2023-10-20 22:09:46 +08:00
|
|
|
package sync
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"errors"
|
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
|
2024-08-13 16:18:28 +08:00
|
|
|
"github.com/grafana/authlib/claims"
|
2023-10-20 22:09:46 +08:00
|
|
|
"github.com/stretchr/testify/assert"
|
2023-10-26 00:15:41 +08:00
|
|
|
"golang.org/x/sync/singleflight"
|
2023-10-20 22:09:46 +08:00
|
|
|
|
2024-07-25 17:52:14 +08:00
|
|
|
"github.com/grafana/grafana/pkg/apimachinery/identity"
|
2023-10-20 22:09:46 +08:00
|
|
|
"github.com/grafana/grafana/pkg/infra/log"
|
2024-07-03 14:08:57 +08:00
|
|
|
"github.com/grafana/grafana/pkg/infra/tracing"
|
2023-10-20 22:09:46 +08:00
|
|
|
"github.com/grafana/grafana/pkg/login/social"
|
2023-12-08 18:20:42 +08:00
|
|
|
"github.com/grafana/grafana/pkg/login/social/socialtest"
|
2023-10-20 22:09:46 +08:00
|
|
|
"github.com/grafana/grafana/pkg/services/auth"
|
|
|
|
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
|
|
|
"github.com/grafana/grafana/pkg/services/authn"
|
|
|
|
"github.com/grafana/grafana/pkg/services/login"
|
|
|
|
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
|
|
|
type testCase struct {
|
|
|
|
desc string
|
|
|
|
identity *authn.Identity
|
|
|
|
oauthInfo *social.OAuthInfo
|
|
|
|
|
|
|
|
expectedHasEntryToken *login.UserAuth
|
|
|
|
expectHasEntryCalled bool
|
|
|
|
|
|
|
|
expectedTryRefreshErr error
|
|
|
|
expectTryRefreshTokenCalled bool
|
|
|
|
|
|
|
|
expectRevokeTokenCalled bool
|
|
|
|
expectInvalidateOauthTokensCalled bool
|
|
|
|
|
|
|
|
expectedErr error
|
|
|
|
}
|
|
|
|
|
|
|
|
tests := []testCase{
|
|
|
|
{
|
2024-02-05 23:44:25 +08:00
|
|
|
desc: "should skip sync when identity is not a user",
|
2024-08-13 16:18:28 +08:00
|
|
|
identity: &authn.Identity{ID: "1", Type: claims.TypeServiceAccount},
|
2024-02-05 23:44:25 +08:00
|
|
|
expectTryRefreshTokenCalled: false,
|
2023-10-20 22:09:46 +08:00
|
|
|
},
|
|
|
|
{
|
2024-02-05 23:44:25 +08:00
|
|
|
desc: "should skip sync when identity is a user but is not authenticated with session token",
|
2024-08-13 16:18:28 +08:00
|
|
|
identity: &authn.Identity{ID: "1", Type: claims.TypeUser},
|
2024-02-05 23:44:25 +08:00
|
|
|
expectTryRefreshTokenCalled: false,
|
2023-10-20 22:09:46 +08:00
|
|
|
},
|
|
|
|
{
|
2024-02-05 23:44:25 +08:00
|
|
|
desc: "should invalidate access token and session token if token refresh fails",
|
2024-08-13 16:18:28 +08:00
|
|
|
identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
|
2023-10-20 22:09:46 +08:00
|
|
|
expectHasEntryCalled: true,
|
|
|
|
expectedTryRefreshErr: errors.New("some err"),
|
|
|
|
expectTryRefreshTokenCalled: true,
|
|
|
|
expectInvalidateOauthTokensCalled: true,
|
|
|
|
expectRevokeTokenCalled: true,
|
|
|
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
|
|
|
expectedErr: authn.ErrExpiredAccessToken,
|
|
|
|
},
|
|
|
|
{
|
2024-02-05 23:44:25 +08:00
|
|
|
desc: "should refresh the token successfully",
|
2024-08-13 16:18:28 +08:00
|
|
|
identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
|
2024-02-05 23:44:25 +08:00
|
|
|
expectHasEntryCalled: false,
|
|
|
|
expectTryRefreshTokenCalled: true,
|
|
|
|
expectInvalidateOauthTokensCalled: false,
|
|
|
|
expectRevokeTokenCalled: false,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
desc: "should not invalidate the token if the token has already been refreshed by another request (singleflight)",
|
2024-08-13 16:18:28 +08:00
|
|
|
identity: &authn.Identity{ID: "1", Type: claims.TypeUser, SessionToken: &auth.UserToken{}, AuthenticatedBy: login.AzureADAuthModule},
|
2024-02-05 23:44:25 +08:00
|
|
|
expectHasEntryCalled: true,
|
|
|
|
expectTryRefreshTokenCalled: true,
|
|
|
|
expectInvalidateOauthTokensCalled: false,
|
|
|
|
expectRevokeTokenCalled: false,
|
|
|
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
|
|
|
expectedTryRefreshErr: errors.New("some err"),
|
2023-10-20 22:09:46 +08:00
|
|
|
},
|
2024-02-05 23:44:25 +08:00
|
|
|
|
|
|
|
// TODO: address coverage of oauthtoken sync
|
2023-10-20 22:09:46 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
t.Run(tt.desc, func(t *testing.T) {
|
|
|
|
var (
|
|
|
|
hasEntryCalled bool
|
|
|
|
tryRefreshCalled bool
|
|
|
|
invalidateTokensCalled bool
|
|
|
|
revokeTokenCalled bool
|
|
|
|
)
|
|
|
|
|
|
|
|
service := &oauthtokentest.MockOauthTokenService{
|
2024-07-25 17:52:14 +08:00
|
|
|
HasOAuthEntryFunc: func(ctx context.Context, usr identity.Requester) (*login.UserAuth, bool, error) {
|
2023-10-20 22:09:46 +08:00
|
|
|
hasEntryCalled = true
|
|
|
|
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
|
|
|
|
},
|
|
|
|
InvalidateOAuthTokensFunc: func(ctx context.Context, usr *login.UserAuth) error {
|
|
|
|
invalidateTokensCalled = true
|
|
|
|
return nil
|
|
|
|
},
|
2024-07-25 17:52:14 +08:00
|
|
|
TryTokenRefreshFunc: func(ctx context.Context, usr identity.Requester) error {
|
2023-10-20 22:09:46 +08:00
|
|
|
tryRefreshCalled = true
|
|
|
|
return tt.expectedTryRefreshErr
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
sessionService := &authtest.FakeUserAuthTokenService{
|
|
|
|
RevokeTokenProvider: func(ctx context.Context, token *auth.UserToken, soft bool) error {
|
|
|
|
revokeTokenCalled = true
|
|
|
|
return nil
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
if tt.oauthInfo == nil {
|
|
|
|
tt.oauthInfo = &social.OAuthInfo{
|
|
|
|
UseRefreshToken: true,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
socialService := &socialtest.FakeSocialService{
|
|
|
|
ExpectedAuthInfoProvider: tt.oauthInfo,
|
|
|
|
}
|
|
|
|
|
|
|
|
sync := &OAuthTokenSync{
|
2024-02-05 23:44:25 +08:00
|
|
|
log: log.NewNopLogger(),
|
|
|
|
service: service,
|
|
|
|
sessionService: sessionService,
|
|
|
|
socialService: socialService,
|
|
|
|
singleflightGroup: new(singleflight.Group),
|
2024-07-03 14:08:57 +08:00
|
|
|
tracer: tracing.InitializeTracerForTest(),
|
2023-10-20 22:09:46 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
err := sync.SyncOauthTokenHook(context.Background(), tt.identity, nil)
|
|
|
|
assert.ErrorIs(t, err, tt.expectedErr)
|
|
|
|
assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled)
|
|
|
|
assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled)
|
|
|
|
assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled)
|
|
|
|
assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled)
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|