diff --git a/pkg/services/auth/authimpl/external_session_store.go b/pkg/services/auth/authimpl/external_session_store.go index 95e379b79b6..33a43355e59 100644 --- a/pkg/services/auth/authimpl/external_session_store.go +++ b/pkg/services/auth/authimpl/external_session_store.go @@ -56,6 +56,8 @@ func (s *store) Get(ctx context.Context, ID int64) (*auth.ExternalSession, error return externalSession, nil } +// List returns a list of external sessions that match the given query. +// If the result set contains more than one entry, the entries are sorted by ID in descending order. func (s *store) List(ctx context.Context, query *auth.ListExternalSessionQuery) ([]*auth.ExternalSession, error) { ctx, span := s.tracer.Start(ctx, "externalsession.List") defer span.End() @@ -65,6 +67,10 @@ func (s *store) List(ctx context.Context, query *auth.ListExternalSessionQuery) externalSession.ID = query.ID } + if query.UserID != 0 { + externalSession.UserID = query.UserID + } + hash := sha256.New() if query.SessionID != "" { @@ -80,7 +86,7 @@ func (s *store) List(ctx context.Context, query *auth.ListExternalSessionQuery) queryResult := make([]*auth.ExternalSession, 0) err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error { - return sess.Find(&queryResult, externalSession) + return sess.Desc("id").Find(&queryResult, externalSession) }) if err != nil { return nil, err diff --git a/pkg/services/auth/external_session.go b/pkg/services/auth/external_session.go index e67b5a10259..32e9e1f4236 100644 --- a/pkg/services/auth/external_session.go +++ b/pkg/services/auth/external_session.go @@ -51,6 +51,7 @@ type UpdateExternalSessionCommand struct { type ListExternalSessionQuery struct { ID int64 + UserID int64 NameID string SessionID string } diff --git a/pkg/services/authn/authnimpl/sync/oauth_token_sync.go b/pkg/services/authn/authnimpl/sync/oauth_token_sync.go index a63853648a7..cb2103437ce 100644 --- a/pkg/services/authn/authnimpl/sync/oauth_token_sync.go +++ b/pkg/services/authn/authnimpl/sync/oauth_token_sync.go @@ -93,7 +93,11 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, id *authn.Ident updateCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 15*time.Second) defer cancel() - token, refreshErr := s.service.TryTokenRefresh(updateCtx, id, id.SessionToken) + token, refreshErr := s.service.TryTokenRefresh(updateCtx, id, &oauthtoken.TokenRefreshMetadata{ + ExternalSessionID: id.SessionToken.ExternalSessionId, + AuthModule: id.GetAuthenticatedBy(), + AuthID: id.GetAuthID(), + }) if refreshErr != nil { if errors.Is(refreshErr, context.Canceled) { return nil, nil @@ -107,7 +111,7 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, id *authn.Ident ctxLogger.Error("Failed to refresh OAuth access token", "id", id.ID, "error", refreshErr) // log the user out - if err := s.sessionService.RevokeToken(ctx, id.SessionToken, false); err != nil { + if err := s.sessionService.RevokeToken(ctx, id.SessionToken, false); err != nil && !errors.Is(err, auth.ErrUserTokenNotFound) { ctxLogger.Warn("Failed to revoke session token", "id", id.ID, "tokenId", id.SessionToken.Id, "error", err) } diff --git a/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go b/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go index 5f5e1303a95..3178f5390f2 100644 --- a/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go +++ b/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go @@ -25,6 +25,7 @@ import ( contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" ) @@ -77,6 +78,14 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) { expectRevokeTokenCalled: false, expectToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)}, }, + { + desc: "should not invalidate session if token refresh fails with no refresh token", + identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule}, + expectedTryRefreshErr: oauthtoken.ErrNoRefreshTokenFound, + expectTryRefreshTokenCalled: true, + expectRevokeTokenCalled: true, + expectedErr: oauthtoken.ErrNoRefreshTokenFound, + }, // TODO: address coverage of oauthtoken sync } @@ -89,7 +98,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) { ) service := &oauthtokentest.MockOauthTokenService{ - TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester, _ *auth.UserToken) (*oauth2.Token, error) { + TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester, _ *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error) { tryRefreshCalled = true return nil, tt.expectedTryRefreshErr }, diff --git a/pkg/services/authn/clients/oauth.go b/pkg/services/authn/clients/oauth.go index 0e8053a39b7..e6d4b67d122 100644 --- a/pkg/services/authn/clients/oauth.go +++ b/pkg/services/authn/clients/oauth.go @@ -297,7 +297,9 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester, sessionToke ctxLogger := c.log.FromContext(ctx).New("userID", userID) - if err := c.oauthService.InvalidateOAuthTokens(ctx, user, sessionToken); err != nil { + if err := c.oauthService.InvalidateOAuthTokens(ctx, user, &oauthtoken.TokenRefreshMetadata{ + ExternalSessionID: sessionToken.ExternalSessionId, + AuthModule: user.GetAuthenticatedBy()}); err != nil { ctxLogger.Error("Failed to invalidate tokens", "error", err) } diff --git a/pkg/services/authn/clients/oauth_test.go b/pkg/services/authn/clients/oauth_test.go index 2bd536d6544..b6a9a07c488 100644 --- a/pkg/services/authn/clients/oauth_test.go +++ b/pkg/services/authn/clients/oauth_test.go @@ -19,10 +19,12 @@ import ( "github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social/socialtest" + "github.com/grafana/grafana/pkg/models/usertoken" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/setting" @@ -481,7 +483,7 @@ func TestOAuth_Logout(t *testing.T) { "id_token": "some.id.token", }) }, - InvalidateOAuthTokensFunc: func(_ context.Context, _ identity.Requester, _ *auth.UserToken) error { + InvalidateOAuthTokensFunc: func(_ context.Context, _ identity.Requester, _ *oauthtoken.TokenRefreshMetadata) error { invalidateTokenCalled = true return nil }, @@ -492,7 +494,7 @@ func TestOAuth_Logout(t *testing.T) { } c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, mockService, fakeSocialSvc, &setting.OSSImpl{Cfg: tt.cfg}, featuremgmt.WithFeatures(), tracing.InitializeTracerForTest()) - redirect, ok := c.Logout(context.Background(), &authn.Identity{ID: "1", Type: claims.TypeUser}, nil) + redirect, ok := c.Logout(context.Background(), &authn.Identity{ID: "1", Type: claims.TypeUser}, &usertoken.UserToken{}) assert.Equal(t, tt.expectedOK, ok) if tt.expectedOK { diff --git a/pkg/services/login/authinfo.go b/pkg/services/login/authinfo.go index 3e9751d0ea2..095e3390ce9 100644 --- a/pkg/services/login/authinfo.go +++ b/pkg/services/login/authinfo.go @@ -5,6 +5,7 @@ import ( "strings" ) +//go:generate mockery --name AuthInfoService --structname MockAuthInfoService --outpkg authinfotest --filename auth_info_service_mock.go --output ./authinfotest/ type AuthInfoService interface { GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) (*UserAuth, error) GetUserLabels(ctx context.Context, query GetUserLabelsQuery) (map[int64]string, error) diff --git a/pkg/services/login/authinfotest/auth_info_service_mock.go b/pkg/services/login/authinfotest/auth_info_service_mock.go new file mode 100644 index 00000000000..42f9bd60b7f --- /dev/null +++ b/pkg/services/login/authinfotest/auth_info_service_mock.go @@ -0,0 +1,765 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package authinfotest + +import ( + "context" + + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/user" + mock "github.com/stretchr/testify/mock" +) + +// NewMockAuthInfoService creates a new instance of MockAuthInfoService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockAuthInfoService(t interface { + mock.TestingT + Cleanup(func()) +}) *MockAuthInfoService { + mock := &MockAuthInfoService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockAuthInfoService is an autogenerated mock type for the AuthInfoService type +type MockAuthInfoService struct { + mock.Mock +} + +type MockAuthInfoService_Expecter struct { + mock *mock.Mock +} + +func (_m *MockAuthInfoService) EXPECT() *MockAuthInfoService_Expecter { + return &MockAuthInfoService_Expecter{mock: &_m.Mock} +} + +// DeleteUserAuthInfo provides a mock function for the type MockAuthInfoService +func (_mock *MockAuthInfoService) DeleteUserAuthInfo(ctx context.Context, userID int64) error { + ret := _mock.Called(ctx, userID) + + if len(ret) == 0 { + panic("no return value specified for DeleteUserAuthInfo") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = returnFunc(ctx, userID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockAuthInfoService_DeleteUserAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteUserAuthInfo' +type MockAuthInfoService_DeleteUserAuthInfo_Call struct { + *mock.Call +} + +// DeleteUserAuthInfo is a helper method to define mock.On call +// - ctx context.Context +// - userID int64 +func (_e *MockAuthInfoService_Expecter) DeleteUserAuthInfo(ctx interface{}, userID interface{}) *MockAuthInfoService_DeleteUserAuthInfo_Call { + return &MockAuthInfoService_DeleteUserAuthInfo_Call{Call: _e.mock.On("DeleteUserAuthInfo", ctx, userID)} +} + +func (_c *MockAuthInfoService_DeleteUserAuthInfo_Call) Run(run func(ctx context.Context, userID int64)) *MockAuthInfoService_DeleteUserAuthInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 int64 + if args[1] != nil { + arg1 = args[1].(int64) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockAuthInfoService_DeleteUserAuthInfo_Call) Return(err error) *MockAuthInfoService_DeleteUserAuthInfo_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockAuthInfoService_DeleteUserAuthInfo_Call) RunAndReturn(run func(ctx context.Context, userID int64) error) *MockAuthInfoService_DeleteUserAuthInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetAuthInfo provides a mock function for the type MockAuthInfoService +func (_mock *MockAuthInfoService) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) { + ret := _mock.Called(ctx, query) + + if len(ret) == 0 { + panic("no return value specified for GetAuthInfo") + } + + var r0 *login.UserAuth + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) (*login.UserAuth, error)); ok { + return returnFunc(ctx, query) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) *login.UserAuth); ok { + r0 = returnFunc(ctx, query) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*login.UserAuth) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *login.GetAuthInfoQuery) error); ok { + r1 = returnFunc(ctx, query) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockAuthInfoService_GetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthInfo' +type MockAuthInfoService_GetAuthInfo_Call struct { + *mock.Call +} + +// GetAuthInfo is a helper method to define mock.On call +// - ctx context.Context +// - query *login.GetAuthInfoQuery +func (_e *MockAuthInfoService_Expecter) GetAuthInfo(ctx interface{}, query interface{}) *MockAuthInfoService_GetAuthInfo_Call { + return &MockAuthInfoService_GetAuthInfo_Call{Call: _e.mock.On("GetAuthInfo", ctx, query)} +} + +func (_c *MockAuthInfoService_GetAuthInfo_Call) Run(run func(ctx context.Context, query *login.GetAuthInfoQuery)) *MockAuthInfoService_GetAuthInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *login.GetAuthInfoQuery + if args[1] != nil { + arg1 = args[1].(*login.GetAuthInfoQuery) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockAuthInfoService_GetAuthInfo_Call) Return(userAuth *login.UserAuth, err error) *MockAuthInfoService_GetAuthInfo_Call { + _c.Call.Return(userAuth, err) + return _c +} + +func (_c *MockAuthInfoService_GetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error)) *MockAuthInfoService_GetAuthInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetUserLabels provides a mock function for the type MockAuthInfoService +func (_mock *MockAuthInfoService) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) { + ret := _mock.Called(ctx, query) + + if len(ret) == 0 { + panic("no return value specified for GetUserLabels") + } + + var r0 map[int64]string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) (map[int64]string, error)); ok { + return returnFunc(ctx, query) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) map[int64]string); ok { + r0 = returnFunc(ctx, query) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, login.GetUserLabelsQuery) error); ok { + r1 = returnFunc(ctx, query) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockAuthInfoService_GetUserLabels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserLabels' +type MockAuthInfoService_GetUserLabels_Call struct { + *mock.Call +} + +// GetUserLabels is a helper method to define mock.On call +// - ctx context.Context +// - query login.GetUserLabelsQuery +func (_e *MockAuthInfoService_Expecter) GetUserLabels(ctx interface{}, query interface{}) *MockAuthInfoService_GetUserLabels_Call { + return &MockAuthInfoService_GetUserLabels_Call{Call: _e.mock.On("GetUserLabels", ctx, query)} +} + +func (_c *MockAuthInfoService_GetUserLabels_Call) Run(run func(ctx context.Context, query login.GetUserLabelsQuery)) *MockAuthInfoService_GetUserLabels_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 login.GetUserLabelsQuery + if args[1] != nil { + arg1 = args[1].(login.GetUserLabelsQuery) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockAuthInfoService_GetUserLabels_Call) Return(int64ToString map[int64]string, err error) *MockAuthInfoService_GetUserLabels_Call { + _c.Call.Return(int64ToString, err) + return _c +} + +func (_c *MockAuthInfoService_GetUserLabels_Call) RunAndReturn(run func(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error)) *MockAuthInfoService_GetUserLabels_Call { + _c.Call.Return(run) + return _c +} + +// SetAuthInfo provides a mock function for the type MockAuthInfoService +func (_mock *MockAuthInfoService) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error { + ret := _mock.Called(ctx, cmd) + + if len(ret) == 0 { + panic("no return value specified for SetAuthInfo") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *login.SetAuthInfoCommand) error); ok { + r0 = returnFunc(ctx, cmd) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockAuthInfoService_SetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetAuthInfo' +type MockAuthInfoService_SetAuthInfo_Call struct { + *mock.Call +} + +// SetAuthInfo is a helper method to define mock.On call +// - ctx context.Context +// - cmd *login.SetAuthInfoCommand +func (_e *MockAuthInfoService_Expecter) SetAuthInfo(ctx interface{}, cmd interface{}) *MockAuthInfoService_SetAuthInfo_Call { + return &MockAuthInfoService_SetAuthInfo_Call{Call: _e.mock.On("SetAuthInfo", ctx, cmd)} +} + +func (_c *MockAuthInfoService_SetAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.SetAuthInfoCommand)) *MockAuthInfoService_SetAuthInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *login.SetAuthInfoCommand + if args[1] != nil { + arg1 = args[1].(*login.SetAuthInfoCommand) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockAuthInfoService_SetAuthInfo_Call) Return(err error) *MockAuthInfoService_SetAuthInfo_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockAuthInfoService_SetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.SetAuthInfoCommand) error) *MockAuthInfoService_SetAuthInfo_Call { + _c.Call.Return(run) + return _c +} + +// UpdateAuthInfo provides a mock function for the type MockAuthInfoService +func (_mock *MockAuthInfoService) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error { + ret := _mock.Called(ctx, cmd) + + if len(ret) == 0 { + panic("no return value specified for UpdateAuthInfo") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *login.UpdateAuthInfoCommand) error); ok { + r0 = returnFunc(ctx, cmd) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockAuthInfoService_UpdateAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAuthInfo' +type MockAuthInfoService_UpdateAuthInfo_Call struct { + *mock.Call +} + +// UpdateAuthInfo is a helper method to define mock.On call +// - ctx context.Context +// - cmd *login.UpdateAuthInfoCommand +func (_e *MockAuthInfoService_Expecter) UpdateAuthInfo(ctx interface{}, cmd interface{}) *MockAuthInfoService_UpdateAuthInfo_Call { + return &MockAuthInfoService_UpdateAuthInfo_Call{Call: _e.mock.On("UpdateAuthInfo", ctx, cmd)} +} + +func (_c *MockAuthInfoService_UpdateAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand)) *MockAuthInfoService_UpdateAuthInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *login.UpdateAuthInfoCommand + if args[1] != nil { + arg1 = args[1].(*login.UpdateAuthInfoCommand) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockAuthInfoService_UpdateAuthInfo_Call) Return(err error) *MockAuthInfoService_UpdateAuthInfo_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockAuthInfoService_UpdateAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error) *MockAuthInfoService_UpdateAuthInfo_Call { + _c.Call.Return(run) + return _c +} + +// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockStore(t interface { + mock.TestingT + Cleanup(func()) +}) *MockStore { + mock := &MockStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockStore is an autogenerated mock type for the Store type +type MockStore struct { + mock.Mock +} + +type MockStore_Expecter struct { + mock *mock.Mock +} + +func (_m *MockStore) EXPECT() *MockStore_Expecter { + return &MockStore_Expecter{mock: &_m.Mock} +} + +// DeleteUserAuthInfo provides a mock function for the type MockStore +func (_mock *MockStore) DeleteUserAuthInfo(ctx context.Context, userID int64) error { + ret := _mock.Called(ctx, userID) + + if len(ret) == 0 { + panic("no return value specified for DeleteUserAuthInfo") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = returnFunc(ctx, userID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockStore_DeleteUserAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteUserAuthInfo' +type MockStore_DeleteUserAuthInfo_Call struct { + *mock.Call +} + +// DeleteUserAuthInfo is a helper method to define mock.On call +// - ctx context.Context +// - userID int64 +func (_e *MockStore_Expecter) DeleteUserAuthInfo(ctx interface{}, userID interface{}) *MockStore_DeleteUserAuthInfo_Call { + return &MockStore_DeleteUserAuthInfo_Call{Call: _e.mock.On("DeleteUserAuthInfo", ctx, userID)} +} + +func (_c *MockStore_DeleteUserAuthInfo_Call) Run(run func(ctx context.Context, userID int64)) *MockStore_DeleteUserAuthInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 int64 + if args[1] != nil { + arg1 = args[1].(int64) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockStore_DeleteUserAuthInfo_Call) Return(err error) *MockStore_DeleteUserAuthInfo_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockStore_DeleteUserAuthInfo_Call) RunAndReturn(run func(ctx context.Context, userID int64) error) *MockStore_DeleteUserAuthInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetAuthInfo provides a mock function for the type MockStore +func (_mock *MockStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) { + ret := _mock.Called(ctx, query) + + if len(ret) == 0 { + panic("no return value specified for GetAuthInfo") + } + + var r0 *login.UserAuth + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) (*login.UserAuth, error)); ok { + return returnFunc(ctx, query) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) *login.UserAuth); ok { + r0 = returnFunc(ctx, query) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*login.UserAuth) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *login.GetAuthInfoQuery) error); ok { + r1 = returnFunc(ctx, query) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockStore_GetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthInfo' +type MockStore_GetAuthInfo_Call struct { + *mock.Call +} + +// GetAuthInfo is a helper method to define mock.On call +// - ctx context.Context +// - query *login.GetAuthInfoQuery +func (_e *MockStore_Expecter) GetAuthInfo(ctx interface{}, query interface{}) *MockStore_GetAuthInfo_Call { + return &MockStore_GetAuthInfo_Call{Call: _e.mock.On("GetAuthInfo", ctx, query)} +} + +func (_c *MockStore_GetAuthInfo_Call) Run(run func(ctx context.Context, query *login.GetAuthInfoQuery)) *MockStore_GetAuthInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *login.GetAuthInfoQuery + if args[1] != nil { + arg1 = args[1].(*login.GetAuthInfoQuery) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockStore_GetAuthInfo_Call) Return(userAuth *login.UserAuth, err error) *MockStore_GetAuthInfo_Call { + _c.Call.Return(userAuth, err) + return _c +} + +func (_c *MockStore_GetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error)) *MockStore_GetAuthInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetUserLabels provides a mock function for the type MockStore +func (_mock *MockStore) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) { + ret := _mock.Called(ctx, query) + + if len(ret) == 0 { + panic("no return value specified for GetUserLabels") + } + + var r0 map[int64]string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) (map[int64]string, error)); ok { + return returnFunc(ctx, query) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) map[int64]string); ok { + r0 = returnFunc(ctx, query) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, login.GetUserLabelsQuery) error); ok { + r1 = returnFunc(ctx, query) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockStore_GetUserLabels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserLabels' +type MockStore_GetUserLabels_Call struct { + *mock.Call +} + +// GetUserLabels is a helper method to define mock.On call +// - ctx context.Context +// - query login.GetUserLabelsQuery +func (_e *MockStore_Expecter) GetUserLabels(ctx interface{}, query interface{}) *MockStore_GetUserLabels_Call { + return &MockStore_GetUserLabels_Call{Call: _e.mock.On("GetUserLabels", ctx, query)} +} + +func (_c *MockStore_GetUserLabels_Call) Run(run func(ctx context.Context, query login.GetUserLabelsQuery)) *MockStore_GetUserLabels_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 login.GetUserLabelsQuery + if args[1] != nil { + arg1 = args[1].(login.GetUserLabelsQuery) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockStore_GetUserLabels_Call) Return(int64ToString map[int64]string, err error) *MockStore_GetUserLabels_Call { + _c.Call.Return(int64ToString, err) + return _c +} + +func (_c *MockStore_GetUserLabels_Call) RunAndReturn(run func(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error)) *MockStore_GetUserLabels_Call { + _c.Call.Return(run) + return _c +} + +// SetAuthInfo provides a mock function for the type MockStore +func (_mock *MockStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error { + ret := _mock.Called(ctx, cmd) + + if len(ret) == 0 { + panic("no return value specified for SetAuthInfo") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *login.SetAuthInfoCommand) error); ok { + r0 = returnFunc(ctx, cmd) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockStore_SetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetAuthInfo' +type MockStore_SetAuthInfo_Call struct { + *mock.Call +} + +// SetAuthInfo is a helper method to define mock.On call +// - ctx context.Context +// - cmd *login.SetAuthInfoCommand +func (_e *MockStore_Expecter) SetAuthInfo(ctx interface{}, cmd interface{}) *MockStore_SetAuthInfo_Call { + return &MockStore_SetAuthInfo_Call{Call: _e.mock.On("SetAuthInfo", ctx, cmd)} +} + +func (_c *MockStore_SetAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.SetAuthInfoCommand)) *MockStore_SetAuthInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *login.SetAuthInfoCommand + if args[1] != nil { + arg1 = args[1].(*login.SetAuthInfoCommand) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockStore_SetAuthInfo_Call) Return(err error) *MockStore_SetAuthInfo_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockStore_SetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.SetAuthInfoCommand) error) *MockStore_SetAuthInfo_Call { + _c.Call.Return(run) + return _c +} + +// UpdateAuthInfo provides a mock function for the type MockStore +func (_mock *MockStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error { + ret := _mock.Called(ctx, cmd) + + if len(ret) == 0 { + panic("no return value specified for UpdateAuthInfo") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *login.UpdateAuthInfoCommand) error); ok { + r0 = returnFunc(ctx, cmd) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockStore_UpdateAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAuthInfo' +type MockStore_UpdateAuthInfo_Call struct { + *mock.Call +} + +// UpdateAuthInfo is a helper method to define mock.On call +// - ctx context.Context +// - cmd *login.UpdateAuthInfoCommand +func (_e *MockStore_Expecter) UpdateAuthInfo(ctx interface{}, cmd interface{}) *MockStore_UpdateAuthInfo_Call { + return &MockStore_UpdateAuthInfo_Call{Call: _e.mock.On("UpdateAuthInfo", ctx, cmd)} +} + +func (_c *MockStore_UpdateAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand)) *MockStore_UpdateAuthInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *login.UpdateAuthInfoCommand + if args[1] != nil { + arg1 = args[1].(*login.UpdateAuthInfoCommand) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockStore_UpdateAuthInfo_Call) Return(err error) *MockStore_UpdateAuthInfo_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockStore_UpdateAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error) *MockStore_UpdateAuthInfo_Call { + _c.Call.Return(run) + return _c +} + +// NewMockUserProtectionService creates a new instance of MockUserProtectionService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockUserProtectionService(t interface { + mock.TestingT + Cleanup(func()) +}) *MockUserProtectionService { + mock := &MockUserProtectionService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockUserProtectionService is an autogenerated mock type for the UserProtectionService type +type MockUserProtectionService struct { + mock.Mock +} + +type MockUserProtectionService_Expecter struct { + mock *mock.Mock +} + +func (_m *MockUserProtectionService) EXPECT() *MockUserProtectionService_Expecter { + return &MockUserProtectionService_Expecter{mock: &_m.Mock} +} + +// AllowUserMapping provides a mock function for the type MockUserProtectionService +func (_mock *MockUserProtectionService) AllowUserMapping(user1 *user.User, authModule string) error { + ret := _mock.Called(user1, authModule) + + if len(ret) == 0 { + panic("no return value specified for AllowUserMapping") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(*user.User, string) error); ok { + r0 = returnFunc(user1, authModule) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockUserProtectionService_AllowUserMapping_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllowUserMapping' +type MockUserProtectionService_AllowUserMapping_Call struct { + *mock.Call +} + +// AllowUserMapping is a helper method to define mock.On call +// - user1 *user.User +// - authModule string +func (_e *MockUserProtectionService_Expecter) AllowUserMapping(user1 interface{}, authModule interface{}) *MockUserProtectionService_AllowUserMapping_Call { + return &MockUserProtectionService_AllowUserMapping_Call{Call: _e.mock.On("AllowUserMapping", user1, authModule)} +} + +func (_c *MockUserProtectionService_AllowUserMapping_Call) Run(run func(user1 *user.User, authModule string)) *MockUserProtectionService_AllowUserMapping_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 *user.User + if args[0] != nil { + arg0 = args[0].(*user.User) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockUserProtectionService_AllowUserMapping_Call) Return(err error) *MockUserProtectionService_AllowUserMapping_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockUserProtectionService_AllowUserMapping_Call) RunAndReturn(run func(user1 *user.User, authModule string) error) *MockUserProtectionService_AllowUserMapping_Call { + _c.Call.Return(run) + return _c +} diff --git a/pkg/services/oauthtoken/oauth_token.go b/pkg/services/oauthtoken/oauth_token.go index 0304eca46f1..a3503e06522 100644 --- a/pkg/services/oauthtoken/oauth_token.go +++ b/pkg/services/oauthtoken/oauth_token.go @@ -11,6 +11,7 @@ import ( "github.com/go-jose/go-jose/v4/jwt" "github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" "golang.org/x/oauth2" @@ -57,8 +58,14 @@ var _ OAuthTokenService = (*Service)(nil) type OAuthTokenService interface { GetCurrentOAuthToken(context.Context, identity.Requester, *auth.UserToken) *oauth2.Token IsOAuthPassThruEnabled(*datasources.DataSource) bool - TryTokenRefresh(context.Context, identity.Requester, *auth.UserToken) (*oauth2.Token, error) - InvalidateOAuthTokens(context.Context, identity.Requester, *auth.UserToken) error + TryTokenRefresh(context.Context, identity.Requester, *TokenRefreshMetadata) (*oauth2.Token, error) + InvalidateOAuthTokens(context.Context, identity.Requester, *TokenRefreshMetadata) error +} + +type TokenRefreshMetadata struct { + ExternalSessionID int64 + AuthModule string + AuthID string } func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer, @@ -102,51 +109,71 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Request ctxLogger = ctxLogger.New("userID", userID) - if !strings.HasPrefix(usr.GetAuthenticatedBy(), "oauth_") { - ctxLogger.Warn("The specified user's auth provider is not oauth", - "authmodule", usr.GetAuthenticatedBy()) + tokenRefreshMetadata := &TokenRefreshMetadata{ + ExternalSessionID: 0, + } + var persistedToken *oauth2.Token + // Find the external session associated with the user and session token + // regardless of the improvedExternalSessionHandling feature toggle, + // because Grafana writes and updates both tables to make the switch + // to the new session handling smoother. + externalSession, err := o.getExternalSession(ctx, usr, userID, sessionToken) + if err != nil && !errors.Is(err, auth.ErrExternalSessionNotFound) { + ctxLogger.Error("Failed to get external session", "error", err) + return nil + } + + // If the feature toggle is enabled, an external session is required. + if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) && (externalSession == nil || errors.Is(err, auth.ErrExternalSessionNotFound)) { + ctxLogger.Error("No external session found for user", "userID", userID) + return nil + } + + // externalSession can be nil if Grafana was updated from a version where the + // external session table was not used yet (did not exist) and the user has not logged in since + // the version update (therefore no external session was created for the user yet). + if externalSession != nil { + tokenRefreshMetadata.ExternalSessionID = externalSession.ID + } + + authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{ + UserId: userID, + }) + if err != nil { + if errors.Is(err, user.ErrUserNotFound) { + ctxLogger.Warn("No AuthInfo found for user", "userID", userID) + return nil + } + + ctxLogger.Error("Failed to fetch AuthInfo for user", "userID", userID, "error", err) + return nil + } + + tokenRefreshMetadata.AuthID = authInfo.AuthId + tokenRefreshMetadata.AuthModule = authInfo.AuthModule + + if !strings.HasPrefix(tokenRefreshMetadata.AuthModule, "oauth_") { + ctxLogger.Warn("The specified user's auth provider is not oauth", + "authmodule", tokenRefreshMetadata.AuthModule) return nil } - var persistedToken *oauth2.Token if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) { - externalSession, err := o.sessionService.GetExternalSession(ctx, sessionToken.ExternalSessionId) - if err != nil { - if errors.Is(err, auth.ErrExternalSessionNotFound) { - return nil - } - ctxLogger.Error("Failed to fetch external session", "error", err) - return nil - } - persistedToken = buildOAuthTokenFromExternalSession(externalSession) - - if persistedToken.RefreshToken == "" { - return persistedToken - } } else { - authInfo, ok, _ := o.hasOAuthEntry(ctx, usr) - if !ok { - return nil - } - - if err := checkOAuthRefreshToken(authInfo); err != nil { - if errors.Is(err, ErrNoRefreshTokenFound) { - return buildOAuthTokenFromAuthInfo(authInfo) - } - - return nil - } - persistedToken = buildOAuthTokenFromAuthInfo(authInfo) } + if persistedToken.RefreshToken == "" { + return persistedToken + } + refreshNeeded := needTokenRefresh(ctx, persistedToken) if !refreshNeeded { return persistedToken } - token, err := o.TryTokenRefresh(ctx, usr, sessionToken) + token, err := o.TryTokenRefresh(ctx, usr, tokenRefreshMetadata) if err != nil { if errors.Is(err, ErrNoRefreshTokenFound) { return persistedToken @@ -214,7 +241,7 @@ func (o *Service) hasOAuthEntry(ctx context.Context, usr identity.Requester) (*l // TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful // It uses a server lock to prevent getting the Refresh Token multiple times for a given User -func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error) { +func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, tokenRefreshMetadata *TokenRefreshMetadata) (*oauth2.Token, error) { ctx, span := o.tracer.Start(ctx, "oauthtoken.TryTokenRefresh") defer span.End() @@ -239,14 +266,13 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s ctxLogger = ctxLogger.New("userID", userID) - // get the token's auth provider (f.e. azuread) - currAuthenticator := usr.GetAuthenticatedBy() - if !strings.HasPrefix(currAuthenticator, "oauth") { - ctxLogger.Warn("The specified user's auth provider is not OAuth", "authmodule", currAuthenticator) + if !strings.HasPrefix(tokenRefreshMetadata.AuthModule, "oauth_") { + ctxLogger.Warn("The specified user's auth provider is not oauth", + "authmodule", tokenRefreshMetadata.AuthModule) return nil, nil } - provider := strings.TrimPrefix(currAuthenticator, "oauth_") + provider := strings.TrimPrefix(tokenRefreshMetadata.AuthModule, "oauth_") currentOAuthInfo := o.SocialService.GetOAuthInfoProvider(provider) if currentOAuthInfo == nil { ctxLogger.Warn("OAuth provider not found", "provider", provider) @@ -261,7 +287,7 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s lockKey := fmt.Sprintf("oauth-refresh-token-%d", userID) if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) { - lockKey = fmt.Sprintf("oauth-refresh-token-%d-%d", userID, sessionToken.ExternalSessionId) + lockKey = fmt.Sprintf("oauth-refresh-token-%d-%d", userID, tokenRefreshMetadata.ExternalSessionID) } lockTimeConfig := serverlock.LockTimeConfig{ @@ -290,7 +316,7 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s var persistedToken *oauth2.Token var externalSession *auth.ExternalSession if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) { - externalSession, err = o.sessionService.GetExternalSession(ctx, sessionToken.ExternalSessionId) + externalSession, err = o.sessionService.GetExternalSession(ctx, tokenRefreshMetadata.ExternalSessionID) if err != nil { if errors.Is(err, auth.ErrExternalSessionNotFound) { ctxLogger.Error("External session was not found for user", "error", err) @@ -321,7 +347,7 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s return } - newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, persistedToken, usr, sessionToken) + newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, persistedToken, usr, tokenRefreshMetadata) }, retryOpt) if lockErr != nil { ctxLogger.Error("Failed to obtain token refresh lock", "error", lockErr) @@ -330,14 +356,14 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s // Silence ErrNoRefreshTokenFound if errors.Is(cmdErr, ErrNoRefreshTokenFound) { - return nil, nil + return nil, ErrNoRefreshTokenFound } return newToken, cmdErr } // InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero -func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) error { +func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, tokenRefreshMetadata *TokenRefreshMetadata) error { userID, err := usr.GetInternalID() if err != nil { logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err) @@ -347,7 +373,7 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Reques ctxLogger := logger.FromContext(ctx).New("userID", userID) if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) { - err := o.sessionService.UpdateExternalSession(ctx, sessionToken.ExternalSessionId, &auth.UpdateExternalSessionCommand{ + err := o.sessionService.UpdateExternalSession(ctx, tokenRefreshMetadata.ExternalSessionID, &auth.UpdateExternalSessionCommand{ Token: &oauth2.Token{}, }) if err != nil { @@ -358,8 +384,8 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Reques return o.AuthInfoService.UpdateAuthInfo(ctx, &login.UpdateAuthInfoCommand{ UserId: userID, - AuthModule: usr.GetAuthenticatedBy(), - AuthId: usr.GetAuthID(), + AuthModule: tokenRefreshMetadata.AuthModule, + AuthId: tokenRefreshMetadata.AuthID, OAuthToken: &oauth2.Token{ AccessToken: "", RefreshToken: "", @@ -368,13 +394,14 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Reques }) } -func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken *oauth2.Token, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error) { +func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken *oauth2.Token, usr identity.Requester, tokenRefreshMetadata *TokenRefreshMetadata) (*oauth2.Token, error) { ctx, span := o.tracer.Start(ctx, "oauthtoken.tryGetOrRefreshOAuthToken") defer span.End() userID, err := usr.GetInternalID() if err != nil { logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err) + span.SetStatus(codes.Error, "Failed to convert user id to int") return nil, err } @@ -382,8 +409,11 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken ctxLogger := logger.FromContext(ctx).New("userID", userID) + // tryGetOrRefreshOAuthToken assumes that the AuthModule has RefreshToken enabled + // which is checked by the caller (TryTokenRefresh) if persistedToken.RefreshToken == "" { - ctxLogger.Warn("No refresh token available", "authmodule", usr.GetAuthenticatedBy()) + ctxLogger.Error("No refresh token available", "authmodule", tokenRefreshMetadata.AuthModule) + span.SetStatus(codes.Error, ErrNoRefreshTokenFound.Error()) return nil, ErrNoRefreshTokenFound } @@ -392,50 +422,44 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken return persistedToken, nil } - authProvider := usr.GetAuthenticatedBy() - connect, err := o.SocialService.GetConnector(authProvider) + connect, err := o.SocialService.GetConnector(tokenRefreshMetadata.AuthModule) if err != nil { - ctxLogger.Error("Failed to get oauth connector", "provider", authProvider, "error", err) + ctxLogger.Error("Failed to get oauth connector", "provider", tokenRefreshMetadata.AuthModule, "error", err) + span.SetStatus(codes.Error, "Failed to get oauth connector: "+err.Error()) return nil, err } - client, err := o.SocialService.GetOAuthHttpClient(authProvider) + client, err := o.SocialService.GetOAuthHttpClient(tokenRefreshMetadata.AuthModule) if err != nil { - ctxLogger.Error("Failed to get oauth http client", "provider", authProvider, "error", err) + ctxLogger.Error("Failed to get oauth http client", "provider", tokenRefreshMetadata.AuthModule, "error", err) + span.SetStatus(codes.Error, "Failed to get oauth http client") return nil, err } ctx = context.WithValue(ctx, oauth2.HTTPClient, client) start := time.Now() // TokenSource handles refreshing the token if it has expired - token, err := connect.TokenSource(ctx, persistedToken).Token() + token, refreshErr := connect.TokenSource(ctx, persistedToken).Token() duration := time.Since(start) - o.tokenRefreshDuration.WithLabelValues(authProvider, fmt.Sprintf("%t", err == nil)).Observe(duration.Seconds()) + o.tokenRefreshDuration.WithLabelValues(tokenRefreshMetadata.AuthModule, fmt.Sprintf("%t", err == nil)).Observe(duration.Seconds()) - if err != nil { + if refreshErr != nil { span.SetAttributes(attribute.Bool("token_refreshed", false)) ctxLogger.Error("Failed to retrieve oauth access token", - "provider", usr.GetAuthenticatedBy(), "error", err) + "provider", tokenRefreshMetadata.AuthModule, "error", refreshErr) // token refresh failed, invalidate the old token - if err := o.InvalidateOAuthTokens(ctx, usr, sessionToken); err != nil { - ctxLogger.Warn("Failed to invalidate OAuth tokens", "authID", usr.GetAuthID(), "error", err) + if err := o.InvalidateOAuthTokens(ctx, usr, tokenRefreshMetadata); err != nil { + ctxLogger.Warn("Failed to invalidate OAuth tokens", "authID", tokenRefreshMetadata.AuthID, "error", err) } - return nil, err + return nil, refreshErr } span.SetAttributes(attribute.Bool("token_refreshed", true)) // If the tokens are not the same, update the entry in the DB if !tokensEq(persistedToken, token) { - updateAuthCommand := &login.UpdateAuthInfoCommand{ - UserId: userID, - AuthModule: usr.GetAuthenticatedBy(), - AuthId: usr.GetAuthID(), - OAuthToken: token, - } - if o.Cfg.Env == setting.Dev { ctxLogger.Debug("Oauth got token", "auth_module", usr.GetAuthenticatedBy(), @@ -446,17 +470,32 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken } if !o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) { + updateAuthCommand := &login.UpdateAuthInfoCommand{ + UserId: userID, + AuthModule: tokenRefreshMetadata.AuthModule, + AuthId: tokenRefreshMetadata.AuthID, + OAuthToken: token, + } if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil { - ctxLogger.Error("Failed to update auth info during token refresh", "authID", usr.GetAuthID(), "error", err) + ctxLogger.Error("Failed to update auth info during token refresh", "authID", tokenRefreshMetadata.AuthID, "error", err) + span.SetStatus(codes.Error, "Failed to update auth info during token refresh") return nil, err } } - if err := o.sessionService.UpdateExternalSession(ctx, sessionToken.ExternalSessionId, &auth.UpdateExternalSessionCommand{ - Token: token, - }); err != nil { - ctxLogger.Error("Failed to update external session during token refresh", "error", err) - return nil, err + // Update the external session with the new token if we the user has an external session, + // regardless of the feature flag state to keep the `user_external_session` table in sync. + // ExternalSessionID should always be set except for some edge cases: + // - when Grafana was updated to a version where the `improvedExternalSessionHandling` feature flag + // was enabled after the user logged in + if tokenRefreshMetadata.ExternalSessionID != 0 { + if err := o.sessionService.UpdateExternalSession(ctx, tokenRefreshMetadata.ExternalSessionID, &auth.UpdateExternalSessionCommand{ + Token: token, + }); err != nil { + ctxLogger.Error("Failed to update external session during token refresh", "error", err) + span.SetStatus(codes.Error, "Failed to update external session during token refresh") + return nil, err + } } ctxLogger.Debug("Updated oauth info for user") @@ -502,6 +541,11 @@ func needTokenRefresh(ctx context.Context, persistedToken *oauth2.Token) bool { ctxLogger := logger.FromContext(ctx) + if persistedToken.AccessToken == "" { + ctxLogger.Debug("Access token has been cleared, need to refresh") + return true + } + idTokenExp, err := GetIDTokenExpiry(persistedToken) if err != nil { ctxLogger.Warn("Could not get ID Token expiry", "error", err) @@ -552,22 +596,6 @@ func buildOAuthTokenFromExternalSession(externalSession *auth.ExternalSession) * return token } -func checkOAuthRefreshToken(authInfo *login.UserAuth) error { - if !strings.Contains(authInfo.AuthModule, "oauth") { - logger.Warn("The specified user's auth provider is not oauth", - "authmodule", authInfo.AuthModule, "userid", authInfo.UserId) - return ErrNotAnOAuthProvider - } - - if authInfo.OAuthRefreshToken == "" { - logger.Warn("No refresh token available", - "authmodule", authInfo.AuthModule, "userid", authInfo.UserId) - return ErrNoRefreshTokenFound - } - - return nil -} - // GetIDTokenExpiry extracts the expiry time from the ID token func GetIDTokenExpiry(token *oauth2.Token) (time.Time, error) { idToken, ok := token.Extra("id_token").(string) @@ -601,3 +629,28 @@ func getExpiryWithSkew(expiry time.Time) (adjustedExpiry time.Time, hasTokenExpi hasTokenExpired = adjustedExpiry.Before(time.Now()) return } + +// getExternalSession fetches the external session based on the user and session token. +// When using the render module, it fetches the most recent external session for the user +// since the session token ID is not available. +// For regular users, it uses the session token ID to fetch the external session. +func (o *Service) getExternalSession(ctx context.Context, usr identity.Requester, userID int64, sessionToken *auth.UserToken) (*auth.ExternalSession, error) { + if usr.GetAuthenticatedBy() == login.RenderModule { + // When using render module, we don't have the session token ID, so we need to fetch the most recent session + // entry for the user (as it is done with the old flow). + // In the future, we might want to consider passing the session token ID to the render module to make this more robust. + externalSessions, err := o.sessionService.FindExternalSessions(ctx, &auth.ListExternalSessionQuery{UserID: userID}) + if err != nil { + return nil, err + } + + if len(externalSessions) == 0 || externalSessions[0] == nil { + return nil, auth.ErrExternalSessionNotFound + } + + return externalSessions[0], nil + } + + // For regular users, we use the session token ID to fetch the external session + return o.sessionService.GetExternalSession(ctx, sessionToken.ExternalSessionId) +} diff --git a/pkg/services/oauthtoken/oauth_token_test.go b/pkg/services/oauthtoken/oauth_token_test.go index 487a4f6426e..2dfc3e6e68e 100644 --- a/pkg/services/oauthtoken/oauth_token_test.go +++ b/pkg/services/oauthtoken/oauth_token_test.go @@ -2,7 +2,6 @@ package oauthtoken import ( "context" - "errors" "testing" "time" @@ -18,7 +17,6 @@ import ( "github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social/socialtest" - "github.com/grafana/grafana/pkg/models/usertoken" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/authn" @@ -38,69 +36,57 @@ func TestMain(m *testing.M) { testsuite.Run(m) } -type FakeAuthInfoStore struct { - login.Store - ExpectedError error - ExpectedOAuth *login.UserAuth -} +var ( + unexpiredTokenWithoutRefresh = &oauth2.Token{ + AccessToken: "testaccess", + Expiry: time.Now().Add(time.Hour), + TokenType: "Bearer", + } -func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) { - return f.ExpectedOAuth, f.ExpectedError -} + unexpiredTokenWithoutRefreshWithIDToken = unexpiredTokenWithoutRefresh.WithExtra(map[string]interface{}{ + "id_token": UNEXPIRED_ID_TOKEN, + }) -func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error { - return f.ExpectedError -} - -func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error { - f.ExpectedOAuth.OAuthAccessToken = cmd.OAuthToken.AccessToken - f.ExpectedOAuth.OAuthExpiry = cmd.OAuthToken.Expiry - f.ExpectedOAuth.OAuthTokenType = cmd.OAuthToken.TokenType - f.ExpectedOAuth.OAuthRefreshToken = cmd.OAuthToken.RefreshToken - return f.ExpectedError -} - -func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error { - return f.ExpectedError -} - -func TestIntegration_TryTokenRefresh(t *testing.T) { - testutil.SkipIntegrationTestInShortMode(t) - - unexpiredToken := &oauth2.Token{ + unexpiredToken = &oauth2.Token{ AccessToken: "testaccess", RefreshToken: "testrefresh", Expiry: time.Now().Add(time.Hour), TokenType: "Bearer", } - unexpiredTokenWithIDToken := unexpiredToken.WithExtra(map[string]interface{}{ + + unexpiredTokenWithIDToken = unexpiredToken.WithExtra(map[string]interface{}{ "id_token": UNEXPIRED_ID_TOKEN, }) - expiredToken := &oauth2.Token{ + expiredToken = &oauth2.Token{ AccessToken: "testaccess", RefreshToken: "testrefresh", Expiry: time.Now().Add(-time.Hour), TokenType: "Bearer", } +) - type environment struct { - sessionService *authtest.MockUserAuthTokenService - authInfoService *authinfotest.FakeService - serverLock *serverlock.ServerLockService - socialConnector *socialtest.MockSocialConnector - socialService *socialtest.FakeSocialService +type environment struct { + sessionService *authtest.MockUserAuthTokenService + authInfoService *authinfotest.MockAuthInfoService + serverLock *serverlock.ServerLockService + socialConnector *socialtest.MockSocialConnector + socialService *socialtest.FakeSocialService - store db.DB - service *Service - } + store db.DB + service *Service +} + +func TestIntegration_TryTokenRefresh(t *testing.T) { + testutil.SkipIntegrationTestInShortMode(t) type testCase struct { - desc string - identity identity.Requester - setup func(env *environment) - expectedToken *oauth2.Token - expectedErr error + desc string + identity identity.Requester + refreshMetadata *TokenRefreshMetadata + setup func(env *environment) + expectedToken *oauth2.Token + expectedErr error } userIdentity := &authn.Identity{ @@ -122,53 +108,74 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { identity: &authn.Identity{ID: "invalid", Type: claims.TypeUser}, }, { - desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned", - identity: userIdentity, - setup: func(env *environment) { - env.authInfoService.ExpectedError = errors.New("some error") - }, + desc: "should skip token refresh when no oauth provider was found", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.SAMLAuthModule}, }, { - desc: "should skip token refresh if the user doesn't have an oauth entry", - identity: userIdentity, + desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { - env.authInfoService.ExpectedUserAuth = &login.UserAuth{ - AuthModule: login.SAMLAuthModule, - } - }, - }, - { - desc: "should skip token refresh when no oauth provider was found", - identity: userIdentity, - setup: func(env *environment) { - env.authInfoService.ExpectedUserAuth = &login.UserAuth{ - AuthModule: login.GenericOAuthModule, - } - }, - }, - { - desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)", - identity: userIdentity, - setup: func(env *environment) { - env.authInfoService.ExpectedUserAuth = &login.UserAuth{ - AuthModule: login.GenericOAuthModule, - } env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ UseRefreshToken: false, } }, }, { - desc: "should skip token refresh when the token is still valid and no id token is present", - identity: userIdentity, + desc: "should skip token refresh if there's an unexpected error while looking up the user auth entry, additionally, no error should be returned", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { - env.authInfoService.ExpectedUserAuth = &login.UserAuth{ + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(nil, assert.AnError).Once() + + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + }, + }, + { + desc: "should skip token refresh when there is no refresh token and the provider does not require one", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: false, + } + }, + expectedToken: nil, + }, + { + desc: "should return error when there is no refresh token and provider requires one", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, + setup: func(env *environment) { + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + OAuthAccessToken: expiredToken.AccessToken, + OAuthRefreshToken: "", + OAuthExpiry: expiredToken.Expiry, + }, nil) + + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + }, + expectedToken: nil, + expectedErr: ErrNoRefreshTokenFound, + }, + { + desc: "should skip token refresh when the token is still valid and no id token is present", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, + setup: func(env *environment) { + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ AuthModule: login.GenericOAuthModule, OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken, OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken, OAuthExpiry: unexpiredTokenWithIDToken.Expiry, OAuthTokenType: unexpiredTokenWithIDToken.TokenType, - } + }, nil) env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ UseRefreshToken: true, @@ -177,17 +184,18 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { expectedToken: unexpiredToken, }, { - desc: "should not refresh the tokens if access token or id token have not expired yet", - identity: userIdentity, + desc: "should not refresh the tokens if access token or id token have not expired yet", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { - env.authInfoService.ExpectedUserAuth = &login.UserAuth{ + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ AuthModule: login.GenericOAuthModule, OAuthIdToken: UNEXPIRED_ID_TOKEN, OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken, OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken, OAuthExpiry: unexpiredTokenWithIDToken.Expiry, OAuthTokenType: unexpiredTokenWithIDToken.TokenType, - } + }, nil) env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ UseRefreshToken: true, @@ -196,33 +204,14 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { expectedToken: unexpiredTokenWithIDToken, }, { - desc: "should skip token refresh when there is no refresh token", - identity: userIdentity, - setup: func(env *environment) { - env.authInfoService.ExpectedUserAuth = &login.UserAuth{ - AuthModule: login.GenericOAuthModule, - OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken, - OAuthRefreshToken: "", - OAuthExpiry: unexpiredTokenWithIDToken.Expiry, - } - env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ - UseRefreshToken: true, - } - }, - expectedToken: &oauth2.Token{ - AccessToken: unexpiredTokenWithIDToken.AccessToken, - RefreshToken: "", - Expiry: unexpiredTokenWithIDToken.Expiry, - }, - }, - { - desc: "should do token refresh when the token is expired", - identity: userIdentity, + desc: "should do token refresh when the token is expired", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ UseRefreshToken: true, } - env.authInfoService.ExpectedUserAuth = &login.UserAuth{ + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ AuthModule: login.GenericOAuthModule, AuthId: "subject", UserId: 1, @@ -231,7 +220,16 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { OAuthExpiry: expiredToken.Expiry, OAuthTokenType: expiredToken.TokenType, OAuthIdToken: EXPIRED_ID_TOKEN, - } + }, nil) + + env.authInfoService.On("UpdateAuthInfo", mock.Anything, mock.MatchedBy(func(cmd *login.UpdateAuthInfoCommand) bool { + return cmd.UserId == 1234 && cmd.AuthModule == login.GenericOAuthModule && + cmd.OAuthToken.AccessToken == unexpiredTokenWithIDToken.AccessToken && + cmd.OAuthToken.RefreshToken == unexpiredTokenWithIDToken.RefreshToken && + cmd.OAuthToken.Expiry.Equal(unexpiredTokenWithIDToken.Expiry) && + cmd.OAuthToken.TokenType == unexpiredTokenWithIDToken.TokenType + })).Return(nil).Once() + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() @@ -239,13 +237,14 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { expectedToken: unexpiredTokenWithIDToken, }, { - desc: "should refresh token when the id token is expired", - identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}, + desc: "should refresh token when the id token is expired", + identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ UseRefreshToken: true, } - env.authInfoService.ExpectedUserAuth = &login.UserAuth{ + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ AuthModule: login.GenericOAuthModule, AuthId: "subject", UserId: 1, @@ -254,7 +253,16 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { OAuthExpiry: unexpiredTokenWithIDToken.Expiry, OAuthTokenType: unexpiredTokenWithIDToken.TokenType, OAuthIdToken: EXPIRED_ID_TOKEN, - } + }, nil) + + env.authInfoService.On("UpdateAuthInfo", mock.Anything, mock.MatchedBy(func(cmd *login.UpdateAuthInfoCommand) bool { + return cmd.UserId == 1234 && cmd.AuthModule == login.GenericOAuthModule && + cmd.OAuthToken.AccessToken == unexpiredTokenWithIDToken.AccessToken && + cmd.OAuthToken.RefreshToken == unexpiredTokenWithIDToken.RefreshToken && + cmd.OAuthToken.Expiry.Equal(unexpiredTokenWithIDToken.Expiry) && + cmd.OAuthToken.TokenType == unexpiredTokenWithIDToken.TokenType + })).Return(nil).Once() + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() @@ -262,22 +270,14 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { expectedToken: unexpiredTokenWithIDToken, }, { - desc: "should return ErrRetriesExhausted when lock cannot be acquired", - identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}, + desc: "should return ErrRetriesExhausted when lock cannot be acquired", + identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ UseRefreshToken: true, } - env.authInfoService.ExpectedUserAuth = &login.UserAuth{ - AuthModule: login.GenericOAuthModule, - AuthId: "subject", - UserId: 1234, - OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken, - OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken, - OAuthExpiry: unexpiredTokenWithIDToken.Expiry, - OAuthTokenType: unexpiredTokenWithIDToken.TokenType, - OAuthIdToken: EXPIRED_ID_TOKEN, - } + _ = env.store.WithDbSession(context.Background(), func(sess *db.Session) error { _, err := sess.Exec(`INSERT INTO server_lock (operation_uid, last_execution, version) VALUES (?, ?, ?)`, "oauth-refresh-token-1234", time.Now().Add(2*time.Second).Unix(), 0) return err @@ -285,6 +285,42 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { }, expectedErr: ErrRetriesExhausted, }, + { + desc: "should be able to refresh token when the caller is render service and the access token is expired", + identity: &authn.Identity{ + AuthenticatedBy: login.RenderModule, + ID: "1", + Type: claims.TypeUser, + }, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.MatchedBy(func(query *login.GetAuthInfoQuery) bool { + return query.UserId == 1 + })).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + AuthId: "subject", + UserId: 1, + OAuthAccessToken: expiredToken.AccessToken, + OAuthRefreshToken: expiredToken.RefreshToken, + OAuthExpiry: expiredToken.Expiry, + OAuthTokenType: expiredToken.TokenType, + OAuthIdToken: EXPIRED_ID_TOKEN, + }, nil).Once() + env.authInfoService.On("UpdateAuthInfo", mock.Anything, mock.MatchedBy(func(cmd *login.UpdateAuthInfoCommand) bool { + return cmd.UserId == 1 && cmd.AuthModule == login.GenericOAuthModule && + cmd.OAuthToken.AccessToken == unexpiredTokenWithIDToken.AccessToken && + cmd.OAuthToken.RefreshToken == unexpiredTokenWithIDToken.RefreshToken && + cmd.OAuthToken.Expiry.Equal(unexpiredTokenWithIDToken.Expiry) && + cmd.OAuthToken.TokenType == unexpiredTokenWithIDToken.TokenType + })).Return(nil).Once() + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() + }, + expectedToken: unexpiredTokenWithIDToken, + }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { @@ -294,7 +330,7 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { env := environment{ sessionService: authtest.NewMockUserAuthTokenService(t), - authInfoService: &authinfotest.FakeService{}, + authInfoService: authinfotest.NewMockAuthInfoService(t), serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()), socialConnector: socialConnector, socialService: &socialtest.FakeSocialService{ @@ -319,7 +355,7 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { ) // token refresh - actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity, &usertoken.UserToken{ExternalSessionId: 1}) + actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity, tt.refreshMetadata) if tt.expectedErr != nil { assert.ErrorIs(t, err, tt.expectedErr) @@ -347,45 +383,19 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { testutil.SkipIntegrationTestInShortMode(t) - unexpiredToken := &oauth2.Token{ - AccessToken: "testaccess", - RefreshToken: "testrefresh", - Expiry: time.Now().Add(time.Hour), - TokenType: "Bearer", - } - unexpiredTokenWithIDToken := unexpiredToken.WithExtra(map[string]interface{}{ - "id_token": UNEXPIRED_ID_TOKEN, - }) - - expiredToken := &oauth2.Token{ - AccessToken: "testaccess", - RefreshToken: "testrefresh", - Expiry: time.Now().Add(-time.Hour), - TokenType: "Bearer", - } - userIdentity := &authn.Identity{ AuthenticatedBy: login.GenericOAuthModule, ID: "1234", Type: claims.TypeUser, } - type environment struct { - sessionService *authtest.MockUserAuthTokenService - serverLock *serverlock.ServerLockService - socialConnector *socialtest.MockSocialConnector - socialService *socialtest.FakeSocialService - - store db.DB - service *Service - } - type testCase struct { - desc string - identity identity.Requester - setup func(env *environment) - expectedToken *oauth2.Token - expectedErr error + desc string + identity identity.Requester + refreshMetadata *TokenRefreshMetadata + setup func(env *environment) + expectedToken *oauth2.Token + expectedErr error } tests := []testCase{ @@ -401,8 +411,14 @@ func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { identity: &authn.Identity{ID: "invalid", Type: claims.TypeUser}, }, { - desc: "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned", - identity: userIdentity, + desc: "should skip token refresh when no oauth provider was found", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.SAMLAuthModule}, + }, + { + desc: "should skip token refresh if there's an unexpected error while looking up the external session entry, additionally, no error should be returned", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(nil, assert.AnError).Once() @@ -411,10 +427,11 @@ func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { } }, }, - // Kinda impossible to happen, can only happen after the feature is enabled and logged in users don't have their external sessions set + // Edge case, can only happen after the feature is enabled and logged in users don't have their external sessions set { - desc: "should skip token refresh if the user doesn't have an external session", - identity: userIdentity, + desc: "should skip token refresh if the user doesn't have an external session", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(nil, auth.ErrExternalSessionNotFound).Once() @@ -424,15 +441,17 @@ func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { }, }, { - desc: "should skip token refresh when no oauth provider was found", - identity: userIdentity, + desc: "should skip token refresh when no oauth provider was found", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.socialService.ExpectedAuthInfoProvider = nil }, }, { - desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)", - identity: userIdentity, + desc: "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ UseRefreshToken: false, @@ -440,8 +459,9 @@ func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { }, }, { - desc: "should skip token refresh when the token is still valid and no id token is present", - identity: userIdentity, + desc: "should skip token refresh when the token is still valid and no id token is present", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ ID: 1, @@ -458,8 +478,40 @@ func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { expectedToken: unexpiredToken, }, { - desc: "should not do token refresh if access token or id token have not expired yet", - identity: userIdentity, + desc: "should skip token refresh when there is no refresh token and the provider does not require one", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: false, + } + }, + expectedToken: nil, + }, + { + desc: "should return error when there is no refresh token and provider requires one", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, + setup: func(env *environment) { + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1, + AccessToken: expiredToken.AccessToken, + RefreshToken: "", + ExpiresAt: expiredToken.Expiry, + }, nil).Once() + + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + }, + expectedToken: nil, + expectedErr: ErrNoRefreshTokenFound, + }, + { + desc: "should not do token refresh if access token or id token have not expired yet", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ ID: 1, @@ -477,42 +529,17 @@ func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { expectedToken: unexpiredTokenWithIDToken, }, { - desc: "should skip token refresh when there is no refresh token", - identity: userIdentity, - setup: func(env *environment) { - env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ - ID: 1, - UserID: 1, - AccessToken: unexpiredTokenWithIDToken.AccessToken, - RefreshToken: "", - ExpiresAt: unexpiredTokenWithIDToken.Expiry, - }, nil).Once() - - env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ - UseRefreshToken: true, - } - }, - expectedToken: &oauth2.Token{ - AccessToken: unexpiredTokenWithIDToken.AccessToken, - RefreshToken: "", - Expiry: unexpiredTokenWithIDToken.Expiry, - }, - }, - { - desc: "should refresh token when the access token is expired", - identity: &authn.Identity{ - AuthenticatedBy: login.GenericOAuthModule, - ID: "1", - Type: claims.TypeUser, - }, + desc: "should refresh token when the access token is expired", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ ID: 1, UserID: 1, AccessToken: expiredToken.AccessToken, - IDToken: UNEXPIRED_ID_TOKEN, RefreshToken: expiredToken.RefreshToken, ExpiresAt: expiredToken.Expiry, + IDToken: UNEXPIRED_ID_TOKEN, }, nil).Once() env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() @@ -526,12 +553,13 @@ func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { expectedToken: unexpiredTokenWithIDToken, }, { - desc: "should refresh token when the id token is expired", - identity: userIdentity, + desc: "should refresh token when the id token is expired", + identity: userIdentity, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ ID: 1, - UserID: 1, + UserID: 1234, AccessToken: unexpiredTokenWithIDToken.AccessToken, RefreshToken: unexpiredTokenWithIDToken.RefreshToken, ExpiresAt: unexpiredTokenWithIDToken.Expiry, @@ -549,8 +577,38 @@ func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { expectedToken: unexpiredTokenWithIDToken, }, { - desc: "should return ErrRetriesExhausted when lock cannot be acquired", - identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}, + desc: "should be able to refresh token when the caller is render service and the access token is expired", + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, + identity: &authn.Identity{ + AuthenticatedBy: login.RenderModule, + ID: "1", + Type: claims.TypeUser, + }, + setup: func(env *environment) { + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1, + AuthModule: login.RenderModule, + AccessToken: expiredToken.AccessToken, + RefreshToken: expiredToken.RefreshToken, + ExpiresAt: expiredToken.Expiry, + IDToken: UNEXPIRED_ID_TOKEN, + }, nil).Once() + + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() + + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + }, + expectedToken: unexpiredTokenWithIDToken, + }, + { + desc: "should return ErrRetriesExhausted when lock cannot be acquired", + identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule}, + refreshMetadata: &TokenRefreshMetadata{ExternalSessionID: 1, AuthModule: login.GenericOAuthModule}, setup: func(env *environment) { env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ UseRefreshToken: true, @@ -572,6 +630,7 @@ func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { env := environment{ sessionService: authtest.NewMockUserAuthTokenService(t), + authInfoService: authinfotest.NewMockAuthInfoService(t), serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()), socialConnector: socialConnector, socialService: &socialtest.FakeSocialService{ @@ -586,7 +645,7 @@ func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { env.service = ProvideService( env.socialService, - nil, + env.authInfoService, setting.NewCfg(), prometheus.NewRegistry(), env.serverLock, @@ -596,7 +655,7 @@ func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) { ) // token refresh - actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity, &usertoken.UserToken{ExternalSessionId: 1}) + actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity, tt.refreshMetadata) if tt.expectedErr != nil { assert.ErrorIs(t, err, tt.expectedErr) @@ -635,34 +694,44 @@ func verifyUpdateExternalSessionCommand(token *oauth2.Token) func(*auth.UpdateEx func TestOAuthTokenSync_needTokenRefresh(t *testing.T) { tests := []struct { name string - usr *login.UserAuth + token *oauth2.Token expectedTokenRefreshFlag bool expectedTokenDuration time.Duration }{ { - name: "should not need token refresh when token has no expiration date", - usr: &login.UserAuth{}, + name: "should not need token refresh when token has no expiration date", + token: &oauth2.Token{ + AccessToken: "some_access_token", + Expiry: time.Time{}, + }, expectedTokenRefreshFlag: false, }, { name: "should not need token refresh with an invalid jwt token that might result in an error when parsing", - usr: &login.UserAuth{ - OAuthIdToken: "invalid_jwt_format", - }, + token: (&oauth2.Token{ + AccessToken: "some_access_token", + }).WithExtra(map[string]any{"id_token": "invalid_jwt_format"}), expectedTokenRefreshFlag: false, }, { - name: "should flag token refresh with id token is expired", - usr: &login.UserAuth{ - OAuthIdToken: EXPIRED_ID_TOKEN, + name: "should flag token refresh when access token is empty", + token: &oauth2.Token{ + AccessToken: "", }, expectedTokenRefreshFlag: true, + }, + { + name: "should flag token refresh with id token is expired", + token: (&oauth2.Token{ + AccessToken: "some_access_token"}).WithExtra(map[string]any{"id_token": EXPIRED_ID_TOKEN}), + expectedTokenRefreshFlag: true, expectedTokenDuration: time.Second, }, { name: "should flag token refresh when expiry date is zero", - usr: &login.UserAuth{ - OAuthExpiry: time.Unix(0, 0), + token: &oauth2.Token{ + AccessToken: "some_access_token", + Expiry: time.Unix(0, 0), }, expectedTokenRefreshFlag: true, expectedTokenDuration: time.Second, @@ -670,10 +739,686 @@ func TestOAuthTokenSync_needTokenRefresh(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - token := buildOAuthTokenFromAuthInfo(tt.usr) - needsTokenRefresh := needTokenRefresh(context.Background(), token) + needsTokenRefresh := needTokenRefresh(context.Background(), tt.token) assert.Equal(t, tt.expectedTokenRefreshFlag, needsTokenRefresh) }) } } + +func TestIntegration_GetCurrentOAuthToken(t *testing.T) { + testutil.SkipIntegrationTestInShortMode(t) + + type testCase struct { + desc string + identity identity.Requester + sessionToken *auth.UserToken + setup func(env *environment) + expectedToken *oauth2.Token + } + + userIdentity := &authn.Identity{ + AuthenticatedBy: login.GenericOAuthModule, + ID: "1234", + Type: claims.TypeUser, + } + + tests := []testCase{ + { + desc: "should return nil when identity is nil", + identity: nil, + expectedToken: nil, + }, + { + desc: "should return nil when identity is not a user", + identity: &authn.Identity{ID: "1", Type: claims.TypeServiceAccount}, + expectedToken: nil, + }, + { + desc: "should refresh token for render service user", + identity: &authn.Identity{ID: "1", Type: claims.TypeUser, AuthenticatedBy: login.RenderModule}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + AuthId: "subject", + UserId: 1, + OAuthAccessToken: expiredToken.AccessToken, + OAuthRefreshToken: expiredToken.RefreshToken, + OAuthExpiry: expiredToken.Expiry, + OAuthTokenType: expiredToken.TokenType, + OAuthIdToken: EXPIRED_ID_TOKEN, + }, nil) + + env.sessionService.On("FindExternalSessions", mock.Anything, &auth.ListExternalSessionQuery{UserID: 1}).Return([]*auth.ExternalSession{ + { + ID: 1, + UserID: 1, + AuthModule: login.GenericOAuthModule, + AccessToken: expiredToken.AccessToken, + RefreshToken: expiredToken.RefreshToken, + ExpiresAt: expiredToken.Expiry, + IDToken: EXPIRED_ID_TOKEN, + }, + }, nil).Once() + + env.authInfoService.On("UpdateAuthInfo", mock.Anything, mock.MatchedBy(func(cmd *login.UpdateAuthInfoCommand) bool { + return cmd.UserId == 1 && cmd.AuthModule == login.GenericOAuthModule && + cmd.OAuthToken.AccessToken == unexpiredTokenWithIDToken.AccessToken && + cmd.OAuthToken.RefreshToken == unexpiredTokenWithIDToken.RefreshToken && + cmd.OAuthToken.Expiry.Equal(unexpiredTokenWithIDToken.Expiry) && + cmd.OAuthToken.TokenType == unexpiredTokenWithIDToken.TokenType + })).Return(nil).Once() + + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() + }, + expectedToken: unexpiredTokenWithIDToken, + }, + { + desc: "should refresh token for render service user with multiple external sessions", + identity: &authn.Identity{ID: "1", Type: claims.TypeUser, AuthenticatedBy: login.RenderModule}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + AuthId: "subject", + UserId: 1, + OAuthAccessToken: expiredToken.AccessToken, + OAuthRefreshToken: expiredToken.RefreshToken, + OAuthExpiry: expiredToken.Expiry, + OAuthTokenType: expiredToken.TokenType, + OAuthIdToken: EXPIRED_ID_TOKEN, + }, nil) + + // Return multiple external sessions, the most recent one is returned first by the query + env.sessionService.On("FindExternalSessions", mock.Anything, &auth.ListExternalSessionQuery{UserID: 1}).Return([]*auth.ExternalSession{ + { + ID: 2, // newer session + UserID: 1, + AuthModule: login.GenericOAuthModule, + AccessToken: expiredToken.AccessToken, + RefreshToken: expiredToken.RefreshToken, + ExpiresAt: expiredToken.Expiry, + IDToken: EXPIRED_ID_TOKEN, + }, + { + ID: 1, // older session + UserID: 1, + AuthModule: login.GenericOAuthModule, + }}, nil).Once() + + env.authInfoService.On("UpdateAuthInfo", mock.Anything, mock.MatchedBy(func(cmd *login.UpdateAuthInfoCommand) bool { + return cmd.UserId == 1 && cmd.AuthModule == login.GenericOAuthModule && + cmd.OAuthToken.AccessToken == unexpiredTokenWithIDToken.AccessToken && + cmd.OAuthToken.RefreshToken == unexpiredTokenWithIDToken.RefreshToken && + cmd.OAuthToken.Expiry.Equal(unexpiredTokenWithIDToken.Expiry) && + cmd.OAuthToken.TokenType == unexpiredTokenWithIDToken.TokenType + })).Return(nil).Once() + + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(2), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() + }, + expectedToken: unexpiredTokenWithIDToken, + }, + { + desc: "should skip token refresh when the token is still valid and no id token is present", + identity: userIdentity, + sessionToken: &auth.UserToken{ExternalSessionId: 1}, + setup: func(env *environment) { + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + OAuthAccessToken: unexpiredToken.AccessToken, + OAuthRefreshToken: unexpiredToken.RefreshToken, + OAuthExpiry: unexpiredToken.Expiry, + OAuthTokenType: unexpiredToken.TokenType, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1234, + AuthModule: login.GenericOAuthModule, + AccessToken: unexpiredToken.AccessToken, + RefreshToken: unexpiredToken.RefreshToken, + ExpiresAt: unexpiredToken.Expiry, + }, nil).Once() + + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + }, + expectedToken: unexpiredToken, + }, + { + desc: "should not do token refresh if access token or id token have not expired yet", + identity: userIdentity, + sessionToken: &auth.UserToken{ExternalSessionId: 1}, + setup: func(env *environment) { + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + OAuthIdToken: UNEXPIRED_ID_TOKEN, + OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken, + OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken, + OAuthExpiry: unexpiredTokenWithIDToken.Expiry, + OAuthTokenType: unexpiredTokenWithIDToken.TokenType, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1234, + AuthModule: login.GenericOAuthModule, + AccessToken: unexpiredToken.AccessToken, + RefreshToken: unexpiredToken.RefreshToken, + ExpiresAt: unexpiredToken.Expiry, + }, nil).Once() + + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + }, + expectedToken: unexpiredTokenWithIDToken, + }, + { + desc: "should return the unexpired access and id token when token refresh is disabled", + identity: userIdentity, + sessionToken: &auth.UserToken{ExternalSessionId: 1}, + setup: func(env *environment) { + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + OAuthIdToken: UNEXPIRED_ID_TOKEN, + OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken, + OAuthExpiry: unexpiredTokenWithIDToken.Expiry, + OAuthTokenType: unexpiredTokenWithIDToken.TokenType, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1234, + AuthModule: login.GenericOAuthModule, + AccessToken: unexpiredToken.AccessToken, + ExpiresAt: unexpiredToken.Expiry, + }, nil).Once() + + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: false, + } + }, + expectedToken: unexpiredTokenWithoutRefreshWithIDToken, + }, + // Edge case, can only happen after the feature is enabled and logged in users don't have their external sessions set, + { + desc: "should refresh token when the access token is expired and the external session was not found", + identity: userIdentity, + sessionToken: &auth.UserToken{ExternalSessionId: 1}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + AuthId: "subject", + UserId: 1234, + OAuthAccessToken: expiredToken.AccessToken, + OAuthRefreshToken: expiredToken.RefreshToken, + OAuthExpiry: expiredToken.Expiry, + OAuthTokenType: expiredToken.TokenType, + OAuthIdToken: EXPIRED_ID_TOKEN, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(nil, auth.ErrExternalSessionNotFound).Once() + + env.authInfoService.On("UpdateAuthInfo", mock.Anything, mock.MatchedBy(func(cmd *login.UpdateAuthInfoCommand) bool { + return cmd.UserId == 1234 && cmd.AuthModule == login.GenericOAuthModule && + cmd.OAuthToken.AccessToken == unexpiredTokenWithIDToken.AccessToken && + cmd.OAuthToken.RefreshToken == unexpiredTokenWithIDToken.RefreshToken && + cmd.OAuthToken.Expiry.Equal(unexpiredTokenWithIDToken.Expiry) && + cmd.OAuthToken.TokenType == unexpiredTokenWithIDToken.TokenType + })).Return(nil).Once() + + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() + }, + expectedToken: unexpiredTokenWithIDToken, + }, + { + desc: "should refresh token when the access token is expired", + identity: userIdentity, + sessionToken: &auth.UserToken{ExternalSessionId: 1}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + AuthId: "subject", + UserId: 1234, + OAuthAccessToken: expiredToken.AccessToken, + OAuthRefreshToken: expiredToken.RefreshToken, + OAuthExpiry: expiredToken.Expiry, + OAuthTokenType: expiredToken.TokenType, + OAuthIdToken: EXPIRED_ID_TOKEN, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1234, + AccessToken: expiredToken.AccessToken, + RefreshToken: expiredToken.RefreshToken, + ExpiresAt: expiredToken.Expiry, + IDToken: UNEXPIRED_ID_TOKEN, + }, nil).Once() + + env.authInfoService.On("UpdateAuthInfo", mock.Anything, mock.MatchedBy(func(cmd *login.UpdateAuthInfoCommand) bool { + return cmd.UserId == 1234 && cmd.AuthModule == login.GenericOAuthModule && + cmd.OAuthToken.AccessToken == unexpiredTokenWithIDToken.AccessToken && + cmd.OAuthToken.RefreshToken == unexpiredTokenWithIDToken.RefreshToken && + cmd.OAuthToken.Expiry.Equal(unexpiredTokenWithIDToken.Expiry) && + cmd.OAuthToken.TokenType == unexpiredTokenWithIDToken.TokenType + })).Return(nil).Once() + + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() + }, + expectedToken: unexpiredTokenWithIDToken, + }, + { + desc: "should refresh token when the id token is expired", + identity: userIdentity, + sessionToken: &auth.UserToken{ExternalSessionId: 1}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + AuthId: "subject", + UserId: 1234, + OAuthAccessToken: unexpiredTokenWithIDToken.AccessToken, + OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken, + OAuthExpiry: unexpiredTokenWithIDToken.Expiry, + OAuthTokenType: unexpiredTokenWithIDToken.TokenType, + OAuthIdToken: EXPIRED_ID_TOKEN, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1234, + AuthModule: login.GenericOAuthModule, + AccessToken: unexpiredToken.AccessToken, + RefreshToken: unexpiredToken.RefreshToken, + ExpiresAt: unexpiredToken.Expiry, + IDToken: EXPIRED_ID_TOKEN, + }, nil).Once() + + env.authInfoService.On("UpdateAuthInfo", mock.Anything, mock.MatchedBy(func(cmd *login.UpdateAuthInfoCommand) bool { + return cmd.UserId == 1234 && cmd.AuthModule == login.GenericOAuthModule && + cmd.OAuthToken.AccessToken == unexpiredTokenWithIDToken.AccessToken && + cmd.OAuthToken.RefreshToken == unexpiredTokenWithIDToken.RefreshToken && + cmd.OAuthToken.Expiry.Equal(unexpiredTokenWithIDToken.Expiry) && + cmd.OAuthToken.TokenType == unexpiredTokenWithIDToken.TokenType + })).Return(nil).Once() + + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() + }, + expectedToken: unexpiredTokenWithIDToken, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + socialConnector := socialtest.NewMockSocialConnector(t) + store := db.InitTestDB(t) + features := featuremgmt.WithFeatures() + + env := environment{ + sessionService: authtest.NewMockUserAuthTokenService(t), + authInfoService: authinfotest.NewMockAuthInfoService(t), + serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()), + socialConnector: socialConnector, + socialService: &socialtest.FakeSocialService{ + ExpectedConnector: socialConnector, + }, + store: store, + } + + if tt.setup != nil { + tt.setup(&env) + } + + env.service = ProvideService( + env.socialService, + env.authInfoService, + setting.NewCfg(), + prometheus.NewRegistry(), + env.serverLock, + tracing.InitializeTracerForTest(), + env.sessionService, + features, + ) + + actualToken := env.service.GetCurrentOAuthToken(context.Background(), tt.identity, tt.sessionToken) + + if tt.expectedToken == nil { + assert.Nil(t, actualToken) + return + } + + assert.NotNil(t, actualToken) + assert.Equal(t, tt.expectedToken.AccessToken, actualToken.AccessToken) + assert.Equal(t, tt.expectedToken.RefreshToken, actualToken.RefreshToken) + assert.WithinDuration(t, tt.expectedToken.Expiry, actualToken.Expiry, time.Second) + assert.Equal(t, tt.expectedToken.TokenType, actualToken.TokenType) + if tt.expectedToken.Extra("id_token") != nil { + assert.Equal(t, tt.expectedToken.Extra("id_token"), actualToken.Extra("id_token")) + } else { + assert.Nil(t, actualToken.Extra("id_token")) + } + }) + } +} + +func TestIntegration_GetCurrentOAuthToken_WithExternalSessions(t *testing.T) { + testutil.SkipIntegrationTestInShortMode(t) + + type testCase struct { + desc string + identity identity.Requester + sessionToken *auth.UserToken + setup func(env *environment) + expectedToken *oauth2.Token + } + + userIdentity := &authn.Identity{ + AuthenticatedBy: login.GenericOAuthModule, + ID: "1234", + Type: claims.TypeUser, + } + + tests := []testCase{ + { + desc: "should return nil when identity is nil", + identity: nil, + expectedToken: nil, + }, + { + desc: "should return nil when identity is not a user", + identity: &authn.Identity{ID: "1", Type: claims.TypeServiceAccount}, + expectedToken: nil, + }, + { + desc: "should refresh token for render service user", + identity: &authn.Identity{ID: "1", Type: claims.TypeUser, AuthenticatedBy: login.RenderModule}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(3)).Return(&auth.ExternalSession{ + ID: 3, + UserID: 1, + AuthModule: login.GenericOAuthModule, + AccessToken: expiredToken.AccessToken, + RefreshToken: expiredToken.RefreshToken, + ExpiresAt: expiredToken.Expiry, + IDToken: EXPIRED_ID_TOKEN, + }, nil).Once() + + env.sessionService.On("FindExternalSessions", mock.Anything, &auth.ListExternalSessionQuery{UserID: 1}).Return([]*auth.ExternalSession{ + { + ID: 3, + UserID: 1, + AuthModule: login.GenericOAuthModule, + AccessToken: expiredToken.AccessToken, + RefreshToken: expiredToken.RefreshToken, + ExpiresAt: expiredToken.Expiry, + IDToken: EXPIRED_ID_TOKEN, + }, + }, nil).Once() + + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(3), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() + }, + expectedToken: unexpiredTokenWithIDToken, + }, + { + desc: "should refresh token for render service user with multiple external sessions", + identity: &authn.Identity{ID: "1", Type: claims.TypeUser, AuthenticatedBy: login.RenderModule}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + AuthId: "subject", + UserId: 1, + OAuthAccessToken: expiredToken.AccessToken, + OAuthRefreshToken: expiredToken.RefreshToken, + OAuthExpiry: expiredToken.Expiry, + OAuthTokenType: expiredToken.TokenType, + OAuthIdToken: EXPIRED_ID_TOKEN, + }, nil) + + // Return multiple external sessions, the most recent one is returned first by the query + env.sessionService.On("FindExternalSessions", mock.Anything, &auth.ListExternalSessionQuery{UserID: 1}).Return([]*auth.ExternalSession{ + { + ID: 2, // newer session + UserID: 1, + AuthModule: login.GenericOAuthModule, + AccessToken: expiredToken.AccessToken, + RefreshToken: expiredToken.RefreshToken, + ExpiresAt: expiredToken.Expiry, + IDToken: EXPIRED_ID_TOKEN, + }, + { + ID: 1, // older session + UserID: 1, + AuthModule: login.GenericOAuthModule, + }}, nil).Once() + + env.sessionService.On("GetExternalSession", mock.Anything, int64(2)).Return(&auth.ExternalSession{ + ID: 2, + UserID: 1, + AuthModule: login.GenericOAuthModule, + AccessToken: expiredToken.AccessToken, + RefreshToken: expiredToken.RefreshToken, + ExpiresAt: expiredToken.Expiry, + IDToken: EXPIRED_ID_TOKEN, + }, nil).Once() + + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(2), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() + }, + expectedToken: unexpiredTokenWithIDToken, + }, + { + desc: "should skip token refresh when the token is still valid and no id token is present", + identity: userIdentity, + sessionToken: &auth.UserToken{ExternalSessionId: 1}, + setup: func(env *environment) { + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1234, + AuthModule: login.GenericOAuthModule, + AccessToken: unexpiredToken.AccessToken, + RefreshToken: unexpiredToken.RefreshToken, + ExpiresAt: unexpiredToken.Expiry, + }, nil).Once() + + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + }, + expectedToken: unexpiredToken, + }, + { + desc: "should return the unexpired access and id token when token refresh is disabled", + identity: userIdentity, + sessionToken: &auth.UserToken{ExternalSessionId: 1}, + setup: func(env *environment) { + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1234, + AuthModule: login.GenericOAuthModule, + AccessToken: unexpiredTokenWithIDToken.AccessToken, + ExpiresAt: unexpiredTokenWithIDToken.Expiry, + IDToken: UNEXPIRED_ID_TOKEN, + }, nil).Once() + + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: false, + } + }, + expectedToken: unexpiredTokenWithoutRefreshWithIDToken, + }, + { + desc: "should not do token refresh if access token or id token have not expired yet", + identity: userIdentity, + sessionToken: &auth.UserToken{ExternalSessionId: 1}, + setup: func(env *environment) { + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1234, + AuthModule: login.GenericOAuthModule, + AccessToken: unexpiredTokenWithIDToken.AccessToken, + RefreshToken: unexpiredTokenWithIDToken.RefreshToken, + ExpiresAt: unexpiredTokenWithIDToken.Expiry, + IDToken: UNEXPIRED_ID_TOKEN, + }, nil).Once() + + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + }, + expectedToken: unexpiredTokenWithIDToken, + }, + { + desc: "should refresh token when the access token is expired", + identity: userIdentity, + sessionToken: &auth.UserToken{ExternalSessionId: 1}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1, + AccessToken: expiredToken.AccessToken, + RefreshToken: expiredToken.RefreshToken, + ExpiresAt: expiredToken.Expiry, + IDToken: UNEXPIRED_ID_TOKEN, + }, nil).Twice() + + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() + }, + expectedToken: unexpiredTokenWithIDToken, + }, + { + desc: "should refresh token when the id token is expired", + identity: userIdentity, + sessionToken: &auth.UserToken{ExternalSessionId: 1}, + setup: func(env *environment) { + env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{ + UseRefreshToken: true, + } + env.authInfoService.On("GetAuthInfo", mock.Anything, mock.Anything).Return(&login.UserAuth{ + AuthModule: login.GenericOAuthModule, + }, nil) + + env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{ + ID: 1, + UserID: 1234, + AuthModule: login.GenericOAuthModule, + AccessToken: unexpiredToken.AccessToken, + RefreshToken: unexpiredToken.RefreshToken, + ExpiresAt: unexpiredToken.Expiry, + IDToken: EXPIRED_ID_TOKEN, + }, nil).Twice() + + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() + }, + expectedToken: unexpiredTokenWithIDToken, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + socialConnector := socialtest.NewMockSocialConnector(t) + store := db.InitTestDB(t) + features := featuremgmt.WithFeatures(featuremgmt.FlagImprovedExternalSessionHandling) + + env := environment{ + sessionService: authtest.NewMockUserAuthTokenService(t), + authInfoService: authinfotest.NewMockAuthInfoService(t), + serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()), + socialConnector: socialConnector, + socialService: &socialtest.FakeSocialService{ + ExpectedConnector: socialConnector, + }, + store: store, + } + + if tt.setup != nil { + tt.setup(&env) + } + + env.service = ProvideService( + env.socialService, + env.authInfoService, + setting.NewCfg(), + prometheus.NewRegistry(), + env.serverLock, + tracing.InitializeTracerForTest(), + env.sessionService, + features, + ) + + actualToken := env.service.GetCurrentOAuthToken(context.Background(), tt.identity, tt.sessionToken) + + if tt.expectedToken == nil { + assert.Nil(t, actualToken) + return + } + + assert.NotNil(t, actualToken) + assert.Equal(t, tt.expectedToken.AccessToken, actualToken.AccessToken) + assert.Equal(t, tt.expectedToken.RefreshToken, actualToken.RefreshToken) + assert.WithinDuration(t, tt.expectedToken.Expiry, actualToken.Expiry, time.Second) + if tt.expectedToken.Extra("id_token") != nil { + assert.Equal(t, tt.expectedToken.Extra("id_token"), actualToken.Extra("id_token")) + } else { + assert.Nil(t, actualToken.Extra("id_token")) + } + }) + } +} diff --git a/pkg/services/oauthtoken/oauthtokentest/mock.go b/pkg/services/oauthtoken/oauthtokentest/mock.go index 39e2f6d8fd9..b9319461480 100644 --- a/pkg/services/oauthtoken/oauthtokentest/mock.go +++ b/pkg/services/oauthtoken/oauthtokentest/mock.go @@ -8,13 +8,14 @@ import ( "github.com/grafana/grafana/pkg/apimachinery/identity" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/datasources" + "github.com/grafana/grafana/pkg/services/oauthtoken" ) type MockOauthTokenService struct { GetCurrentOauthTokenFunc func(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) *oauth2.Token IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool - InvalidateOAuthTokensFunc func(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) error - TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error) + InvalidateOAuthTokensFunc func(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) error + TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error) } func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) *oauth2.Token { @@ -31,16 +32,16 @@ func (m *MockOauthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSourc return false } -func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) error { +func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) error { if m.InvalidateOAuthTokensFunc != nil { - return m.InvalidateOAuthTokensFunc(ctx, usr, sessionToken) + return m.InvalidateOAuthTokensFunc(ctx, usr, metadata) } return nil } -func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error) { +func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error) { if m.TryTokenRefreshFunc != nil { - return m.TryTokenRefreshFunc(ctx, usr, sessionToken) + return m.TryTokenRefreshFunc(ctx, usr, metadata) } return nil, nil } diff --git a/pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go b/pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go index 8c58b43d232..a8a6eeafeae 100644 --- a/pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go +++ b/pkg/services/oauthtoken/oauthtokentest/oauthtokentest.go @@ -29,10 +29,10 @@ func (s *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool { return oauthtoken.IsOAuthPassThruEnabled(ds) } -func (s *Service) TryTokenRefresh(context.Context, identity.Requester, *auth.UserToken) (*oauth2.Token, error) { +func (s *Service) TryTokenRefresh(context.Context, identity.Requester, *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error) { return s.Token, nil } -func (s *Service) InvalidateOAuthTokens(context.Context, identity.Requester, *auth.UserToken) error { +func (s *Service) InvalidateOAuthTokens(context.Context, identity.Requester, *oauthtoken.TokenRefreshMetadata) error { return nil }