refactor: lean on rs/cors

This commit is contained in:
maxts0gt 2025-07-25 19:02:31 +09:00
parent bce7b80f50
commit 6940e0121e
2 changed files with 33 additions and 65 deletions

View File

@ -646,28 +646,7 @@ func registerAPIRouter(router *mux.Router) {
apiRouter.MethodNotAllowedHandler = collectAPIStats("methodnotallowed", httpTraceAll(methodNotAllowedHandler("S3")))
}
// corsCredentialsWrapper wraps http.ResponseWriter to post-process CORS headers
type corsCredentialsWrapper struct {
http.ResponseWriter
headerWritten bool
}
func (w *corsCredentialsWrapper) WriteHeader(code int) {
if !w.headerWritten {
w.fixCORSCredentialsViolation()
w.headerWritten = true
}
w.ResponseWriter.WriteHeader(code)
}
func (w *corsCredentialsWrapper) Write(b []byte) (int, error) {
if !w.headerWritten {
w.fixCORSCredentialsViolation()
w.headerWritten = true
}
return w.ResponseWriter.Write(b)
}
// configHasWildcard checks if any configured CORS origin is a wildcard
func configHasWildcard() bool {
for _, o := range globalAPIConfig.getCorsAllowOrigins() {
if o == "*" {
@ -677,20 +656,6 @@ func configHasWildcard() bool {
return false
}
func (w *corsCredentialsWrapper) fixCORSCredentialsViolation() {
hdr := w.Header()
// Only run if CORS actually ran
if hdr.Get("Access-Control-Allow-Origin") == "" {
return
}
// CORS spec compliance: never allow credentials with wildcard origins
if hdr.Get("Access-Control-Allow-Origin") == "*" || configHasWildcard() {
hdr.Del("Access-Control-Allow-Credentials")
}
}
// corsHandler handler for CORS (Cross Origin Resource Sharing)
func corsHandler(handler http.Handler) http.Handler {
commonS3Headers := []string{
@ -714,34 +679,37 @@ func corsHandler(handler http.Handler) http.Handler {
"x-amz*",
"*",
}
opts := cors.Options{
AllowOriginFunc: func(origin string) bool {
for _, allowedOrigin := range globalAPIConfig.getCorsAllowOrigins() {
if wildcard.MatchSimple(allowedOrigin, origin) {
return true
}
}
return false
},
AllowedMethods: []string{
http.MethodGet,
http.MethodPut,
http.MethodHead,
http.MethodPost,
http.MethodDelete,
http.MethodOptions,
http.MethodPatch,
},
AllowedHeaders: commonS3Headers,
ExposedHeaders: commonS3Headers,
AllowCredentials: true,
}
corsMiddleware := cors.New(opts)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Wrap response writer to post-process CORS headers
wrapper := &corsCredentialsWrapper{ResponseWriter: w}
corsMiddleware.Handler(handler).ServeHTTP(wrapper, r)
// Configure CORS dynamically based on current settings
// This ensures we handle configuration changes and wildcard security properly
hasWildcard := configHasWildcard()
opts := cors.Options{
AllowOriginFunc: func(origin string) bool {
for _, allowedOrigin := range globalAPIConfig.getCorsAllowOrigins() {
if wildcard.MatchSimple(allowedOrigin, origin) {
return true
}
}
return false
},
AllowedMethods: []string{
http.MethodGet,
http.MethodPut,
http.MethodHead,
http.MethodPost,
http.MethodDelete,
http.MethodOptions,
http.MethodPatch,
},
AllowedHeaders: commonS3Headers,
ExposedHeaders: commonS3Headers,
// CORS spec compliance: disable credentials when wildcard origins are configured
// This prevents the security vulnerability where any website can make credentialed requests
AllowCredentials: !hasWildcard,
}
// Use rs/cors directly without custom wrapper to avoid interface issues
cors.New(opts).Handler(handler).ServeHTTP(w, r)
})
}

View File

@ -168,7 +168,7 @@ func TestCORSUnauthorizedOrigin(t *testing.T) {
// Test preflight request from unauthorized origin
req := httptest.NewRequest("OPTIONS", "/", nil)
req.Header.Set("Origin", "https://example.com")
req.Header.Set("Origin", "https://example.org") // This origin is NOT in the allowed list
req.Header.Set("Access-Control-Request-Method", "GET")
rr := httptest.NewRecorder()