| 
									
										
										
										
											2021-10-13 22:45:15 +08:00
										 |  |  | package api | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2021-10-13 22:45:15 +08:00
										 |  |  | 	"net/http" | 
					
						
							|  |  |  | 	"testing" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/stretchr/testify/assert" | 
					
						
							|  |  |  | 	"github.com/stretchr/testify/require" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-10 15:56:04 +08:00
										 |  |  | 	"github.com/grafana/grafana/pkg/infra/log" | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 	"github.com/grafana/grafana/pkg/models/usertoken" | 
					
						
							|  |  |  | 	"github.com/grafana/grafana/pkg/services/authn" | 
					
						
							|  |  |  | 	"github.com/grafana/grafana/pkg/services/authn/authntest" | 
					
						
							| 
									
										
										
										
											2022-11-18 17:12:17 +08:00
										 |  |  | 	"github.com/grafana/grafana/pkg/services/secrets/fakes" | 
					
						
							| 
									
										
										
										
											2021-10-13 22:45:15 +08:00
										 |  |  | 	"github.com/grafana/grafana/pkg/setting" | 
					
						
							| 
									
										
										
										
											2025-07-02 22:45:07 +08:00
										 |  |  | 	"github.com/grafana/grafana/pkg/web/webtest" | 
					
						
							| 
									
										
										
										
											2021-10-13 22:45:15 +08:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-02 22:45:07 +08:00
										 |  |  | func setClientWithoutRedirectFollow(t *testing.T, s *webtest.Server) { | 
					
						
							| 
									
										
										
										
											2023-01-12 23:44:08 +08:00
										 |  |  | 	t.Helper() | 
					
						
							| 
									
										
										
										
											2025-07-02 22:45:07 +08:00
										 |  |  | 	s.HttpClient = &http.Client{ | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 		CheckRedirect: func(req *http.Request, via []*http.Request) error { | 
					
						
							|  |  |  | 			return http.ErrUseLastResponse | 
					
						
							|  |  |  | 		}, | 
					
						
							| 
									
										
										
										
											2023-01-12 23:44:08 +08:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2021-10-13 22:45:15 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | func TestOAuthLogin_Redirect(t *testing.T) { | 
					
						
							|  |  |  | 	type testCase struct { | 
					
						
							|  |  |  | 		desc             string | 
					
						
							|  |  |  | 		expectedErr      error | 
					
						
							|  |  |  | 		expectedCode     int | 
					
						
							|  |  |  | 		expectedRedirect *authn.Redirect | 
					
						
							| 
									
										
										
										
											2021-10-13 22:45:15 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 	tests := []testCase{ | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			desc:         "should be redirected to /login when passing un-configured provider", | 
					
						
							|  |  |  | 			expectedErr:  authn.ErrClientNotConfigured, | 
					
						
							|  |  |  | 			expectedCode: http.StatusFound, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			desc:         "should be redirected to provider", | 
					
						
							|  |  |  | 			expectedCode: http.StatusFound, | 
					
						
							|  |  |  | 			expectedRedirect: &authn.Redirect{ | 
					
						
							|  |  |  | 				URL: "https://some-provider.com", | 
					
						
							|  |  |  | 				Extra: map[string]string{ | 
					
						
							|  |  |  | 					authn.KeyOAuthState: "some-state", | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			desc:         "should set pkce cookie", | 
					
						
							|  |  |  | 			expectedCode: http.StatusFound, | 
					
						
							|  |  |  | 			expectedRedirect: &authn.Redirect{ | 
					
						
							|  |  |  | 				URL: "https://some-provider.com", | 
					
						
							|  |  |  | 				Extra: map[string]string{ | 
					
						
							|  |  |  | 					authn.KeyOAuthState: "some-state", | 
					
						
							|  |  |  | 					authn.KeyOAuthPKCE:  "pkce-", | 
					
						
							|  |  |  | 				}, | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2021-10-13 22:45:15 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 	for _, tt := range tests { | 
					
						
							|  |  |  | 		t.Run(tt.desc, func(t *testing.T) { | 
					
						
							|  |  |  | 			server := SetupAPITestServer(t, func(hs *HTTPServer) { | 
					
						
							|  |  |  | 				hs.Cfg = setting.NewCfg() | 
					
						
							|  |  |  | 				hs.SecretsService = fakes.NewFakeSecretsService() | 
					
						
							|  |  |  | 				hs.authnService = &authntest.FakeService{ | 
					
						
							|  |  |  | 					ExpectedErr:      tt.expectedErr, | 
					
						
							|  |  |  | 					ExpectedRedirect: tt.expectedRedirect, | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// we need to prevent the http.Client from following redirects
 | 
					
						
							| 
									
										
										
										
											2025-07-02 22:45:07 +08:00
										 |  |  | 			setClientWithoutRedirectFollow(t, server) | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 			res, err := server.Send(server.NewGetRequest("/login/generic_oauth")) | 
					
						
							|  |  |  | 			require.NoError(t, err) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			assert.Equal(t, http.StatusFound, res.StatusCode) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// on every error we should get redirected to /login
 | 
					
						
							|  |  |  | 			if tt.expectedErr != nil { | 
					
						
							|  |  |  | 				assert.Equal(t, "/login", res.Header.Get("Location")) | 
					
						
							|  |  |  | 			} else { | 
					
						
							|  |  |  | 				// check that we get correct redirect url
 | 
					
						
							|  |  |  | 				assert.Equal(t, tt.expectedRedirect.URL, res.Header.Get("Location")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				require.GreaterOrEqual(t, len(res.Cookies()), 1) | 
					
						
							|  |  |  | 				if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" { | 
					
						
							|  |  |  | 					require.Len(t, res.Cookies(), 2) | 
					
						
							|  |  |  | 				} else { | 
					
						
							|  |  |  | 					require.Len(t, res.Cookies(), 1) | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				require.GreaterOrEqual(t, len(res.Cookies()), 1) | 
					
						
							|  |  |  | 				stateCookie := res.Cookies()[0] | 
					
						
							|  |  |  | 				assert.Equal(t, OauthStateCookieName, stateCookie.Name) | 
					
						
							|  |  |  | 				assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthState], stateCookie.Value) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" { | 
					
						
							|  |  |  | 					require.Len(t, res.Cookies(), 2) | 
					
						
							|  |  |  | 					pkceCookie := res.Cookies()[1] | 
					
						
							|  |  |  | 					assert.Equal(t, OauthPKCECookieName, pkceCookie.Name) | 
					
						
							|  |  |  | 					assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthPKCE], pkceCookie.Value) | 
					
						
							|  |  |  | 				} else { | 
					
						
							|  |  |  | 					require.Len(t, res.Cookies(), 1) | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				require.NoError(t, res.Body.Close()) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 		}) | 
					
						
							| 
									
										
										
										
											2021-10-13 22:45:15 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-01-12 23:44:08 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | func TestOAuthLogin_AuthorizationCode(t *testing.T) { | 
					
						
							|  |  |  | 	type testCase struct { | 
					
						
							|  |  |  | 		desc             string | 
					
						
							|  |  |  | 		expectedErr      error | 
					
						
							|  |  |  | 		expectedIdentity *authn.Identity | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-01-12 23:44:08 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 	tests := []testCase{ | 
					
						
							| 
									
										
										
										
											2023-01-12 23:44:08 +08:00
										 |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 			desc:        "should redirect to /login on error", | 
					
						
							|  |  |  | 			expectedErr: errors.New("some error"), | 
					
						
							| 
									
										
										
										
											2023-01-12 23:44:08 +08:00
										 |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 			desc: "should redirect to / and set session cookie on successful authentication", | 
					
						
							|  |  |  | 			expectedIdentity: &authn.Identity{ | 
					
						
							|  |  |  | 				SessionToken: &usertoken.UserToken{UnhashedToken: "some-token"}, | 
					
						
							| 
									
										
										
										
											2023-01-12 23:44:08 +08:00
										 |  |  | 			}, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	for _, tt := range tests { | 
					
						
							|  |  |  | 		t.Run(tt.desc, func(t *testing.T) { | 
					
						
							|  |  |  | 			var cfg *setting.Cfg | 
					
						
							|  |  |  | 			server := SetupAPITestServer(t, func(hs *HTTPServer) { | 
					
						
							|  |  |  | 				cfg = setting.NewCfg() | 
					
						
							|  |  |  | 				hs.Cfg = cfg | 
					
						
							|  |  |  | 				hs.Cfg.LoginCookieName = "some_name" | 
					
						
							|  |  |  | 				hs.SecretsService = fakes.NewFakeSecretsService() | 
					
						
							|  |  |  | 				hs.authnService = &authntest.FakeService{ | 
					
						
							|  |  |  | 					ExpectedErr:      tt.expectedErr, | 
					
						
							|  |  |  | 					ExpectedIdentity: tt.expectedIdentity, | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// we need to prevent the http.Client from following redirects
 | 
					
						
							| 
									
										
										
										
											2025-07-02 22:45:07 +08:00
										 |  |  | 			setClientWithoutRedirectFollow(t, server) | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 			res, err := server.Send(server.NewGetRequest("/login/generic_oauth?code=code")) | 
					
						
							|  |  |  | 			require.NoError(t, err) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			require.GreaterOrEqual(t, len(res.Cookies()), 3) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// make sure oauth state cookie is deleted
 | 
					
						
							|  |  |  | 			assert.Equal(t, OauthStateCookieName, res.Cookies()[0].Name) | 
					
						
							|  |  |  | 			assert.Equal(t, "", res.Cookies()[0].Value) | 
					
						
							|  |  |  | 			assert.Equal(t, -1, res.Cookies()[0].MaxAge) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// make sure oauth pkce cookie is deleted
 | 
					
						
							|  |  |  | 			assert.Equal(t, OauthPKCECookieName, res.Cookies()[1].Name) | 
					
						
							|  |  |  | 			assert.Equal(t, "", res.Cookies()[1].Value) | 
					
						
							|  |  |  | 			assert.Equal(t, -1, res.Cookies()[1].MaxAge) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			if tt.expectedErr != nil { | 
					
						
							|  |  |  | 				require.Len(t, res.Cookies(), 3) | 
					
						
							|  |  |  | 				assert.Equal(t, http.StatusFound, res.StatusCode) | 
					
						
							|  |  |  | 				assert.Equal(t, "/login", res.Header.Get("Location")) | 
					
						
							|  |  |  | 				assert.Equal(t, loginErrorCookieName, res.Cookies()[2].Name) | 
					
						
							|  |  |  | 			} else { | 
					
						
							|  |  |  | 				require.Len(t, res.Cookies(), 4) | 
					
						
							|  |  |  | 				assert.Equal(t, http.StatusFound, res.StatusCode) | 
					
						
							|  |  |  | 				assert.Equal(t, "/", res.Header.Get("Location")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				// verify session expiry cookie is set
 | 
					
						
							|  |  |  | 				assert.Equal(t, cfg.LoginCookieName, res.Cookies()[2].Name) | 
					
						
							|  |  |  | 				assert.Equal(t, "grafana_session_expiry", res.Cookies()[3].Name) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			require.NoError(t, res.Body.Close()) | 
					
						
							|  |  |  | 		}) | 
					
						
							| 
									
										
										
										
											2023-01-12 23:44:08 +08:00
										 |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | func TestOAuthLogin_Error(t *testing.T) { | 
					
						
							|  |  |  | 	server := SetupAPITestServer(t, func(hs *HTTPServer) { | 
					
						
							|  |  |  | 		hs.Cfg = setting.NewCfg() | 
					
						
							| 
									
										
										
										
											2023-08-10 15:56:04 +08:00
										 |  |  | 		hs.log = log.NewNopLogger() | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 		hs.SecretsService = fakes.NewFakeSecretsService() | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-02 22:45:07 +08:00
										 |  |  | 	setClientWithoutRedirectFollow(t, server) | 
					
						
							| 
									
										
										
										
											2023-08-09 14:54:52 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	res, err := server.Send(server.NewGetRequest("/login/azuread?error=someerror")) | 
					
						
							|  |  |  | 	require.NoError(t, err) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	assert.Equal(t, http.StatusFound, res.StatusCode) | 
					
						
							|  |  |  | 	assert.Equal(t, "/login", res.Header.Get("Location")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	require.Len(t, res.Cookies(), 1) | 
					
						
							|  |  |  | 	errCookie := res.Cookies()[0] | 
					
						
							|  |  |  | 	assert.Equal(t, loginErrorCookieName, errCookie.Name) | 
					
						
							|  |  |  | 	require.NoError(t, res.Body.Close()) | 
					
						
							|  |  |  | } |