mirror of https://github.com/grafana/grafana.git
Auth: Check id token expiry date (#69829)
* fixed: added id token expiry check to oauth token sync * use go-jose and id token in cache * Update pkg/services/authn/authnimpl/sync/oauth_token_sync.go * refactored getOAuthTokenCacheTTL and added unit tests * Small changes to oauth_token_sync * Remove unnecessary contexthandler changes --------- Co-authored-by: linoman <2051016+linoman@users.noreply.github.com> Co-authored-by: Mihaly Gyongyosi <mgyongyosi@users.noreply.github.com>
This commit is contained in:
parent
7bf3998510
commit
6d98d06f6e
|
@ -3,14 +3,18 @@ package sync
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v3/jwt"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
"github.com/grafana/grafana/pkg/util/errutil"
|
"github.com/grafana/grafana/pkg/util/errutil"
|
||||||
|
@ -64,10 +68,15 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
idTokenExpiry, err := getIDTokenExpiry(token)
|
||||||
|
if err != nil {
|
||||||
|
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
// token has no expire time configured, so we don't have to refresh it
|
// token has no expire time configured, so we don't have to refresh it
|
||||||
if token.OAuthExpiry.IsZero() {
|
if token.OAuthExpiry.IsZero() {
|
||||||
// cache the token check, so we don't perform it on every request
|
// cache the token check, so we don't perform it on every request
|
||||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry))
|
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,11 +93,19 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
expires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||||
|
|
||||||
|
hasIdTokenExpired := false
|
||||||
|
idTokenExpires := time.Time{}
|
||||||
|
|
||||||
|
if !idTokenExpiry.IsZero() {
|
||||||
|
idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||||
|
hasIdTokenExpired = idTokenExpires.Before(time.Now())
|
||||||
|
}
|
||||||
// token has not expired, so we don't have to refresh it
|
// token has not expired, so we don't have to refresh it
|
||||||
if !expires.Before(time.Now()) {
|
if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired {
|
||||||
// cache the token check, so we don't perform it on every request
|
// cache the token check, so we don't perform it on every request
|
||||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(expires))
|
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,15 +130,47 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||||
|
|
||||||
const maxOAuthTokenCacheTTL = 10 * time.Minute
|
const maxOAuthTokenCacheTTL = 10 * time.Minute
|
||||||
|
|
||||||
func getOAuthTokenCacheTTL(t time.Time) time.Duration {
|
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
|
||||||
if t.IsZero() {
|
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||||
return maxOAuthTokenCacheTTL
|
return maxOAuthTokenCacheTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
ttl := time.Until(t)
|
min := func(a, b time.Duration) time.Duration {
|
||||||
if ttl > maxOAuthTokenCacheTTL {
|
if a <= b {
|
||||||
return maxOAuthTokenCacheTTL
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
return ttl
|
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() {
|
||||||
|
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||||
|
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getIDTokenExpiry extracts the expiry time from the ID token
|
||||||
|
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
|
||||||
|
if token.OAuthIdToken == "" {
|
||||||
|
return time.Time{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Claims struct {
|
||||||
|
Exp int64 `json:"exp"`
|
||||||
|
}
|
||||||
|
var claims Claims
|
||||||
|
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Unix(claims.Exp, 0), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,12 +2,13 @@ package sync
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
|
@ -18,9 +19,11 @@ import (
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
desc string
|
desc string
|
||||||
identity *authn.Identity
|
identity *authn.Identity
|
||||||
|
@ -95,6 +98,13 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||||
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
|
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "should refresh access token when ID token has expired",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectTryRefreshTokenCalled: true,
|
||||||
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -155,3 +165,93 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fakeIDToken is used to create a fake invalid token to verify expiry logic
|
||||||
|
func fakeIDToken(t *testing.T, expiryDate time.Time) string {
|
||||||
|
type Header struct {
|
||||||
|
Kid string `json:"kid"`
|
||||||
|
Alg string `json:"alg"`
|
||||||
|
}
|
||||||
|
type Payload struct {
|
||||||
|
Iss string `json:"iss"`
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
Exp int64 `json:"exp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := json.Marshal(Header{Kid: "123", Alg: "none"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
u := expiryDate.UTC().Unix()
|
||||||
|
payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
fakeSignature := []byte("6ICJm")
|
||||||
|
return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
|
||||||
|
defaultTime := time.Now()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accessTokenExpiry time.Time
|
||||||
|
idTokenExpiry time.Time
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
|
||||||
|
accessTokenExpiry: time.Time{},
|
||||||
|
idTokenExpiry: time.Time{},
|
||||||
|
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
|
||||||
|
accessTokenExpiry: time.Time{},
|
||||||
|
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
|
||||||
|
accessTokenExpiry: time.Time{},
|
||||||
|
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
|
||||||
|
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: time.Time{},
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
|
||||||
|
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: time.Time{},
|
||||||
|
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
|
||||||
|
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
|
||||||
|
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
|
||||||
|
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||||
|
want: maxOAuthTokenCacheTTL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue