2022-06-02 21:52:30 +08:00
package csrf
2022-03-01 02:58:56 +08:00
import (
2022-07-14 02:28:59 +08:00
"errors"
2022-03-01 02:58:56 +08:00
"net/http"
"net/http/httptest"
2022-07-14 02:28:59 +08:00
"strings"
2022-03-01 02:58:56 +08:00
"testing"
2022-07-14 02:28:59 +08:00
"github.com/stretchr/testify/assert"
2022-03-01 02:58:56 +08:00
"github.com/stretchr/testify/require"
2022-06-02 21:52:30 +08:00
"github.com/grafana/grafana/pkg/setting"
2022-03-01 02:58:56 +08:00
)
// TestMiddlewareCSRF sends requests through the CSRF middleware (via
// csrfScenario) and asserts the HTTP status code recorded for each combination
// of login-cookie name, method, Origin header and Host.
func TestMiddlewareCSRF(t *testing.T) {
	type testCase struct {
		name       string
		cookieName string
		method     string
		origin     string
		host       string
		code       int
	}

	cases := []testCase{
		{
			name:       "mismatched origin and host is forbidden",
			cookieName: "foo",
			method:     "GET",
			host:       "localhost",
			origin:     "http://notLocalhost",
			code:       http.StatusForbidden,
		},
		{
			name:       "mismatched origin and host is NOT forbidden with a 'Safe Method'",
			cookieName: "foo",
			method:     "TRACE",
			host:       "localhost",
			origin:     "http://notLocalhost",
			code:       http.StatusOK,
		},
		{
			name:       "mismatched origin and host is NOT forbidden without a cookie",
			cookieName: "",
			method:     "GET",
			host:       "localhost",
			origin:     "http://notLocalhost",
			code:       http.StatusOK,
		},
		{
			// No Origin header: the malformed host alone must be rejected.
			name:       "malformed host is a bad request",
			cookieName: "foo",
			method:     "GET",
			host:       "localhost:80:80",
			code:       http.StatusBadRequest,
		},
		{
			name:       "host works without port",
			cookieName: "foo",
			method:     "GET",
			host:       "localhost",
			origin:     "http://localhost",
			code:       http.StatusOK,
		},
		{
			name:       "port does not have to match",
			cookieName: "foo",
			method:     "GET",
			host:       "localhost:80",
			origin:     "http://localhost:3000",
			code:       http.StatusOK,
		},
		{
			name:       "IPv6 host works with port",
			cookieName: "foo",
			method:     "GET",
			host:       "[::1]:3000",
			origin:     "http://[::1]:3000",
			code:       http.StatusOK,
		},
		{
			name:       "IPv6 host (with longer address) works with port",
			cookieName: "foo",
			method:     "GET",
			host:       "[2001:db8::1]:3000",
			origin:     "http://[2001:db8::1]:3000",
			code:       http.StatusOK,
		},
		{
			name:       "IPv6 host (with longer address) works without port",
			cookieName: "foo",
			method:     "GET",
			host:       "[2001:db8::1]",
			origin:     "http://[2001:db8::1]",
			code:       http.StatusOK,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			recorder := csrfScenario(t, tc.cookieName, tc.method, tc.origin, tc.host)
			require.Equal(t, tc.code, recorder.Code)
		})
	}
}
2022-07-14 02:28:59 +08:00
func TestCSRF_Check ( t * testing . T ) {
tests := [ ] struct {
name string
request * http . Request
2023-04-24 21:11:08 +08:00
getCfg func ( ) * setting . Cfg
2022-07-14 02:28:59 +08:00
addtHeader map [ string ] struct { }
trustedOrigins map [ string ] struct { }
safeEndpoints map [ string ] struct { }
expectedOK bool
expectedStatus int
} {
{
2023-04-24 21:11:08 +08:00
name : "base case" ,
getCfg : func ( ) * setting . Cfg {
return setting . NewCfg ( )
} ,
request : postRequest ( t , "" , nil , true ) ,
2022-07-14 02:28:59 +08:00
expectedOK : true ,
} ,
{
2023-04-24 21:11:08 +08:00
name : "base with null origin header" ,
getCfg : func ( ) * setting . Cfg {
return setting . NewCfg ( )
} ,
request : postRequest ( t , "" , map [ string ] string { "Origin" : "null" } , true ) ,
2022-07-14 02:28:59 +08:00
expectedStatus : http . StatusForbidden ,
} ,
{
2023-04-24 21:11:08 +08:00
name : "grafana.org" ,
getCfg : func ( ) * setting . Cfg {
return setting . NewCfg ( )
} ,
request : postRequest ( t , "grafana.org" , map [ string ] string { "Origin" : "https://grafana.org" } , true ) ,
2022-07-14 02:28:59 +08:00
expectedOK : true ,
} ,
{
2023-04-24 21:11:08 +08:00
name : "grafana.org with X-Forwarded-Host" ,
getCfg : func ( ) * setting . Cfg {
return setting . NewCfg ( )
} ,
request : postRequest ( t , "grafana.localhost" , map [ string ] string { "X-Forwarded-Host" : "grafana.org" , "Origin" : "https://grafana.org" } , true ) ,
2022-07-14 02:28:59 +08:00
expectedStatus : http . StatusForbidden ,
} ,
{
2023-04-24 21:11:08 +08:00
name : "grafana.org with X-Forwarded-Host and header trusted" ,
getCfg : func ( ) * setting . Cfg {
return setting . NewCfg ( )
} ,
request : postRequest ( t , "grafana.localhost" , map [ string ] string { "X-Forwarded-Host" : "grafana.org" , "Origin" : "https://grafana.org" } , true ) ,
2022-07-14 02:28:59 +08:00
addtHeader : map [ string ] struct { } { "X-Forwarded-Host" : { } } ,
expectedOK : true ,
} ,
{
2023-04-24 21:11:08 +08:00
name : "grafana.org from grafana.com" ,
getCfg : func ( ) * setting . Cfg {
return setting . NewCfg ( )
} ,
request : postRequest ( t , "grafana.org" , map [ string ] string { "Origin" : "https://grafana.com" } , true ) ,
2022-07-14 02:28:59 +08:00
expectedStatus : http . StatusForbidden ,
} ,
{
2023-04-24 21:11:08 +08:00
name : "grafana.org from grafana.com explicit trust for grafana.com" ,
getCfg : func ( ) * setting . Cfg {
return setting . NewCfg ( )
} ,
request : postRequest ( t , "grafana.org" , map [ string ] string { "Origin" : "https://grafana.com" } , true ) ,
2022-07-14 02:28:59 +08:00
trustedOrigins : map [ string ] struct { } { "grafana.com" : { } } ,
expectedOK : true ,
} ,
{
2023-04-24 21:11:08 +08:00
name : "grafana.org from grafana.com with X-Forwarded-Host and header trusted" ,
getCfg : func ( ) * setting . Cfg {
return setting . NewCfg ( )
} ,
request : postRequest ( t , "grafana.localhost" , map [ string ] string { "X-Forwarded-Host" : "grafana.org" , "Origin" : "https://grafana.com" } , true ) ,
2022-07-14 02:28:59 +08:00
addtHeader : map [ string ] struct { } { "X-Forwarded-Host" : { } } ,
trustedOrigins : map [ string ] struct { } { "grafana.com" : { } } ,
expectedOK : true ,
} ,
{
2023-04-24 21:11:08 +08:00
name : "safe endpoint" ,
getCfg : func ( ) * setting . Cfg {
return setting . NewCfg ( )
} ,
request : postRequest ( t , "example.org/foo/bar" , map [ string ] string { "Origin" : "null" } , true ) ,
2022-07-14 02:28:59 +08:00
safeEndpoints : map [ string ] struct { } { "foo/bar" : { } } ,
expectedOK : true ,
} ,
2023-04-24 21:11:08 +08:00
{
name : "grafana.org with X-Forwarded-Host; will skip csrf check if login cookie is not present; without login cookie, should return nil because login cookie is not present" ,
getCfg : func ( ) * setting . Cfg {
cfg := setting . NewCfg ( )
cfg . SectionWithEnvOverrides ( "security" ) . Key ( "csrf_always_check" ) . SetValue ( "false" )
return cfg
} ,
request : postRequest ( t , "grafana.localhost" , map [ string ] string { "X-Forwarded-Host" : "grafana.org" , "Origin" : "https://grafana.org" } , false ) ,
expectedOK : true ,
} ,
{
name : "grafana.org with X-Forwarded-Host; will perform csrf check even if login cookie is not present, should return error because host name does not match origin" ,
getCfg : func ( ) * setting . Cfg {
cfg := setting . NewCfg ( )
cfg . SectionWithEnvOverrides ( "security" ) . Key ( "csrf_always_check" ) . SetValue ( "true" )
return cfg
} ,
request : postRequest ( t , "grafana.localhost" , map [ string ] string { "X-Forwarded-Host" : "grafana.org" , "Origin" : "https://grafana.org" } , false ) ,
expectedStatus : http . StatusForbidden ,
} ,
2022-07-14 02:28:59 +08:00
}
for _ , tc := range tests {
tc := tc
t . Run ( tc . name , func ( t * testing . T ) {
2023-04-24 21:11:08 +08:00
csrf := ProvideCSRFFilter ( tc . getCfg ( ) )
csrf . trustedOrigins = tc . trustedOrigins
csrf . headers = tc . addtHeader
csrf . safeEndpoints = tc . safeEndpoints
csrf . cfg . LoginCookieName = "LoginCookie"
err := csrf . check ( tc . request )
2022-07-14 02:28:59 +08:00
if tc . expectedOK {
require . NoError ( t , err )
} else {
require . Error ( t , err )
var actual * errorWithStatus
require . True ( t , errors . As ( err , & actual ) )
assert . EqualValues ( t , tc . expectedStatus , actual . HTTPStatus )
}
} )
}
}
2023-04-24 21:11:08 +08:00
func postRequest ( t testing . TB , hostname string , headers map [ string ] string , withLoginCookie bool ) * http . Request {
2022-07-14 02:28:59 +08:00
t . Helper ( )
urlParts := strings . SplitN ( hostname , "/" , 2 )
path := "/"
if len ( urlParts ) == 2 {
path = urlParts [ 1 ]
}
r , err := http . NewRequest ( http . MethodPost , path , nil )
require . NoError ( t , err )
r . Host = urlParts [ 0 ]
2023-04-24 21:11:08 +08:00
if withLoginCookie {
r . AddCookie ( & http . Cookie {
Name : "LoginCookie" ,
Value : "this should not be important" ,
} )
}
2022-07-14 02:28:59 +08:00
for k , v := range headers {
r . Header . Set ( k , v )
}
return r
}
2022-03-01 02:58:56 +08:00
func csrfScenario ( t * testing . T , cookieName , method , origin , host string ) * httptest . ResponseRecorder {
req , err := http . NewRequest ( method , "/" , nil )
if err != nil {
t . Fatal ( err )
}
req . AddCookie ( & http . Cookie {
Name : cookieName ,
} )
// Note: Not sure where host header populates req.Host, or how that works.
req . Host = host
req . Header . Set ( "HOST" , host )
req . Header . Set ( "ORIGIN" , origin )
testHandler := http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
} )
rr := httptest . NewRecorder ( )
2022-06-02 21:52:30 +08:00
cfg := setting . NewCfg ( )
cfg . LoginCookieName = cookieName
service := ProvideCSRFFilter ( cfg )
2022-07-14 02:28:59 +08:00
handler := service . Middleware ( ) ( testHandler )
2022-03-01 02:58:56 +08:00
handler . ServeHTTP ( rr , req )
return rr
}
2023-04-24 21:11:08 +08:00
// TestProvideCSRFFilter checks that the filter built by ProvideCSRFFilter has
// alwaysCheck matching the security.csrf_always_check setting. Each case runs
// as a named subtest, so a failure identifies which configuration broke.
func TestProvideCSRFFilter(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name                string
		getInput            func() *setting.Cfg
		expectedAlwaysCheck bool
	}{
		{
			name:                "defaults to false when csrf_always_check is not set",
			getInput:            setting.NewCfg,
			expectedAlwaysCheck: false,
		},
		{
			name: "false when csrf_always_check is set to false",
			getInput: func() *setting.Cfg {
				cfg := setting.NewCfg()
				cfg.SectionWithEnvOverrides("security").Key("csrf_always_check").SetValue("false")
				return cfg
			},
			expectedAlwaysCheck: false,
		},
		{
			name: "true when csrf_always_check is set to true",
			getInput: func() *setting.Cfg {
				cfg := setting.NewCfg()
				cfg.SectionWithEnvOverrides("security").Key("csrf_always_check").SetValue("true")
				return cfg
			},
			expectedAlwaysCheck: true,
		},
	}

	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			csrf := ProvideCSRFFilter(tc.getInput())
			assert.Equal(t, tc.expectedAlwaysCheck, csrf.alwaysCheck)
		})
	}
}