mirror of https://github.com/grafana/grafana.git
Auth: Fix render user OAuth passthrough (#111636)
* devenv: fix volumes section when sources don't contain one * wip * Working correctly with improvedExternalSessionHandling on * Remove not needed lines * Working with the old flow, tests * Handle compatibility with the feature toggle, tests wip * Tests * Cleanup * Address feedback * Align tests * Add comment * Fix issue with session removal after the invalidation of tokens * Remove commented out code * clean up
This commit is contained in:
parent
3fea6e65f7
commit
53f4803e98
|
@ -56,6 +56,8 @@ func (s *store) Get(ctx context.Context, ID int64) (*auth.ExternalSession, error
|
|||
return externalSession, nil
|
||||
}
|
||||
|
||||
// List returns a list of external sessions that match the given query.
|
||||
// If the result set contains more than one entry, the entries are sorted by ID in descending order.
|
||||
func (s *store) List(ctx context.Context, query *auth.ListExternalSessionQuery) ([]*auth.ExternalSession, error) {
|
||||
ctx, span := s.tracer.Start(ctx, "externalsession.List")
|
||||
defer span.End()
|
||||
|
@ -65,6 +67,10 @@ func (s *store) List(ctx context.Context, query *auth.ListExternalSessionQuery)
|
|||
externalSession.ID = query.ID
|
||||
}
|
||||
|
||||
if query.UserID != 0 {
|
||||
externalSession.UserID = query.UserID
|
||||
}
|
||||
|
||||
hash := sha256.New()
|
||||
|
||||
if query.SessionID != "" {
|
||||
|
@ -80,7 +86,7 @@ func (s *store) List(ctx context.Context, query *auth.ListExternalSessionQuery)
|
|||
|
||||
queryResult := make([]*auth.ExternalSession, 0)
|
||||
err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
|
||||
return sess.Find(&queryResult, externalSession)
|
||||
return sess.Desc("id").Find(&queryResult, externalSession)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -51,6 +51,7 @@ type UpdateExternalSessionCommand struct {
|
|||
|
||||
type ListExternalSessionQuery struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
NameID string
|
||||
SessionID string
|
||||
}
|
||||
|
|
|
@ -93,7 +93,11 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, id *authn.Ident
|
|||
updateCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
token, refreshErr := s.service.TryTokenRefresh(updateCtx, id, id.SessionToken)
|
||||
token, refreshErr := s.service.TryTokenRefresh(updateCtx, id, &oauthtoken.TokenRefreshMetadata{
|
||||
ExternalSessionID: id.SessionToken.ExternalSessionId,
|
||||
AuthModule: id.GetAuthenticatedBy(),
|
||||
AuthID: id.GetAuthID(),
|
||||
})
|
||||
if refreshErr != nil {
|
||||
if errors.Is(refreshErr, context.Canceled) {
|
||||
return nil, nil
|
||||
|
@ -107,7 +111,7 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, id *authn.Ident
|
|||
ctxLogger.Error("Failed to refresh OAuth access token", "id", id.ID, "error", refreshErr)
|
||||
|
||||
// log the user out
|
||||
if err := s.sessionService.RevokeToken(ctx, id.SessionToken, false); err != nil {
|
||||
if err := s.sessionService.RevokeToken(ctx, id.SessionToken, false); err != nil && !errors.Is(err, auth.ErrUserTokenNotFound) {
|
||||
ctxLogger.Warn("Failed to revoke session token", "id", id.ID, "tokenId", id.SessionToken.Id, "error", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||
)
|
||||
|
||||
|
@ -77,6 +78,14 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
|||
expectRevokeTokenCalled: false,
|
||||
expectToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||
},
|
||||
{
|
||||
desc: "should not invalidate session if token refresh fails with no refresh token",
|
||||
identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
|
||||
expectedTryRefreshErr: oauthtoken.ErrNoRefreshTokenFound,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectRevokeTokenCalled: true,
|
||||
expectedErr: oauthtoken.ErrNoRefreshTokenFound,
|
||||
},
|
||||
|
||||
// TODO: address coverage of oauthtoken sync
|
||||
}
|
||||
|
@ -89,7 +98,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
|||
)
|
||||
|
||||
service := &oauthtokentest.MockOauthTokenService{
|
||||
TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester, _ *auth.UserToken) (*oauth2.Token, error) {
|
||||
TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester, _ *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error) {
|
||||
tryRefreshCalled = true
|
||||
return nil, tt.expectedTryRefreshErr
|
||||
},
|
||||
|
|
|
@ -297,7 +297,9 @@ func (c *OAuth) Logout(ctx context.Context, user identity.Requester, sessionToke
|
|||
|
||||
ctxLogger := c.log.FromContext(ctx).New("userID", userID)
|
||||
|
||||
if err := c.oauthService.InvalidateOAuthTokens(ctx, user, sessionToken); err != nil {
|
||||
if err := c.oauthService.InvalidateOAuthTokens(ctx, user, &oauthtoken.TokenRefreshMetadata{
|
||||
ExternalSessionID: sessionToken.ExternalSessionId,
|
||||
AuthModule: user.GetAuthenticatedBy()}); err != nil {
|
||||
ctxLogger.Error("Failed to invalidate tokens", "error", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -19,10 +19,12 @@ import (
|
|||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/login/social/socialtest"
|
||||
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
|
@ -481,7 +483,7 @@ func TestOAuth_Logout(t *testing.T) {
|
|||
"id_token": "some.id.token",
|
||||
})
|
||||
},
|
||||
InvalidateOAuthTokensFunc: func(_ context.Context, _ identity.Requester, _ *auth.UserToken) error {
|
||||
InvalidateOAuthTokensFunc: func(_ context.Context, _ identity.Requester, _ *oauthtoken.TokenRefreshMetadata) error {
|
||||
invalidateTokenCalled = true
|
||||
return nil
|
||||
},
|
||||
|
@ -492,7 +494,7 @@ func TestOAuth_Logout(t *testing.T) {
|
|||
}
|
||||
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), tt.cfg, mockService, fakeSocialSvc, &setting.OSSImpl{Cfg: tt.cfg}, featuremgmt.WithFeatures(), tracing.InitializeTracerForTest())
|
||||
|
||||
redirect, ok := c.Logout(context.Background(), &authn.Identity{ID: "1", Type: claims.TypeUser}, nil)
|
||||
redirect, ok := c.Logout(context.Background(), &authn.Identity{ID: "1", Type: claims.TypeUser}, &usertoken.UserToken{})
|
||||
|
||||
assert.Equal(t, tt.expectedOK, ok)
|
||||
if tt.expectedOK {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
//go:generate mockery --name AuthInfoService --structname MockAuthInfoService --outpkg authinfotest --filename auth_info_service_mock.go --output ./authinfotest/
|
||||
type AuthInfoService interface {
|
||||
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) (*UserAuth, error)
|
||||
GetUserLabels(ctx context.Context, query GetUserLabelsQuery) (map[int64]string, error)
|
||||
|
|
|
@ -0,0 +1,765 @@
|
|||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package authinfotest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewMockAuthInfoService creates a new instance of MockAuthInfoService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockAuthInfoService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockAuthInfoService {
|
||||
mock := &MockAuthInfoService{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// MockAuthInfoService is an autogenerated mock type for the AuthInfoService type
|
||||
type MockAuthInfoService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockAuthInfoService_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockAuthInfoService) EXPECT() *MockAuthInfoService_Expecter {
|
||||
return &MockAuthInfoService_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// DeleteUserAuthInfo provides a mock function for the type MockAuthInfoService
|
||||
func (_mock *MockAuthInfoService) DeleteUserAuthInfo(ctx context.Context, userID int64) error {
|
||||
ret := _mock.Called(ctx, userID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for DeleteUserAuthInfo")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, int64) error); ok {
|
||||
r0 = returnFunc(ctx, userID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockAuthInfoService_DeleteUserAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteUserAuthInfo'
|
||||
type MockAuthInfoService_DeleteUserAuthInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// DeleteUserAuthInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - userID int64
|
||||
func (_e *MockAuthInfoService_Expecter) DeleteUserAuthInfo(ctx interface{}, userID interface{}) *MockAuthInfoService_DeleteUserAuthInfo_Call {
|
||||
return &MockAuthInfoService_DeleteUserAuthInfo_Call{Call: _e.mock.On("DeleteUserAuthInfo", ctx, userID)}
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_DeleteUserAuthInfo_Call) Run(run func(ctx context.Context, userID int64)) *MockAuthInfoService_DeleteUserAuthInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 int64
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(int64)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_DeleteUserAuthInfo_Call) Return(err error) *MockAuthInfoService_DeleteUserAuthInfo_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_DeleteUserAuthInfo_Call) RunAndReturn(run func(ctx context.Context, userID int64) error) *MockAuthInfoService_DeleteUserAuthInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetAuthInfo provides a mock function for the type MockAuthInfoService
|
||||
func (_mock *MockAuthInfoService) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
|
||||
ret := _mock.Called(ctx, query)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetAuthInfo")
|
||||
}
|
||||
|
||||
var r0 *login.UserAuth
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) (*login.UserAuth, error)); ok {
|
||||
return returnFunc(ctx, query)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) *login.UserAuth); ok {
|
||||
r0 = returnFunc(ctx, query)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*login.UserAuth)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, *login.GetAuthInfoQuery) error); ok {
|
||||
r1 = returnFunc(ctx, query)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockAuthInfoService_GetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthInfo'
|
||||
type MockAuthInfoService_GetAuthInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetAuthInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - query *login.GetAuthInfoQuery
|
||||
func (_e *MockAuthInfoService_Expecter) GetAuthInfo(ctx interface{}, query interface{}) *MockAuthInfoService_GetAuthInfo_Call {
|
||||
return &MockAuthInfoService_GetAuthInfo_Call{Call: _e.mock.On("GetAuthInfo", ctx, query)}
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_GetAuthInfo_Call) Run(run func(ctx context.Context, query *login.GetAuthInfoQuery)) *MockAuthInfoService_GetAuthInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *login.GetAuthInfoQuery
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*login.GetAuthInfoQuery)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_GetAuthInfo_Call) Return(userAuth *login.UserAuth, err error) *MockAuthInfoService_GetAuthInfo_Call {
|
||||
_c.Call.Return(userAuth, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_GetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error)) *MockAuthInfoService_GetAuthInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetUserLabels provides a mock function for the type MockAuthInfoService
|
||||
func (_mock *MockAuthInfoService) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) {
|
||||
ret := _mock.Called(ctx, query)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetUserLabels")
|
||||
}
|
||||
|
||||
var r0 map[int64]string
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) (map[int64]string, error)); ok {
|
||||
return returnFunc(ctx, query)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) map[int64]string); ok {
|
||||
r0 = returnFunc(ctx, query)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[int64]string)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, login.GetUserLabelsQuery) error); ok {
|
||||
r1 = returnFunc(ctx, query)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockAuthInfoService_GetUserLabels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserLabels'
|
||||
type MockAuthInfoService_GetUserLabels_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetUserLabels is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - query login.GetUserLabelsQuery
|
||||
func (_e *MockAuthInfoService_Expecter) GetUserLabels(ctx interface{}, query interface{}) *MockAuthInfoService_GetUserLabels_Call {
|
||||
return &MockAuthInfoService_GetUserLabels_Call{Call: _e.mock.On("GetUserLabels", ctx, query)}
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_GetUserLabels_Call) Run(run func(ctx context.Context, query login.GetUserLabelsQuery)) *MockAuthInfoService_GetUserLabels_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 login.GetUserLabelsQuery
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(login.GetUserLabelsQuery)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_GetUserLabels_Call) Return(int64ToString map[int64]string, err error) *MockAuthInfoService_GetUserLabels_Call {
|
||||
_c.Call.Return(int64ToString, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_GetUserLabels_Call) RunAndReturn(run func(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error)) *MockAuthInfoService_GetUserLabels_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAuthInfo provides a mock function for the type MockAuthInfoService
|
||||
func (_mock *MockAuthInfoService) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
|
||||
ret := _mock.Called(ctx, cmd)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SetAuthInfo")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.SetAuthInfoCommand) error); ok {
|
||||
r0 = returnFunc(ctx, cmd)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockAuthInfoService_SetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetAuthInfo'
|
||||
type MockAuthInfoService_SetAuthInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SetAuthInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - cmd *login.SetAuthInfoCommand
|
||||
func (_e *MockAuthInfoService_Expecter) SetAuthInfo(ctx interface{}, cmd interface{}) *MockAuthInfoService_SetAuthInfo_Call {
|
||||
return &MockAuthInfoService_SetAuthInfo_Call{Call: _e.mock.On("SetAuthInfo", ctx, cmd)}
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_SetAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.SetAuthInfoCommand)) *MockAuthInfoService_SetAuthInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *login.SetAuthInfoCommand
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*login.SetAuthInfoCommand)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_SetAuthInfo_Call) Return(err error) *MockAuthInfoService_SetAuthInfo_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_SetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.SetAuthInfoCommand) error) *MockAuthInfoService_SetAuthInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateAuthInfo provides a mock function for the type MockAuthInfoService
|
||||
func (_mock *MockAuthInfoService) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
|
||||
ret := _mock.Called(ctx, cmd)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateAuthInfo")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.UpdateAuthInfoCommand) error); ok {
|
||||
r0 = returnFunc(ctx, cmd)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockAuthInfoService_UpdateAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAuthInfo'
|
||||
type MockAuthInfoService_UpdateAuthInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UpdateAuthInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - cmd *login.UpdateAuthInfoCommand
|
||||
func (_e *MockAuthInfoService_Expecter) UpdateAuthInfo(ctx interface{}, cmd interface{}) *MockAuthInfoService_UpdateAuthInfo_Call {
|
||||
return &MockAuthInfoService_UpdateAuthInfo_Call{Call: _e.mock.On("UpdateAuthInfo", ctx, cmd)}
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_UpdateAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand)) *MockAuthInfoService_UpdateAuthInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *login.UpdateAuthInfoCommand
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*login.UpdateAuthInfoCommand)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_UpdateAuthInfo_Call) Return(err error) *MockAuthInfoService_UpdateAuthInfo_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockAuthInfoService_UpdateAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error) *MockAuthInfoService_UpdateAuthInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMockStore creates a new instance of MockStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockStore(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockStore {
|
||||
mock := &MockStore{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// MockStore is an autogenerated mock type for the Store type
|
||||
type MockStore struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockStore_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockStore) EXPECT() *MockStore_Expecter {
|
||||
return &MockStore_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// DeleteUserAuthInfo provides a mock function for the type MockStore
|
||||
func (_mock *MockStore) DeleteUserAuthInfo(ctx context.Context, userID int64) error {
|
||||
ret := _mock.Called(ctx, userID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for DeleteUserAuthInfo")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, int64) error); ok {
|
||||
r0 = returnFunc(ctx, userID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockStore_DeleteUserAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteUserAuthInfo'
|
||||
type MockStore_DeleteUserAuthInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// DeleteUserAuthInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - userID int64
|
||||
func (_e *MockStore_Expecter) DeleteUserAuthInfo(ctx interface{}, userID interface{}) *MockStore_DeleteUserAuthInfo_Call {
|
||||
return &MockStore_DeleteUserAuthInfo_Call{Call: _e.mock.On("DeleteUserAuthInfo", ctx, userID)}
|
||||
}
|
||||
|
||||
func (_c *MockStore_DeleteUserAuthInfo_Call) Run(run func(ctx context.Context, userID int64)) *MockStore_DeleteUserAuthInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 int64
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(int64)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStore_DeleteUserAuthInfo_Call) Return(err error) *MockStore_DeleteUserAuthInfo_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStore_DeleteUserAuthInfo_Call) RunAndReturn(run func(ctx context.Context, userID int64) error) *MockStore_DeleteUserAuthInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetAuthInfo provides a mock function for the type MockStore
|
||||
func (_mock *MockStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
|
||||
ret := _mock.Called(ctx, query)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetAuthInfo")
|
||||
}
|
||||
|
||||
var r0 *login.UserAuth
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) (*login.UserAuth, error)); ok {
|
||||
return returnFunc(ctx, query)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.GetAuthInfoQuery) *login.UserAuth); ok {
|
||||
r0 = returnFunc(ctx, query)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*login.UserAuth)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, *login.GetAuthInfoQuery) error); ok {
|
||||
r1 = returnFunc(ctx, query)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockStore_GetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthInfo'
|
||||
type MockStore_GetAuthInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetAuthInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - query *login.GetAuthInfoQuery
|
||||
func (_e *MockStore_Expecter) GetAuthInfo(ctx interface{}, query interface{}) *MockStore_GetAuthInfo_Call {
|
||||
return &MockStore_GetAuthInfo_Call{Call: _e.mock.On("GetAuthInfo", ctx, query)}
|
||||
}
|
||||
|
||||
func (_c *MockStore_GetAuthInfo_Call) Run(run func(ctx context.Context, query *login.GetAuthInfoQuery)) *MockStore_GetAuthInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *login.GetAuthInfoQuery
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*login.GetAuthInfoQuery)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStore_GetAuthInfo_Call) Return(userAuth *login.UserAuth, err error) *MockStore_GetAuthInfo_Call {
|
||||
_c.Call.Return(userAuth, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStore_GetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error)) *MockStore_GetAuthInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetUserLabels provides a mock function for the type MockStore
|
||||
func (_mock *MockStore) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) {
|
||||
ret := _mock.Called(ctx, query)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetUserLabels")
|
||||
}
|
||||
|
||||
var r0 map[int64]string
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) (map[int64]string, error)); ok {
|
||||
return returnFunc(ctx, query)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, login.GetUserLabelsQuery) map[int64]string); ok {
|
||||
r0 = returnFunc(ctx, query)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[int64]string)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, login.GetUserLabelsQuery) error); ok {
|
||||
r1 = returnFunc(ctx, query)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockStore_GetUserLabels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserLabels'
|
||||
type MockStore_GetUserLabels_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetUserLabels is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - query login.GetUserLabelsQuery
|
||||
func (_e *MockStore_Expecter) GetUserLabels(ctx interface{}, query interface{}) *MockStore_GetUserLabels_Call {
|
||||
return &MockStore_GetUserLabels_Call{Call: _e.mock.On("GetUserLabels", ctx, query)}
|
||||
}
|
||||
|
||||
func (_c *MockStore_GetUserLabels_Call) Run(run func(ctx context.Context, query login.GetUserLabelsQuery)) *MockStore_GetUserLabels_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 login.GetUserLabelsQuery
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(login.GetUserLabelsQuery)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStore_GetUserLabels_Call) Return(int64ToString map[int64]string, err error) *MockStore_GetUserLabels_Call {
|
||||
_c.Call.Return(int64ToString, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStore_GetUserLabels_Call) RunAndReturn(run func(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error)) *MockStore_GetUserLabels_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAuthInfo provides a mock function for the type MockStore
|
||||
func (_mock *MockStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
|
||||
ret := _mock.Called(ctx, cmd)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SetAuthInfo")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.SetAuthInfoCommand) error); ok {
|
||||
r0 = returnFunc(ctx, cmd)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockStore_SetAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetAuthInfo'
|
||||
type MockStore_SetAuthInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SetAuthInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - cmd *login.SetAuthInfoCommand
|
||||
func (_e *MockStore_Expecter) SetAuthInfo(ctx interface{}, cmd interface{}) *MockStore_SetAuthInfo_Call {
|
||||
return &MockStore_SetAuthInfo_Call{Call: _e.mock.On("SetAuthInfo", ctx, cmd)}
|
||||
}
|
||||
|
||||
func (_c *MockStore_SetAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.SetAuthInfoCommand)) *MockStore_SetAuthInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *login.SetAuthInfoCommand
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*login.SetAuthInfoCommand)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStore_SetAuthInfo_Call) Return(err error) *MockStore_SetAuthInfo_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStore_SetAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.SetAuthInfoCommand) error) *MockStore_SetAuthInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateAuthInfo provides a mock function for the type MockStore
|
||||
func (_mock *MockStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
|
||||
ret := _mock.Called(ctx, cmd)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateAuthInfo")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *login.UpdateAuthInfoCommand) error); ok {
|
||||
r0 = returnFunc(ctx, cmd)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockStore_UpdateAuthInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAuthInfo'
|
||||
type MockStore_UpdateAuthInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UpdateAuthInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - cmd *login.UpdateAuthInfoCommand
|
||||
func (_e *MockStore_Expecter) UpdateAuthInfo(ctx interface{}, cmd interface{}) *MockStore_UpdateAuthInfo_Call {
|
||||
return &MockStore_UpdateAuthInfo_Call{Call: _e.mock.On("UpdateAuthInfo", ctx, cmd)}
|
||||
}
|
||||
|
||||
func (_c *MockStore_UpdateAuthInfo_Call) Run(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand)) *MockStore_UpdateAuthInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *login.UpdateAuthInfoCommand
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*login.UpdateAuthInfoCommand)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStore_UpdateAuthInfo_Call) Return(err error) *MockStore_UpdateAuthInfo_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockStore_UpdateAuthInfo_Call) RunAndReturn(run func(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error) *MockStore_UpdateAuthInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMockUserProtectionService creates a new instance of MockUserProtectionService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockUserProtectionService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockUserProtectionService {
|
||||
mock := &MockUserProtectionService{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// MockUserProtectionService is an autogenerated mock type for the UserProtectionService type
|
||||
type MockUserProtectionService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockUserProtectionService_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockUserProtectionService) EXPECT() *MockUserProtectionService_Expecter {
|
||||
return &MockUserProtectionService_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// AllowUserMapping provides a mock function for the type MockUserProtectionService
|
||||
func (_mock *MockUserProtectionService) AllowUserMapping(user1 *user.User, authModule string) error {
|
||||
ret := _mock.Called(user1, authModule)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for AllowUserMapping")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(*user.User, string) error); ok {
|
||||
r0 = returnFunc(user1, authModule)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockUserProtectionService_AllowUserMapping_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllowUserMapping'
|
||||
type MockUserProtectionService_AllowUserMapping_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// AllowUserMapping is a helper method to define mock.On call
|
||||
// - user1 *user.User
|
||||
// - authModule string
|
||||
func (_e *MockUserProtectionService_Expecter) AllowUserMapping(user1 interface{}, authModule interface{}) *MockUserProtectionService_AllowUserMapping_Call {
|
||||
return &MockUserProtectionService_AllowUserMapping_Call{Call: _e.mock.On("AllowUserMapping", user1, authModule)}
|
||||
}
|
||||
|
||||
func (_c *MockUserProtectionService_AllowUserMapping_Call) Run(run func(user1 *user.User, authModule string)) *MockUserProtectionService_AllowUserMapping_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 *user.User
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(*user.User)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockUserProtectionService_AllowUserMapping_Call) Return(err error) *MockUserProtectionService_AllowUserMapping_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockUserProtectionService_AllowUserMapping_Call) RunAndReturn(run func(user1 *user.User, authModule string) error) *MockUserProtectionService_AllowUserMapping_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
|
@ -57,8 +58,14 @@ var _ OAuthTokenService = (*Service)(nil)
|
|||
type OAuthTokenService interface {
|
||||
GetCurrentOAuthToken(context.Context, identity.Requester, *auth.UserToken) *oauth2.Token
|
||||
IsOAuthPassThruEnabled(*datasources.DataSource) bool
|
||||
TryTokenRefresh(context.Context, identity.Requester, *auth.UserToken) (*oauth2.Token, error)
|
||||
InvalidateOAuthTokens(context.Context, identity.Requester, *auth.UserToken) error
|
||||
TryTokenRefresh(context.Context, identity.Requester, *TokenRefreshMetadata) (*oauth2.Token, error)
|
||||
InvalidateOAuthTokens(context.Context, identity.Requester, *TokenRefreshMetadata) error
|
||||
}
|
||||
|
||||
type TokenRefreshMetadata struct {
|
||||
ExternalSessionID int64
|
||||
AuthModule string
|
||||
AuthID string
|
||||
}
|
||||
|
||||
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg, registerer prometheus.Registerer,
|
||||
|
@ -102,51 +109,71 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr identity.Request
|
|||
|
||||
ctxLogger = ctxLogger.New("userID", userID)
|
||||
|
||||
if !strings.HasPrefix(usr.GetAuthenticatedBy(), "oauth_") {
|
||||
ctxLogger.Warn("The specified user's auth provider is not oauth",
|
||||
"authmodule", usr.GetAuthenticatedBy())
|
||||
tokenRefreshMetadata := &TokenRefreshMetadata{
|
||||
ExternalSessionID: 0,
|
||||
}
|
||||
var persistedToken *oauth2.Token
|
||||
// Find the external session associated with the user and session token
|
||||
// regardless of the improvedExternalSessionHandling feature toggle,
|
||||
// because Grafana writes and updates both tables to make the switch
|
||||
// to the new session handling smoother.
|
||||
externalSession, err := o.getExternalSession(ctx, usr, userID, sessionToken)
|
||||
if err != nil && !errors.Is(err, auth.ErrExternalSessionNotFound) {
|
||||
ctxLogger.Error("Failed to get external session", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the feature toggle is enabled, an external session is required.
|
||||
if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) && (externalSession == nil || errors.Is(err, auth.ErrExternalSessionNotFound)) {
|
||||
ctxLogger.Error("No external session found for user", "userID", userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// externalSession can be nil if Grafana was updated from a version where the
|
||||
// external session table was not used yet (did not exist) and the user has not logged in since
|
||||
// the version update (therefore no external session was created for the user yet).
|
||||
if externalSession != nil {
|
||||
tokenRefreshMetadata.ExternalSessionID = externalSession.ID
|
||||
}
|
||||
|
||||
authInfo, err := o.AuthInfoService.GetAuthInfo(ctx, &login.GetAuthInfoQuery{
|
||||
UserId: userID,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
ctxLogger.Warn("No AuthInfo found for user", "userID", userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
ctxLogger.Error("Failed to fetch AuthInfo for user", "userID", userID, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
tokenRefreshMetadata.AuthID = authInfo.AuthId
|
||||
tokenRefreshMetadata.AuthModule = authInfo.AuthModule
|
||||
|
||||
if !strings.HasPrefix(tokenRefreshMetadata.AuthModule, "oauth_") {
|
||||
ctxLogger.Warn("The specified user's auth provider is not oauth",
|
||||
"authmodule", tokenRefreshMetadata.AuthModule)
|
||||
return nil
|
||||
}
|
||||
|
||||
var persistedToken *oauth2.Token
|
||||
if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) {
|
||||
externalSession, err := o.sessionService.GetExternalSession(ctx, sessionToken.ExternalSessionId)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrExternalSessionNotFound) {
|
||||
return nil
|
||||
}
|
||||
ctxLogger.Error("Failed to fetch external session", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
persistedToken = buildOAuthTokenFromExternalSession(externalSession)
|
||||
|
||||
if persistedToken.RefreshToken == "" {
|
||||
return persistedToken
|
||||
}
|
||||
} else {
|
||||
authInfo, ok, _ := o.hasOAuthEntry(ctx, usr)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := checkOAuthRefreshToken(authInfo); err != nil {
|
||||
if errors.Is(err, ErrNoRefreshTokenFound) {
|
||||
return buildOAuthTokenFromAuthInfo(authInfo)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
persistedToken = buildOAuthTokenFromAuthInfo(authInfo)
|
||||
}
|
||||
|
||||
if persistedToken.RefreshToken == "" {
|
||||
return persistedToken
|
||||
}
|
||||
|
||||
refreshNeeded := needTokenRefresh(ctx, persistedToken)
|
||||
if !refreshNeeded {
|
||||
return persistedToken
|
||||
}
|
||||
|
||||
token, err := o.TryTokenRefresh(ctx, usr, sessionToken)
|
||||
token, err := o.TryTokenRefresh(ctx, usr, tokenRefreshMetadata)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNoRefreshTokenFound) {
|
||||
return persistedToken
|
||||
|
@ -214,7 +241,7 @@ func (o *Service) hasOAuthEntry(ctx context.Context, usr identity.Requester) (*l
|
|||
|
||||
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful
|
||||
// It uses a server lock to prevent getting the Refresh Token multiple times for a given User
|
||||
func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error) {
|
||||
func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, tokenRefreshMetadata *TokenRefreshMetadata) (*oauth2.Token, error) {
|
||||
ctx, span := o.tracer.Start(ctx, "oauthtoken.TryTokenRefresh")
|
||||
defer span.End()
|
||||
|
||||
|
@ -239,14 +266,13 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s
|
|||
|
||||
ctxLogger = ctxLogger.New("userID", userID)
|
||||
|
||||
// get the token's auth provider (f.e. azuread)
|
||||
currAuthenticator := usr.GetAuthenticatedBy()
|
||||
if !strings.HasPrefix(currAuthenticator, "oauth") {
|
||||
ctxLogger.Warn("The specified user's auth provider is not OAuth", "authmodule", currAuthenticator)
|
||||
if !strings.HasPrefix(tokenRefreshMetadata.AuthModule, "oauth_") {
|
||||
ctxLogger.Warn("The specified user's auth provider is not oauth",
|
||||
"authmodule", tokenRefreshMetadata.AuthModule)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
provider := strings.TrimPrefix(currAuthenticator, "oauth_")
|
||||
provider := strings.TrimPrefix(tokenRefreshMetadata.AuthModule, "oauth_")
|
||||
currentOAuthInfo := o.SocialService.GetOAuthInfoProvider(provider)
|
||||
if currentOAuthInfo == nil {
|
||||
ctxLogger.Warn("OAuth provider not found", "provider", provider)
|
||||
|
@ -261,7 +287,7 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s
|
|||
|
||||
lockKey := fmt.Sprintf("oauth-refresh-token-%d", userID)
|
||||
if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) {
|
||||
lockKey = fmt.Sprintf("oauth-refresh-token-%d-%d", userID, sessionToken.ExternalSessionId)
|
||||
lockKey = fmt.Sprintf("oauth-refresh-token-%d-%d", userID, tokenRefreshMetadata.ExternalSessionID)
|
||||
}
|
||||
|
||||
lockTimeConfig := serverlock.LockTimeConfig{
|
||||
|
@ -290,7 +316,7 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s
|
|||
var persistedToken *oauth2.Token
|
||||
var externalSession *auth.ExternalSession
|
||||
if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) {
|
||||
externalSession, err = o.sessionService.GetExternalSession(ctx, sessionToken.ExternalSessionId)
|
||||
externalSession, err = o.sessionService.GetExternalSession(ctx, tokenRefreshMetadata.ExternalSessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrExternalSessionNotFound) {
|
||||
ctxLogger.Error("External session was not found for user", "error", err)
|
||||
|
@ -321,7 +347,7 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s
|
|||
return
|
||||
}
|
||||
|
||||
newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, persistedToken, usr, sessionToken)
|
||||
newToken, cmdErr = o.tryGetOrRefreshOAuthToken(ctx, persistedToken, usr, tokenRefreshMetadata)
|
||||
}, retryOpt)
|
||||
if lockErr != nil {
|
||||
ctxLogger.Error("Failed to obtain token refresh lock", "error", lockErr)
|
||||
|
@ -330,14 +356,14 @@ func (o *Service) TryTokenRefresh(ctx context.Context, usr identity.Requester, s
|
|||
|
||||
// Silence ErrNoRefreshTokenFound
|
||||
if errors.Is(cmdErr, ErrNoRefreshTokenFound) {
|
||||
return nil, nil
|
||||
return nil, ErrNoRefreshTokenFound
|
||||
}
|
||||
|
||||
return newToken, cmdErr
|
||||
}
|
||||
|
||||
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
|
||||
func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) error {
|
||||
func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, tokenRefreshMetadata *TokenRefreshMetadata) error {
|
||||
userID, err := usr.GetInternalID()
|
||||
if err != nil {
|
||||
logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err)
|
||||
|
@ -347,7 +373,7 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Reques
|
|||
ctxLogger := logger.FromContext(ctx).New("userID", userID)
|
||||
|
||||
if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) {
|
||||
err := o.sessionService.UpdateExternalSession(ctx, sessionToken.ExternalSessionId, &auth.UpdateExternalSessionCommand{
|
||||
err := o.sessionService.UpdateExternalSession(ctx, tokenRefreshMetadata.ExternalSessionID, &auth.UpdateExternalSessionCommand{
|
||||
Token: &oauth2.Token{},
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -358,8 +384,8 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Reques
|
|||
|
||||
return o.AuthInfoService.UpdateAuthInfo(ctx, &login.UpdateAuthInfoCommand{
|
||||
UserId: userID,
|
||||
AuthModule: usr.GetAuthenticatedBy(),
|
||||
AuthId: usr.GetAuthID(),
|
||||
AuthModule: tokenRefreshMetadata.AuthModule,
|
||||
AuthId: tokenRefreshMetadata.AuthID,
|
||||
OAuthToken: &oauth2.Token{
|
||||
AccessToken: "",
|
||||
RefreshToken: "",
|
||||
|
@ -368,13 +394,14 @@ func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr identity.Reques
|
|||
})
|
||||
}
|
||||
|
||||
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken *oauth2.Token, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error) {
|
||||
func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken *oauth2.Token, usr identity.Requester, tokenRefreshMetadata *TokenRefreshMetadata) (*oauth2.Token, error) {
|
||||
ctx, span := o.tracer.Start(ctx, "oauthtoken.tryGetOrRefreshOAuthToken")
|
||||
defer span.End()
|
||||
|
||||
userID, err := usr.GetInternalID()
|
||||
if err != nil {
|
||||
logger.Error("Failed to convert user id to int", "id", usr.GetID(), "error", err)
|
||||
span.SetStatus(codes.Error, "Failed to convert user id to int")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -382,8 +409,11 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken
|
|||
|
||||
ctxLogger := logger.FromContext(ctx).New("userID", userID)
|
||||
|
||||
// tryGetOrRefreshOAuthToken assumes that the AuthModule has RefreshToken enabled
|
||||
// which is checked by the caller (TryTokenRefresh)
|
||||
if persistedToken.RefreshToken == "" {
|
||||
ctxLogger.Warn("No refresh token available", "authmodule", usr.GetAuthenticatedBy())
|
||||
ctxLogger.Error("No refresh token available", "authmodule", tokenRefreshMetadata.AuthModule)
|
||||
span.SetStatus(codes.Error, ErrNoRefreshTokenFound.Error())
|
||||
return nil, ErrNoRefreshTokenFound
|
||||
}
|
||||
|
||||
|
@ -392,50 +422,44 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken
|
|||
return persistedToken, nil
|
||||
}
|
||||
|
||||
authProvider := usr.GetAuthenticatedBy()
|
||||
connect, err := o.SocialService.GetConnector(authProvider)
|
||||
connect, err := o.SocialService.GetConnector(tokenRefreshMetadata.AuthModule)
|
||||
if err != nil {
|
||||
ctxLogger.Error("Failed to get oauth connector", "provider", authProvider, "error", err)
|
||||
ctxLogger.Error("Failed to get oauth connector", "provider", tokenRefreshMetadata.AuthModule, "error", err)
|
||||
span.SetStatus(codes.Error, "Failed to get oauth connector: "+err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := o.SocialService.GetOAuthHttpClient(authProvider)
|
||||
client, err := o.SocialService.GetOAuthHttpClient(tokenRefreshMetadata.AuthModule)
|
||||
if err != nil {
|
||||
ctxLogger.Error("Failed to get oauth http client", "provider", authProvider, "error", err)
|
||||
ctxLogger.Error("Failed to get oauth http client", "provider", tokenRefreshMetadata.AuthModule, "error", err)
|
||||
span.SetStatus(codes.Error, "Failed to get oauth http client")
|
||||
return nil, err
|
||||
}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||
|
||||
start := time.Now()
|
||||
// TokenSource handles refreshing the token if it has expired
|
||||
token, err := connect.TokenSource(ctx, persistedToken).Token()
|
||||
token, refreshErr := connect.TokenSource(ctx, persistedToken).Token()
|
||||
duration := time.Since(start)
|
||||
o.tokenRefreshDuration.WithLabelValues(authProvider, fmt.Sprintf("%t", err == nil)).Observe(duration.Seconds())
|
||||
o.tokenRefreshDuration.WithLabelValues(tokenRefreshMetadata.AuthModule, fmt.Sprintf("%t", err == nil)).Observe(duration.Seconds())
|
||||
|
||||
if err != nil {
|
||||
if refreshErr != nil {
|
||||
span.SetAttributes(attribute.Bool("token_refreshed", false))
|
||||
ctxLogger.Error("Failed to retrieve oauth access token",
|
||||
"provider", usr.GetAuthenticatedBy(), "error", err)
|
||||
"provider", tokenRefreshMetadata.AuthModule, "error", refreshErr)
|
||||
|
||||
// token refresh failed, invalidate the old token
|
||||
if err := o.InvalidateOAuthTokens(ctx, usr, sessionToken); err != nil {
|
||||
ctxLogger.Warn("Failed to invalidate OAuth tokens", "authID", usr.GetAuthID(), "error", err)
|
||||
if err := o.InvalidateOAuthTokens(ctx, usr, tokenRefreshMetadata); err != nil {
|
||||
ctxLogger.Warn("Failed to invalidate OAuth tokens", "authID", tokenRefreshMetadata.AuthID, "error", err)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
return nil, refreshErr
|
||||
}
|
||||
|
||||
span.SetAttributes(attribute.Bool("token_refreshed", true))
|
||||
|
||||
// If the tokens are not the same, update the entry in the DB
|
||||
if !tokensEq(persistedToken, token) {
|
||||
updateAuthCommand := &login.UpdateAuthInfoCommand{
|
||||
UserId: userID,
|
||||
AuthModule: usr.GetAuthenticatedBy(),
|
||||
AuthId: usr.GetAuthID(),
|
||||
OAuthToken: token,
|
||||
}
|
||||
|
||||
if o.Cfg.Env == setting.Dev {
|
||||
ctxLogger.Debug("Oauth got token",
|
||||
"auth_module", usr.GetAuthenticatedBy(),
|
||||
|
@ -446,17 +470,32 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken
|
|||
}
|
||||
|
||||
if !o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) {
|
||||
updateAuthCommand := &login.UpdateAuthInfoCommand{
|
||||
UserId: userID,
|
||||
AuthModule: tokenRefreshMetadata.AuthModule,
|
||||
AuthId: tokenRefreshMetadata.AuthID,
|
||||
OAuthToken: token,
|
||||
}
|
||||
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
|
||||
ctxLogger.Error("Failed to update auth info during token refresh", "authID", usr.GetAuthID(), "error", err)
|
||||
ctxLogger.Error("Failed to update auth info during token refresh", "authID", tokenRefreshMetadata.AuthID, "error", err)
|
||||
span.SetStatus(codes.Error, "Failed to update auth info during token refresh")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := o.sessionService.UpdateExternalSession(ctx, sessionToken.ExternalSessionId, &auth.UpdateExternalSessionCommand{
|
||||
Token: token,
|
||||
}); err != nil {
|
||||
ctxLogger.Error("Failed to update external session during token refresh", "error", err)
|
||||
return nil, err
|
||||
// Update the external session with the new token if we the user has an external session,
|
||||
// regardless of the feature flag state to keep the `user_external_session` table in sync.
|
||||
// ExternalSessionID should always be set except for some edge cases:
|
||||
// - when Grafana was updated to a version where the `improvedExternalSessionHandling` feature flag
|
||||
// was enabled after the user logged in
|
||||
if tokenRefreshMetadata.ExternalSessionID != 0 {
|
||||
if err := o.sessionService.UpdateExternalSession(ctx, tokenRefreshMetadata.ExternalSessionID, &auth.UpdateExternalSessionCommand{
|
||||
Token: token,
|
||||
}); err != nil {
|
||||
ctxLogger.Error("Failed to update external session during token refresh", "error", err)
|
||||
span.SetStatus(codes.Error, "Failed to update external session during token refresh")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ctxLogger.Debug("Updated oauth info for user")
|
||||
|
@ -502,6 +541,11 @@ func needTokenRefresh(ctx context.Context, persistedToken *oauth2.Token) bool {
|
|||
|
||||
ctxLogger := logger.FromContext(ctx)
|
||||
|
||||
if persistedToken.AccessToken == "" {
|
||||
ctxLogger.Debug("Access token has been cleared, need to refresh")
|
||||
return true
|
||||
}
|
||||
|
||||
idTokenExp, err := GetIDTokenExpiry(persistedToken)
|
||||
if err != nil {
|
||||
ctxLogger.Warn("Could not get ID Token expiry", "error", err)
|
||||
|
@ -552,22 +596,6 @@ func buildOAuthTokenFromExternalSession(externalSession *auth.ExternalSession) *
|
|||
return token
|
||||
}
|
||||
|
||||
func checkOAuthRefreshToken(authInfo *login.UserAuth) error {
|
||||
if !strings.Contains(authInfo.AuthModule, "oauth") {
|
||||
logger.Warn("The specified user's auth provider is not oauth",
|
||||
"authmodule", authInfo.AuthModule, "userid", authInfo.UserId)
|
||||
return ErrNotAnOAuthProvider
|
||||
}
|
||||
|
||||
if authInfo.OAuthRefreshToken == "" {
|
||||
logger.Warn("No refresh token available",
|
||||
"authmodule", authInfo.AuthModule, "userid", authInfo.UserId)
|
||||
return ErrNoRefreshTokenFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIDTokenExpiry extracts the expiry time from the ID token
|
||||
func GetIDTokenExpiry(token *oauth2.Token) (time.Time, error) {
|
||||
idToken, ok := token.Extra("id_token").(string)
|
||||
|
@ -601,3 +629,28 @@ func getExpiryWithSkew(expiry time.Time) (adjustedExpiry time.Time, hasTokenExpi
|
|||
hasTokenExpired = adjustedExpiry.Before(time.Now())
|
||||
return
|
||||
}
|
||||
|
||||
// getExternalSession fetches the external session based on the user and session token.
|
||||
// When using the render module, it fetches the most recent external session for the user
|
||||
// since the session token ID is not available.
|
||||
// For regular users, it uses the session token ID to fetch the external session.
|
||||
func (o *Service) getExternalSession(ctx context.Context, usr identity.Requester, userID int64, sessionToken *auth.UserToken) (*auth.ExternalSession, error) {
|
||||
if usr.GetAuthenticatedBy() == login.RenderModule {
|
||||
// When using render module, we don't have the session token ID, so we need to fetch the most recent session
|
||||
// entry for the user (as it is done with the old flow).
|
||||
// In the future, we might want to consider passing the session token ID to the render module to make this more robust.
|
||||
externalSessions, err := o.sessionService.FindExternalSessions(ctx, &auth.ListExternalSessionQuery{UserID: userID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(externalSessions) == 0 || externalSessions[0] == nil {
|
||||
return nil, auth.ErrExternalSessionNotFound
|
||||
}
|
||||
|
||||
return externalSessions[0], nil
|
||||
}
|
||||
|
||||
// For regular users, we use the session token ID to fetch the external session
|
||||
return o.sessionService.GetExternalSession(ctx, sessionToken.ExternalSessionId)
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -8,13 +8,14 @@ import (
|
|||
"github.com/grafana/grafana/pkg/apimachinery/identity"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
)
|
||||
|
||||
type MockOauthTokenService struct {
|
||||
GetCurrentOauthTokenFunc func(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) *oauth2.Token
|
||||
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
|
||||
InvalidateOAuthTokensFunc func(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) error
|
||||
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error)
|
||||
InvalidateOAuthTokensFunc func(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) error
|
||||
TryTokenRefreshFunc func(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) *oauth2.Token {
|
||||
|
@ -31,16 +32,16 @@ func (m *MockOauthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSourc
|
|||
return false
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) error {
|
||||
func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) error {
|
||||
if m.InvalidateOAuthTokensFunc != nil {
|
||||
return m.InvalidateOAuthTokensFunc(ctx, usr, sessionToken)
|
||||
return m.InvalidateOAuthTokensFunc(ctx, usr, metadata)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr identity.Requester, sessionToken *auth.UserToken) (*oauth2.Token, error) {
|
||||
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr identity.Requester, metadata *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error) {
|
||||
if m.TryTokenRefreshFunc != nil {
|
||||
return m.TryTokenRefreshFunc(ctx, usr, sessionToken)
|
||||
return m.TryTokenRefreshFunc(ctx, usr, metadata)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -29,10 +29,10 @@ func (s *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
|
|||
return oauthtoken.IsOAuthPassThruEnabled(ds)
|
||||
}
|
||||
|
||||
func (s *Service) TryTokenRefresh(context.Context, identity.Requester, *auth.UserToken) (*oauth2.Token, error) {
|
||||
func (s *Service) TryTokenRefresh(context.Context, identity.Requester, *oauthtoken.TokenRefreshMetadata) (*oauth2.Token, error) {
|
||||
return s.Token, nil
|
||||
}
|
||||
|
||||
func (s *Service) InvalidateOAuthTokens(context.Context, identity.Requester, *auth.UserToken) error {
|
||||
func (s *Service) InvalidateOAuthTokens(context.Context, identity.Requester, *oauthtoken.TokenRefreshMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue