Auth: Fix render user OAuth passthrough (#111636)

* devenv: fix volumes section when sources don't contain one

* wip

* Working correctly with improvedExternalSessionHandling on

* Remove not needed lines

* Working with the old flow, tests

* Handle compatibility with the feature toggle, tests wip

* Tests

* Cleanup

* Address feedback

* Align tests

* Add comment

* Fix issue with session removal after the invalidation of tokens

* Remove commented out code

* clean up
This commit is contained in:
Misi 2025-10-07 10:52:43 +02:00 committed by GitHub
parent 3fea6e65f7
commit 53f4803e98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 1912 additions and 323 deletions

View File

@ -56,6 +56,8 @@ func (s *store) Get(ctx context.Context, ID int64) (*auth.ExternalSession, error
return externalSession, nil
}
// List returns a list of external sessions that match the given query.
// If the result set contains more than one entry, the entries are sorted by ID in descending order.
func (s *store) List(ctx context.Context, query *auth.ListExternalSessionQuery) ([]*auth.ExternalSession, error) {
ctx, span := s.tracer.Start(ctx, "externalsession.List")
defer span.End()
@ -65,6 +67,10 @@ func (s *store) List(ctx context.Context, query *auth.ListExternalSessionQuery)
externalSession.ID = query.ID
}
if query.UserID != 0 {
externalSession.UserID = query.UserID
}
hash := sha256.New()
if query.SessionID != "" {
@ -80,7 +86,7 @@ func (s *store) List(ctx context.Context, query *auth.ListExternalSessionQuery)
queryResult := make([]*auth.ExternalSession, 0)
err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
return sess.Find(&queryResult, externalSession)
return sess.Desc("id").Find(&queryResult, externalSession)
})
if err != nil {
return nil, err

View File

@ -51,6 +51,7 @@ type UpdateExternalSessionCommand struct {
type ListExternalSessionQuery struct {
ID int64
UserID int64
NameID string
SessionID string
}

View File

@ -93,7 +93,11 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, id *authn.Ident
updateCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 15*time.Second)
defer cancel()
token, refreshErr := s.service.TryTokenRefresh(updateCtx, id, id.SessionToken)
token, refreshErr := s.service.TryTokenRefresh(updateCtx, id, &oauthtoken.TokenRefreshMetadata{
ExternalSessionID: id.SessionToken.ExternalSessionId,
AuthModule: id.GetAuthenticatedBy(),
AuthID: id.GetAuthID(),
})
if refreshErr != nil {
if errors.Is(refreshErr, context.Canceled) {
return nil, nil
@ -107,7 +111,7 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, id *authn.Ident
ctxLogger.Error("Failed to refresh OAuth access token", "id", id.ID, "error", refreshErr)
// log the user out
if err := s.sessionService.RevokeToken(ctx, id.SessionToken, false); err != nil {
if err := s.sessionService.RevokeToken(ctx, id.SessionToken, false); err != nil && !errors.Is(err, auth.ErrUserTokenNotFound) {
ctxLogger.Warn("Failed to revoke session token", "id", id.ID, "tokenId", id.SessionToken.Id, "error", err)
}

View File

@ -25,6 +25,7 @@ import (
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
)
@ -77,6 +78,14 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
expectRevokeTokenCalled: false,
expectToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
},
{
desc: "should not invalidate session if token refresh fails with no refresh token",
identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
expectedTryRefreshErr: oauthtoken.ErrNoRefreshTokenFound,
expectTryRefreshTokenCalled: true,
expectRevokeTokenCalled: true,
expectedErr: oauthtoken.ErrNoRefreshTokenFound,
},
// TODO: address coverage of oauthtoken sync
}
@ -89,7 +98,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
)
service := &oauthtokentest.MockOauthTokenService{
TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester, _ *auth.UserToken) (*oauth2.Token, error) {
TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester, _ *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error) {
tryRefreshCalled = true
return nil, tt.expectedTryRefreshErr
},

View File

@ -297,7 +297,9 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester, sessionToke
ctxLogger := c.log.FromContext(ctx).New("userID", userID)
if err := c.oauthService.InvalidateOAuthTokens(ctx, user, sessionToken); err != nil {
if err := c.oauthService.InvalidateOAuthTokens(ctx, user, &oauthtoken.TokenRefreshMetadata{
ExternalSessionID: sessionToken.ExternalSessionId,
AuthModule: user.GetAuthenticatedBy()}); err != nil {
ctxLogger.Error("Failed to invalidate tokens", "error", err)
}

View File

@ -19,10 +19,12 @@ import (
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/social/socialtest"
"github.com/grafana/grafana/pkg/models/usertoken"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting"
@ -481,7 +483,7 @@ func TestOAuth_Logout(t *testing.T) {
"id_token": "some.id.token",
})
},
InvalidateOAuthTokensFunc: func(_ context.Context, _ identity.Requester, _ *auth.UserToken) error {
InvalidateOAuthTokensFunc: func(_ context.Context, _ identity.Requester, _ *oauthtoken.TokenRefreshMetadata) error {
invalidateTokenCalled = true
return nil
},
@ -492,7 +494,7 @@ func TestOAuth_Logout(t *testing.T) {
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, mockService, fakeSocialSvc, &setting.OSSImpl{Cfg: tt.cfg}, featuremgmt.WithFeatures(), tracing.InitializeTracerForTest())
redirect, ok := c.Logout(context.Background(), &authn.Identity{ID: "1", Type: claims.TypeUser}, nil)
redirect, ok := c.Logout(context.Background(), &authn.Identity{ID: "1", Type: claims.TypeUser}, &usertoken.UserToken{})
assert.Equal(t, tt.expectedOK, ok)
if tt.expectedOK {

View File

@ -5,6 +5,7 @@ import (
"strings"
)
//go:generate mockery --name AuthInfoService --structname MockAuthInfoService --outpkg authinfotest --filename auth_info_service_mock.go --output ./authinfotest/
type AuthInfoService interface {
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) (*UserAuth, error)
GetUserLabels(ctx context.Context, query GetUserLabelsQuery) (map[int64]string, error)

View File

@ -0,0 +1,765 @@
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package authinfotest
import (
"context"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/user"
mock "github.com/stretchr/testify/mock"
)
// NewMockAuthInfoService creates a new instance of MockAuthInfoService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockAuthInfoService(t interface {
mock.TestingT
Cleanup(func())
}) *MockAuthInfoService {
mock := &MockAuthInfoService{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// MockAuthInfoService is an autogenerated mock type for the AuthInfoService type
type MockAuthInfoService struct {
mock.Mock
}
type MockAuthInfoService_Expecter struct {
mock *mock.Mock
}
func (_m *MockAuthInfoService) EXPECT() *MockAuthInfoService_Expecter {
return &MockAuthInfoService_Expecter{mock: &_m.Mock}
}
// DeleteUserAuthInfo provides a mock function for the type MockAuthInfoService
func (_mock *MockAuthInfoService) DeleteUserAuthInfo(ctx context.Context, userID int64) error {
ret := _mock.Called(ctx, userID)
if len(ret) == 0 {
panic("no return value specified for DeleteUserAuthInfo")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, int64) error); ok {
r0 = returnFunc(ctx, userID)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockAuthInfoService_DeleteUserAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteUserAuthInfo'
type MockAuthInfoService_DeleteUserAuthInfo_Call struct {
*mock.Call
}
// DeleteUserAuthInfo is a helper method to define mock.On call
// - ctx context.Context
// - userID int64
func (_e *MockAuthInfoService_Expecter) DeleteUserAuthInfo(ctx interface{}, userID interface{}) *MockAuthInfoService_DeleteUserAuthInfo_Call {
return &MockAuthInfoService_DeleteUserAuthInfo_Call{Call: _e.mock.On("DeleteUserAuthInfo", ctx, userID)}
}
func (_c *MockAuthInfoService_DeleteUserAuthInfo_Call) Run(run func(ctx context.Context, userID int64)) *MockAuthInfoService_DeleteUserAuthInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 int64
if args[1] != nil {
arg1 = args[1].(int64)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockAuthInfoService_DeleteUserAuthInfo_Call) Return(err error) *MockAuthInfoService_DeleteUserAuthInfo_Call {
_c.Call.Return(err)
return _c
}
func (_c *MockAuthInfoService_DeleteUserAuthInfo_Call) RunAndReturn(run func(ctx context.Context, userID int64) error) *MockAuthInfoService_DeleteUserAuthInfo_Call {
_c.Call.Return(run)
return _c
}
// GetAuthInfo provides a mock function for the type MockAuthInfoService
func (_mock *MockAuthInfoService) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
ret := _mock.Called(ctx, query)
if len(ret) == 0 {
panic("no return value specified for GetAuthInfo")
}
var r0 *login.UserAuth
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) (*login.UserAuth, error)); ok {
return returnFunc(ctx, query)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) *login.UserAuth); ok {
r0 = returnFunc(ctx, query)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*login.UserAuth)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, *login.GetAuthInfoQuery) error); ok {
r1 = returnFunc(ctx, query)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockAuthInfoService_GetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthInfo'
type MockAuthInfoService_GetAuthInfo_Call struct {
*mock.Call
}
// GetAuthInfo is a helper method to define mock.On call
// - ctx context.Context
// - query *login.GetAuthInfoQuery
func (_e *MockAuthInfoService_Expecter) GetAuthInfo(ctx interface{}, query interface{}) *MockAuthInfoService_GetAuthInfo_Call {
return &MockAuthInfoService_GetAuthInfo_Call{Call: _e.mock.On("GetAuthInfo", ctx, query)}
}
func (_c *MockAuthInfoService_GetAuthInfo_Call) Run(run func(ctx context.Context, query *login.GetAuthInfoQuery)) *MockAuthInfoService_GetAuthInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *login.GetAuthInfoQuery
if args[1] != nil {
arg1 = args[1].(*login.GetAuthInfoQuery)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockAuthInfoService_GetAuthInfo_Call) Return(userAuth *login.UserAuth, err error) *MockAuthInfoService_GetAuthInfo_Call {
_c.Call.Return(userAuth, err)
return _c
}
func (_c *MockAuthInfoService_GetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error)) *MockAuthInfoService_GetAuthInfo_Call {
_c.Call.Return(run)
return _c
}
// GetUserLabels provides a mock function for the type MockAuthInfoService
func (_mock *MockAuthInfoService) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) {
ret := _mock.Called(ctx, query)
if len(ret) == 0 {
panic("no return value specified for GetUserLabels")
}
var r0 map[int64]string
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) (map[int64]string, error)); ok {
return returnFunc(ctx, query)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) map[int64]string); ok {
r0 = returnFunc(ctx, query)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]string)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, login.GetUserLabelsQuery) error); ok {
r1 = returnFunc(ctx, query)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockAuthInfoService_GetUserLabels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserLabels'
type MockAuthInfoService_GetUserLabels_Call struct {
*mock.Call
}
// GetUserLabels is a helper method to define mock.On call
// - ctx context.Context
// - query login.GetUserLabelsQuery
func (_e *MockAuthInfoService_Expecter) GetUserLabels(ctx interface{}, query interface{}) *MockAuthInfoService_GetUserLabels_Call {
return &MockAuthInfoService_GetUserLabels_Call{Call: _e.mock.On("GetUserLabels", ctx, query)}
}
func (_c *MockAuthInfoService_GetUserLabels_Call) Run(run func(ctx context.Context, query login.GetUserLabelsQuery)) *MockAuthInfoService_GetUserLabels_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 login.GetUserLabelsQuery
if args[1] != nil {
arg1 = args[1].(login.GetUserLabelsQuery)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockAuthInfoService_GetUserLabels_Call) Return(int64ToString map[int64]string, err error) *MockAuthInfoService_GetUserLabels_Call {
_c.Call.Return(int64ToString, err)
return _c
}
func (_c *MockAuthInfoService_GetUserLabels_Call) RunAndReturn(run func(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error)) *MockAuthInfoService_GetUserLabels_Call {
_c.Call.Return(run)
return _c
}
// SetAuthInfo provides a mock function for the type MockAuthInfoService
func (_mock *MockAuthInfoService) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
ret := _mock.Called(ctx, cmd)
if len(ret) == 0 {
panic("no return value specified for SetAuthInfo")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.SetAuthInfoCommand) error); ok {
r0 = returnFunc(ctx, cmd)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockAuthInfoService_SetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetAuthInfo'
type MockAuthInfoService_SetAuthInfo_Call struct {
*mock.Call
}
// SetAuthInfo is a helper method to define mock.On call
// - ctx context.Context
// - cmd *login.SetAuthInfoCommand
func (_e *MockAuthInfoService_Expecter) SetAuthInfo(ctx interface{}, cmd interface{}) *MockAuthInfoService_SetAuthInfo_Call {
return &MockAuthInfoService_SetAuthInfo_Call{Call: _e.mock.On("SetAuthInfo", ctx, cmd)}
}
func (_c *MockAuthInfoService_SetAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.SetAuthInfoCommand)) *MockAuthInfoService_SetAuthInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *login.SetAuthInfoCommand
if args[1] != nil {
arg1 = args[1].(*login.SetAuthInfoCommand)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockAuthInfoService_SetAuthInfo_Call) Return(err error) *MockAuthInfoService_SetAuthInfo_Call {
_c.Call.Return(err)
return _c
}
func (_c *MockAuthInfoService_SetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.SetAuthInfoCommand) error) *MockAuthInfoService_SetAuthInfo_Call {
_c.Call.Return(run)
return _c
}
// UpdateAuthInfo provides a mock function for the type MockAuthInfoService
func (_mock *MockAuthInfoService) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
ret := _mock.Called(ctx, cmd)
if len(ret) == 0 {
panic("no return value specified for UpdateAuthInfo")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.UpdateAuthInfoCommand) error); ok {
r0 = returnFunc(ctx, cmd)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockAuthInfoService_UpdateAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAuthInfo'
type MockAuthInfoService_UpdateAuthInfo_Call struct {
*mock.Call
}
// UpdateAuthInfo is a helper method to define mock.On call
// - ctx context.Context
// - cmd *login.UpdateAuthInfoCommand
func (_e *MockAuthInfoService_Expecter) UpdateAuthInfo(ctx interface{}, cmd interface{}) *MockAuthInfoService_UpdateAuthInfo_Call {
return &MockAuthInfoService_UpdateAuthInfo_Call{Call: _e.mock.On("UpdateAuthInfo", ctx, cmd)}
}
func (_c *MockAuthInfoService_UpdateAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand)) *MockAuthInfoService_UpdateAuthInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *login.UpdateAuthInfoCommand
if args[1] != nil {
arg1 = args[1].(*login.UpdateAuthInfoCommand)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockAuthInfoService_UpdateAuthInfo_Call) Return(err error) *MockAuthInfoService_UpdateAuthInfo_Call {
_c.Call.Return(err)
return _c
}
func (_c *MockAuthInfoService_UpdateAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error) *MockAuthInfoService_UpdateAuthInfo_Call {
_c.Call.Return(run)
return _c
}
// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockStore(t interface {
mock.TestingT
Cleanup(func())
}) *MockStore {
mock := &MockStore{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// MockStore is an autogenerated mock type for the Store type
type MockStore struct {
mock.Mock
}
type MockStore_Expecter struct {
mock *mock.Mock
}
func (_m *MockStore) EXPECT() *MockStore_Expecter {
return &MockStore_Expecter{mock: &_m.Mock}
}
// DeleteUserAuthInfo provides a mock function for the type MockStore
func (_mock *MockStore) DeleteUserAuthInfo(ctx context.Context, userID int64) error {
ret := _mock.Called(ctx, userID)
if len(ret) == 0 {
panic("no return value specified for DeleteUserAuthInfo")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, int64) error); ok {
r0 = returnFunc(ctx, userID)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockStore_DeleteUserAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteUserAuthInfo'
type MockStore_DeleteUserAuthInfo_Call struct {
*mock.Call
}
// DeleteUserAuthInfo is a helper method to define mock.On call
// - ctx context.Context
// - userID int64
func (_e *MockStore_Expecter) DeleteUserAuthInfo(ctx interface{}, userID interface{}) *MockStore_DeleteUserAuthInfo_Call {
return &MockStore_DeleteUserAuthInfo_Call{Call: _e.mock.On("DeleteUserAuthInfo", ctx, userID)}
}
func (_c *MockStore_DeleteUserAuthInfo_Call) Run(run func(ctx context.Context, userID int64)) *MockStore_DeleteUserAuthInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 int64
if args[1] != nil {
arg1 = args[1].(int64)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockStore_DeleteUserAuthInfo_Call) Return(err error) *MockStore_DeleteUserAuthInfo_Call {
_c.Call.Return(err)
return _c
}
func (_c *MockStore_DeleteUserAuthInfo_Call) RunAndReturn(run func(ctx context.Context, userID int64) error) *MockStore_DeleteUserAuthInfo_Call {
_c.Call.Return(run)
return _c
}
// GetAuthInfo provides a mock function for the type MockStore
func (_mock *MockStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
ret := _mock.Called(ctx, query)
if len(ret) == 0 {
panic("no return value specified for GetAuthInfo")
}
var r0 *login.UserAuth
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) (*login.UserAuth, error)); ok {
return returnFunc(ctx, query)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) *login.UserAuth); ok {
r0 = returnFunc(ctx, query)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*login.UserAuth)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, *login.GetAuthInfoQuery) error); ok {
r1 = returnFunc(ctx, query)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockStore_GetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthInfo'
type MockStore_GetAuthInfo_Call struct {
*mock.Call
}
// GetAuthInfo is a helper method to define mock.On call
// - ctx context.Context
// - query *login.GetAuthInfoQuery
func (_e *MockStore_Expecter) GetAuthInfo(ctx interface{}, query interface{}) *MockStore_GetAuthInfo_Call {
return &MockStore_GetAuthInfo_Call{Call: _e.mock.On("GetAuthInfo", ctx, query)}
}
func (_c *MockStore_GetAuthInfo_Call) Run(run func(ctx context.Context, query *login.GetAuthInfoQuery)) *MockStore_GetAuthInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *login.GetAuthInfoQuery
if args[1] != nil {
arg1 = args[1].(*login.GetAuthInfoQuery)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockStore_GetAuthInfo_Call) Return(userAuth *login.UserAuth, err error) *MockStore_GetAuthInfo_Call {
_c.Call.Return(userAuth, err)
return _c
}
func (_c *MockStore_GetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error)) *MockStore_GetAuthInfo_Call {
_c.Call.Return(run)
return _c
}
// GetUserLabels provides a mock function for the type MockStore
func (_mock *MockStore) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) {
ret := _mock.Called(ctx, query)
if len(ret) == 0 {
panic("no return value specified for GetUserLabels")
}
var r0 map[int64]string
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) (map[int64]string, error)); ok {
return returnFunc(ctx, query)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) map[int64]string); ok {
r0 = returnFunc(ctx, query)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]string)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, login.GetUserLabelsQuery) error); ok {
r1 = returnFunc(ctx, query)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockStore_GetUserLabels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserLabels'
type MockStore_GetUserLabels_Call struct {
*mock.Call
}
// GetUserLabels is a helper method to define mock.On call
// - ctx context.Context
// - query login.GetUserLabelsQuery
func (_e *MockStore_Expecter) GetUserLabels(ctx interface{}, query interface{}) *MockStore_GetUserLabels_Call {
return &MockStore_GetUserLabels_Call{Call: _e.mock.On("GetUserLabels", ctx, query)}
}
func (_c *MockStore_GetUserLabels_Call) Run(run func(ctx context.Context, query login.GetUserLabelsQuery)) *MockStore_GetUserLabels_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 login.GetUserLabelsQuery
if args[1] != nil {
arg1 = args[1].(login.GetUserLabelsQuery)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockStore_GetUserLabels_Call) Return(int64ToString map[int64]string, err error) *MockStore_GetUserLabels_Call {
_c.Call.Return(int64ToString, err)
return _c
}
func (_c *MockStore_GetUserLabels_Call) RunAndReturn(run func(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error)) *MockStore_GetUserLabels_Call {
_c.Call.Return(run)
return _c
}
// SetAuthInfo provides a mock function for the type MockStore
func (_mock *MockStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
ret := _mock.Called(ctx, cmd)
if len(ret) == 0 {
panic("no return value specified for SetAuthInfo")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.SetAuthInfoCommand) error); ok {
r0 = returnFunc(ctx, cmd)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockStore_SetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetAuthInfo'
type MockStore_SetAuthInfo_Call struct {
*mock.Call
}
// SetAuthInfo is a helper method to define mock.On call
// - ctx context.Context
// - cmd *login.SetAuthInfoCommand
func (_e *MockStore_Expecter) SetAuthInfo(ctx interface{}, cmd interface{}) *MockStore_SetAuthInfo_Call {
return &MockStore_SetAuthInfo_Call{Call: _e.mock.On("SetAuthInfo", ctx, cmd)}
}
func (_c *MockStore_SetAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.SetAuthInfoCommand)) *MockStore_SetAuthInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *login.SetAuthInfoCommand
if args[1] != nil {
arg1 = args[1].(*login.SetAuthInfoCommand)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockStore_SetAuthInfo_Call) Return(err error) *MockStore_SetAuthInfo_Call {
_c.Call.Return(err)
return _c
}
func (_c *MockStore_SetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.SetAuthInfoCommand) error) *MockStore_SetAuthInfo_Call {
_c.Call.Return(run)
return _c
}
// UpdateAuthInfo provides a mock function for the type MockStore
func (_mock *MockStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
ret := _mock.Called(ctx, cmd)
if len(ret) == 0 {
panic("no return value specified for UpdateAuthInfo")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.UpdateAuthInfoCommand) error); ok {
r0 = returnFunc(ctx, cmd)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockStore_UpdateAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAuthInfo'
type MockStore_UpdateAuthInfo_Call struct {
*mock.Call
}
// UpdateAuthInfo is a helper method to define mock.On call
// - ctx context.Context
// - cmd *login.UpdateAuthInfoCommand
func (_e *MockStore_Expecter) UpdateAuthInfo(ctx interface{}, cmd interface{}) *MockStore_UpdateAuthInfo_Call {
return &MockStore_UpdateAuthInfo_Call{Call: _e.mock.On("UpdateAuthInfo", ctx, cmd)}
}
func (_c *MockStore_UpdateAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand)) *MockStore_UpdateAuthInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *login.UpdateAuthInfoCommand
if args[1] != nil {
arg1 = args[1].(*login.UpdateAuthInfoCommand)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockStore_UpdateAuthInfo_Call) Return(err error) *MockStore_UpdateAuthInfo_Call {
_c.Call.Return(err)
return _c
}
func (_c *MockStore_UpdateAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error) *MockStore_UpdateAuthInfo_Call {
_c.Call.Return(run)
return _c
}
// NewMockUserProtectionService creates a new instance of MockUserProtectionService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockUserProtectionService(t interface {
mock.TestingT
Cleanup(func())
}) *MockUserProtectionService {
mock := &MockUserProtectionService{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// MockUserProtectionService is an autogenerated mock type for the UserProtectionService type
type MockUserProtectionService struct {
mock.Mock
}
type MockUserProtectionService_Expecter struct {
mock *mock.Mock
}
func (_m *MockUserProtectionService) EXPECT() *MockUserProtectionService_Expecter {
return &MockUserProtectionService_Expecter{mock: &_m.Mock}
}
// AllowUserMapping provides a mock function for the type MockUserProtectionService
func (_mock *MockUserProtectionService) AllowUserMapping(user1 *user.User, authModule string) error {
ret := _mock.Called(user1, authModule)
if len(ret) == 0 {
panic("no return value specified for AllowUserMapping")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(*user.User, string) error); ok {
r0 = returnFunc(user1, authModule)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockUserProtectionService_AllowUserMapping_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllowUserMapping'
type MockUserProtectionService_AllowUserMapping_Call struct {
*mock.Call
}
// AllowUserMapping is a helper method to define mock.On call
// - user1 *user.User
// - authModule string
func (_e *MockUserProtectionService_Expecter) AllowUserMapping(user1 interface{}, authModule interface{}) *MockUserProtectionService_AllowUserMapping_Call {
return &MockUserProtectionService_AllowUserMapping_Call{Call: _e.mock.On("AllowUserMapping", user1, authModule)}
}
func (_c *MockUserProtectionService_AllowUserMapping_Call) Run(run func(user1 *user.User, authModule string)) *MockUserProtectionService_AllowUserMapping_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 *user.User
if args[0] != nil {
arg0 = args[0].(*user.User)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockUserProtectionService_AllowUserMapping_Call) Return(err error) *MockUserProtectionService_AllowUserMapping_Call {
_c.Call.Return(err)
return _c
}
func (_c *MockUserProtectionService_AllowUserMapping_Call) RunAndReturn(run func(user1 *user.User, authModule string) error) *MockUserProtectionService_AllowUserMapping_Call {
_c.Call.Return(run)
return _c
}

View File

@ -11,6 +11,7 @@ import (
"github.com/go-jose/go-jose/v4/jwt"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
@ -57,8 +58,14 @@ var _ OAuthTokenService = (*Service)(nil)
type OAuthTokenService interface {
GetCurrentOAuthToken(context.Context, identity.Requester, *auth.UserToken) *oauth2.Token
IsOAuthPassThruEnabled(*datasources.DataSource) bool
TryTokenRefresh(context.Context, identity.Requester, *auth.UserToken) (*oauth2.Token, error)
InvalidateOAuthTokens(context.Context, identity.Requester, *auth.UserToken) error
TryTokenRefresh(context.Context, identity.Requester, *TokenRefreshMetadata) (*oauth2.Token, error)
InvalidateOAuthTokens(context.Context, identity.Requester, *TokenRefreshMetadata) error
}
type TokenRefreshMetadata struct {
ExternalSessionID int64
AuthModule string
AuthID string
}
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer,
@ -102,51 +109,71 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Request
ctxLogger = ctxLogger.New("userID", userID)
if !strings.HasPrefix(usr.GetAuthenticatedBy(), "oauth_") {
ctxLogger.Warn("The specified user's auth provider is not oauth",
"authmodule", usr.GetAuthenticatedBy())
tokenRefreshMetadata := &TokenRefreshMetadata{
ExternalSessionID: 0,
}
var persistedToken *oauth2.Token
// Find the external session associated with the user and session token
// regardless of the improvedExternalSessionHandling feature toggle,
// because Grafana writes and updates both tables to make the switch
// to the new session handling smoother.
externalSession, err := o.getExternalSession(ctx, usr, userID, sessionToken)
if err != nil && !errors.Is(err, auth.ErrExternalSessionNotFound) {
ctxLogger.Error("Failed to get external session", "error", err)
return nil
}
// If the feature toggle is enabled, an external session is required.
if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) && (externalSession == nil || errors.Is(err, auth.ErrExternalSessionNotFound)) {
ctxLogger.Error("No external session found for user", "userID", userID)
return nil
}
// externalSession can be nil if Grafana was updated from a version where the
// external session table was not used yet (did not exist) and the user has not logged in since
// the version update (therefore no external session was created for the user yet).
if externalSession != nil {
tokenRefreshMetadata.ExternalSessionID = externalSession.ID
}
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{
UserId: userID,
})
if err != nil {
if errors.Is(err, user.ErrUserNotFound) {
ctxLogger.Warn("No AuthInfo found for user", "userID", userID)
return nil
}
ctxLogger.Error("Failed to fetch AuthInfo for user", "userID", userID, "error", err)
return nil
}
tokenRefreshMetadata.AuthID = authInfo.AuthId
tokenRefreshMetadata.AuthModule = authInfo.AuthModule
if !strings.HasPrefix(tokenRefreshMetadata.AuthModule, "oauth_") {
ctxLogger.Warn("The specified user's auth provider is not oauth",
"authmodule", tokenRefreshMetadata.AuthModule)
return nil
}
var persistedToken *oauth2.Token
if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) {
externalSession, err := o.sessionService.GetExternalSession(ctx, sessionToken.ExternalSessionId)
if err != nil {
if errors.Is(err, auth.ErrExternalSessionNotFound) {
return nil
}
ctxLogger.Error("Failed to fetch external session", "error", err)
return nil
}
persistedToken = buildOAuthTokenFromExternalSession(externalSession)
if persistedToken.RefreshToken == "" {
return persistedToken
}
} else {
authInfo, ok, _ := o.hasOAuthEntry(ctx, usr)
if !ok {
return nil
}
if err := checkOAuthRefreshToken(authInfo); err != nil {
if errors.Is(err, ErrNoRefreshTokenFound) {
return buildOAuthTokenFromAuthInfo(authInfo)
}
return nil
}
persistedToken = buildOAuthTokenFromAuthInfo(authInfo)
}
if persistedToken.RefreshToken == "" {
return persistedToken
}
refreshNeeded := needTokenRefresh(ctx, persistedToken)
if !refreshNeeded {
return persistedToken
}
token, err := o.TryTokenRefresh(ctx, usr, sessionToken)
token, err := o.TryTokenRefresh(ctx, usr, tokenRefreshMetadata)
if err != nil {
if errors.Is(err, ErrNoRefreshTokenFound) {
return persistedToken
@ -214,7 +241,7 @@ func (o *Service) hasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful
// It uses a server lock to prevent getting the Refresh Token multiple times for a given User
func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error) {
func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, tokenRefreshMetadata *TokenRefreshMetadata) (*oauth2.Token, error) {
ctx, span := o.tracer.Start(ctx, "oauthtoken.TryTokenRefresh")
defer span.End()
@ -239,14 +266,13 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s
ctxLogger = ctxLogger.New("userID", userID)
// get the token's auth provider (f.e. azuread)
currAuthenticator := usr.GetAuthenticatedBy()
if !strings.HasPrefix(currAuthenticator, "oauth") {
ctxLogger.Warn("The specified user's auth provider is not OAuth", "authmodule", currAuthenticator)
if !strings.HasPrefix(tokenRefreshMetadata.AuthModule, "oauth_") {
ctxLogger.Warn("The specified user's auth provider is not oauth",
"authmodule", tokenRefreshMetadata.AuthModule)
return nil, nil
}
provider := strings.TrimPrefix(currAuthenticator, "oauth_")
provider := strings.TrimPrefix(tokenRefreshMetadata.AuthModule, "oauth_")
currentOAuthInfo := o.SocialService.GetOAuthInfoProvider(provider)
if currentOAuthInfo == nil {
ctxLogger.Warn("OAuth provider not found", "provider", provider)
@ -261,7 +287,7 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s
lockKey := fmt.Sprintf("oauth-refresh-token-%d", userID)
if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) {
lockKey = fmt.Sprintf("oauth-refresh-token-%d-%d", userID, sessionToken.ExternalSessionId)
lockKey = fmt.Sprintf("oauth-refresh-token-%d-%d", userID, tokenRefreshMetadata.ExternalSessionID)
}
lockTimeConfig := serverlock.LockTimeConfig{
@ -290,7 +316,7 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s
var persistedToken *oauth2.Token
var externalSession *auth.ExternalSession
if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) {
externalSession, err = o.sessionService.GetExternalSession(ctx, sessionToken.ExternalSessionId)
externalSession, err = o.sessionService.GetExternalSession(ctx, tokenRefreshMetadata.ExternalSessionID)
if err != nil {
if errors.Is(err, auth.ErrExternalSessionNotFound) {
ctxLogger.Error("External session was not found for user", "error", err)
@ -321,7 +347,7 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s
return
}
newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, persistedToken, usr, sessionToken)
newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, persistedToken, usr, tokenRefreshMetadata)
}, retryOpt)
if lockErr != nil {
ctxLogger.Error("Failed to obtain token refresh lock", "error", lockErr)
@ -330,14 +356,14 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s
// Silence ErrNoRefreshTokenFound
if errors.Is(cmdErr, ErrNoRefreshTokenFound) {
return nil, nil
return nil, ErrNoRefreshTokenFound
}
return newToken, cmdErr
}
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) error {
func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, tokenRefreshMetadata *TokenRefreshMetadata) error {
userID, err := usr.GetInternalID()
if err != nil {
logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err)
@ -347,7 +373,7 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Reques
ctxLogger := logger.FromContext(ctx).New("userID", userID)
if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) {
err := o.sessionService.UpdateExternalSession(ctx, sessionToken.ExternalSessionId, &auth.UpdateExternalSessionCommand{
err := o.sessionService.UpdateExternalSession(ctx, tokenRefreshMetadata.ExternalSessionID, &auth.UpdateExternalSessionCommand{
Token: &oauth2.Token{},
})
if err != nil {
@ -358,8 +384,8 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Reques
return o.AuthInfoService.UpdateAuthInfo(ctx, &login.UpdateAuthInfoCommand{
UserId: userID,
AuthModule: usr.GetAuthenticatedBy(),
AuthId: usr.GetAuthID(),
AuthModule: tokenRefreshMetadata.AuthModule,
AuthId: tokenRefreshMetadata.AuthID,
OAuthToken: &oauth2.Token{
AccessToken: "",
RefreshToken: "",
@ -368,13 +394,14 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Reques
})
}
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken *oauth2.Token, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error) {
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken *oauth2.Token, usr identity.Requester, tokenRefreshMetadata *TokenRefreshMetadata) (*oauth2.Token, error) {
ctx, span := o.tracer.Start(ctx, "oauthtoken.tryGetOrRefreshOAuthToken")
defer span.End()
userID, err := usr.GetInternalID()
if err != nil {
logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err)
span.SetStatus(codes.Error, "Failed to convert user id to int")
return nil, err
}
@ -382,8 +409,11 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken
ctxLogger := logger.FromContext(ctx).New("userID", userID)
// tryGetOrRefreshOAuthToken assumes that the AuthModule has RefreshToken enabled
// which is checked by the caller (TryTokenRefresh)
if persistedToken.RefreshToken == "" {
ctxLogger.Warn("No refresh token available", "authmodule", usr.GetAuthenticatedBy())
ctxLogger.Error("No refresh token available", "authmodule", tokenRefreshMetadata.AuthModule)
span.SetStatus(codes.Error, ErrNoRefreshTokenFound.Error())
return nil, ErrNoRefreshTokenFound
}
@ -392,50 +422,44 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken
return persistedToken, nil
}
authProvider := usr.GetAuthenticatedBy()
connect, err := o.SocialService.GetConnector(authProvider)
connect, err := o.SocialService.GetConnector(tokenRefreshMetadata.AuthModule)
if err != nil {
ctxLogger.Error("Failed to get oauth connector", "provider", authProvider, "error", err)
ctxLogger.Error("Failed to get oauth connector", "provider", tokenRefreshMetadata.AuthModule, "error", err)
span.SetStatus(codes.Error, "Failed to get oauth connector: "+err.Error())
return nil, err
}
client, err := o.SocialService.GetOAuthHttpClient(authProvider)
client, err := o.SocialService.GetOAuthHttpClient(tokenRefreshMetadata.AuthModule)
if err != nil {
ctxLogger.Error("Failed to get oauth http client", "provider", authProvider, "error", err)
ctxLogger.Error("Failed to get oauth http client", "provider", tokenRefreshMetadata.AuthModule, "error", err)
span.SetStatus(codes.Error, "Failed to get oauth http client")
return nil, err
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
start := time.Now()
// TokenSource handles refreshing the token if it has expired
token, err := connect.TokenSource(ctx, persistedToken).Token()
token, refreshErr := connect.TokenSource(ctx, persistedToken).Token()
duration := time.Since(start)
o.tokenRefreshDuration.WithLabelValues(authProvider, fmt.Sprintf("%t", err == nil)).Observe(duration.Seconds())
o.tokenRefreshDuration.WithLabelValues(tokenRefreshMetadata.AuthModule, fmt.Sprintf("%t", err == nil)).Observe(duration.Seconds())
if err != nil {
if refreshErr != nil {
span.SetAttributes(attribute.Bool("token_refreshed", false))
ctxLogger.Error("Failed to retrieve oauth access token",
"provider", usr.GetAuthenticatedBy(), "error", err)
"provider", tokenRefreshMetadata.AuthModule, "error", refreshErr)
// token refresh failed, invalidate the old token
if err := o.InvalidateOAuthTokens(ctx, usr, sessionToken); err != nil {
ctxLogger.Warn("Failed to invalidate OAuth tokens", "authID", usr.GetAuthID(), "error", err)
if err := o.InvalidateOAuthTokens(ctx, usr, tokenRefreshMetadata); err != nil {
ctxLogger.Warn("Failed to invalidate OAuth tokens", "authID", tokenRefreshMetadata.AuthID, "error", err)
}
return nil, err
return nil, refreshErr
}
span.SetAttributes(attribute.Bool("token_refreshed", true))
// If the tokens are not the same, update the entry in the DB
if !tokensEq(persistedToken, token) {
updateAuthCommand := &login.UpdateAuthInfoCommand{
UserId: userID,
AuthModule: usr.GetAuthenticatedBy(),
AuthId: usr.GetAuthID(),
OAuthToken: token,
}
if o.Cfg.Env == setting.Dev {
ctxLogger.Debug("Oauth got token",
"auth_module", usr.GetAuthenticatedBy(),
@ -446,17 +470,32 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken
}
if !o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) {
updateAuthCommand := &login.UpdateAuthInfoCommand{
UserId: userID,
AuthModule: tokenRefreshMetadata.AuthModule,
AuthId: tokenRefreshMetadata.AuthID,
OAuthToken: token,
}
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
ctxLogger.Error("Failed to update auth info during token refresh", "authID", usr.GetAuthID(), "error", err)
ctxLogger.Error("Failed to update auth info during token refresh", "authID", tokenRefreshMetadata.AuthID, "error", err)
span.SetStatus(codes.Error, "Failed to update auth info during token refresh")
return nil, err
}
}
if err := o.sessionService.UpdateExternalSession(ctx, sessionToken.ExternalSessionId, &auth.UpdateExternalSessionCommand{
Token: token,
}); err != nil {
ctxLogger.Error("Failed to update external session during token refresh", "error", err)
return nil, err
// Update the external session with the new token if we the user has an external session,
// regardless of the feature flag state to keep the `user_external_session` table in sync.
// ExternalSessionID should always be set except for some edge cases:
// - when Grafana was updated to a version where the `improvedExternalSessionHandling` feature flag
// was enabled after the user logged in
if tokenRefreshMetadata.ExternalSessionID != 0 {
if err := o.sessionService.UpdateExternalSession(ctx, tokenRefreshMetadata.ExternalSessionID, &auth.UpdateExternalSessionCommand{
Token: token,
}); err != nil {
ctxLogger.Error("Failed to update external session during token refresh", "error", err)
span.SetStatus(codes.Error, "Failed to update external session during token refresh")
return nil, err
}
}
ctxLogger.Debug("Updated oauth info for user")
@ -502,6 +541,11 @@ func needTokenRefresh(ctx context.Context, persistedToken *oauth2.Token) bool {
ctxLogger := logger.FromContext(ctx)
if persistedToken.AccessToken == "" {
ctxLogger.Debug("Access token has been cleared, need to refresh")
return true
}
idTokenExp, err := GetIDTokenExpiry(persistedToken)
if err != nil {
ctxLogger.Warn("Could not get ID Token expiry", "error", err)
@ -552,22 +596,6 @@ func buildOAuthTokenFromExternalSession(externalSession *auth.ExternalSession) *
return token
}
func checkOAuthRefreshToken(authInfo *login.UserAuth) error {
if !strings.Contains(authInfo.AuthModule, "oauth") {
logger.Warn("The specified user's auth provider is not oauth",
"authmodule", authInfo.AuthModule, "userid", authInfo.UserId)
return ErrNotAnOAuthProvider
}
if authInfo.OAuthRefreshToken == "" {
logger.Warn("No refresh token available",
"authmodule", authInfo.AuthModule, "userid", authInfo.UserId)
return ErrNoRefreshTokenFound
}
return nil
}
// GetIDTokenExpiry extracts the expiry time from the ID token
func GetIDTokenExpiry(token *oauth2.Token) (time.Time, error) {
idToken, ok := token.Extra("id_token").(string)
@ -601,3 +629,28 @@ func getExpiryWithSkew(expiry time.Time) (adjustedExpiry time.Time, hasTokenExpi
hasTokenExpired = adjustedExpiry.Before(time.Now())
return
}
// getExternalSession fetches the external session based on the user and session token.
// When using the render module, it fetches the most recent external session for the user
// since the session token ID is not available.
// For regular users, it uses the session token ID to fetch the external session.
func (o *Service) getExternalSession(ctx context.Context, usr identity.Requester, userID int64, sessionToken *auth.UserToken) (*auth.ExternalSession, error) {
if usr.GetAuthenticatedBy() == login.RenderModule {
// When using render module, we don't have the session token ID, so we need to fetch the most recent session
// entry for the user (as it is done with the old flow).
// In the future, we might want to consider passing the session token ID to the render module to make this more robust.
externalSessions, err := o.sessionService.FindExternalSessions(ctx, &auth.ListExternalSessionQuery{UserID: userID})
if err != nil {
return nil, err
}
if len(externalSessions) == 0 || externalSessions[0] == nil {
return nil, auth.ErrExternalSessionNotFound
}
return externalSessions[0], nil
}
// For regular users, we use the session token ID to fetch the external session
return o.sessionService.GetExternalSession(ctx, sessionToken.ExternalSessionId)
}

File diff suppressed because it is too large Load Diff

View File

@ -8,13 +8,14 @@ import (
"github.com/grafana/grafana/pkg/apimachinery/identity"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/oauthtoken"
)
type MockOauthTokenService struct {
GetCurrentOauthTokenFunc func(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) *oauth2.Token
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
InvalidateOAuthTokensFunc func(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) error
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error)
InvalidateOAuthTokensFunc func(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) error
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error)
}
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) *oauth2.Token {
@ -31,16 +32,16 @@ func (m *MockOauthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSourc
return false
}
func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) error {
func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) error {
if m.InvalidateOAuthTokensFunc != nil {
return m.InvalidateOAuthTokensFunc(ctx, usr, sessionToken)
return m.InvalidateOAuthTokensFunc(ctx, usr, metadata)
}
return nil
}
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error) {
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error) {
if m.TryTokenRefreshFunc != nil {
return m.TryTokenRefreshFunc(ctx, usr, sessionToken)
return m.TryTokenRefreshFunc(ctx, usr, metadata)
}
return nil, nil
}

View File

@ -29,10 +29,10 @@ func (s *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
return oauthtoken.IsOAuthPassThruEnabled(ds)
}
func (s *Service) TryTokenRefresh(context.Context, identity.Requester, *auth.UserToken) (*oauth2.Token, error) {
func (s *Service) TryTokenRefresh(context.Context, identity.Requester, *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error) {
return s.Token, nil
}
func (s *Service) InvalidateOAuthTokens(context.Context, identity.Requester, *auth.UserToken) error {
func (s *Service) InvalidateOAuthTokens(context.Context, identity.Requester, *oauthtoken.TokenRefreshMetadata) error {
return nil
}