mirror of https://github.com/grafana/grafana.git
				
				
				
			
		
			
				
	
	
		
			211 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			211 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
| package api
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"net/http"
 | |
| 	"testing"
 | |
| 
 | |
| 	"github.com/stretchr/testify/assert"
 | |
| 	"github.com/stretchr/testify/require"
 | |
| 
 | |
| 	"github.com/grafana/grafana/pkg/infra/log"
 | |
| 	"github.com/grafana/grafana/pkg/models/usertoken"
 | |
| 	"github.com/grafana/grafana/pkg/services/authn"
 | |
| 	"github.com/grafana/grafana/pkg/services/authn/authntest"
 | |
| 	"github.com/grafana/grafana/pkg/services/secrets/fakes"
 | |
| 	"github.com/grafana/grafana/pkg/setting"
 | |
| 	"github.com/grafana/grafana/pkg/web/webtest"
 | |
| )
 | |
| 
 | |
| func setClientWithoutRedirectFollow(t *testing.T, s *webtest.Server) {
 | |
| 	t.Helper()
 | |
| 	s.HttpClient = &http.Client{
 | |
| 		CheckRedirect: func(req *http.Request, via []*http.Request) error {
 | |
| 			return http.ErrUseLastResponse
 | |
| 		},
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestOAuthLogin_Redirect(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		desc             string
 | |
| 		expectedErr      error
 | |
| 		expectedCode     int
 | |
| 		expectedRedirect *authn.Redirect
 | |
| 	}
 | |
| 
 | |
| 	tests := []testCase{
 | |
| 		{
 | |
| 			desc:         "should be redirected to /login when passing un-configured provider",
 | |
| 			expectedErr:  authn.ErrClientNotConfigured,
 | |
| 			expectedCode: http.StatusFound,
 | |
| 		},
 | |
| 		{
 | |
| 			desc:         "should be redirected to provider",
 | |
| 			expectedCode: http.StatusFound,
 | |
| 			expectedRedirect: &authn.Redirect{
 | |
| 				URL: "https://some-provider.com",
 | |
| 				Extra: map[string]string{
 | |
| 					authn.KeyOAuthState: "some-state",
 | |
| 				},
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			desc:         "should set pkce cookie",
 | |
| 			expectedCode: http.StatusFound,
 | |
| 			expectedRedirect: &authn.Redirect{
 | |
| 				URL: "https://some-provider.com",
 | |
| 				Extra: map[string]string{
 | |
| 					authn.KeyOAuthState: "some-state",
 | |
| 					authn.KeyOAuthPKCE:  "pkce-",
 | |
| 				},
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tt := range tests {
 | |
| 		t.Run(tt.desc, func(t *testing.T) {
 | |
| 			server := SetupAPITestServer(t, func(hs *HTTPServer) {
 | |
| 				hs.Cfg = setting.NewCfg()
 | |
| 				hs.SecretsService = fakes.NewFakeSecretsService()
 | |
| 				hs.authnService = &authntest.FakeService{
 | |
| 					ExpectedErr:      tt.expectedErr,
 | |
| 					ExpectedRedirect: tt.expectedRedirect,
 | |
| 				}
 | |
| 			})
 | |
| 
 | |
| 			// we need to prevent the http.Client from following redirects
 | |
| 			setClientWithoutRedirectFollow(t, server)
 | |
| 
 | |
| 			res, err := server.Send(server.NewGetRequest("/login/generic_oauth"))
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			assert.Equal(t, http.StatusFound, res.StatusCode)
 | |
| 
 | |
| 			// on every error we should get redirected to /login
 | |
| 			if tt.expectedErr != nil {
 | |
| 				assert.Equal(t, "/login", res.Header.Get("Location"))
 | |
| 			} else {
 | |
| 				// check that we get correct redirect url
 | |
| 				assert.Equal(t, tt.expectedRedirect.URL, res.Header.Get("Location"))
 | |
| 
 | |
| 				require.GreaterOrEqual(t, len(res.Cookies()), 1)
 | |
| 				if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" {
 | |
| 					require.Len(t, res.Cookies(), 2)
 | |
| 				} else {
 | |
| 					require.Len(t, res.Cookies(), 1)
 | |
| 				}
 | |
| 
 | |
| 				require.GreaterOrEqual(t, len(res.Cookies()), 1)
 | |
| 				stateCookie := res.Cookies()[0]
 | |
| 				assert.Equal(t, OauthStateCookieName, stateCookie.Name)
 | |
| 				assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthState], stateCookie.Value)
 | |
| 
 | |
| 				if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" {
 | |
| 					require.Len(t, res.Cookies(), 2)
 | |
| 					pkceCookie := res.Cookies()[1]
 | |
| 					assert.Equal(t, OauthPKCECookieName, pkceCookie.Name)
 | |
| 					assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthPKCE], pkceCookie.Value)
 | |
| 				} else {
 | |
| 					require.Len(t, res.Cookies(), 1)
 | |
| 				}
 | |
| 
 | |
| 				require.NoError(t, res.Body.Close())
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestOAuthLogin_AuthorizationCode(t *testing.T) {
 | |
| 	type testCase struct {
 | |
| 		desc             string
 | |
| 		expectedErr      error
 | |
| 		expectedIdentity *authn.Identity
 | |
| 	}
 | |
| 
 | |
| 	tests := []testCase{
 | |
| 		{
 | |
| 			desc:        "should redirect to /login on error",
 | |
| 			expectedErr: errors.New("some error"),
 | |
| 		},
 | |
| 		{
 | |
| 			desc: "should redirect to / and set session cookie on successful authentication",
 | |
| 			expectedIdentity: &authn.Identity{
 | |
| 				SessionToken: &usertoken.UserToken{UnhashedToken: "some-token"},
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tt := range tests {
 | |
| 		t.Run(tt.desc, func(t *testing.T) {
 | |
| 			var cfg *setting.Cfg
 | |
| 			server := SetupAPITestServer(t, func(hs *HTTPServer) {
 | |
| 				cfg = setting.NewCfg()
 | |
| 				hs.Cfg = cfg
 | |
| 				hs.Cfg.LoginCookieName = "some_name"
 | |
| 				hs.SecretsService = fakes.NewFakeSecretsService()
 | |
| 				hs.authnService = &authntest.FakeService{
 | |
| 					ExpectedErr:      tt.expectedErr,
 | |
| 					ExpectedIdentity: tt.expectedIdentity,
 | |
| 				}
 | |
| 			})
 | |
| 
 | |
| 			// we need to prevent the http.Client from following redirects
 | |
| 			setClientWithoutRedirectFollow(t, server)
 | |
| 
 | |
| 			res, err := server.Send(server.NewGetRequest("/login/generic_oauth?code=code"))
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			require.GreaterOrEqual(t, len(res.Cookies()), 3)
 | |
| 
 | |
| 			// make sure oauth state cookie is deleted
 | |
| 			assert.Equal(t, OauthStateCookieName, res.Cookies()[0].Name)
 | |
| 			assert.Equal(t, "", res.Cookies()[0].Value)
 | |
| 			assert.Equal(t, -1, res.Cookies()[0].MaxAge)
 | |
| 
 | |
| 			// make sure oauth pkce cookie is deleted
 | |
| 			assert.Equal(t, OauthPKCECookieName, res.Cookies()[1].Name)
 | |
| 			assert.Equal(t, "", res.Cookies()[1].Value)
 | |
| 			assert.Equal(t, -1, res.Cookies()[1].MaxAge)
 | |
| 
 | |
| 			if tt.expectedErr != nil {
 | |
| 				require.Len(t, res.Cookies(), 3)
 | |
| 				assert.Equal(t, http.StatusFound, res.StatusCode)
 | |
| 				assert.Equal(t, "/login", res.Header.Get("Location"))
 | |
| 				assert.Equal(t, loginErrorCookieName, res.Cookies()[2].Name)
 | |
| 			} else {
 | |
| 				require.Len(t, res.Cookies(), 4)
 | |
| 				assert.Equal(t, http.StatusFound, res.StatusCode)
 | |
| 				assert.Equal(t, "/", res.Header.Get("Location"))
 | |
| 
 | |
| 				// verify session expiry cookie is set
 | |
| 				assert.Equal(t, cfg.LoginCookieName, res.Cookies()[2].Name)
 | |
| 				assert.Equal(t, "grafana_session_expiry", res.Cookies()[3].Name)
 | |
| 			}
 | |
| 
 | |
| 			require.NoError(t, res.Body.Close())
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestOAuthLogin_Error(t *testing.T) {
 | |
| 	server := SetupAPITestServer(t, func(hs *HTTPServer) {
 | |
| 		hs.Cfg = setting.NewCfg()
 | |
| 		hs.log = log.NewNopLogger()
 | |
| 		hs.SecretsService = fakes.NewFakeSecretsService()
 | |
| 	})
 | |
| 
 | |
| 	setClientWithoutRedirectFollow(t, server)
 | |
| 
 | |
| 	res, err := server.Send(server.NewGetRequest("/login/azuread?error=someerror"))
 | |
| 	require.NoError(t, err)
 | |
| 
 | |
| 	assert.Equal(t, http.StatusFound, res.StatusCode)
 | |
| 	assert.Equal(t, "/login", res.Header.Get("Location"))
 | |
| 
 | |
| 	require.Len(t, res.Cookies(), 1)
 | |
| 	errCookie := res.Cookies()[0]
 | |
| 	assert.Equal(t, loginErrorCookieName, errCookie.Name)
 | |
| 	require.NoError(t, res.Body.Close())
 | |
| }
 |