diff --git a/pkg/login/social/generic_oauth.go b/pkg/login/social/generic_oauth.go index 4dc503d0dce..efcc7472f32 100644 --- a/pkg/login/social/generic_oauth.go +++ b/pkg/login/social/generic_oauth.go @@ -1,17 +1,19 @@ package social import ( + "bytes" + "compress/zlib" "encoding/base64" "encoding/json" "errors" "fmt" + "io/ioutil" "net/http" "net/mail" "regexp" - "github.com/grafana/grafana/pkg/util/errutil" - "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/util/errutil" "golang.org/x/oauth2" ) @@ -209,6 +211,37 @@ func (s *SocialGenericOAuth) extractFromToken(token *oauth2.Token) *UserInfoJson return nil } + headerBytes, err := base64.RawURLEncoding.DecodeString(matched[1]) + if err != nil { + s.log.Error("Error base64 decoding header", "header", matched[1], "error", err) + return nil + } + + var header map[string]string + if err := json.Unmarshal(headerBytes, &header); err != nil { + s.log.Error("Error deserializing header", "error", err) + return nil + } + + if compression, ok := header["zip"]; ok { + if compression != "DEF" { + s.log.Warn("Unknown compression algorithm", "algorithm", compression) + return nil + } + + fr, err := zlib.NewReader(bytes.NewReader(rawJSON)) + if err != nil { + s.log.Error("Error creating zlib reader", "error", err) + return nil + } + defer fr.Close() + rawJSON, err = ioutil.ReadAll(fr) + if err != nil { + s.log.Error("Error decompressing payload", "error", err) + return nil + } + } + var data UserInfoJson if err := json.Unmarshal(rawJSON, &data); err != nil { s.log.Error("Error decoding id_token JSON", "raw_json", string(data.rawJSON), "error", err) diff --git a/pkg/login/social/generic_oauth_test.go b/pkg/login/social/generic_oauth_test.go index c832f2c9bbe..d246188513b 100644 --- a/pkg/login/social/generic_oauth_test.go +++ b/pkg/login/social/generic_oauth_test.go @@ -427,3 +427,64 @@ func TestUserInfoSearchesForLogin(t *testing.T) { } }) } + +func TestPayloadCompression(t *testing.T) { + provider := SocialGenericOAuth{ + SocialBase: &SocialBase{ + log: log.NewWithLevel("generic_oauth_test", log15.LvlDebug), + }, + emailAttributePath: "email", + } + + tests := []struct { + Name string + OAuth2Extra interface{} + ExpectedEmail string + }{ + { + Name: "Given a valid DEFLATE compressed id_token, return userInfo", + OAuth2Extra: map[string]interface{}{ + // { "role": "Admin", "email": "john.doe@example.com" } + "id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsInppcCI6IkRFRiJ9.eJyrVkrNTczMUbJSysrPyNNLyU91SK1IzC3ISdVLzs9V0lEqys9JBco6puRm5inVAgCFRw_6.XrV4ZKhw19dTcnviXanBD8lwjeALCYtDiESMmGzC-ho", + }, + ExpectedEmail: "john.doe@example.com", + }, + { + Name: "Given an invalid DEFLATE compressed id_token, return nil", + OAuth2Extra: map[string]interface{}{ + // { "role": "Admin", "email": "john.doe@example.com" } + "id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsInppcCI6IkRFRiJ9.00eJyrVkrNTczMUbJSysrPyNNLyU91SK1IzC3ISdVLzs9V0lEqys9JBco6puRm5inVAgCFRw_6.XrV4ZKhw19dTcnviXanBD8lwjeALCYtDiESMmGzC-ho", + }, + ExpectedEmail: "", + }, + { + Name: "Given an unsupported GZIP compressed id_token, return nil", + OAuth2Extra: map[string]interface{}{ + // { "role": "Admin", "email": "john.doe@example.com" } + "id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsInppcCI6IkdaSVAifQ.H4sIAAAAAAAAAKtWSs1NzMxRslLKys_I00vJT3VIrUjMLchJ1UvOz1XSUSrKz0kFyjqm5GbmKdUCANotxTkvAAAA.85AXm3JOF5qflEA0goDFvlbZl2q3eFvqVcehz860W-o", + }, + ExpectedEmail: "", + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + staticToken := oauth2.Token{ + AccessToken: "", + TokenType: "", + RefreshToken: "", + Expiry: time.Now(), + } + + token := staticToken.WithExtra(test.OAuth2Extra) + userInfo := provider.extractFromToken(token) + + if test.ExpectedEmail == "" { + require.Nil(t, userInfo, "Testing case %q", test.Name) + } else { + require.NotNil(t, userInfo, "Testing case %q", test.Name) + require.Equal(t, test.ExpectedEmail, userInfo.Email) + } + }) + } +}