mirror of https://github.com/grafana/grafana.git
211 lines
6.1 KiB
Go
211 lines
6.1 KiB
Go
package api
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/grafana/grafana/pkg/infra/log"
|
|
"github.com/grafana/grafana/pkg/models/usertoken"
|
|
"github.com/grafana/grafana/pkg/services/authn"
|
|
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
|
"github.com/grafana/grafana/pkg/services/secrets/fakes"
|
|
"github.com/grafana/grafana/pkg/setting"
|
|
"github.com/grafana/grafana/pkg/web/webtest"
|
|
)
|
|
|
|
func setClientWithoutRedirectFollow(t *testing.T, s *webtest.Server) {
|
|
t.Helper()
|
|
s.HttpClient = &http.Client{
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
},
|
|
}
|
|
}
|
|
|
|
func TestOAuthLogin_Redirect(t *testing.T) {
|
|
type testCase struct {
|
|
desc string
|
|
expectedErr error
|
|
expectedCode int
|
|
expectedRedirect *authn.Redirect
|
|
}
|
|
|
|
tests := []testCase{
|
|
{
|
|
desc: "should be redirected to /login when passing un-configured provider",
|
|
expectedErr: authn.ErrClientNotConfigured,
|
|
expectedCode: http.StatusFound,
|
|
},
|
|
{
|
|
desc: "should be redirected to provider",
|
|
expectedCode: http.StatusFound,
|
|
expectedRedirect: &authn.Redirect{
|
|
URL: "https://some-provider.com",
|
|
Extra: map[string]string{
|
|
authn.KeyOAuthState: "some-state",
|
|
},
|
|
},
|
|
},
|
|
{
|
|
desc: "should set pkce cookie",
|
|
expectedCode: http.StatusFound,
|
|
expectedRedirect: &authn.Redirect{
|
|
URL: "https://some-provider.com",
|
|
Extra: map[string]string{
|
|
authn.KeyOAuthState: "some-state",
|
|
authn.KeyOAuthPKCE: "pkce-",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.desc, func(t *testing.T) {
|
|
server := SetupAPITestServer(t, func(hs *HTTPServer) {
|
|
hs.Cfg = setting.NewCfg()
|
|
hs.SecretsService = fakes.NewFakeSecretsService()
|
|
hs.authnService = &authntest.FakeService{
|
|
ExpectedErr: tt.expectedErr,
|
|
ExpectedRedirect: tt.expectedRedirect,
|
|
}
|
|
})
|
|
|
|
// we need to prevent the http.Client from following redirects
|
|
setClientWithoutRedirectFollow(t, server)
|
|
|
|
res, err := server.Send(server.NewGetRequest("/login/generic_oauth"))
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, http.StatusFound, res.StatusCode)
|
|
|
|
// on every error we should get redirected to /login
|
|
if tt.expectedErr != nil {
|
|
assert.Equal(t, "/login", res.Header.Get("Location"))
|
|
} else {
|
|
// check that we get correct redirect url
|
|
assert.Equal(t, tt.expectedRedirect.URL, res.Header.Get("Location"))
|
|
|
|
require.GreaterOrEqual(t, len(res.Cookies()), 1)
|
|
if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" {
|
|
require.Len(t, res.Cookies(), 2)
|
|
} else {
|
|
require.Len(t, res.Cookies(), 1)
|
|
}
|
|
|
|
require.GreaterOrEqual(t, len(res.Cookies()), 1)
|
|
stateCookie := res.Cookies()[0]
|
|
assert.Equal(t, OauthStateCookieName, stateCookie.Name)
|
|
assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthState], stateCookie.Value)
|
|
|
|
if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" {
|
|
require.Len(t, res.Cookies(), 2)
|
|
pkceCookie := res.Cookies()[1]
|
|
assert.Equal(t, OauthPKCECookieName, pkceCookie.Name)
|
|
assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthPKCE], pkceCookie.Value)
|
|
} else {
|
|
require.Len(t, res.Cookies(), 1)
|
|
}
|
|
|
|
require.NoError(t, res.Body.Close())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOAuthLogin_AuthorizationCode(t *testing.T) {
|
|
type testCase struct {
|
|
desc string
|
|
expectedErr error
|
|
expectedIdentity *authn.Identity
|
|
}
|
|
|
|
tests := []testCase{
|
|
{
|
|
desc: "should redirect to /login on error",
|
|
expectedErr: errors.New("some error"),
|
|
},
|
|
{
|
|
desc: "should redirect to / and set session cookie on successful authentication",
|
|
expectedIdentity: &authn.Identity{
|
|
SessionToken: &usertoken.UserToken{UnhashedToken: "some-token"},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.desc, func(t *testing.T) {
|
|
var cfg *setting.Cfg
|
|
server := SetupAPITestServer(t, func(hs *HTTPServer) {
|
|
cfg = setting.NewCfg()
|
|
hs.Cfg = cfg
|
|
hs.Cfg.LoginCookieName = "some_name"
|
|
hs.SecretsService = fakes.NewFakeSecretsService()
|
|
hs.authnService = &authntest.FakeService{
|
|
ExpectedErr: tt.expectedErr,
|
|
ExpectedIdentity: tt.expectedIdentity,
|
|
}
|
|
})
|
|
|
|
// we need to prevent the http.Client from following redirects
|
|
setClientWithoutRedirectFollow(t, server)
|
|
|
|
res, err := server.Send(server.NewGetRequest("/login/generic_oauth?code=code"))
|
|
require.NoError(t, err)
|
|
|
|
require.GreaterOrEqual(t, len(res.Cookies()), 3)
|
|
|
|
// make sure oauth state cookie is deleted
|
|
assert.Equal(t, OauthStateCookieName, res.Cookies()[0].Name)
|
|
assert.Equal(t, "", res.Cookies()[0].Value)
|
|
assert.Equal(t, -1, res.Cookies()[0].MaxAge)
|
|
|
|
// make sure oauth pkce cookie is deleted
|
|
assert.Equal(t, OauthPKCECookieName, res.Cookies()[1].Name)
|
|
assert.Equal(t, "", res.Cookies()[1].Value)
|
|
assert.Equal(t, -1, res.Cookies()[1].MaxAge)
|
|
|
|
if tt.expectedErr != nil {
|
|
require.Len(t, res.Cookies(), 3)
|
|
assert.Equal(t, http.StatusFound, res.StatusCode)
|
|
assert.Equal(t, "/login", res.Header.Get("Location"))
|
|
assert.Equal(t, loginErrorCookieName, res.Cookies()[2].Name)
|
|
} else {
|
|
require.Len(t, res.Cookies(), 4)
|
|
assert.Equal(t, http.StatusFound, res.StatusCode)
|
|
assert.Equal(t, "/", res.Header.Get("Location"))
|
|
|
|
// verify session expiry cookie is set
|
|
assert.Equal(t, cfg.LoginCookieName, res.Cookies()[2].Name)
|
|
assert.Equal(t, "grafana_session_expiry", res.Cookies()[3].Name)
|
|
}
|
|
|
|
require.NoError(t, res.Body.Close())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOAuthLogin_Error(t *testing.T) {
|
|
server := SetupAPITestServer(t, func(hs *HTTPServer) {
|
|
hs.Cfg = setting.NewCfg()
|
|
hs.log = log.NewNopLogger()
|
|
hs.SecretsService = fakes.NewFakeSecretsService()
|
|
})
|
|
|
|
setClientWithoutRedirectFollow(t, server)
|
|
|
|
res, err := server.Send(server.NewGetRequest("/login/azuread?error=someerror"))
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, http.StatusFound, res.StatusCode)
|
|
assert.Equal(t, "/login", res.Header.Get("Location"))
|
|
|
|
require.Len(t, res.Cookies(), 1)
|
|
errCookie := res.Cookies()[0]
|
|
assert.Equal(t, loginErrorCookieName, errCookie.Name)
|
|
require.NoError(t, res.Body.Close())
|
|
}
|