From efeb0daec639f725ff94f56c53da9671059c61d4 Mon Sep 17 00:00:00 2001 From: Karl Persson Date: Mon, 30 Jan 2023 12:45:04 +0100 Subject: [PATCH] AuthN: Add oauth clients and perform oauth authentication with authn.Service (#62072) * AuthN: Update signature of redirect client and RedirectURL function * OAuth: use authn.Service to perform oauth authentication and login if feature toggle is enabled * AuthN: register oauth clients * AuthN: set auth module metadata * AuthN: add logs for failed login attempts * AuthN: Don't use enable disabled setting * OAuth: only run hooks when authnService feature toggle is disabled * OAuth: Add function to handle oauth errors from authn.Service --- pkg/api/login_oauth.go | 104 +++++-- pkg/api/login_oauth_test.go | 1 + pkg/services/authn/authn.go | 27 +- pkg/services/authn/authnimpl/service.go | 24 +- pkg/services/authn/authnimpl/service_test.go | 12 +- pkg/services/authn/authntest/fake.go | 10 +- pkg/services/authn/clients/oauth.go | 237 ++++++++++++++ pkg/services/authn/clients/oauth_test.go | 305 +++++++++++++++++++ 8 files changed, 681 insertions(+), 39 deletions(-) create mode 100644 pkg/services/authn/clients/oauth.go create mode 100644 pkg/services/authn/clients/oauth_test.go diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index f444615081b..9572f2f0f59 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -17,11 +17,14 @@ import ( "github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/middleware/cookies" + "github.com/grafana/grafana/pkg/services/authn" contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" + "github.com/grafana/grafana/pkg/services/featuremgmt" loginservice "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/util/errutil" "github.com/grafana/grafana/pkg/web" ) @@ -70,11 +73,59 @@ func genPKCECode() (string, string, error) { } func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { - loginInfo := loginservice.LoginInfo{ - AuthModule: "oauth", - } name := web.Params(ctx.Req)[":name"] - loginInfo.AuthModule = name + loginInfo := loginservice.LoginInfo{AuthModule: name} + + if errorParam := ctx.Query("error"); errorParam != "" { + errorDesc := ctx.Query("error_description") + oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc) + hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc) + return + } + + code := ctx.Query("code") + + if hs.Features.IsEnabled(featuremgmt.FlagAuthnService) { + req := &authn.Request{HTTPRequest: ctx.Req, Resp: ctx.Resp} + if code == "" { + redirect, err := hs.authnService.RedirectURL(ctx.Req.Context(), authn.ClientWithPrefix(name), req) + if err != nil { + hs.handleAuthnOAuthErr(ctx, "failed to generate oauth redirect url", err) + return + } + + if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" { + cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) + } + + cookies.WriteCookie(ctx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) + ctx.Redirect(redirect.URL) + return + } + + identity, err := hs.authnService.Login(ctx.Req.Context(), authn.ClientWithPrefix(name), req) + // NOTE: always delete these cookies, even if login failed + cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg) + cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg) + + if err != nil { + hs.handleAuthnOAuthErr(ctx, "failed to perform login for oauth request", err) + return + } + + metrics.MApiLoginOAuth.Inc() + cookies.WriteSessionCookie(ctx, hs.Cfg, identity.SessionToken.UnhashedToken, hs.Cfg.LoginMaxLifetime) + + redirectURL := setting.AppSubUrl + "/" + if redirectTo := ctx.GetCookie("redirect_to"); len(redirectTo) > 0 && hs.ValidateRedirectTo(redirectTo) == nil { + redirectURL = redirectTo + cookies.DeleteCookie(ctx.Resp, "redirect_to", hs.CookieOptionsFromCfg) + } + + ctx.Redirect(redirectURL) + return + } + provider := hs.SocialService.GetOAuthInfoProvider(name) if provider == nil { hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, errors.New("OAuth not enabled")) @@ -87,15 +138,6 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { return } - errorParam := ctx.Query("error") - if errorParam != "" { - errorDesc := ctx.Query("error_description") - oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc) - hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc) - return - } - - code := ctx.Query("code") if code == "" { var opts []oauth2.AuthCodeOption if provider.UsePKCE { @@ -106,6 +148,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { HttpStatus: http.StatusInternalServerError, PublicMessage: "An internal error occurred", }) + return } cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) @@ -345,6 +388,19 @@ func (hs *HTTPServer) hashStatecode(code, seed string) string { return hex.EncodeToString(hashBytes[:]) } +func (hs *HTTPServer) handleAuthnOAuthErr(c *contextmodel.ReqContext, msg string, err error) { + gfErr := &errutil.Error{} + if errors.As(err, gfErr) { + if gfErr.Public().Message != "" { + c.Handle(hs.Cfg, gfErr.Public().StatusCode, gfErr.Public().Message, err) + return + } + } + + c.Logger.Warn(msg, "err", err) + c.Redirect(hs.Cfg.AppSubURL + "/login") +} + type LoginError struct { HttpStatus int PublicMessage string @@ -354,18 +410,24 @@ type LoginError struct { func (hs *HTTPServer) handleOAuthLoginError(ctx *contextmodel.ReqContext, info loginservice.LoginInfo, err LoginError) { ctx.Handle(hs.Cfg, err.HttpStatus, err.PublicMessage, err.Err) - info.Error = err.Err - if info.Error == nil { - info.Error = errors.New(err.PublicMessage) - } - info.HTTPStatus = err.HttpStatus + // login hooks is handled by authn.Service + if !hs.Features.IsEnabled(featuremgmt.FlagAuthnService) { + info.Error = err.Err + if info.Error == nil { + info.Error = errors.New(err.PublicMessage) + } + info.HTTPStatus = err.HttpStatus - hs.HooksService.RunLoginHook(&info, ctx) + hs.HooksService.RunLoginHook(&info, ctx) + } } func (hs *HTTPServer) handleOAuthLoginErrorWithRedirect(ctx *contextmodel.ReqContext, info loginservice.LoginInfo, err error, v ...interface{}) { hs.redirectWithError(ctx, err, v...) - info.Error = err - hs.HooksService.RunLoginHook(&info, ctx) + // login hooks is handled by authn.Service + if !hs.Features.IsEnabled(featuremgmt.FlagAuthnService) { + info.Error = err + hs.HooksService.RunLoginHook(&info, ctx) + } } diff --git a/pkg/api/login_oauth_test.go b/pkg/api/login_oauth_test.go index b8b798e1deb..cbd1614a9dc 100644 --- a/pkg/api/login_oauth_test.go +++ b/pkg/api/login_oauth_test.go @@ -34,6 +34,7 @@ func setupSocialHTTPServerWithConfig(t *testing.T, cfg *setting.Cfg) *HTTPServer SocialService: social.ProvideService(cfg, featuremgmt.WithFeatures()), HooksService: hooks.ProvideService(), SecretsService: fakes.NewFakeSecretsService(), + Features: featuremgmt.WithFeatures(), } } diff --git a/pkg/services/authn/authn.go b/pkg/services/authn/authn.go index 64522ff26b5..0a762fd9c53 100644 --- a/pkg/services/authn/authn.go +++ b/pkg/services/authn/authn.go @@ -33,7 +33,7 @@ const ( MetaKeyAuthModule = "authModule" ) -// ClientParams are hints to the auth serviAuthN: Post login hooksce about how to handle the identity management +// ClientParams are hints to the auth service about how to handle the identity management // from the authenticating client. type ClientParams struct { // Update the internal representation of the entity from the identity provided @@ -42,7 +42,7 @@ type ClientParams struct { SyncTeamMembers bool // Create entity in the DB if it doesn't exist AllowSignUp bool - // EnableDisabledUsers is a hint to the auth service that it should reenable disabled users + // EnableDisabledUsers is a hint to the auth service that it should re-enable disabled users EnableDisabledUsers bool // LookUpParams are the arguments used to look up the entity in the DB. LookUpParams login.UserLookupParams @@ -63,7 +63,7 @@ type Service interface { // A lower number means higher priority. RegisterPostLoginHook(hook PostLoginHookFn, priority uint) // RedirectURL will generate url that we can use to initiate auth flow for supported clients. - RedirectURL(ctx context.Context, client string, r *Request) (string, error) + RedirectURL(ctx context.Context, client string, r *Request) (*Redirect, error) } type Client interface { @@ -83,7 +83,7 @@ type ContextAwareClient interface { type RedirectClient interface { Client - RedirectURL(ctx context.Context, r *Request) (string, error) + RedirectURL(ctx context.Context, r *Request) (*Redirect, error) } type PasswordClient interface { @@ -122,6 +122,18 @@ func (r *Request) GetMeta(k string) string { return r.metadata[k] } +const ( + KeyOAuthPKCE = "pkce" + KeyOAuthState = "state" +) + +type Redirect struct { + // Url used for redirect + URL string + // Extra contains data used for redirect, e.g. for oauth this would be state and pkce + Extra map[string]string +} + const ( NamespaceUser = "user" NamespaceAPIKey = "api-key" @@ -144,7 +156,7 @@ type Identity struct { ID string // IsAnonymous IsAnonymous bool - // Login is the short hand identifier of the entity. Should be unique. + // Login is the shorthand identifier of the entity. Should be unique. Login string // Name is the display name of the entity. It is not guaranteed to be unique. Name string @@ -283,3 +295,8 @@ func IdentityFromSignedInUser(id string, usr *user.SignedInUser, params ClientPa ClientParams: params, } } + +// ClientWithPrefix returns a client name prefixed with "auth.client." +func ClientWithPrefix(name string) string { + return fmt.Sprintf("auth.client.%s", name) +} diff --git a/pkg/services/authn/authnimpl/service.go b/pkg/services/authn/authnimpl/service.go index cba480ee297..52ddf7e43a1 100644 --- a/pkg/services/authn/authnimpl/service.go +++ b/pkg/services/authn/authnimpl/service.go @@ -11,6 +11,7 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/network" "github.com/grafana/grafana/pkg/infra/tracing" + "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/apikey" "github.com/grafana/grafana/pkg/services/auth" @@ -52,6 +53,7 @@ func ProvideService( loginAttempts loginattempt.Service, quotaService quota.Service, authInfoService login.AuthInfoService, renderService rendering.Service, features *featuremgmt.FeatureManager, oauthTokenService oauthtoken.OAuthTokenService, + socialService social.Service, ) *Service { s := &Service{ log: log.New("authn.service"), @@ -116,6 +118,21 @@ func ProvideService( s.RegisterClient(clients.ProvideJWT(jwtService, cfg)) } + for name := range socialService.GetOAuthProviders() { + oauthCfg := socialService.GetOAuthInfoProvider(name) + if oauthCfg != nil && oauthCfg.Enabled { + clientName := authn.ClientWithPrefix(name) + + connector, errConnector := socialService.GetConnector(name) + httpClient, errHTTPClient := socialService.GetOAuthHttpClient(name) + if errConnector != nil || errHTTPClient != nil { + s.log.Error("failed to configure oauth client", "client", clientName, "err", multierror.Append(errConnector, errHTTPClient)) + } + + s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient)) + } + } + // FIXME (jguer): move to User package userSyncService := sync.ProvideUserSync(userService, userProtectionService, authInfoService, quotaService) orgUserSyncService := sync.ProvideOrgSync(userService, orgService, accessControlService) @@ -233,6 +250,7 @@ func (s *Service) Login(ctx context.Context, client string, r *authn.Request) (i sessionToken, err := s.sessionService.CreateToken(ctx, &user.User{ID: id}, ip, r.HTTPRequest.UserAgent()) if err != nil { + s.log.FromContext(ctx).Error("failed to create session", "client", client, "userId", id, "err", err) return nil, err } @@ -244,19 +262,19 @@ func (s *Service) RegisterPostLoginHook(hook authn.PostLoginHookFn, priority uin s.postLoginHooks.insert(hook, priority) } -func (s *Service) RedirectURL(ctx context.Context, client string, r *authn.Request) (string, error) { +func (s *Service) RedirectURL(ctx context.Context, client string, r *authn.Request) (*authn.Redirect, error) { ctx, span := s.tracer.Start(ctx, "authn.RedirectURL") defer span.End() span.SetAttributes(attributeKeyClient, client, attribute.Key(attributeKeyClient).String(client)) c, ok := s.clients[client] if !ok { - return "", authn.ErrClientNotConfigured.Errorf("client not configured: %s", client) + return nil, authn.ErrClientNotConfigured.Errorf("client not configured: %s", client) } redirectClient, ok := c.(authn.RedirectClient) if !ok { - return "", authn.ErrUnsupportedClient.Errorf("client does not support generating redirect url: %s", client) + return nil, authn.ErrUnsupportedClient.Errorf("client does not support generating redirect url: %s", client) } return redirectClient.RedirectURL(ctx, r) diff --git a/pkg/services/authn/authnimpl/service_test.go b/pkg/services/authn/authnimpl/service_test.go index a778242b2c2..5e6f8631d51 100644 --- a/pkg/services/authn/authnimpl/service_test.go +++ b/pkg/services/authn/authnimpl/service_test.go @@ -243,16 +243,13 @@ func TestService_RedirectURL(t *testing.T) { type testCase struct { desc string client string - expectedURL string expectedErr error } tests := []testCase{ { - desc: "should generate url for valid redirect client", - client: "redirect", - expectedURL: "https://localhost/redirect", - expectedErr: nil, + desc: "should generate url for valid redirect client", + client: "redirect", }, { desc: "should return error on non existing client", @@ -269,13 +266,12 @@ func TestService_RedirectURL(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { service := setupTests(t, func(svc *Service) { - svc.RegisterClient(authntest.FakeRedirectClient{ExpectedName: "redirect", ExpectedURL: tt.expectedURL}) + svc.RegisterClient(authntest.FakeRedirectClient{ExpectedName: "redirect"}) svc.RegisterClient(&authntest.FakeClient{ExpectedName: "non-redirect"}) }) - u, err := service.RedirectURL(context.Background(), tt.client, nil) + _, err := service.RedirectURL(context.Background(), tt.client, nil) assert.ErrorIs(t, err, tt.expectedErr) - assert.Equal(t, tt.expectedURL, u) }) } } diff --git a/pkg/services/authn/authntest/fake.go b/pkg/services/authn/authntest/fake.go index a048bf6a091..a3e119ead9e 100644 --- a/pkg/services/authn/authntest/fake.go +++ b/pkg/services/authn/authntest/fake.go @@ -53,6 +53,8 @@ type FakeRedirectClient struct { ExpectedErr error ExpectedURL string ExpectedName string + ExpectedOK bool + ExpectedRedirect *authn.Redirect ExpectedIdentity *authn.Identity } @@ -64,6 +66,10 @@ func (f FakeRedirectClient) Authenticate(ctx context.Context, r *authn.Request) return f.ExpectedIdentity, f.ExpectedErr } -func (f FakeRedirectClient) RedirectURL(ctx context.Context, r *authn.Request) (string, error) { - return f.ExpectedURL, f.ExpectedErr +func (f FakeRedirectClient) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) { + return f.ExpectedRedirect, f.ExpectedErr +} + +func (f FakeRedirectClient) Test(ctx context.Context, r *authn.Request) bool { + return f.ExpectedOK } diff --git a/pkg/services/authn/clients/oauth.go b/pkg/services/authn/clients/oauth.go new file mode 100644 index 00000000000..0c99385ee6f --- /dev/null +++ b/pkg/services/authn/clients/oauth.go @@ -0,0 +1,237 @@ +package clients + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "net/http" + "strings" + + "golang.org/x/oauth2" + + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/services/authn" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/org" + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/util/errutil" +) + +const ( + hostedDomainParamName = "hd" + codeVerifierParamName = "code_verifier" + codeChallengeParamName = "code_challenge" + codeChallengeMethodParamName = "code_challenge_method" + codeChallengeMethod = "S256" + + oauthStateQueryName = "state" + oauthStateCookieName = "oauth_state" + oauthPKCECookieName = "oauth_code_verifier" +) + +var ( + errOAuthGenPKCE = errutil.NewBase(errutil.StatusInternal, "auth.oauth.pkce.internal", errutil.WithPublicMessage("An internal error occurred")) + errOAuthMissingPKCE = errutil.NewBase(errutil.StatusBadRequest, "auth.oauth.pkce.missing", errutil.WithPublicMessage("Missing required pkce cookie")) + + errOAuthGenState = errutil.NewBase(errutil.StatusInternal, "auth.oauth.state.internal", errutil.WithPublicMessage("An internal error occurred")) + errOAuthMissingState = errutil.NewBase(errutil.StatusBadRequest, "auth.oauth.state.missing", errutil.WithPublicMessage("Missing saved oauth state")) + errOAuthInvalidState = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.state.invalid", errutil.WithPublicMessage("Provided state does not match stored state")) + + errOAuthTokenExchange = errutil.NewBase(errutil.StatusInternal, "auth.oauth.token.exchange", errutil.WithPublicMessage("Failed to get token from provider")) + + errOAuthMissingRequiredEmail = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.email.missing") + errOAuthEmailNotAllowed = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.email.not-allowed") +) + +var _ authn.RedirectClient = new(OAuth) + +func ProvideOAuth( + name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo, + connector social.SocialConnector, httpClient *http.Client, +) *OAuth { + return &OAuth{ + name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")), + log.New(name), cfg, oauthCfg, connector, httpClient, + } +} + +type OAuth struct { + name string + moduleName string + log log.Logger + cfg *setting.Cfg + oauthCfg *social.OAuthInfo + connector social.SocialConnector + httpClient *http.Client +} + +func (c *OAuth) Name() string { + return c.name +} + +func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) { + r.SetMeta(authn.MetaKeyAuthModule, c.moduleName) + // get hashed state stored in cookie + stateCookie, err := r.HTTPRequest.Cookie(oauthStateCookieName) + if err != nil { + return nil, errOAuthMissingState.Errorf("missing state cookie") + } + + if stateCookie.Value == "" { + return nil, errOAuthMissingState.Errorf("missing state value in state cookie") + } + + // get state returned by the idp and hash it + stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, c.oauthCfg.ClientSecret) + // compare the state returned by idp against the one we stored in cookie + if stateQuery != stateCookie.Value { + return nil, errOAuthInvalidState.Errorf("provided state did not match stored state") + } + + var opts []oauth2.AuthCodeOption + // if pkce is enabled for client validate we have the cookie and set it as url param + if c.oauthCfg.UsePKCE { + pkceCookie, err := r.HTTPRequest.Cookie(oauthPKCECookieName) + if err != nil { + return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err) + } + opts = append(opts, oauth2.SetAuthURLParam(codeVerifierParamName, pkceCookie.Value)) + } + + clientCtx := context.WithValue(ctx, oauth2.HTTPClient, c.httpClient) + // exchange auth code to a valid token + token, err := c.connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...) + if err != nil { + return nil, err + } + token.TokenType = "Bearer" + + userInfo, err := c.connector.UserInfo(c.connector.Client(clientCtx, token), token) + if err != nil { + return nil, errOAuthTokenExchange.Errorf("failed to exchange code to token: %w", err) + } + + if userInfo.Email == "" { + return nil, errOAuthMissingRequiredEmail.Errorf("required attribute email was not provided") + } + + if !c.connector.IsEmailAllowed(userInfo.Email) { + return nil, errOAuthEmailNotAllowed.Errorf("provided email is not allowed") + } + + return &authn.Identity{ + Login: userInfo.Login, + Name: userInfo.Name, + Email: userInfo.Email, + IsGrafanaAdmin: userInfo.IsGrafanaAdmin, + AuthModule: c.moduleName, + AuthID: userInfo.Id, + Groups: userInfo.Groups, + OAuthToken: token, + OrgRoles: getOAuthOrgRole(userInfo, c.cfg), + ClientParams: authn.ClientParams{ + SyncUser: true, + SyncTeamMembers: true, + AllowSignUp: c.connector.IsSignupAllowed(), + LookUpParams: login.UserLookupParams{Email: &userInfo.Email}, + }, + }, nil +} + +func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) { + var opts []oauth2.AuthCodeOption + + if c.oauthCfg.HostedDomain != "" { + opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, c.oauthCfg.HostedDomain)) + } + + var plainPKCE string + if c.oauthCfg.UsePKCE { + pkce, hashedPKCE, err := genPKCECode() + if err != nil { + return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err) + } + + plainPKCE = pkce + opts = append(opts, + oauth2.SetAuthURLParam(codeChallengeParamName, hashedPKCE), + oauth2.SetAuthURLParam(codeChallengeMethodParamName, codeChallengeMethod), + ) + } + + state, hashedSate, err := genOAuthState(c.cfg.SecretKey, c.oauthCfg.ClientSecret) + if err != nil { + return nil, errOAuthGenState.Errorf("failed to generate state: %w", err) + } + + return &authn.Redirect{ + URL: c.connector.AuthCodeURL(state, opts...), + Extra: map[string]string{ + authn.KeyOAuthState: hashedSate, + authn.KeyOAuthPKCE: plainPKCE, + }, + }, nil +} + +// genPKCECode returns a random URL-friendly string and it's base64 URL encoded SHA256 digest. +func genPKCECode() (string, string, error) { + // IETF RFC 7636 specifies that the code verifier should be 43-128 + // characters from a set of unreserved URI characters which is + // almost the same as the set of characters in base64url. + // https://datatracker.ietf.org/doc/html/rfc7636#section-4.1 + // + // It doesn't hurt to generate a few more bytes here, we generate + // 96 bytes which we then encode using base64url to make sure + // they're within the set of unreserved characters. + // + // 96 is chosen because 96*8/6 = 128, which means that we'll have + // 128 characters after it has been base64 encoded. + raw := make([]byte, 96) + _, err := rand.Read(raw) + if err != nil { + return "", "", err + } + ascii := make([]byte, 128) + base64.RawURLEncoding.Encode(ascii, raw) + + shasum := sha256.Sum256(ascii) + pkce := base64.RawURLEncoding.EncodeToString(shasum[:]) + return string(ascii), pkce, nil +} + +func genOAuthState(secret, seed string) (string, string, error) { + rnd := make([]byte, 32) + if _, err := rand.Read(rnd); err != nil { + return "", "", err + } + state := base64.URLEncoding.EncodeToString(rnd) + return state, hashOAuthState(state, secret, seed), nil +} + +func hashOAuthState(state, secret, seed string) string { + hashBytes := sha256.Sum256([]byte(state + secret + seed)) + return hex.EncodeToString(hashBytes[:]) +} + +func getOAuthOrgRole(userInfo *social.BasicUserInfo, cfg *setting.Cfg) map[int64]org.RoleType { + orgRoles := make(map[int64]org.RoleType, 0) + if cfg.OAuthSkipOrgRoleUpdateSync { + return orgRoles + } + + if userInfo.Role == "" || !userInfo.Role.IsValid() { + return orgRoles + } + + orgID := int64(1) + if cfg.AutoAssignOrg && cfg.AutoAssignOrgId > 0 { + orgID = int64(cfg.AutoAssignOrgId) + } + + orgRoles[orgID] = userInfo.Role + return orgRoles +} diff --git a/pkg/services/authn/clients/oauth_test.go b/pkg/services/authn/clients/oauth_test.go new file mode 100644 index 00000000000..b4f5469d23c --- /dev/null +++ b/pkg/services/authn/clients/oauth_test.go @@ -0,0 +1,305 @@ +package clients + +import ( + "context" + "net/http" + "net/url" + "testing" + + "golang.org/x/oauth2" + + "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/services/authn" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/org" + "github.com/grafana/grafana/pkg/setting" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOAuth_Authenticate(t *testing.T) { + type testCase struct { + desc string + req *authn.Request + oauthCfg *social.OAuthInfo + + addStateCookie bool + stateCookieValue string + + addPKCECookie bool + pkceCookieValue string + + isEmailAllowed bool + userInfo *social.BasicUserInfo + + expectedErr error + expectedIdentity *authn.Identity + } + + tests := []testCase{ + { + desc: "should return error when missing state cookie", + req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}}, + oauthCfg: &social.OAuthInfo{}, + expectedErr: errOAuthMissingState, + }, + { + desc: "should return error when state cookie is present but don't have a value", + req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}}, + oauthCfg: &social.OAuthInfo{}, + addStateCookie: true, + stateCookieValue: "", + expectedErr: errOAuthMissingState, + }, + { + desc: "should return error when state from ipd does not match stored state", + req: &authn.Request{HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://grafana.com/?state=some-other-state"), + }, + }, + oauthCfg: &social.OAuthInfo{UsePKCE: true}, + addStateCookie: true, + stateCookieValue: "some-state", + expectedErr: errOAuthInvalidState, + }, + { + desc: "should return error when pkce is configured but the cookie is not present", + req: &authn.Request{HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://grafana.com/?state=some-state"), + }, + }, + oauthCfg: &social.OAuthInfo{UsePKCE: true}, + addStateCookie: true, + stateCookieValue: "some-state", + expectedErr: errOAuthMissingPKCE, + }, + { + desc: "should return error when email is empty", + req: &authn.Request{HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://grafana.com/?state=some-state"), + }, + }, + oauthCfg: &social.OAuthInfo{UsePKCE: true}, + addStateCookie: true, + stateCookieValue: "some-state", + addPKCECookie: true, + pkceCookieValue: "some-pkce-value", + userInfo: &social.BasicUserInfo{}, + expectedErr: errOAuthMissingRequiredEmail, + }, + { + desc: "should return error when email is not allowed", + req: &authn.Request{HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://grafana.com/?state=some-state"), + }, + }, + oauthCfg: &social.OAuthInfo{UsePKCE: true}, + addStateCookie: true, + stateCookieValue: "some-state", + addPKCECookie: true, + pkceCookieValue: "some-pkce-value", + userInfo: &social.BasicUserInfo{Email: "some@email.com"}, + isEmailAllowed: false, + expectedErr: errOAuthEmailNotAllowed, + }, + { + desc: "should return identity for valid request", + req: &authn.Request{HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://grafana.com/?state=some-state"), + }, + }, + oauthCfg: &social.OAuthInfo{UsePKCE: true}, + addStateCookie: true, + stateCookieValue: "some-state", + addPKCECookie: true, + pkceCookieValue: "some-pkce-value", + isEmailAllowed: true, + userInfo: &social.BasicUserInfo{ + Id: "123", + Name: "name", + Email: "some@email.com", + Role: "Admin", + Groups: []string{"grp1", "grp2"}, + }, + expectedIdentity: &authn.Identity{ + Email: "some@email.com", + AuthModule: "oauth_azuread", + AuthID: "123", + Name: "name", + Groups: []string{"grp1", "grp2"}, + OAuthToken: &oauth2.Token{}, + OrgRoles: map[int64]org.RoleType{1: org.RoleAdmin}, + ClientParams: authn.ClientParams{ + SyncUser: true, + SyncTeamMembers: true, + AllowSignUp: true, + LookUpParams: login.UserLookupParams{Email: strPtr("some@email.com")}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + cfg := setting.NewCfg() + + if tt.addStateCookie { + v := tt.stateCookieValue + if v != "" { + v = hashOAuthState(v, cfg.SecretKey, tt.oauthCfg.ClientSecret) + } + tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthStateCookieName, Value: v}) + } + + if tt.addPKCECookie { + tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue}) + } + + c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, tt.oauthCfg, fakeConnector{ + ExpectedUserInfo: tt.userInfo, + ExpectedToken: &oauth2.Token{}, + ExpectedIsSignupAllowed: true, + ExpectedIsEmailAllowed: tt.isEmailAllowed, + }, nil) + identity, err := c.Authenticate(context.Background(), tt.req) + assert.ErrorIs(t, err, tt.expectedErr) + + if tt.expectedIdentity != nil { + assert.Equal(t, tt.expectedIdentity.Login, identity.Login) + assert.Equal(t, tt.expectedIdentity.Name, identity.Name) + assert.Equal(t, tt.expectedIdentity.Email, identity.Email) + assert.Equal(t, tt.expectedIdentity.AuthID, identity.AuthID) + assert.Equal(t, tt.expectedIdentity.AuthModule, identity.AuthModule) + assert.Equal(t, tt.expectedIdentity.Groups, identity.Groups) + + assert.Equal(t, tt.expectedIdentity.ClientParams.SyncUser, identity.ClientParams.SyncUser) + assert.Equal(t, tt.expectedIdentity.ClientParams.AllowSignUp, identity.ClientParams.AllowSignUp) + assert.Equal(t, tt.expectedIdentity.ClientParams.SyncTeamMembers, identity.ClientParams.SyncTeamMembers) + assert.Equal(t, tt.expectedIdentity.ClientParams.EnableDisabledUsers, identity.ClientParams.EnableDisabledUsers) + + assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.Email, identity.ClientParams.LookUpParams.Email) + assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.Login, identity.ClientParams.LookUpParams.Login) + assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.UserID, identity.ClientParams.LookUpParams.UserID) + } else { + assert.Nil(t, tt.expectedIdentity) + } + }) + } +} + +func TestOAuth_RedirectURL(t *testing.T) { + type testCase struct { + desc string + oauthCfg *social.OAuthInfo + expectedErr error + + numCallOptions int + authCodeUrlCalled bool + } + + tests := []testCase{ + { + desc: "should generate redirect url and state", + oauthCfg: &social.OAuthInfo{}, + authCodeUrlCalled: true, + }, + { + desc: "should generate redirect url with hosted domain option if configured", + oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com"}, + numCallOptions: 1, + authCodeUrlCalled: true, + }, + { + desc: "should generate redirect url with pkce if configured", + oauthCfg: &social.OAuthInfo{UsePKCE: true}, + numCallOptions: 2, + authCodeUrlCalled: true, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + var ( + authCodeUrlCalled = false + ) + + c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), tt.oauthCfg, mockConnector{ + AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string { + authCodeUrlCalled = true + require.Len(t, opts, tt.numCallOptions) + return "" + }, + }, nil) + + redirect, err := c.RedirectURL(context.Background(), nil) + assert.ErrorIs(t, err, tt.expectedErr) + assert.Equal(t, tt.authCodeUrlCalled, authCodeUrlCalled) + + if tt.expectedErr != nil { + return + } + + assert.NotEmpty(t, redirect.Extra[authn.KeyOAuthState]) + if tt.oauthCfg.UsePKCE { + assert.NotEmpty(t, redirect.Extra[authn.KeyOAuthPKCE]) + } + }) + } +} + +type mockConnector struct { + AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string + social.SocialConnector +} + +func (m mockConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + if m.AuthCodeURLFunc != nil { + return m.AuthCodeURLFunc(state, opts...) + } + return "" +} + +var _ social.SocialConnector = new(fakeConnector) + +type fakeConnector struct { + ExpectedUserInfo *social.BasicUserInfo + ExpectedUserInfoErr error + ExpectedIsEmailAllowed bool + ExpectedIsSignupAllowed bool + ExpectedToken *oauth2.Token + ExpectedTokenErr error + social.SocialConnector +} + +func (f fakeConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { + return f.ExpectedUserInfo, f.ExpectedUserInfoErr +} + +func (f fakeConnector) IsEmailAllowed(email string) bool { + return f.ExpectedIsEmailAllowed +} + +func (f fakeConnector) IsSignupAllowed() bool { + return f.ExpectedIsSignupAllowed +} + +func (f fakeConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return f.ExpectedToken, f.ExpectedTokenErr +} + +func (f fakeConnector) Client(ctx context.Context, t *oauth2.Token) *http.Client { + return nil +} + +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +}