| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | import re | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from fastapi import Request | 
					
						
							|  |  |  | from starlette.middleware.base import BaseHTTPMiddleware | 
					
						
							|  |  |  | from typing import Dict | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | class SecurityHeadersMiddleware(BaseHTTPMiddleware): | 
					
						
							|  |  |  |     async def dispatch(self, request: Request, call_next): | 
					
						
							|  |  |  |         response = await call_next(request) | 
					
						
							|  |  |  |         response.headers.update(set_security_headers()) | 
					
						
							|  |  |  |         return response | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | def set_security_headers() -> Dict[str, str]: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Sets security headers based on environment variables. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     This function reads specific environment variables and uses their values | 
					
						
							|  |  |  |     to set corresponding security headers. The headers that can be set are: | 
					
						
							|  |  |  |     - cache-control | 
					
						
							| 
									
										
										
										
											2024-11-07 01:16:22 +08:00
										 |  |  |     - permissions-policy | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  |     - strict-transport-security | 
					
						
							|  |  |  |     - referrer-policy | 
					
						
							|  |  |  |     - x-content-type-options | 
					
						
							|  |  |  |     - x-download-options | 
					
						
							|  |  |  |     - x-frame-options | 
					
						
							| 
									
										
										
										
											2024-09-17 09:02:55 +08:00
										 |  |  |     - x-permitted-cross-domain-policies | 
					
						
							| 
									
										
										
										
											2024-11-30 22:31:54 +08:00
										 |  |  |     - content-security-policy | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     Each environment variable is associated with a specific setter function | 
					
						
							|  |  |  |     that constructs the header. If the environment variable is set, the | 
					
						
							|  |  |  |     corresponding header is added to the options dictionary. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Returns: | 
					
						
							|  |  |  |         dict: A dictionary containing the security headers and their values. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     options = {} | 
					
						
							|  |  |  |     header_setters = { | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |         "CACHE_CONTROL": set_cache_control, | 
					
						
							|  |  |  |         "HSTS": set_hsts, | 
					
						
							| 
									
										
										
										
											2024-11-07 01:16:22 +08:00
										 |  |  |         "PERMISSIONS_POLICY": set_permissions_policy, | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |         "REFERRER_POLICY": set_referrer, | 
					
						
							|  |  |  |         "XCONTENT_TYPE": set_xcontent_type, | 
					
						
							|  |  |  |         "XDOWNLOAD_OPTIONS": set_xdownload_options, | 
					
						
							|  |  |  |         "XFRAME_OPTIONS": set_xframe, | 
					
						
							|  |  |  |         "XPERMITTED_CROSS_DOMAIN_POLICIES": set_xpermitted_cross_domain_policies, | 
					
						
							| 
									
										
										
										
											2024-11-30 22:31:54 +08:00
										 |  |  |         "CONTENT_SECURITY_POLICY": set_content_security_policy, | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for env_var, setter in header_setters.items(): | 
					
						
							|  |  |  |         value = os.environ.get(env_var, None) | 
					
						
							|  |  |  |         if value: | 
					
						
							|  |  |  |             header = setter(value) | 
					
						
							|  |  |  |             if header: | 
					
						
							|  |  |  |                 options.update(header) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return options | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | # Set HTTP Strict Transport Security(HSTS) response header | 
					
						
							|  |  |  | def set_hsts(value: str): | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |     pattern = r"^max-age=(\d+)(;includeSubDomains)?(;preload)?$" | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  |     match = re.match(pattern, value, re.IGNORECASE) | 
					
						
							|  |  |  |     if not match: | 
					
						
							| 
									
										
										
										
											2024-10-26 14:45:28 +08:00
										 |  |  |         value = "max-age=31536000;includeSubDomains" | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |     return {"Strict-Transport-Security": value} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # Set X-Frame-Options response header | 
					
						
							|  |  |  | def set_xframe(value: str): | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |     pattern = r"^(DENY|SAMEORIGIN)$" | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  |     match = re.match(pattern, value, re.IGNORECASE) | 
					
						
							|  |  |  |     if not match: | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |         value = "DENY" | 
					
						
							|  |  |  |     return {"X-Frame-Options": value} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-07 01:16:22 +08:00
										 |  |  | # Set Permissions-Policy response header | 
					
						
							|  |  |  | def set_permissions_policy(value: str): | 
					
						
							|  |  |  |     pattern = r"^(?:(accelerometer|autoplay|camera|clipboard-read|clipboard-write|fullscreen|geolocation|gyroscope|magnetometer|microphone|midi|payment|picture-in-picture|sync-xhr|usb|xr-spatial-tracking)=\((self)?\),?)*$" | 
					
						
							|  |  |  |     match = re.match(pattern, value, re.IGNORECASE) | 
					
						
							|  |  |  |     if not match: | 
					
						
							|  |  |  |         value = "none" | 
					
						
							|  |  |  |     return {"Permissions-Policy": value} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | # Set Referrer-Policy response header | 
					
						
							|  |  |  | def set_referrer(value: str): | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |     pattern = r"^(no-referrer|no-referrer-when-downgrade|origin|origin-when-cross-origin|same-origin|strict-origin|strict-origin-when-cross-origin|unsafe-url)$" | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  |     match = re.match(pattern, value, re.IGNORECASE) | 
					
						
							|  |  |  |     if not match: | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |         value = "no-referrer" | 
					
						
							|  |  |  |     return {"Referrer-Policy": value} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # Set Cache-Control response header | 
					
						
							|  |  |  | def set_cache_control(value: str): | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |     pattern = r"^(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable)(,\s*(public|private|no-cache|no-store|must-revalidate|proxy-revalidate|max-age=\d+|s-maxage=\d+|no-transform|immutable))*$" | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  |     match = re.match(pattern, value, re.IGNORECASE) | 
					
						
							|  |  |  |     if not match: | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |         value = "no-store, max-age=0" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return {"Cache-Control": value} | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # Set X-Download-Options response header | 
					
						
							|  |  |  | def set_xdownload_options(value: str): | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |     if value != "noopen": | 
					
						
							|  |  |  |         value = "noopen" | 
					
						
							|  |  |  |     return {"X-Download-Options": value} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # Set X-Content-Type-Options response header | 
					
						
							|  |  |  | def set_xcontent_type(value: str): | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |     if value != "nosniff": | 
					
						
							|  |  |  |         value = "nosniff" | 
					
						
							|  |  |  |     return {"X-Content-Type-Options": value} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # Set X-Permitted-Cross-Domain-Policies response header | 
					
						
							|  |  |  | def set_xpermitted_cross_domain_policies(value: str): | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |     pattern = r"^(none|master-only|by-content-type|by-ftp-filename)$" | 
					
						
							| 
									
										
										
										
											2024-09-17 08:53:30 +08:00
										 |  |  |     match = re.match(pattern, value, re.IGNORECASE) | 
					
						
							|  |  |  |     if not match: | 
					
						
							| 
									
										
										
										
											2024-09-19 09:24:39 +08:00
										 |  |  |         value = "none" | 
					
						
							|  |  |  |     return {"X-Permitted-Cross-Domain-Policies": value} | 
					
						
							| 
									
										
										
										
											2024-11-30 22:31:54 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-01 15:36:30 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-30 22:31:54 +08:00
										 |  |  | # Set Content-Security-Policy response header | 
					
						
							|  |  |  | def set_content_security_policy(value: str): | 
					
						
							|  |  |  |     return {"Content-Security-Policy": value} |