mirror of https://github.com/grafana/grafana.git
iam: Refresh live connection when ID tokens expire (#107209)
* iam: refresh live connection when ID tokens expire * add coverage for the handler functions * reinstate inadvertently broken unit test
This commit is contained in:
parent
8d8b824f73
commit
4f66c4a2a1
|
@ -1363,6 +1363,7 @@ github.com/grafana/grafana/pkg/build v0.0.0-20250220114259-be81314e2118/go.mod h
|
|||
github.com/grafana/grafana/pkg/build v0.0.0-20250227105625-8f465f124924/go.mod h1:Vw0LdoMma64VgIMVpRY3i0D156jddgUGjTQBOcyeF3k=
|
||||
github.com/grafana/grafana/pkg/build v0.0.0-20250227163402-d78c646f93bb/go.mod h1:Vw0LdoMma64VgIMVpRY3i0D156jddgUGjTQBOcyeF3k=
|
||||
github.com/grafana/grafana/pkg/build v0.0.0-20250403075254-4918d8720c61/go.mod h1:LGVnSwdrS0ZnJ2WXEl5acgDoYPm74EUSFavca1NKHI8=
|
||||
github.com/grafana/grafana/pkg/build v0.0.0-20250625151647-35f89a456cc6/go.mod h1:dIu5dZy00k2TBdpVBXkvSbxHNj5H7lW/sOTpJTtKIXg=
|
||||
github.com/grafana/grafana/pkg/semconv v0.0.0-20250121113133-e747350fee2d/go.mod h1:tfLnBpPYgwrBMRz4EXqPCZJyCjEG4Ev37FSlXnocJ2c=
|
||||
github.com/grafana/grafana/pkg/semconv v0.0.0-20250627191313-2f1a6ae1712b/go.mod h1:mu3yl0GxB0eQZV1q7Kka0pkF3Th9x7W04WrjR9wqBlc=
|
||||
github.com/grafana/grafana/pkg/storage/unified/apistore v0.0.0-20250121113133-e747350fee2d/go.mod h1:CXpwZ3Mkw6xVlGKc0SqUxqXCP3Uv182q6qAQnLaLxRg=
|
||||
|
|
|
@ -3,6 +3,7 @@ module github.com/grafana/grafana/pkg/apimachinery
|
|||
go 1.24.4
|
||||
|
||||
require (
|
||||
github.com/go-jose/go-jose/v3 v3.0.4 // @grafana/identity-access-team
|
||||
github.com/grafana/authlib v0.0.0-20250618124654-54543efcfeed // @grafana/identity-access-team
|
||||
github.com/grafana/authlib/types v0.0.0-20250325095148-d6da9c164a7d // @grafana/identity-access-team
|
||||
github.com/stretchr/testify v1.10.0
|
||||
|
@ -15,7 +16,6 @@ require (
|
|||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
|
||||
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
|
||||
github.com/go-jose/go-jose/v3 v3.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.21.0 // indirect
|
||||
|
|
|
@ -3,7 +3,9 @@ package identity
|
|||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"k8s.io/apiserver/pkg/authentication/user"
|
||||
|
||||
claims "github.com/grafana/authlib/types"
|
||||
|
@ -125,3 +127,31 @@ func intIdentifier(typ claims.IdentityType, id string, expected ...claims.Identi
|
|||
|
||||
return 0, ErrNotIntIdentifier
|
||||
}
|
||||
|
||||
// IsIDTokenExpired returns true if the ID token is expired.
|
||||
// If no ID token exists, returns false.
|
||||
func IsIDTokenExpired(requester Requester) bool {
|
||||
idToken := requester.GetIDToken()
|
||||
if idToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
parsed, err := jwt.ParseSigned(idToken)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Expiry *jwt.NumericDate `json:"exp"`
|
||||
}
|
||||
if err := parsed.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if claims.Expiry != nil {
|
||||
expiryTime := claims.Expiry.Time()
|
||||
return time.Now().After(expiryTime)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
package identity_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/apimachinery/identity"
|
||||
)
|
||||
|
||||
func TestIsIDTokenExpired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token func(t *testing.T) string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "should return false when ID token is not set",
|
||||
token: func(t *testing.T) string {
|
||||
return ""
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "should return false when ID token is not expired",
|
||||
token: func(t *testing.T) string {
|
||||
expiration := time.Now().Add(time.Hour)
|
||||
return createToken(t, &expiration)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "should return true when ID token is expired",
|
||||
token: func(t *testing.T) string {
|
||||
expiration := time.Now().Add(-time.Hour)
|
||||
return createToken(t, &expiration)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "should return false when ID token has no expiry claim",
|
||||
token: func(t *testing.T) string {
|
||||
return createToken(t, nil)
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "should return false when ID token is malformed",
|
||||
token: func(t *testing.T) string {
|
||||
return "invalid.jwt.token"
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "should handle token that expires exactly now",
|
||||
token: func(t *testing.T) string {
|
||||
expiration := time.Now().Add(-time.Millisecond)
|
||||
return createToken(t, &expiration)
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token := tt.token(t)
|
||||
requester := &identity.StaticRequester{IDToken: token}
|
||||
|
||||
result := identity.IsIDTokenExpired(requester)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createToken(t *testing.T, exp *time.Time) string {
|
||||
key := []byte("test-secret-key")
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims := struct {
|
||||
jwt.Claims
|
||||
}{
|
||||
Claims: jwt.Claims{
|
||||
Subject: "test-user",
|
||||
},
|
||||
}
|
||||
|
||||
if exp != nil {
|
||||
claims.Expiry = jwt.NewNumericDate(*exp)
|
||||
}
|
||||
|
||||
token, err := jwt.Signed(signer).Claims(claims).CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
return token
|
||||
}
|
|
@ -641,6 +641,20 @@ func runConcurrentlyIfNeeded(ctx context.Context, semaphore chan struct{}, fn fu
|
|||
return nil
|
||||
}
|
||||
|
||||
func (g *GrafanaLive) checkIDTokenExpirationAndRefresh(user identity.Requester, client *centrifuge.Client) bool {
|
||||
if !identity.IsIDTokenExpired(user) {
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Debug("ID token expired, triggering refresh", "user", client.UserID(), "client", client.ID())
|
||||
err := g.node.Refresh(client.UserID(), centrifuge.WithRefreshExpired(true))
|
||||
if err != nil {
|
||||
logger.Error("Failed to refresh expired ID token", "user", client.UserID(), "client", client.ID(), "error", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (g *GrafanaLive) HandleDatasourceDelete(orgID int64, dsUID string) {
|
||||
if g.runStreamManager == nil {
|
||||
return
|
||||
|
@ -676,6 +690,12 @@ func (g *GrafanaLive) handleOnRPC(clientContextWithSpan context.Context, client
|
|||
logger.Error("No user found in context", "user", client.UserID(), "client", client.ID(), "method", e.Method)
|
||||
return centrifuge.RPCReply{}, centrifuge.ErrorInternal
|
||||
}
|
||||
|
||||
// Check if ID token is expired and trigger refresh if needed
|
||||
if expired := g.checkIDTokenExpirationAndRefresh(user, client); expired {
|
||||
return centrifuge.RPCReply{}, centrifuge.ErrorExpired
|
||||
}
|
||||
|
||||
var req dtos.MetricRequest
|
||||
err := json.Unmarshal(e.Data, &req)
|
||||
if err != nil {
|
||||
|
@ -712,6 +732,11 @@ func (g *GrafanaLive) handleOnSubscribe(clientContextWithSpan context.Context, c
|
|||
return centrifuge.SubscribeReply{}, centrifuge.ErrorInternal
|
||||
}
|
||||
|
||||
// Check if ID token is expired and trigger refresh if needed
|
||||
if expired := g.checkIDTokenExpirationAndRefresh(user, client); expired {
|
||||
return centrifuge.SubscribeReply{}, centrifuge.ErrorExpired
|
||||
}
|
||||
|
||||
// See a detailed comment for StripOrgID about orgID management in Live.
|
||||
orgID, channel, err := orgchannel.StripOrgID(e.Channel)
|
||||
if err != nil {
|
||||
|
@ -813,6 +838,11 @@ func (g *GrafanaLive) handleOnPublish(clientCtxWithSpan context.Context, client
|
|||
return centrifuge.PublishReply{}, centrifuge.ErrorInternal
|
||||
}
|
||||
|
||||
// Check if ID token is expired and trigger refresh if needed
|
||||
if expired := g.checkIDTokenExpirationAndRefresh(user, client); expired {
|
||||
return centrifuge.PublishReply{}, centrifuge.ErrorExpired
|
||||
}
|
||||
|
||||
// See a detailed comment for StripOrgID about orgID management in Live.
|
||||
orgID, channel, err := orgchannel.StripOrgID(e.Channel)
|
||||
if err != nil {
|
||||
|
|
|
@ -7,15 +7,20 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/centrifugal/centrifuge"
|
||||
"github.com/grafana/grafana/pkg/api/routing"
|
||||
"github.com/grafana/grafana/pkg/apimachinery/identity"
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
|
||||
"github.com/grafana/grafana/pkg/services/annotations/annotationstest"
|
||||
"github.com/grafana/grafana/pkg/services/dashboards"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/live/livecontext"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/tests/testsuite"
|
||||
)
|
||||
|
@ -29,20 +34,9 @@ func TestIntegration_provideLiveService_RedisUnavailable(t *testing.T) {
|
|||
|
||||
cfg.LiveHAEngine = "testredisunavailable"
|
||||
|
||||
_, err := ProvideService(nil, cfg,
|
||||
routing.NewRouteRegister(),
|
||||
nil, nil, nil, nil,
|
||||
db.InitTestDB(t),
|
||||
nil,
|
||||
&usagestats.UsageStatsMock{T: t},
|
||||
nil,
|
||||
featuremgmt.WithFeatures(),
|
||||
acimpl.ProvideAccessControl(featuremgmt.WithFeatures()),
|
||||
&dashboards.FakeDashboardService{},
|
||||
annotationstest.NewFakeAnnotationsRepo(),
|
||||
nil, nil)
|
||||
_, err := setupLiveService(cfg, t)
|
||||
|
||||
// Proceeds without live HA if redis is unavaialble
|
||||
// Proceeds without live HA if redis is unavailable
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
|
@ -233,3 +227,173 @@ func Test_getHistogramMetric(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_handleOnPublish_IDTokenExpiration(t *testing.T) {
|
||||
g, err := setupLiveService(nil, t)
|
||||
require.NoError(t, err)
|
||||
|
||||
client, _, err := centrifuge.NewClient(context.Background(), g.node, newDummyTransport("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("expired token", func(t *testing.T) {
|
||||
expiration := time.Now().Add(-time.Hour)
|
||||
token := createToken(t, &expiration)
|
||||
ctx := livecontext.SetContextSignedUser(context.Background(), &identity.StaticRequester{IDToken: token})
|
||||
reply, err := g.handleOnPublish(ctx, client, centrifuge.PublishEvent{
|
||||
Channel: "test",
|
||||
Data: []byte("test"),
|
||||
})
|
||||
require.ErrorIs(t, err, centrifuge.ErrorExpired)
|
||||
require.Empty(t, reply)
|
||||
})
|
||||
|
||||
t.Run("unexpired token", func(t *testing.T) {
|
||||
expiration := time.Now().Add(time.Hour)
|
||||
token := createToken(t, &expiration)
|
||||
ctx := livecontext.SetContextSignedUser(context.Background(), &identity.StaticRequester{IDToken: token})
|
||||
reply, err := g.handleOnPublish(ctx, client, centrifuge.PublishEvent{
|
||||
Channel: "test",
|
||||
Data: []byte("test"),
|
||||
})
|
||||
|
||||
// Another error is returned if the token is not expired but the refresh fails.
|
||||
// That happens because we're providing an invalid orgID as the channel.
|
||||
require.NotErrorIs(t, err, centrifuge.ErrorExpired)
|
||||
require.Empty(t, reply)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_handleOnRPC_IDTokenExpiration(t *testing.T) {
|
||||
g, err := setupLiveService(nil, t)
|
||||
require.NoError(t, err)
|
||||
|
||||
client, _, err := centrifuge.NewClient(context.Background(), g.node, newDummyTransport("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("expired token", func(t *testing.T) {
|
||||
expiration := time.Now().Add(-time.Hour)
|
||||
token := createToken(t, &expiration)
|
||||
ctx := livecontext.SetContextSignedUser(context.Background(), &identity.StaticRequester{IDToken: token})
|
||||
reply, err := g.handleOnRPC(ctx, client, centrifuge.RPCEvent{
|
||||
Method: "grafana.query",
|
||||
Data: []byte("test"),
|
||||
})
|
||||
require.ErrorIs(t, err, centrifuge.ErrorExpired)
|
||||
require.Empty(t, reply)
|
||||
})
|
||||
|
||||
t.Run("unexpired token", func(t *testing.T) {
|
||||
expiration := time.Now().Add(time.Hour)
|
||||
token := createToken(t, &expiration)
|
||||
ctx := livecontext.SetContextSignedUser(context.Background(), &identity.StaticRequester{IDToken: token})
|
||||
reply, err := g.handleOnRPC(ctx, client, centrifuge.RPCEvent{
|
||||
Method: "grafana.query",
|
||||
Data: []byte("test"),
|
||||
})
|
||||
|
||||
// Another error is returned if the token is not expired but the refresh fails.
|
||||
// That happens because we're providing an invalid orgID as the channel.
|
||||
require.NotErrorIs(t, err, centrifuge.ErrorExpired)
|
||||
require.Empty(t, reply)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_handleOnSubscribe_IDTokenExpiration(t *testing.T) {
|
||||
g, err := setupLiveService(nil, t)
|
||||
require.NoError(t, err)
|
||||
|
||||
client, _, err := centrifuge.NewClient(context.Background(), g.node, newDummyTransport("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("expired token", func(t *testing.T) {
|
||||
expiration := time.Now().Add(-time.Hour)
|
||||
token := createToken(t, &expiration)
|
||||
ctx := livecontext.SetContextSignedUser(context.Background(), &identity.StaticRequester{IDToken: token})
|
||||
reply, err := g.handleOnSubscribe(ctx, client, centrifuge.SubscribeEvent{
|
||||
Channel: "test",
|
||||
})
|
||||
require.ErrorIs(t, err, centrifuge.ErrorExpired)
|
||||
require.Empty(t, reply)
|
||||
})
|
||||
|
||||
t.Run("unexpired token", func(t *testing.T) {
|
||||
expiration := time.Now().Add(time.Hour)
|
||||
token := createToken(t, &expiration)
|
||||
ctx := livecontext.SetContextSignedUser(context.Background(), &identity.StaticRequester{IDToken: token})
|
||||
reply, err := g.handleOnSubscribe(ctx, client, centrifuge.SubscribeEvent{
|
||||
Channel: "test",
|
||||
})
|
||||
|
||||
// Another error is returned if the token is not expired but the refresh fails.
|
||||
// That happens because we're providing an invalid orgID as the channel.
|
||||
require.NotErrorIs(t, err, centrifuge.ErrorExpired)
|
||||
require.Empty(t, reply)
|
||||
})
|
||||
}
|
||||
|
||||
func setupLiveService(cfg *setting.Cfg, t *testing.T) (*GrafanaLive, error) {
|
||||
if cfg == nil {
|
||||
cfg = setting.NewCfg()
|
||||
}
|
||||
|
||||
return ProvideService(nil,
|
||||
cfg,
|
||||
routing.NewRouteRegister(),
|
||||
nil, nil, nil, nil,
|
||||
db.InitTestDB(t),
|
||||
nil,
|
||||
&usagestats.UsageStatsMock{T: t},
|
||||
nil,
|
||||
featuremgmt.WithFeatures(),
|
||||
acimpl.ProvideAccessControl(featuremgmt.WithFeatures()),
|
||||
&dashboards.FakeDashboardService{},
|
||||
annotationstest.NewFakeAnnotationsRepo(),
|
||||
nil, nil)
|
||||
}
|
||||
|
||||
type dummyTransport struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (t *dummyTransport) Name() string { return t.name }
|
||||
func (t *dummyTransport) Protocol() centrifuge.ProtocolType { return centrifuge.ProtocolTypeJSON }
|
||||
func (t *dummyTransport) ProtocolVersion() centrifuge.ProtocolVersion {
|
||||
return centrifuge.ProtocolVersion2
|
||||
}
|
||||
func (t *dummyTransport) Emulation() bool { return false }
|
||||
func (t *dummyTransport) Unidirectional() bool { return false }
|
||||
func (t *dummyTransport) DisabledPushFlags() uint64 { return 0 }
|
||||
func (t *dummyTransport) PingPongConfig() centrifuge.PingPongConfig {
|
||||
return centrifuge.PingPongConfig{}
|
||||
}
|
||||
func (t *dummyTransport) Write(data []byte) error { return nil }
|
||||
func (t *dummyTransport) WriteMany(d ...[]byte) error { return nil }
|
||||
func (t *dummyTransport) Close(disconnect centrifuge.Disconnect) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newDummyTransport(name string) *dummyTransport {
|
||||
return &dummyTransport{name: name}
|
||||
}
|
||||
|
||||
func createToken(t *testing.T, exp *time.Time) string {
|
||||
key := []byte("test-secret-key")
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims := struct {
|
||||
jwt.Claims
|
||||
}{
|
||||
Claims: jwt.Claims{
|
||||
Subject: "test-user",
|
||||
},
|
||||
}
|
||||
|
||||
if exp != nil {
|
||||
claims.Expiry = jwt.NewNumericDate(*exp)
|
||||
}
|
||||
|
||||
token, err := jwt.Signed(signer).Claims(claims).CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
return token
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue