diff --git a/pkg/services/authn/authnimpl/sync/oauth_token_sync.go b/pkg/services/authn/authnimpl/sync/oauth_token_sync.go index 7bb8dbf95ca..dc66af605e4 100644 --- a/pkg/services/authn/authnimpl/sync/oauth_token_sync.go +++ b/pkg/services/authn/authnimpl/sync/oauth_token_sync.go @@ -3,14 +3,18 @@ package sync import ( "context" "errors" + "fmt" "strings" "time" + "github.com/go-jose/go-jose/v3/jwt" + "github.com/grafana/grafana/pkg/infra/localcache" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/authn" + "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/util/errutil" @@ -64,10 +68,15 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn return nil } + idTokenExpiry, err := getIDTokenExpiry(token) + if err != nil { + s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err) + } + // token has no expire time configured, so we don't have to refresh it if token.OAuthExpiry.IsZero() { // cache the token check, so we don't perform it on every request - s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry)) + s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry)) return nil } @@ -84,11 +93,19 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn return nil } - expires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta) + accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta) + + hasIdTokenExpired := false + idTokenExpires := time.Time{} + + if !idTokenExpiry.IsZero() { + idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta) + hasIdTokenExpired = idTokenExpires.Before(time.Now()) + } // token has not expired, so we don't have to refresh it - if !expires.Before(time.Now()) { + if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired { // cache the token check, so we don't perform it on every request - s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(expires)) + s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires)) return nil } @@ -113,15 +130,47 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn const maxOAuthTokenCacheTTL = 10 * time.Minute -func getOAuthTokenCacheTTL(t time.Time) time.Duration { - if t.IsZero() { +func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration { + if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() { return maxOAuthTokenCacheTTL } - ttl := time.Until(t) - if ttl > maxOAuthTokenCacheTTL { - return maxOAuthTokenCacheTTL + min := func(a, b time.Duration) time.Duration { + if a <= b { + return a + } + return b } - return ttl + if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() { + return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL) + } + + if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() { + return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL) + } + + return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL) +} + +// getIDTokenExpiry extracts the expiry time from the ID token +func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) { + if token.OAuthIdToken == "" { + return time.Time{}, nil + } + + parsedToken, err := jwt.ParseSigned(token.OAuthIdToken) + if err != nil { + return time.Time{}, fmt.Errorf("error parsing id token: %w", err) + } + + type Claims struct { + Exp int64 `json:"exp"` + } + var claims Claims + if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil { + return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err) + } + + return time.Unix(claims.Exp, 0), nil } diff --git a/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go b/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go index 75606ab8689..fed27e854fa 100644 --- a/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go +++ b/pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go @@ -2,12 +2,13 @@ package sync import ( "context" + "encoding/base64" + "encoding/json" "errors" + "fmt" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/grafana/grafana/pkg/infra/localcache" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/login/social" @@ -18,9 +19,11 @@ import ( "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest" "github.com/grafana/grafana/pkg/services/user" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) { +func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) { type testCase struct { desc string identity *authn.Identity @@ -95,6 +98,13 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) { expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)}, oauthInfo: &social.OAuthInfo{UseRefreshToken: false}, }, + { + desc: "should refresh access token when ID token has expired", + identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}}, + expectHasEntryCalled: true, + expectTryRefreshTokenCalled: true, + expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))}, + }, } for _, tt := range tests { @@ -155,3 +165,93 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) { }) } } + +// fakeIDToken is used to create a fake invalid token to verify expiry logic +func fakeIDToken(t *testing.T, expiryDate time.Time) string { + type Header struct { + Kid string `json:"kid"` + Alg string `json:"alg"` + } + type Payload struct { + Iss string `json:"iss"` + Sub string `json:"sub"` + Exp int64 `json:"exp"` + } + + header, err := json.Marshal(Header{Kid: "123", Alg: "none"}) + require.NoError(t, err) + u := expiryDate.UTC().Unix() + payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u}) + require.NoError(t, err) + + fakeSignature := []byte("6ICJm") + return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature)) +} + +func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) { + defaultTime := time.Now() + tests := []struct { + name string + accessTokenExpiry time.Time + idTokenExpiry time.Time + want time.Duration + }{ + { + name: "should return maxOAuthTokenCacheTTL when no expiry is given", + accessTokenExpiry: time.Time{}, + idTokenExpiry: time.Time{}, + + want: maxOAuthTokenCacheTTL, + }, + { + name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl", + accessTokenExpiry: time.Time{}, + idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + + want: maxOAuthTokenCacheTTL, + }, + { + name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl", + accessTokenExpiry: time.Time{}, + idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), + want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), + }, + { + name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given", + accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + idTokenExpiry: time.Time{}, + want: maxOAuthTokenCacheTTL, + }, + { + name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given", + accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), + idTokenExpiry: time.Time{}, + want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), + }, + { + name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry", + accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL), + idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)), + }, + { + name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry", + accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL), + want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)), + }, + { + name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl", + accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL), + want: maxOAuthTokenCacheTTL, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry) + + assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second)) + }) + } +}