diff --git a/CHANGES.rst b/CHANGES.rst index d62d0383..ce169742 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -74,6 +74,7 @@ Unreleased ``python`` shell if ``readline`` is installed. :issue:`3941` - ``helpers.total_seconds()`` is deprecated. Use ``timedelta.total_seconds()`` instead. :pr:`3962` +- Add type hinting. :pr:`3973`. Version 1.1.2 diff --git a/src/flask/app.py b/src/flask/app.py index 98437cba..7afb0a1e 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -1,11 +1,14 @@ import functools import inspect +import logging import os import sys +import typing as t import weakref from datetime import timedelta from itertools import chain from threading import Lock +from types import TracebackType from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableDict @@ -15,6 +18,7 @@ from werkzeug.exceptions import HTTPException from werkzeug.exceptions import InternalServerError from werkzeug.routing import BuildError from werkzeug.routing import Map +from werkzeug.routing import MapAdapter from werkzeug.routing import RequestRedirect from werkzeug.routing import RoutingException from werkzeug.routing import Rule @@ -53,15 +57,30 @@ from .signals import request_started from .signals import request_tearing_down from .templating import DispatchingJinjaLoader from .templating import Environment +from .typing import AfterRequestCallable +from .typing import BeforeRequestCallable +from .typing import ErrorHandlerCallable +from .typing import ResponseReturnValue +from .typing import TeardownCallable +from .typing import TemplateContextProcessorCallable +from .typing import TemplateFilterCallable +from .typing import TemplateGlobalCallable +from .typing import TemplateTestCallable +from .typing import URLDefaultCallable +from .typing import URLValuePreprocessorCallable from .wrappers import Request from .wrappers import Response +if t.TYPE_CHECKING: + from .blueprints import Blueprint + from .testing import FlaskClient + from .testing import FlaskCliRunner if sys.version_info >= (3, 8): iscoroutinefunction = inspect.iscoroutinefunction else: - def iscoroutinefunction(func): + def iscoroutinefunction(func: t.Any) -> bool: while inspect.ismethod(func): func = func.__func__ @@ -71,7 +90,7 @@ else: return inspect.iscoroutinefunction(func) -def _make_timedelta(value): +def _make_timedelta(value: t.Optional[timedelta]) -> t.Optional[timedelta]: if value is None or isinstance(value, timedelta): return value @@ -295,7 +314,7 @@ class Flask(Scaffold): #: This is a ``dict`` instead of an ``ImmutableDict`` to allow #: easier configuration. #: - jinja_options = {} + jinja_options: dict = {} #: Default configuration parameters. default_config = ImmutableDict( @@ -347,7 +366,7 @@ class Flask(Scaffold): #: the test client that is used with when `test_client` is used. #: #: .. versionadded:: 0.7 - test_client_class = None + test_client_class: t.Optional[t.Type["FlaskClient"]] = None #: The :class:`~click.testing.CliRunner` subclass, by default #: :class:`~flask.testing.FlaskCliRunner` that is used by @@ -355,7 +374,7 @@ class Flask(Scaffold): #: Flask app object as the first argument. #: #: .. versionadded:: 1.0 - test_cli_runner_class = None + test_cli_runner_class: t.Optional[t.Type["FlaskCliRunner"]] = None #: the session interface to use. By default an instance of #: :class:`~flask.sessions.SecureCookieSessionInterface` is used here. @@ -365,16 +384,16 @@ class Flask(Scaffold): def __init__( self, - import_name, - static_url_path=None, - static_folder="static", - static_host=None, - host_matching=False, - subdomain_matching=False, - template_folder="templates", - instance_path=None, - instance_relative_config=False, - root_path=None, + import_name: str, + static_url_path: t.Optional[str] = None, + static_folder: t.Optional[str] = "static", + static_host: t.Optional[str] = None, + host_matching: bool = False, + subdomain_matching: bool = False, + template_folder: t.Optional[str] = "templates", + instance_path: t.Optional[str] = None, + instance_relative_config: bool = False, + root_path: t.Optional[str] = None, ): super().__init__( import_name=import_name, @@ -409,14 +428,16 @@ class Flask(Scaffold): #: tried. #: #: .. versionadded:: 0.9 - self.url_build_error_handlers = [] + self.url_build_error_handlers: t.List[ + t.Callable[[Exception, str, dict], str] + ] = [] #: A list of functions that will be called at the beginning of the #: first request to this instance. To register a function, use the #: :meth:`before_first_request` decorator. #: #: .. versionadded:: 0.8 - self.before_first_request_funcs = [] + self.before_first_request_funcs: t.List[BeforeRequestCallable] = [] #: A list of functions that are called when the application context #: is destroyed. Since the application context is also torn down @@ -424,13 +445,13 @@ class Flask(Scaffold): #: from databases. #: #: .. versionadded:: 0.9 - self.teardown_appcontext_funcs = [] + self.teardown_appcontext_funcs: t.List[TeardownCallable] = [] #: A list of shell context processor functions that should be run #: when a shell context is created. #: #: .. versionadded:: 0.11 - self.shell_context_processors = [] + self.shell_context_processors: t.List[t.Callable[[], t.Dict[str, t.Any]]] = [] #: Maps registered blueprint names to blueprint objects. The #: dict retains the order the blueprints were registered in. @@ -438,7 +459,7 @@ class Flask(Scaffold): #: not track how often they were attached. #: #: .. versionadded:: 0.7 - self.blueprints = {} + self.blueprints: t.Dict[str, "Blueprint"] = {} #: a place where extensions can store application specific state. For #: example this is where an extension could store database engines and @@ -449,7 +470,7 @@ class Flask(Scaffold): #: ``'foo'``. #: #: .. versionadded:: 0.7 - self.extensions = {} + self.extensions: dict = {} #: The :class:`~werkzeug.routing.Map` for this instance. You can use #: this to change the routing converters after the class was created @@ -492,18 +513,18 @@ class Flask(Scaffold): f"{self.static_url_path}/", endpoint="static", host=static_host, - view_func=lambda **kw: self_ref().send_static_file(**kw), + view_func=lambda **kw: self_ref().send_static_file(**kw), # type: ignore # noqa: B950 ) # Set the name of the Click group in case someone wants to add # the app's commands to another CLI tool. self.cli.name = self.name - def _is_setup_finished(self): + def _is_setup_finished(self) -> bool: return self.debug and self._got_first_request @locked_cached_property - def name(self): + def name(self) -> str: # type: ignore """The name of the application. This is usually the import name with the difference that it's guessed from the run file if the import name is main. This name is used as a display name when @@ -520,7 +541,7 @@ class Flask(Scaffold): return self.import_name @property - def propagate_exceptions(self): + def propagate_exceptions(self) -> bool: """Returns the value of the ``PROPAGATE_EXCEPTIONS`` configuration value in case it's set, otherwise a sensible default is returned. @@ -532,7 +553,7 @@ class Flask(Scaffold): return self.testing or self.debug @property - def preserve_context_on_exception(self): + def preserve_context_on_exception(self) -> bool: """Returns the value of the ``PRESERVE_CONTEXT_ON_EXCEPTION`` configuration value in case it's set, otherwise a sensible default is returned. @@ -545,7 +566,7 @@ class Flask(Scaffold): return self.debug @locked_cached_property - def logger(self): + def logger(self) -> logging.Logger: """A standard Python :class:`~logging.Logger` for the app, with the same name as :attr:`name`. @@ -572,7 +593,7 @@ class Flask(Scaffold): return create_logger(self) @locked_cached_property - def jinja_env(self): + def jinja_env(self) -> Environment: """The Jinja environment used to load templates. The environment is created the first time this property is @@ -582,7 +603,7 @@ class Flask(Scaffold): return self.create_jinja_environment() @property - def got_first_request(self): + def got_first_request(self) -> bool: """This attribute is set to ``True`` if the application started handling the first request. @@ -590,7 +611,7 @@ class Flask(Scaffold): """ return self._got_first_request - def make_config(self, instance_relative=False): + def make_config(self, instance_relative: bool = False) -> Config: """Used to create the config attribute by the Flask constructor. The `instance_relative` parameter is passed in from the constructor of Flask (there named `instance_relative_config`) and indicates if @@ -607,7 +628,7 @@ class Flask(Scaffold): defaults["DEBUG"] = get_debug_flag() return self.config_class(root_path, defaults) - def auto_find_instance_path(self): + def auto_find_instance_path(self) -> str: """Tries to locate the instance path if it was not provided to the constructor of the application class. It will basically calculate the path to a folder named ``instance`` next to your main file or @@ -620,7 +641,7 @@ class Flask(Scaffold): return os.path.join(package_path, "instance") return os.path.join(prefix, "var", f"{self.name}-instance") - def open_instance_resource(self, resource, mode="rb"): + def open_instance_resource(self, resource: str, mode: str = "rb") -> t.IO[t.AnyStr]: """Opens a resource from the application's instance folder (:attr:`instance_path`). Otherwise works like :meth:`open_resource`. Instance resources can also be opened for @@ -633,7 +654,7 @@ class Flask(Scaffold): return open(os.path.join(self.instance_path, resource), mode) @property - def templates_auto_reload(self): + def templates_auto_reload(self) -> bool: """Reload templates when they are changed. Used by :meth:`create_jinja_environment`. @@ -648,10 +669,10 @@ class Flask(Scaffold): return rv if rv is not None else self.debug @templates_auto_reload.setter - def templates_auto_reload(self, value): + def templates_auto_reload(self, value: bool) -> None: self.config["TEMPLATES_AUTO_RELOAD"] = value - def create_jinja_environment(self): + def create_jinja_environment(self) -> Environment: """Create the Jinja environment based on :attr:`jinja_options` and the various Jinja-related methods of the app. Changing :attr:`jinja_options` after this will have no effect. Also adds @@ -683,10 +704,10 @@ class Flask(Scaffold): session=session, g=g, ) - rv.policies["json.dumps_function"] = json.dumps + rv.policies["json.dumps_function"] = json.dumps # type: ignore return rv - def create_global_jinja_loader(self): + def create_global_jinja_loader(self) -> DispatchingJinjaLoader: """Creates the loader for the Jinja2 environment. Can be used to override just the loader and keeping the rest unchanged. It's discouraged to override this function. Instead one should override @@ -699,7 +720,7 @@ class Flask(Scaffold): """ return DispatchingJinjaLoader(self) - def select_jinja_autoescape(self, filename): + def select_jinja_autoescape(self, filename: str) -> bool: """Returns ``True`` if autoescaping should be active for the given template name. If no template name is given, returns `True`. @@ -709,7 +730,7 @@ class Flask(Scaffold): return True return filename.endswith((".html", ".htm", ".xml", ".xhtml")) - def update_template_context(self, context): + def update_template_context(self, context: dict) -> None: """Update the template context with some commonly used variables. This injects request, session, config and g into the template context as well as everything template context processors want @@ -720,7 +741,9 @@ class Flask(Scaffold): :param context: the context as a dictionary that is updated in place to add extra variables. """ - funcs = self.template_context_processors[None] + funcs: t.Iterable[ + TemplateContextProcessorCallable + ] = self.template_context_processors[None] reqctx = _request_ctx_stack.top if reqctx is not None: for bp in self._request_blueprints(): @@ -734,7 +757,7 @@ class Flask(Scaffold): # existing views. context.update(orig_ctx) - def make_shell_context(self): + def make_shell_context(self) -> dict: """Returns the shell context for an interactive shell for this application. This runs all the registered shell context processors. @@ -758,7 +781,7 @@ class Flask(Scaffold): env = ConfigAttribute("ENV") @property - def debug(self): + def debug(self) -> bool: """Whether debug mode is enabled. When using ``flask run`` to start the development server, an interactive debugger will be shown for unhandled exceptions, and the server will be reloaded when code @@ -775,11 +798,18 @@ class Flask(Scaffold): return self.config["DEBUG"] @debug.setter - def debug(self, value): + def debug(self, value: bool) -> None: self.config["DEBUG"] = value self.jinja_env.auto_reload = self.templates_auto_reload - def run(self, host=None, port=None, debug=None, load_dotenv=True, **options): + def run( + self, + host: t.Optional[str] = None, + port: t.Optional[int] = None, + debug: t.Optional[bool] = None, + load_dotenv: bool = True, + **options: t.Any, + ) -> None: """Runs the application on a local development server. Do not use ``run()`` in a production setting. It is not intended to @@ -887,14 +917,14 @@ class Flask(Scaffold): from werkzeug.serving import run_simple try: - run_simple(host, port, self, **options) + run_simple(t.cast(str, host), port, self, **options) finally: # reset the first request information if the development server # reset normally. This makes it possible to restart the server # without reloader and that stuff from an interactive shell. self._got_first_request = False - def test_client(self, use_cookies=True, **kwargs): + def test_client(self, use_cookies: bool = True, **kwargs: t.Any) -> "FlaskClient": """Creates a test client for this application. For information about unit testing head over to :doc:`/testing`. @@ -947,10 +977,12 @@ class Flask(Scaffold): """ cls = self.test_client_class if cls is None: - from .testing import FlaskClient as cls - return cls(self, self.response_class, use_cookies=use_cookies, **kwargs) + from .testing import FlaskClient as cls # type: ignore + return cls( # type: ignore + self, self.response_class, use_cookies=use_cookies, **kwargs + ) - def test_cli_runner(self, **kwargs): + def test_cli_runner(self, **kwargs: t.Any) -> "FlaskCliRunner": """Create a CLI runner for testing CLI commands. See :ref:`testing-cli`. @@ -963,12 +995,12 @@ class Flask(Scaffold): cls = self.test_cli_runner_class if cls is None: - from .testing import FlaskCliRunner as cls + from .testing import FlaskCliRunner as cls # type: ignore - return cls(self, **kwargs) + return cls(self, **kwargs) # type: ignore @setupmethod - def register_blueprint(self, blueprint, **options): + def register_blueprint(self, blueprint: "Blueprint", **options: t.Any) -> None: """Register a :class:`~flask.Blueprint` on the application. Keyword arguments passed to this method will override the defaults set on the blueprint. @@ -989,7 +1021,7 @@ class Flask(Scaffold): """ blueprint.register(self, options) - def iter_blueprints(self): + def iter_blueprints(self) -> t.ValuesView["Blueprint"]: """Iterates over all blueprints by the order they were registered. .. versionadded:: 0.11 @@ -999,14 +1031,14 @@ class Flask(Scaffold): @setupmethod def add_url_rule( self, - rule, - endpoint=None, - view_func=None, - provide_automatic_options=None, - **options, - ): + rule: str, + endpoint: t.Optional[str] = None, + view_func: t.Optional[t.Callable] = None, + provide_automatic_options: t.Optional[bool] = None, + **options: t.Any, + ) -> None: if endpoint is None: - endpoint = _endpoint_from_view_func(view_func) + endpoint = _endpoint_from_view_func(view_func) # type: ignore options["endpoint"] = endpoint methods = options.pop("methods", None) @@ -1043,13 +1075,13 @@ class Flask(Scaffold): methods |= required_methods rule = self.url_rule_class(rule, methods=methods, **options) - rule.provide_automatic_options = provide_automatic_options + rule.provide_automatic_options = provide_automatic_options # type: ignore self.url_map.add(rule) if view_func is not None: old_func = self.view_functions.get(endpoint) if getattr(old_func, "_flask_sync_wrapper", False): - old_func = old_func.__wrapped__ + old_func = old_func.__wrapped__ # type: ignore if old_func is not None and old_func != view_func: raise AssertionError( "View function mapping is overwriting an existing" @@ -1058,7 +1090,7 @@ class Flask(Scaffold): self.view_functions[endpoint] = self.ensure_sync(view_func) @setupmethod - def template_filter(self, name=None): + def template_filter(self, name: t.Optional[str] = None) -> t.Callable: """A decorator that is used to register custom template filter. You can specify a name for the filter, otherwise the function name will be used. Example:: @@ -1071,14 +1103,16 @@ class Flask(Scaffold): function name will be used. """ - def decorator(f): + def decorator(f: TemplateFilterCallable) -> TemplateFilterCallable: self.add_template_filter(f, name=name) return f return decorator @setupmethod - def add_template_filter(self, f, name=None): + def add_template_filter( + self, f: TemplateFilterCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template filter. Works exactly like the :meth:`template_filter` decorator. @@ -1088,7 +1122,7 @@ class Flask(Scaffold): self.jinja_env.filters[name or f.__name__] = f @setupmethod - def template_test(self, name=None): + def template_test(self, name: t.Optional[str] = None) -> t.Callable: """A decorator that is used to register custom template test. You can specify a name for the test, otherwise the function name will be used. Example:: @@ -1108,14 +1142,16 @@ class Flask(Scaffold): function name will be used. """ - def decorator(f): + def decorator(f: TemplateTestCallable) -> TemplateTestCallable: self.add_template_test(f, name=name) return f return decorator @setupmethod - def add_template_test(self, f, name=None): + def add_template_test( + self, f: TemplateTestCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template test. Works exactly like the :meth:`template_test` decorator. @@ -1127,7 +1163,7 @@ class Flask(Scaffold): self.jinja_env.tests[name or f.__name__] = f @setupmethod - def template_global(self, name=None): + def template_global(self, name: t.Optional[str] = None) -> t.Callable: """A decorator that is used to register a custom template global function. You can specify a name for the global function, otherwise the function name will be used. Example:: @@ -1142,14 +1178,16 @@ class Flask(Scaffold): function name will be used. """ - def decorator(f): + def decorator(f: TemplateGlobalCallable) -> TemplateGlobalCallable: self.add_template_global(f, name=name) return f return decorator @setupmethod - def add_template_global(self, f, name=None): + def add_template_global( + self, f: TemplateGlobalCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template global function. Works exactly like the :meth:`template_global` decorator. @@ -1161,7 +1199,7 @@ class Flask(Scaffold): self.jinja_env.globals[name or f.__name__] = f @setupmethod - def before_first_request(self, f): + def before_first_request(self, f: BeforeRequestCallable) -> BeforeRequestCallable: """Registers a function to be run before the first request to this instance of the application. @@ -1174,7 +1212,7 @@ class Flask(Scaffold): return f @setupmethod - def teardown_appcontext(self, f): + def teardown_appcontext(self, f: TeardownCallable) -> TeardownCallable: """Registers a function to be called when the application context ends. These functions are typically also called when the request context is popped. @@ -1207,7 +1245,7 @@ class Flask(Scaffold): return f @setupmethod - def shell_context_processor(self, f): + def shell_context_processor(self, f: t.Callable) -> t.Callable: """Registers a shell context processor function. .. versionadded:: 0.11 @@ -1215,7 +1253,7 @@ class Flask(Scaffold): self.shell_context_processors.append(f) return f - def _find_error_handler(self, e): + def _find_error_handler(self, e: Exception) -> t.Optional[ErrorHandlerCallable]: """Return a registered error handler for an exception in this order: blueprint handler for a specific code, app handler for a specific code, blueprint handler for an exception class, app handler for an exception @@ -1235,8 +1273,11 @@ class Flask(Scaffold): if handler is not None: return handler + return None - def handle_http_exception(self, e): + def handle_http_exception( + self, e: HTTPException + ) -> t.Union[HTTPException, ResponseReturnValue]: """Handles an HTTP exception. By default this will invoke the registered error handlers and fall back to returning the exception as response. @@ -1269,7 +1310,7 @@ class Flask(Scaffold): return e return handler(e) - def trap_http_exception(self, e): + def trap_http_exception(self, e: Exception) -> bool: """Checks if an HTTP exception should be trapped or not. By default this will return ``False`` for all exceptions except for a bad request key error if ``TRAP_BAD_REQUEST_ERRORS`` is set to ``True``. It @@ -1304,7 +1345,9 @@ class Flask(Scaffold): return False - def handle_user_exception(self, e): + def handle_user_exception( + self, e: Exception + ) -> t.Union[HTTPException, ResponseReturnValue]: """This method is called whenever an exception occurs that should be handled. A special case is :class:`~werkzeug .exceptions.HTTPException` which is forwarded to the @@ -1334,7 +1377,7 @@ class Flask(Scaffold): return handler(e) - def handle_exception(self, e): + def handle_exception(self, e: Exception) -> Response: """Handle an exception that did not have an error handler associated with it, or that was raised from an error handler. This always causes a 500 ``InternalServerError``. @@ -1374,6 +1417,7 @@ class Flask(Scaffold): raise e self.log_exception(exc_info) + server_error: t.Union[InternalServerError, ResponseReturnValue] server_error = InternalServerError(original_exception=e) handler = self._find_error_handler(server_error) @@ -1382,7 +1426,12 @@ class Flask(Scaffold): return self.finalize_request(server_error, from_error_handler=True) - def log_exception(self, exc_info): + def log_exception( + self, + exc_info: t.Union[ + t.Tuple[type, BaseException, TracebackType], t.Tuple[None, None, None] + ], + ) -> None: """Logs an exception. This is called by :meth:`handle_exception` if debugging is disabled and right before the handler is called. The default implementation logs the exception as error on the @@ -1394,7 +1443,7 @@ class Flask(Scaffold): f"Exception on {request.path} [{request.method}]", exc_info=exc_info ) - def raise_routing_exception(self, request): + def raise_routing_exception(self, request: Request) -> t.NoReturn: """Exceptions that are recording during routing are reraised with this method. During debug we are not reraising redirect requests for non ``GET``, ``HEAD``, or ``OPTIONS`` requests and we're raising @@ -1407,13 +1456,13 @@ class Flask(Scaffold): or not isinstance(request.routing_exception, RequestRedirect) or request.method in ("GET", "HEAD", "OPTIONS") ): - raise request.routing_exception + raise request.routing_exception # type: ignore from .debughelpers import FormDataRoutingRedirect raise FormDataRoutingRedirect(request) - def dispatch_request(self): + def dispatch_request(self) -> ResponseReturnValue: """Does the request dispatching. Matches the URL and returns the return value of the view or error handler. This does not have to be a response object. In order to convert the return value to a @@ -1437,7 +1486,7 @@ class Flask(Scaffold): # otherwise dispatch to the handler for that endpoint return self.view_functions[rule.endpoint](**req.view_args) - def full_dispatch_request(self): + def full_dispatch_request(self) -> Response: """Dispatches the request and on top of that performs request pre and postprocessing as well as HTTP exception catching and error handling. @@ -1454,7 +1503,11 @@ class Flask(Scaffold): rv = self.handle_user_exception(e) return self.finalize_request(rv) - def finalize_request(self, rv, from_error_handler=False): + def finalize_request( + self, + rv: t.Union[ResponseReturnValue, HTTPException], + from_error_handler: bool = False, + ) -> Response: """Given the return value from a view function this finalizes the request by converting it into a response and invoking the postprocessing functions. This is invoked for both normal @@ -1479,7 +1532,7 @@ class Flask(Scaffold): ) return response - def try_trigger_before_first_request_functions(self): + def try_trigger_before_first_request_functions(self) -> None: """Called before each request and will ensure that it triggers the :attr:`before_first_request_funcs` and only exactly once per application instance (which means process usually). @@ -1495,7 +1548,7 @@ class Flask(Scaffold): func() self._got_first_request = True - def make_default_options_response(self): + def make_default_options_response(self) -> Response: """This method is called to create the default ``OPTIONS`` response. This can be changed through subclassing to change the default behavior of ``OPTIONS`` responses. @@ -1508,7 +1561,7 @@ class Flask(Scaffold): rv.allow.update(methods) return rv - def should_ignore_error(self, error): + def should_ignore_error(self, error: t.Optional[BaseException]) -> bool: """This is called to figure out if an error should be ignored or not as far as the teardown system is concerned. If this function returns ``True`` then the teardown handlers will not be @@ -1518,7 +1571,7 @@ class Flask(Scaffold): """ return False - def ensure_sync(self, func): + def ensure_sync(self, func: t.Callable) -> t.Callable: """Ensure that the function is synchronous for WSGI workers. Plain ``def`` functions are returned as-is. ``async def`` functions are wrapped to run and wait for the response. @@ -1532,7 +1585,7 @@ class Flask(Scaffold): return func - def make_response(self, rv): + def make_response(self, rv: ResponseReturnValue) -> Response: """Convert the return value from a view function to an instance of :attr:`response_class`. @@ -1620,7 +1673,7 @@ class Flask(Scaffold): # evaluate a WSGI callable, or coerce a different response # class to the correct type try: - rv = self.response_class.force_type(rv, request.environ) + rv = self.response_class.force_type(rv, request.environ) # type: ignore # noqa: B950 except TypeError as e: raise TypeError( f"{e}\nThe view function did not return a valid" @@ -1636,10 +1689,11 @@ class Flask(Scaffold): f" callable, but it was a {type(rv).__name__}." ) + rv = t.cast(Response, rv) # prefer the status if it was provided if status is not None: if isinstance(status, (str, bytes, bytearray)): - rv.status = status + rv.status = status # type: ignore else: rv.status_code = status @@ -1649,7 +1703,9 @@ class Flask(Scaffold): return rv - def create_url_adapter(self, request): + def create_url_adapter( + self, request: t.Optional[Request] + ) -> t.Optional[MapAdapter]: """Creates a URL adapter for the given request. The URL adapter is created at a point where the request context is not yet set up so the request is passed explicitly. @@ -1687,21 +1743,25 @@ class Flask(Scaffold): url_scheme=self.config["PREFERRED_URL_SCHEME"], ) - def inject_url_defaults(self, endpoint, values): + return None + + def inject_url_defaults(self, endpoint: str, values: dict) -> None: """Injects the URL defaults for the given endpoint directly into the values dictionary passed. This is used internally and automatically called on URL building. .. versionadded:: 0.7 """ - funcs = self.url_default_functions[None] + funcs: t.Iterable[URLDefaultCallable] = self.url_default_functions[None] if "." in endpoint: bp = endpoint.rsplit(".", 1)[0] funcs = chain(funcs, self.url_default_functions[bp]) for func in funcs: func(endpoint, values) - def handle_url_build_error(self, error, endpoint, values): + def handle_url_build_error( + self, error: Exception, endpoint: str, values: dict + ) -> str: """Handle :class:`~werkzeug.routing.BuildError` on :meth:`url_for`. """ @@ -1722,7 +1782,7 @@ class Flask(Scaffold): raise error - def preprocess_request(self): + def preprocess_request(self) -> t.Optional[ResponseReturnValue]: """Called before the request is dispatched. Calls :attr:`url_value_preprocessors` registered with the app and the current blueprint (if any). Then calls :attr:`before_request_funcs` @@ -1733,14 +1793,16 @@ class Flask(Scaffold): further request handling is stopped. """ - funcs = self.url_value_preprocessors[None] + funcs: t.Iterable[URLValuePreprocessorCallable] = self.url_value_preprocessors[ + None + ] for bp in self._request_blueprints(): if bp in self.url_value_preprocessors: funcs = chain(funcs, self.url_value_preprocessors[bp]) for func in funcs: func(request.endpoint, request.view_args) - funcs = self.before_request_funcs[None] + funcs: t.Iterable[BeforeRequestCallable] = self.before_request_funcs[None] for bp in self._request_blueprints(): if bp in self.before_request_funcs: funcs = chain(funcs, self.before_request_funcs[bp]) @@ -1749,7 +1811,9 @@ class Flask(Scaffold): if rv is not None: return rv - def process_response(self, response): + return None + + def process_response(self, response: Response) -> Response: """Can be overridden in order to modify the response object before it's sent to the WSGI server. By default this will call all the :meth:`after_request` decorated functions. @@ -1763,7 +1827,7 @@ class Flask(Scaffold): instance of :attr:`response_class`. """ ctx = _request_ctx_stack.top - funcs = ctx._after_request_functions + funcs: t.Iterable[AfterRequestCallable] = ctx._after_request_functions for bp in self._request_blueprints(): if bp in self.after_request_funcs: funcs = chain(funcs, reversed(self.after_request_funcs[bp])) @@ -1775,7 +1839,9 @@ class Flask(Scaffold): self.session_interface.save_session(self, ctx.session, response) return response - def do_teardown_request(self, exc=_sentinel): + def do_teardown_request( + self, exc: t.Optional[BaseException] = _sentinel # type: ignore + ) -> None: """Called after the request is dispatched and the response is returned, right before the request context is popped. @@ -1798,7 +1864,9 @@ class Flask(Scaffold): """ if exc is _sentinel: exc = sys.exc_info()[1] - funcs = reversed(self.teardown_request_funcs[None]) + funcs: t.Iterable[TeardownCallable] = reversed( + self.teardown_request_funcs[None] + ) for bp in self._request_blueprints(): if bp in self.teardown_request_funcs: funcs = chain(funcs, reversed(self.teardown_request_funcs[bp])) @@ -1806,7 +1874,9 @@ class Flask(Scaffold): func(exc) request_tearing_down.send(self, exc=exc) - def do_teardown_appcontext(self, exc=_sentinel): + def do_teardown_appcontext( + self, exc: t.Optional[BaseException] = _sentinel # type: ignore + ) -> None: """Called right before the application context is popped. When handling a request, the application context is popped @@ -1827,7 +1897,7 @@ class Flask(Scaffold): func(exc) appcontext_tearing_down.send(self, exc=exc) - def app_context(self): + def app_context(self) -> AppContext: """Create an :class:`~flask.ctx.AppContext`. Use as a ``with`` block to push the context, which will make :data:`current_app` point at this application. @@ -1848,7 +1918,7 @@ class Flask(Scaffold): """ return AppContext(self) - def request_context(self, environ): + def request_context(self, environ: dict) -> RequestContext: """Create a :class:`~flask.ctx.RequestContext` representing a WSGI environment. Use a ``with`` block to push the context, which will make :data:`request` point at this request. @@ -1864,7 +1934,7 @@ class Flask(Scaffold): """ return RequestContext(self, environ) - def test_request_context(self, *args, **kwargs): + def test_request_context(self, *args: t.Any, **kwargs: t.Any) -> RequestContext: """Create a :class:`~flask.ctx.RequestContext` for a WSGI environment created from the given values. This is mostly useful during testing, where you may want to run a function that uses @@ -1920,7 +1990,7 @@ class Flask(Scaffold): finally: builder.close() - def wsgi_app(self, environ, start_response): + def wsgi_app(self, environ: dict, start_response: t.Callable) -> t.Any: """The actual WSGI application. This is not implemented in :meth:`__call__` so that middlewares can be applied without losing a reference to the app object. Instead of doing this:: @@ -1946,7 +2016,7 @@ class Flask(Scaffold): start the response. """ ctx = self.request_context(environ) - error = None + error: t.Optional[BaseException] = None try: try: ctx.push() @@ -1963,14 +2033,14 @@ class Flask(Scaffold): error = None ctx.auto_pop(error) - def __call__(self, environ, start_response): + def __call__(self, environ: dict, start_response: t.Callable) -> t.Any: """The WSGI server calls the Flask application object as the WSGI application. This calls :meth:`wsgi_app`, which can be wrapped to apply middleware. """ return self.wsgi_app(environ, start_response) - def _request_blueprints(self): + def _request_blueprints(self) -> t.Iterable[str]: if _request_ctx_stack.top.request.blueprint is None: return [] else: diff --git a/src/flask/blueprints.py b/src/flask/blueprints.py index 92345cf2..a2b6c0f5 100644 --- a/src/flask/blueprints.py +++ b/src/flask/blueprints.py @@ -1,9 +1,25 @@ +import typing as t from collections import defaultdict from functools import update_wrapper from .scaffold import _endpoint_from_view_func from .scaffold import _sentinel from .scaffold import Scaffold +from .typing import AfterRequestCallable +from .typing import BeforeRequestCallable +from .typing import ErrorHandlerCallable +from .typing import TeardownCallable +from .typing import TemplateContextProcessorCallable +from .typing import TemplateFilterCallable +from .typing import TemplateGlobalCallable +from .typing import TemplateTestCallable +from .typing import URLDefaultCallable +from .typing import URLValuePreprocessorCallable + +if t.TYPE_CHECKING: + from .app import Flask + +DeferredSetupFunction = t.Callable[["BlueprintSetupState"], t.Callable] class BlueprintSetupState: @@ -13,7 +29,13 @@ class BlueprintSetupState: to all register callback functions. """ - def __init__(self, blueprint, app, options, first_registration): + def __init__( + self, + blueprint: "Blueprint", + app: "Flask", + options: t.Any, + first_registration: bool, + ) -> None: #: a reference to the current application self.app = app @@ -52,7 +74,13 @@ class BlueprintSetupState: self.url_defaults = dict(self.blueprint.url_values_defaults) self.url_defaults.update(self.options.get("url_defaults", ())) - def add_url_rule(self, rule, endpoint=None, view_func=None, **options): + def add_url_rule( + self, + rule: str, + endpoint: t.Optional[str] = None, + view_func: t.Optional[t.Callable] = None, + **options: t.Any, + ) -> None: """A helper method to register a rule (and optionally a view function) to the application. The endpoint is automatically prefixed with the blueprint's name. @@ -64,7 +92,7 @@ class BlueprintSetupState: rule = self.url_prefix options.setdefault("subdomain", self.subdomain) if endpoint is None: - endpoint = _endpoint_from_view_func(view_func) + endpoint = _endpoint_from_view_func(view_func) # type: ignore defaults = self.url_defaults if "defaults" in options: defaults = dict(defaults, **options.pop("defaults")) @@ -142,16 +170,16 @@ class Blueprint(Scaffold): def __init__( self, - name, - import_name, - static_folder=None, - static_url_path=None, - template_folder=None, - url_prefix=None, - subdomain=None, - url_defaults=None, - root_path=None, - cli_group=_sentinel, + name: str, + import_name: str, + static_folder: t.Optional[str] = None, + static_url_path: t.Optional[str] = None, + template_folder: t.Optional[str] = None, + url_prefix: t.Optional[str] = None, + subdomain: t.Optional[str] = None, + url_defaults: t.Optional[dict] = None, + root_path: t.Optional[str] = None, + cli_group: t.Optional[str] = _sentinel, # type: ignore ): super().__init__( import_name=import_name, @@ -163,19 +191,19 @@ class Blueprint(Scaffold): self.name = name self.url_prefix = url_prefix self.subdomain = subdomain - self.deferred_functions = [] + self.deferred_functions: t.List[DeferredSetupFunction] = [] if url_defaults is None: url_defaults = {} self.url_values_defaults = url_defaults self.cli_group = cli_group - self._blueprints = [] + self._blueprints: t.List[t.Tuple["Blueprint", dict]] = [] - def _is_setup_finished(self): + def _is_setup_finished(self) -> bool: return self.warn_on_modifications and self._got_registered_once - def record(self, func): + def record(self, func: t.Callable) -> None: """Registers a function that is called when the blueprint is registered on the application. This function is called with the state as argument as returned by the :meth:`make_setup_state` @@ -193,27 +221,29 @@ class Blueprint(Scaffold): ) self.deferred_functions.append(func) - def record_once(self, func): + def record_once(self, func: t.Callable) -> None: """Works like :meth:`record` but wraps the function in another function that will ensure the function is only called once. If the blueprint is registered a second time on the application, the function passed is not called. """ - def wrapper(state): + def wrapper(state: BlueprintSetupState) -> None: if state.first_registration: func(state) return self.record(update_wrapper(wrapper, func)) - def make_setup_state(self, app, options, first_registration=False): + def make_setup_state( + self, app: "Flask", options: dict, first_registration: bool = False + ) -> BlueprintSetupState: """Creates an instance of :meth:`~flask.blueprints.BlueprintSetupState` object that is later passed to the register callback functions. Subclasses can override this to return a subclass of the setup state. """ return BlueprintSetupState(self, app, options, first_registration) - def register_blueprint(self, blueprint, **options): + def register_blueprint(self, blueprint: "Blueprint", **options: t.Any) -> None: """Register a :class:`~flask.Blueprint` on this blueprint. Keyword arguments passed to this method will override the defaults set on the blueprint. @@ -222,7 +252,7 @@ class Blueprint(Scaffold): """ self._blueprints.append((blueprint, options)) - def register(self, app, options): + def register(self, app: "Flask", options: dict) -> None: """Called by :meth:`Flask.register_blueprint` to register all views and callbacks registered on the blueprint with the application. Creates a :class:`.BlueprintSetupState` and calls @@ -327,7 +357,13 @@ class Blueprint(Scaffold): bp_options["name_prefix"] = options.get("name_prefix", "") + self.name + "." blueprint.register(app, bp_options) - def add_url_rule(self, rule, endpoint=None, view_func=None, **options): + def add_url_rule( + self, + rule: str, + endpoint: t.Optional[str] = None, + view_func: t.Optional[t.Callable] = None, + **options: t.Any, + ) -> None: """Like :meth:`Flask.add_url_rule` but for a blueprint. The endpoint for the :func:`url_for` function is prefixed with the name of the blueprint. """ @@ -339,7 +375,7 @@ class Blueprint(Scaffold): ), "Blueprint view function name should not contain dots" self.record(lambda s: s.add_url_rule(rule, endpoint, view_func, **options)) - def app_template_filter(self, name=None): + def app_template_filter(self, name: t.Optional[str] = None) -> t.Callable: """Register a custom template filter, available application wide. Like :meth:`Flask.template_filter` but for a blueprint. @@ -347,13 +383,15 @@ class Blueprint(Scaffold): function name will be used. """ - def decorator(f): + def decorator(f: TemplateFilterCallable) -> TemplateFilterCallable: self.add_app_template_filter(f, name=name) return f return decorator - def add_app_template_filter(self, f, name=None): + def add_app_template_filter( + self, f: TemplateFilterCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template filter, available application wide. Like :meth:`Flask.add_template_filter` but for a blueprint. Works exactly like the :meth:`app_template_filter` decorator. @@ -362,12 +400,12 @@ class Blueprint(Scaffold): function name will be used. """ - def register_template(state): + def register_template(state: BlueprintSetupState) -> None: state.app.jinja_env.filters[name or f.__name__] = f self.record_once(register_template) - def app_template_test(self, name=None): + def app_template_test(self, name: t.Optional[str] = None) -> t.Callable: """Register a custom template test, available application wide. Like :meth:`Flask.template_test` but for a blueprint. @@ -377,13 +415,15 @@ class Blueprint(Scaffold): function name will be used. """ - def decorator(f): + def decorator(f: TemplateTestCallable) -> TemplateTestCallable: self.add_app_template_test(f, name=name) return f return decorator - def add_app_template_test(self, f, name=None): + def add_app_template_test( + self, f: TemplateTestCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template test, available application wide. Like :meth:`Flask.add_template_test` but for a blueprint. Works exactly like the :meth:`app_template_test` decorator. @@ -394,12 +434,12 @@ class Blueprint(Scaffold): function name will be used. """ - def register_template(state): + def register_template(state: BlueprintSetupState) -> None: state.app.jinja_env.tests[name or f.__name__] = f self.record_once(register_template) - def app_template_global(self, name=None): + def app_template_global(self, name: t.Optional[str] = None) -> t.Callable: """Register a custom template global, available application wide. Like :meth:`Flask.template_global` but for a blueprint. @@ -409,13 +449,15 @@ class Blueprint(Scaffold): function name will be used. """ - def decorator(f): + def decorator(f: TemplateGlobalCallable) -> TemplateGlobalCallable: self.add_app_template_global(f, name=name) return f return decorator - def add_app_template_global(self, f, name=None): + def add_app_template_global( + self, f: TemplateGlobalCallable, name: t.Optional[str] = None + ) -> None: """Register a custom template global, available application wide. Like :meth:`Flask.add_template_global` but for a blueprint. Works exactly like the :meth:`app_template_global` decorator. @@ -426,12 +468,12 @@ class Blueprint(Scaffold): function name will be used. """ - def register_template(state): + def register_template(state: BlueprintSetupState) -> None: state.app.jinja_env.globals[name or f.__name__] = f self.record_once(register_template) - def before_app_request(self, f): + def before_app_request(self, f: BeforeRequestCallable) -> BeforeRequestCallable: """Like :meth:`Flask.before_request`. Such a function is executed before each request, even if outside of a blueprint. """ @@ -442,7 +484,9 @@ class Blueprint(Scaffold): ) return f - def before_app_first_request(self, f): + def before_app_first_request( + self, f: BeforeRequestCallable + ) -> BeforeRequestCallable: """Like :meth:`Flask.before_first_request`. Such a function is executed before the first request to the application. """ @@ -451,7 +495,7 @@ class Blueprint(Scaffold): ) return f - def after_app_request(self, f): + def after_app_request(self, f: AfterRequestCallable) -> AfterRequestCallable: """Like :meth:`Flask.after_request` but for a blueprint. Such a function is executed after each request, even if outside of the blueprint. """ @@ -462,7 +506,7 @@ class Blueprint(Scaffold): ) return f - def teardown_app_request(self, f): + def teardown_app_request(self, f: TeardownCallable) -> TeardownCallable: """Like :meth:`Flask.teardown_request` but for a blueprint. Such a function is executed when tearing down each request, even if outside of the blueprint. @@ -472,7 +516,9 @@ class Blueprint(Scaffold): ) return f - def app_context_processor(self, f): + def app_context_processor( + self, f: TemplateContextProcessorCallable + ) -> TemplateContextProcessorCallable: """Like :meth:`Flask.context_processor` but for a blueprint. Such a function is executed each request, even if outside of the blueprint. """ @@ -481,32 +527,34 @@ class Blueprint(Scaffold): ) return f - def app_errorhandler(self, code): + def app_errorhandler(self, code: t.Union[t.Type[Exception], int]) -> t.Callable: """Like :meth:`Flask.errorhandler` but for a blueprint. This handler is used for all requests, even if outside of the blueprint. """ - def decorator(f): + def decorator(f: ErrorHandlerCallable) -> ErrorHandlerCallable: self.record_once(lambda s: s.app.errorhandler(code)(f)) return f return decorator - def app_url_value_preprocessor(self, f): + def app_url_value_preprocessor( + self, f: URLValuePreprocessorCallable + ) -> URLValuePreprocessorCallable: """Same as :meth:`url_value_preprocessor` but application wide.""" self.record_once( lambda s: s.app.url_value_preprocessors.setdefault(None, []).append(f) ) return f - def app_url_defaults(self, f): + def app_url_defaults(self, f: URLDefaultCallable) -> URLDefaultCallable: """Same as :meth:`url_defaults` but application wide.""" self.record_once( lambda s: s.app.url_default_functions.setdefault(None, []).append(f) ) return f - def ensure_sync(self, f): + def ensure_sync(self, f: t.Callable) -> t.Callable: """Ensure the function is synchronous. Override if you would like custom async to sync behaviour in diff --git a/src/flask/cli.py b/src/flask/cli.py index 79a9a7c4..987c95cd 100644 --- a/src/flask/cli.py +++ b/src/flask/cli.py @@ -27,7 +27,7 @@ except ImportError: try: import ssl except ImportError: - ssl = None + ssl = None # type: ignore class NoAppException(click.UsageError): @@ -860,7 +860,7 @@ def run_command( @click.command("shell", short_help="Run a shell in the app context.") @with_appcontext -def shell_command(): +def shell_command() -> None: """Run an interactive Python shell in the context of a given Flask application. The application will populate the default namespace of this shell according to its configuration. @@ -877,7 +877,7 @@ def shell_command(): f"App: {app.import_name} [{app.env}]\n" f"Instance: {app.instance_path}" ) - ctx = {} + ctx: dict = {} # Support the regular Python interpreter startup script if someone # is using it. @@ -922,7 +922,7 @@ def shell_command(): ) @click.option("--all-methods", is_flag=True, help="Show HEAD and OPTIONS methods.") @with_appcontext -def routes_command(sort, all_methods): +def routes_command(sort: str, all_methods: bool) -> None: """Show all registered routes with endpoints and methods.""" rules = list(current_app.url_map.iter_rules()) @@ -935,9 +935,12 @@ def routes_command(sort, all_methods): if sort in ("endpoint", "rule"): rules = sorted(rules, key=attrgetter(sort)) elif sort == "methods": - rules = sorted(rules, key=lambda rule: sorted(rule.methods)) + rules = sorted(rules, key=lambda rule: sorted(rule.methods)) # type: ignore - rule_methods = [", ".join(sorted(rule.methods - ignored_methods)) for rule in rules] + rule_methods = [ + ", ".join(sorted(rule.methods - ignored_methods)) # type: ignore + for rule in rules + ] headers = ("Endpoint", "Methods", "Rule") widths = ( @@ -975,7 +978,7 @@ debug mode. ) -def main(): +def main() -> None: # TODO omit sys.argv once https://github.com/pallets/click/issues/536 is fixed cli.main(args=sys.argv[1:]) diff --git a/src/flask/config.py b/src/flask/config.py index d2dfec2b..86f21dc8 100644 --- a/src/flask/config.py +++ b/src/flask/config.py @@ -1,6 +1,7 @@ import errno import os import types +import typing as t from werkzeug.utils import import_string @@ -8,11 +9,11 @@ from werkzeug.utils import import_string class ConfigAttribute: """Makes an attribute forward to the config""" - def __init__(self, name, get_converter=None): + def __init__(self, name: str, get_converter: t.Optional[t.Callable] = None) -> None: self.__name__ = name self.get_converter = get_converter - def __get__(self, obj, type=None): + def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any: if obj is None: return self rv = obj.config[self.__name__] @@ -20,7 +21,7 @@ class ConfigAttribute: rv = self.get_converter(rv) return rv - def __set__(self, obj, value): + def __set__(self, obj: t.Any, value: t.Any) -> None: obj.config[self.__name__] = value @@ -68,11 +69,11 @@ class Config(dict): :param defaults: an optional dictionary of default values """ - def __init__(self, root_path, defaults=None): + def __init__(self, root_path: str, defaults: t.Optional[dict] = None) -> None: dict.__init__(self, defaults or {}) self.root_path = root_path - def from_envvar(self, variable_name, silent=False): + def from_envvar(self, variable_name: str, silent: bool = False) -> bool: """Loads a configuration from an environment variable pointing to a configuration file. This is basically just a shortcut with nicer error messages for this line of code:: @@ -96,7 +97,7 @@ class Config(dict): ) return self.from_pyfile(rv, silent=silent) - def from_pyfile(self, filename, silent=False): + def from_pyfile(self, filename: str, silent: bool = False) -> bool: """Updates the values in the config from a Python file. This function behaves as if the file was imported as module with the :meth:`from_object` function. @@ -124,7 +125,7 @@ class Config(dict): self.from_object(d) return True - def from_object(self, obj): + def from_object(self, obj: t.Union[object, str]) -> None: """Updates the values from the given object. An object can be of one of the following two types: @@ -162,7 +163,12 @@ class Config(dict): if key.isupper(): self[key] = getattr(obj, key) - def from_file(self, filename, load, silent=False): + def from_file( + self, + filename: str, + load: t.Callable[[t.IO[t.Any]], t.Mapping], + silent: bool = False, + ) -> bool: """Update the values in the config from a file that is loaded using the ``load`` parameter. The loaded data is passed to the :meth:`from_mapping` method. @@ -196,30 +202,26 @@ class Config(dict): return self.from_mapping(obj) - def from_mapping(self, *mapping, **kwargs): + def from_mapping( + self, mapping: t.Optional[t.Mapping[str, t.Any]] = None, **kwargs: t.Any + ) -> bool: """Updates the config like :meth:`update` ignoring items with non-upper keys. .. versionadded:: 0.11 """ - mappings = [] - if len(mapping) == 1: - if hasattr(mapping[0], "items"): - mappings.append(mapping[0].items()) - else: - mappings.append(mapping[0]) - elif len(mapping) > 1: - raise TypeError( - f"expected at most 1 positional argument, got {len(mapping)}" - ) - mappings.append(kwargs.items()) - for mapping in mappings: - for (key, value) in mapping: - if key.isupper(): - self[key] = value + mappings: t.Dict[str, t.Any] = {} + if mapping is not None: + mappings.update(mapping) + mappings.update(kwargs) + for key, value in mappings.items(): + if key.isupper(): + self[key] = value return True - def get_namespace(self, namespace, lowercase=True, trim_namespace=True): + def get_namespace( + self, namespace: str, lowercase: bool = True, trim_namespace: bool = True + ) -> t.Dict[str, t.Any]: """Returns a dictionary containing a subset of configuration options that match the specified namespace/prefix. Example usage:: @@ -260,5 +262,5 @@ class Config(dict): rv[key] = v return rv - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} {dict.__repr__(self)}>" diff --git a/src/flask/ctx.py b/src/flask/ctx.py index f9cb87d2..70de8cad 100644 --- a/src/flask/ctx.py +++ b/src/flask/ctx.py @@ -1,5 +1,7 @@ import sys +import typing as t from functools import update_wrapper +from types import TracebackType from werkzeug.exceptions import HTTPException @@ -7,6 +9,12 @@ from .globals import _app_ctx_stack from .globals import _request_ctx_stack from .signals import appcontext_popped from .signals import appcontext_pushed +from .typing import AfterRequestCallable + +if t.TYPE_CHECKING: + from .app import Flask + from .sessions import SessionMixin + from .wrappers import Request # a singleton sentinel value for parameter defaults @@ -33,7 +41,7 @@ class _AppCtxGlobals: .. versionadded:: 0.10 """ - def get(self, name, default=None): + def get(self, name: str, default: t.Optional[t.Any] = None) -> t.Any: """Get an attribute by name, or a default value. Like :meth:`dict.get`. @@ -44,7 +52,7 @@ class _AppCtxGlobals: """ return self.__dict__.get(name, default) - def pop(self, name, default=_sentinel): + def pop(self, name: str, default: t.Any = _sentinel) -> t.Any: """Get and remove an attribute by name. Like :meth:`dict.pop`. :param name: Name of attribute to pop. @@ -58,7 +66,7 @@ class _AppCtxGlobals: else: return self.__dict__.pop(name, default) - def setdefault(self, name, default=None): + def setdefault(self, name: str, default: t.Any = None) -> t.Any: """Get the value of an attribute if it is present, otherwise set and return a default value. Like :meth:`dict.setdefault`. @@ -70,20 +78,20 @@ class _AppCtxGlobals: """ return self.__dict__.setdefault(name, default) - def __contains__(self, item): + def __contains__(self, item: t.Any) -> bool: return item in self.__dict__ - def __iter__(self): + def __iter__(self) -> t.Iterator: return iter(self.__dict__) - def __repr__(self): + def __repr__(self) -> str: top = _app_ctx_stack.top if top is not None: return f"" return object.__repr__(self) -def after_this_request(f): +def after_this_request(f: AfterRequestCallable) -> AfterRequestCallable: """Executes a function after this request. This is useful to modify response objects. The function is passed the response object and has to return the same or a new one. @@ -108,7 +116,7 @@ def after_this_request(f): return f -def copy_current_request_context(f): +def copy_current_request_context(f: t.Callable) -> t.Callable: """A helper function that decorates a function to retain the current request context. This is useful when working with greenlets. The moment the function is decorated a copy of the request context is created and @@ -148,7 +156,7 @@ def copy_current_request_context(f): return update_wrapper(wrapper, f) -def has_request_context(): +def has_request_context() -> bool: """If you have code that wants to test if a request context is there or not this function can be used. For instance, you may want to take advantage of request information if the request object is available, but fail @@ -180,7 +188,7 @@ def has_request_context(): return _request_ctx_stack.top is not None -def has_app_context(): +def has_app_context() -> bool: """Works like :func:`has_request_context` but for the application context. You can also just do a boolean check on the :data:`current_app` object instead. @@ -199,7 +207,7 @@ class AppContext: context. """ - def __init__(self, app): + def __init__(self, app: "Flask") -> None: self.app = app self.url_adapter = app.create_url_adapter(None) self.g = app.app_ctx_globals_class() @@ -208,13 +216,13 @@ class AppContext: # but there a basic "refcount" is enough to track them. self._refcnt = 0 - def push(self): + def push(self) -> None: """Binds the app context to the current context.""" self._refcnt += 1 _app_ctx_stack.push(self) appcontext_pushed.send(self.app) - def pop(self, exc=_sentinel): + def pop(self, exc: t.Optional[BaseException] = _sentinel) -> None: # type: ignore """Pops the app context.""" try: self._refcnt -= 1 @@ -227,11 +235,13 @@ class AppContext: assert rv is self, f"Popped wrong app context. ({rv!r} instead of {self!r})" appcontext_popped.send(self.app) - def __enter__(self): + def __enter__(self) -> "AppContext": self.push() return self - def __exit__(self, exc_type, exc_value, tb): + def __exit__( + self, exc_type: type, exc_value: BaseException, tb: TracebackType + ) -> None: self.pop(exc_value) @@ -265,7 +275,13 @@ class RequestContext: that situation, otherwise your unittests will leak memory. """ - def __init__(self, app, environ, request=None, session=None): + def __init__( + self, + app: "Flask", + environ: dict, + request: t.Optional["Request"] = None, + session: t.Optional["SessionMixin"] = None, + ) -> None: self.app = app if request is None: request = app.request_class(environ) @@ -282,7 +298,7 @@ class RequestContext: # other request contexts. Now only if the last level is popped we # get rid of them. Additionally if an application context is missing # one is created implicitly so for each level we add this information - self._implicit_app_ctx_stack = [] + self._implicit_app_ctx_stack: t.List[t.Optional["AppContext"]] = [] # indicator if the context was preserved. Next time another context # is pushed the preserved context is popped. @@ -295,17 +311,17 @@ class RequestContext: # Functions that should be executed after the request on the response # object. These will be called before the regular "after_request" # functions. - self._after_request_functions = [] + self._after_request_functions: t.List[AfterRequestCallable] = [] @property - def g(self): + def g(self) -> AppContext: return _app_ctx_stack.top.g @g.setter - def g(self, value): + def g(self, value: AppContext) -> None: _app_ctx_stack.top.g = value - def copy(self): + def copy(self) -> "RequestContext": """Creates a copy of this request context with the same request object. This can be used to move a request context to a different greenlet. Because the actual request object is the same this cannot be used to @@ -325,17 +341,17 @@ class RequestContext: session=self.session, ) - def match_request(self): + def match_request(self) -> None: """Can be overridden by a subclass to hook into the matching of the request. """ try: - result = self.url_adapter.match(return_rule=True) - self.request.url_rule, self.request.view_args = result + result = self.url_adapter.match(return_rule=True) # type: ignore + self.request.url_rule, self.request.view_args = result # type: ignore except HTTPException as e: self.request.routing_exception = e - def push(self): + def push(self) -> None: """Binds the request context to the current context.""" # If an exception occurs in debug mode or if context preservation is # activated under exception situations exactly one context stays @@ -375,7 +391,7 @@ class RequestContext: if self.session is None: self.session = session_interface.make_null_session(self.app) - def pop(self, exc=_sentinel): + def pop(self, exc: t.Optional[BaseException] = _sentinel) -> None: # type: ignore """Pops the request context and unbinds it by doing that. This will also trigger the execution of functions registered by the :meth:`~flask.Flask.teardown_request` decorator. @@ -414,20 +430,22 @@ class RequestContext: rv is self ), f"Popped wrong request context. ({rv!r} instead of {self!r})" - def auto_pop(self, exc): + def auto_pop(self, exc: t.Optional[BaseException]) -> None: if self.request.environ.get("flask._preserve_context") or ( exc is not None and self.app.preserve_context_on_exception ): self.preserved = True - self._preserved_exc = exc + self._preserved_exc = exc # type: ignore else: self.pop(exc) - def __enter__(self): + def __enter__(self) -> "RequestContext": self.push() return self - def __exit__(self, exc_type, exc_value, tb): + def __exit__( + self, exc_type: type, exc_value: BaseException, tb: TracebackType + ) -> None: # do not pop the request stack if we are in debug mode and an # exception happened. This will allow the debugger to still # access the request object in the interactive shell. Furthermore @@ -435,7 +453,7 @@ class RequestContext: # See flask.testing for how this works. self.auto_pop(exc_value) - def __repr__(self): + def __repr__(self) -> str: return ( f"<{type(self).__name__} {self.request.url!r}" f" [{self.request.method}] of {self.app.name}>" diff --git a/src/flask/debughelpers.py b/src/flask/debughelpers.py index 4bd85bc5..ce65c487 100644 --- a/src/flask/debughelpers.py +++ b/src/flask/debughelpers.py @@ -1,4 +1,5 @@ import os +import typing as t from warnings import warn from .app import Flask @@ -92,7 +93,7 @@ def attach_enctype_error_multidict(request): request.files.__class__ = newcls -def _dump_loader_info(loader): +def _dump_loader_info(loader) -> t.Generator: yield f"class: {type(loader).__module__}.{type(loader).__name__}" for key, value in sorted(loader.__dict__.items()): if key.startswith("_"): @@ -109,7 +110,7 @@ def _dump_loader_info(loader): yield f"{key}: {value!r}" -def explain_template_loading_attempts(app, template, attempts): +def explain_template_loading_attempts(app: Flask, template, attempts) -> None: """This should help developers understand what failed""" info = [f"Locating template {template!r}:"] total_found = 0 @@ -157,7 +158,7 @@ def explain_template_loading_attempts(app, template, attempts): app.logger.info("\n".join(info)) -def explain_ignored_app_run(): +def explain_ignored_app_run() -> None: if os.environ.get("WERKZEUG_RUN_MAIN") != "true": warn( Warning( diff --git a/src/flask/globals.py b/src/flask/globals.py index d46ccb41..5e6e8c75 100644 --- a/src/flask/globals.py +++ b/src/flask/globals.py @@ -1,8 +1,14 @@ +import typing as t from functools import partial from werkzeug.local import LocalProxy from werkzeug.local import LocalStack +if t.TYPE_CHECKING: + from .app import Flask + from .ctx import AppContext + from .sessions import SessionMixin + from .wrappers import Request _request_ctx_err_msg = """\ Working outside of request context. @@ -45,7 +51,7 @@ def _find_app(): # context locals _request_ctx_stack = LocalStack() _app_ctx_stack = LocalStack() -current_app = LocalProxy(_find_app) -request = LocalProxy(partial(_lookup_req_object, "request")) -session = LocalProxy(partial(_lookup_req_object, "session")) -g = LocalProxy(partial(_lookup_app_object, "g")) +current_app: "Flask" = LocalProxy(_find_app) # type: ignore +request: "Request" = LocalProxy(partial(_lookup_req_object, "request")) # type: ignore +session: "SessionMixin" = LocalProxy(partial(_lookup_req_object, "session")) # type: ignore # noqa: B950 +g: "AppContext" = LocalProxy(partial(_lookup_app_object, "g")) # type: ignore diff --git a/src/flask/helpers.py b/src/flask/helpers.py index 6a6bbcf1..99594fce 100644 --- a/src/flask/helpers.py +++ b/src/flask/helpers.py @@ -1,6 +1,8 @@ import os import socket +import typing as t import warnings +from datetime import timedelta from functools import update_wrapper from functools import wraps from threading import RLock @@ -18,8 +20,11 @@ from .globals import request from .globals import session from .signals import message_flashed +if t.TYPE_CHECKING: + from .wrappers import Response -def get_env(): + +def get_env() -> str: """Get the environment the app is running in, indicated by the :envvar:`FLASK_ENV` environment variable. The default is ``'production'``. @@ -27,7 +32,7 @@ def get_env(): return os.environ.get("FLASK_ENV") or "production" -def get_debug_flag(): +def get_debug_flag() -> bool: """Get whether debug mode should be enabled for the app, indicated by the :envvar:`FLASK_DEBUG` environment variable. The default is ``True`` if :func:`.get_env` returns ``'development'``, or ``False`` @@ -41,7 +46,7 @@ def get_debug_flag(): return val.lower() not in ("0", "false", "no") -def get_load_dotenv(default=True): +def get_load_dotenv(default: bool = True) -> bool: """Get whether the user has disabled loading dotenv files by setting :envvar:`FLASK_SKIP_DOTENV`. The default is ``True``, load the files. @@ -56,7 +61,9 @@ def get_load_dotenv(default=True): return val.lower() in ("0", "false", "no") -def stream_with_context(generator_or_function): +def stream_with_context( + generator_or_function: t.Union[t.Generator, t.Callable] +) -> t.Generator: """Request contexts disappear when the response is started on the server. This is done for efficiency reasons and to make it less likely to encounter memory leaks with badly written WSGI middlewares. The downside is that if @@ -91,16 +98,16 @@ def stream_with_context(generator_or_function): .. versionadded:: 0.9 """ try: - gen = iter(generator_or_function) + gen = iter(generator_or_function) # type: ignore except TypeError: - def decorator(*args, **kwargs): - gen = generator_or_function(*args, **kwargs) + def decorator(*args: t.Any, **kwargs: t.Any) -> t.Any: + gen = generator_or_function(*args, **kwargs) # type: ignore return stream_with_context(gen) - return update_wrapper(decorator, generator_or_function) + return update_wrapper(decorator, generator_or_function) # type: ignore - def generator(): + def generator() -> t.Generator: ctx = _request_ctx_stack.top if ctx is None: raise RuntimeError( @@ -120,7 +127,7 @@ def stream_with_context(generator_or_function): yield from gen finally: if hasattr(gen, "close"): - gen.close() + gen.close() # type: ignore # The trick is to start the generator. Then the code execution runs until # the first dummy None is yielded at which point the context was already @@ -131,7 +138,7 @@ def stream_with_context(generator_or_function): return wrapped_g -def make_response(*args): +def make_response(*args: t.Any) -> "Response": """Sometimes it is necessary to set additional headers in a view. Because views do not have to return response objects but can return a value that is converted into a response object by Flask itself, it becomes tricky to @@ -180,7 +187,7 @@ def make_response(*args): return current_app.make_response(args) -def url_for(endpoint, **values): +def url_for(endpoint: str, **values: t.Any) -> str: """Generates a URL to the given endpoint with the method provided. Variable arguments that are unknown to the target endpoint are appended @@ -331,7 +338,7 @@ def url_for(endpoint, **values): return rv -def get_template_attribute(template_name, attribute): +def get_template_attribute(template_name: str, attribute: str) -> t.Any: """Loads a macro (or variable) a template exports. This can be used to invoke a macro from within Python code. If you for example have a template named :file:`_cider.html` with the following contents: @@ -353,7 +360,7 @@ def get_template_attribute(template_name, attribute): return getattr(current_app.jinja_env.get_template(template_name).module, attribute) -def flash(message, category="message"): +def flash(message: str, category: str = "message") -> None: """Flashes a message to the next request. In order to remove the flashed message from the session and to display it to the user, the template has to call :func:`get_flashed_messages`. @@ -379,11 +386,15 @@ def flash(message, category="message"): flashes.append((category, message)) session["_flashes"] = flashes message_flashed.send( - current_app._get_current_object(), message=message, category=category + current_app._get_current_object(), # type: ignore + message=message, + category=category, ) -def get_flashed_messages(with_categories=False, category_filter=()): +def get_flashed_messages( + with_categories: bool = False, category_filter: t.Iterable[str] = () +) -> t.Union[t.List[str], t.List[t.Tuple[str, str]]]: """Pulls all flashed messages from the session and returns them. Further calls in the same request to the function will return the same messages. By default just the messages are returned, @@ -608,7 +619,7 @@ def send_file( ) -def safe_join(directory, *pathnames): +def safe_join(directory: str, *pathnames: str) -> str: """Safely join zero or more untrusted path components to a base directory to avoid escaping the base directory. @@ -631,7 +642,7 @@ def safe_join(directory, *pathnames): return path -def send_from_directory(directory, path, **kwargs): +def send_from_directory(directory: str, path: str, **kwargs: t.Any) -> "Response": """Send a file from within a directory using :func:`send_file`. .. code-block:: python @@ -661,7 +672,7 @@ def send_from_directory(directory, path, **kwargs): .. versionadded:: 0.5 """ - return werkzeug.utils.send_from_directory( + return werkzeug.utils.send_from_directory( # type: ignore directory, path, **_prepare_send_file_kwargs(**kwargs) ) @@ -675,27 +686,32 @@ class locked_cached_property(werkzeug.utils.cached_property): Inherits from Werkzeug's ``cached_property`` (and ``property``). """ - def __init__(self, fget, name=None, doc=None): + def __init__( + self, + fget: t.Callable[[t.Any], t.Any], + name: t.Optional[str] = None, + doc: t.Optional[str] = None, + ) -> None: super().__init__(fget, name=name, doc=doc) self.lock = RLock() - def __get__(self, obj, type=None): + def __get__(self, obj: object, type: type = None) -> t.Any: # type: ignore if obj is None: return self with self.lock: return super().__get__(obj, type=type) - def __set__(self, obj, value): + def __set__(self, obj: object, value: t.Any) -> None: with self.lock: super().__set__(obj, value) - def __delete__(self, obj): + def __delete__(self, obj: object) -> None: with self.lock: super().__delete__(obj) -def total_seconds(td): +def total_seconds(td: timedelta) -> int: """Returns the total seconds from a timedelta object. :param timedelta td: the timedelta to be converted in seconds @@ -716,7 +732,7 @@ def total_seconds(td): return td.days * 60 * 60 * 24 + td.seconds -def is_ip(value): +def is_ip(value: str) -> bool: """Determine if the given string is an IP address. :param value: value to check @@ -736,7 +752,7 @@ def is_ip(value): return False -def run_async(func): +def run_async(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]: """Return a sync function that will run the coroutine function *func*.""" try: from asgiref.sync import async_to_sync @@ -752,7 +768,7 @@ def run_async(func): ) @wraps(func) - def outer(*args, **kwargs): + def outer(*args: t.Any, **kwargs: t.Any) -> t.Any: """This function grabs the current context for the inner function. This is similar to the copy_current_xxx_context functions in the @@ -764,7 +780,7 @@ def run_async(func): ctx = _request_ctx_stack.top.copy() @wraps(func) - async def inner(*a, **k): + async def inner(*a: t.Any, **k: t.Any) -> t.Any: """This restores the context before awaiting the func. This is required as the function must be awaited within the @@ -780,5 +796,5 @@ def run_async(func): return async_to_sync(inner)(*args, **kwargs) - outer._flask_sync_wrapper = True + outer._flask_sync_wrapper = True # type: ignore return outer diff --git a/src/flask/json/__init__.py b/src/flask/json/__init__.py index 7ca0db90..5a6e4942 100644 --- a/src/flask/json/__init__.py +++ b/src/flask/json/__init__.py @@ -1,20 +1,25 @@ import io import json as _json +import typing as t import uuid import warnings from datetime import date -from jinja2.utils import htmlsafe_json_dumps as _jinja_htmlsafe_dumps +from jinja2.utils import htmlsafe_json_dumps as _jinja_htmlsafe_dumps # type: ignore from werkzeug.http import http_date from ..globals import current_app from ..globals import request +if t.TYPE_CHECKING: + from ..app import Flask + from ..wrappers import Response + try: import dataclasses except ImportError: # Python < 3.7 - dataclasses = None + dataclasses = None # type: ignore class JSONEncoder(_json.JSONEncoder): @@ -34,7 +39,7 @@ class JSONEncoder(_json.JSONEncoder): :attr:`flask.Blueprint.json_encoder` to override the default. """ - def default(self, o): + def default(self, o: t.Any) -> t.Any: """Convert ``o`` to a JSON serializable type. See :meth:`json.JSONEncoder.default`. Python does not support overriding how basic types like ``str`` or ``list`` are @@ -48,7 +53,7 @@ class JSONEncoder(_json.JSONEncoder): return dataclasses.asdict(o) if hasattr(o, "__html__"): return str(o.__html__()) - return super().default(self, o) + return super().default(o) class JSONDecoder(_json.JSONDecoder): @@ -62,14 +67,19 @@ class JSONDecoder(_json.JSONDecoder): """ -def _dump_arg_defaults(kwargs, app=None): +def _dump_arg_defaults( + kwargs: t.Dict[str, t.Any], app: t.Optional["Flask"] = None +) -> None: """Inject default arguments for dump functions.""" if app is None: app = current_app if app: - bp = app.blueprints.get(request.blueprint) if request else None - cls = bp.json_encoder if bp and bp.json_encoder else app.json_encoder + cls = app.json_encoder + bp = app.blueprints.get(request.blueprint) if request else None # type: ignore + if bp is not None and bp.json_encoder is not None: + cls = bp.json_encoder + kwargs.setdefault("cls", cls) kwargs.setdefault("ensure_ascii", app.config["JSON_AS_ASCII"]) kwargs.setdefault("sort_keys", app.config["JSON_SORT_KEYS"]) @@ -78,20 +88,25 @@ def _dump_arg_defaults(kwargs, app=None): kwargs.setdefault("cls", JSONEncoder) -def _load_arg_defaults(kwargs, app=None): +def _load_arg_defaults( + kwargs: t.Dict[str, t.Any], app: t.Optional["Flask"] = None +) -> None: """Inject default arguments for load functions.""" if app is None: app = current_app if app: - bp = app.blueprints.get(request.blueprint) if request else None - cls = bp.json_decoder if bp and bp.json_decoder else app.json_decoder + cls = app.json_decoder + bp = app.blueprints.get(request.blueprint) if request else None # type: ignore + if bp is not None and bp.json_decoder is not None: + cls = bp.json_decoder + kwargs.setdefault("cls", cls) else: kwargs.setdefault("cls", JSONDecoder) -def dumps(obj, app=None, **kwargs): +def dumps(obj: t.Any, app: t.Optional["Flask"] = None, **kwargs: t.Any) -> str: """Serialize an object to a string of JSON. Takes the same arguments as the built-in :func:`json.dumps`, with @@ -121,12 +136,14 @@ def dumps(obj, app=None, **kwargs): ) if isinstance(rv, str): - return rv.encode(encoding) + return rv.encode(encoding) # type: ignore return rv -def dump(obj, fp, app=None, **kwargs): +def dump( + obj: t.Any, fp: t.IO[str], app: t.Optional["Flask"] = None, **kwargs: t.Any +) -> None: """Serialize an object to JSON written to a file object. Takes the same arguments as the built-in :func:`json.dump`, with @@ -150,7 +167,7 @@ def dump(obj, fp, app=None, **kwargs): fp.write("") except TypeError: show_warning = True - fp = io.TextIOWrapper(fp, encoding or "utf-8") + fp = io.TextIOWrapper(fp, encoding or "utf-8") # type: ignore if show_warning: warnings.warn( @@ -163,7 +180,7 @@ def dump(obj, fp, app=None, **kwargs): _json.dump(obj, fp, **kwargs) -def loads(s, app=None, **kwargs): +def loads(s: str, app: t.Optional["Flask"] = None, **kwargs: t.Any) -> t.Any: """Deserialize an object from a string of JSON. Takes the same arguments as the built-in :func:`json.loads`, with @@ -199,7 +216,7 @@ def loads(s, app=None, **kwargs): return _json.loads(s, **kwargs) -def load(fp, app=None, **kwargs): +def load(fp: t.IO[str], app: t.Optional["Flask"] = None, **kwargs: t.Any) -> t.Any: """Deserialize an object from JSON read from a file object. Takes the same arguments as the built-in :func:`json.load`, with @@ -227,12 +244,12 @@ def load(fp, app=None, **kwargs): ) if isinstance(fp.read(0), bytes): - fp = io.TextIOWrapper(fp, encoding) + fp = io.TextIOWrapper(fp, encoding) # type: ignore return _json.load(fp, **kwargs) -def htmlsafe_dumps(obj, **kwargs): +def htmlsafe_dumps(obj: t.Any, **kwargs: t.Any) -> str: """Serialize an object to a string of JSON with :func:`dumps`, then replace HTML-unsafe characters with Unicode escapes and mark the result safe with :class:`~markupsafe.Markup`. @@ -256,7 +273,7 @@ def htmlsafe_dumps(obj, **kwargs): return _jinja_htmlsafe_dumps(obj, dumps=dumps, **kwargs) -def htmlsafe_dump(obj, fp, **kwargs): +def htmlsafe_dump(obj: t.Any, fp: t.IO[str], **kwargs: t.Any) -> None: """Serialize an object to JSON written to a file object, replacing HTML-unsafe characters with Unicode escapes. See :func:`htmlsafe_dumps` and :func:`dumps`. @@ -264,7 +281,7 @@ def htmlsafe_dump(obj, fp, **kwargs): fp.write(htmlsafe_dumps(obj, **kwargs)) -def jsonify(*args, **kwargs): +def jsonify(*args: t.Any, **kwargs: t.Any) -> "Response": """Serialize data to JSON and wrap it in a :class:`~flask.Response` with the :mimetype:`application/json` mimetype. diff --git a/src/flask/json/tag.py b/src/flask/json/tag.py index d3c29adb..97f365a9 100644 --- a/src/flask/json/tag.py +++ b/src/flask/json/tag.py @@ -40,6 +40,7 @@ be processed before ``dict``. app.session_interface.serializer.register(TagOrderedDict, index=0) """ +import typing as t from base64 import b64decode from base64 import b64encode from datetime import datetime @@ -60,27 +61,27 @@ class JSONTag: #: The tag to mark the serialized object with. If ``None``, this tag is #: only used as an intermediate step during tagging. - key = None + key: t.Optional[str] = None - def __init__(self, serializer): + def __init__(self, serializer: "TaggedJSONSerializer") -> None: """Create a tagger for the given serializer.""" self.serializer = serializer - def check(self, value): + def check(self, value: t.Any) -> bool: """Check if the given value should be tagged by this tag.""" raise NotImplementedError - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: """Convert the Python object to an object that is a valid JSON type. The tag will be added later.""" raise NotImplementedError - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: """Convert the JSON representation back to the correct type. The tag will already be removed.""" raise NotImplementedError - def tag(self, value): + def tag(self, value: t.Any) -> t.Any: """Convert the value to a valid JSON type and add the tag structure around it.""" return {self.key: self.to_json(value)} @@ -96,18 +97,18 @@ class TagDict(JSONTag): __slots__ = () key = " di" - def check(self, value): + def check(self, value: t.Any) -> bool: return ( isinstance(value, dict) and len(value) == 1 and next(iter(value)) in self.serializer.tags ) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: key = next(iter(value)) return {f"{key}__": self.serializer.tag(value[key])} - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: key = next(iter(value)) return {key[:-2]: value[key]} @@ -115,10 +116,10 @@ class TagDict(JSONTag): class PassDict(JSONTag): __slots__ = () - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, dict) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: # JSON objects may only have string keys, so don't bother tagging the # key here. return {k: self.serializer.tag(v) for k, v in value.items()} @@ -130,23 +131,23 @@ class TagTuple(JSONTag): __slots__ = () key = " t" - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, tuple) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return [self.serializer.tag(item) for item in value] - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: return tuple(value) class PassList(JSONTag): __slots__ = () - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, list) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return [self.serializer.tag(item) for item in value] tag = to_json @@ -156,13 +157,13 @@ class TagBytes(JSONTag): __slots__ = () key = " b" - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, bytes) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return b64encode(value).decode("ascii") - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: return b64decode(value) @@ -174,13 +175,13 @@ class TagMarkup(JSONTag): __slots__ = () key = " m" - def check(self, value): + def check(self, value: t.Any) -> bool: return callable(getattr(value, "__html__", None)) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return str(value.__html__()) - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: return Markup(value) @@ -188,13 +189,13 @@ class TagUUID(JSONTag): __slots__ = () key = " u" - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, UUID) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return value.hex - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: return UUID(value) @@ -202,13 +203,13 @@ class TagDateTime(JSONTag): __slots__ = () key = " d" - def check(self, value): + def check(self, value: t.Any) -> bool: return isinstance(value, datetime) - def to_json(self, value): + def to_json(self, value: t.Any) -> t.Any: return http_date(value) - def to_python(self, value): + def to_python(self, value: t.Any) -> t.Any: return parse_date(value) @@ -242,14 +243,19 @@ class TaggedJSONSerializer: TagDateTime, ] - def __init__(self): - self.tags = {} - self.order = [] + def __init__(self) -> None: + self.tags: t.Dict[str, JSONTag] = {} + self.order: t.List[JSONTag] = [] for cls in self.default_tags: self.register(cls) - def register(self, tag_class, force=False, index=None): + def register( + self, + tag_class: t.Type[JSONTag], + force: bool = False, + index: t.Optional[int] = None, + ) -> None: """Register a new tag with this serializer. :param tag_class: tag class to register. Will be instantiated with this @@ -277,7 +283,7 @@ class TaggedJSONSerializer: else: self.order.insert(index, tag) - def tag(self, value): + def tag(self, value: t.Any) -> t.Dict[str, t.Any]: """Convert a value to a tagged representation if necessary.""" for tag in self.order: if tag.check(value): @@ -285,7 +291,7 @@ class TaggedJSONSerializer: return value - def untag(self, value): + def untag(self, value: t.Dict[str, t.Any]) -> t.Any: """Convert a tagged representation back to the original type.""" if len(value) != 1: return value @@ -297,10 +303,10 @@ class TaggedJSONSerializer: return self.tags[key].to_python(value[key]) - def dumps(self, value): + def dumps(self, value: t.Any) -> str: """Tag the value and dump it to a compact JSON string.""" return dumps(self.tag(value), separators=(",", ":")) - def loads(self, value): + def loads(self, value: str) -> t.Any: """Load data from a JSON string and deserialized any tagged objects.""" return loads(value, object_hook=self.untag) diff --git a/src/flask/logging.py b/src/flask/logging.py index fe6809b2..48a5b7ff 100644 --- a/src/flask/logging.py +++ b/src/flask/logging.py @@ -1,13 +1,17 @@ import logging import sys +import typing as t from werkzeug.local import LocalProxy from .globals import request +if t.TYPE_CHECKING: + from .app import Flask + @LocalProxy -def wsgi_errors_stream(): +def wsgi_errors_stream() -> t.TextIO: """Find the most appropriate error stream for the application. If a request is active, log to ``wsgi.errors``, otherwise use ``sys.stderr``. @@ -19,7 +23,7 @@ def wsgi_errors_stream(): return request.environ["wsgi.errors"] if request else sys.stderr -def has_level_handler(logger): +def has_level_handler(logger: logging.Logger) -> bool: """Check if there is a handler in the logging chain that will handle the given logger's :meth:`effective level <~logging.Logger.getEffectiveLevel>`. """ @@ -33,20 +37,20 @@ def has_level_handler(logger): if not current.propagate: break - current = current.parent + current = current.parent # type: ignore return False #: Log messages to :func:`~flask.logging.wsgi_errors_stream` with the format #: ``[%(asctime)s] %(levelname)s in %(module)s: %(message)s``. -default_handler = logging.StreamHandler(wsgi_errors_stream) +default_handler = logging.StreamHandler(wsgi_errors_stream) # type: ignore default_handler.setFormatter( logging.Formatter("[%(asctime)s] %(levelname)s in %(module)s: %(message)s") ) -def create_logger(app): +def create_logger(app: "Flask") -> logging.Logger: """Get the Flask app's logger and configure it if needed. The logger name will be the same as diff --git a/src/flask/scaffold.py b/src/flask/scaffold.py index 44745b7d..445ac500 100644 --- a/src/flask/scaffold.py +++ b/src/flask/scaffold.py @@ -2,8 +2,11 @@ import importlib.util import os import pkgutil import sys +import typing as t from collections import defaultdict from functools import update_wrapper +from json import JSONDecoder +from json import JSONEncoder from jinja2 import FileSystemLoader from werkzeug.exceptions import default_exceptions @@ -14,17 +17,28 @@ from .globals import current_app from .helpers import locked_cached_property from .helpers import send_from_directory from .templating import _default_template_ctx_processor +from .typing import AfterRequestCallable +from .typing import AppOrBlueprintKey +from .typing import BeforeRequestCallable +from .typing import ErrorHandlerCallable +from .typing import TeardownCallable +from .typing import TemplateContextProcessorCallable +from .typing import URLDefaultCallable +from .typing import URLValuePreprocessorCallable + +if t.TYPE_CHECKING: + from .wrappers import Response # a singleton sentinel value for parameter defaults _sentinel = object() -def setupmethod(f): +def setupmethod(f: t.Callable) -> t.Callable: """Wraps a method so that it performs a check in debug mode if the first request was already handled. """ - def wrapper_func(self, *args, **kwargs): + def wrapper_func(self, *args: t.Any, **kwargs: t.Any) -> t.Any: if self._is_setup_finished(): raise AssertionError( "A setup function was called after the first request " @@ -60,24 +74,24 @@ class Scaffold: """ name: str - _static_folder = None - _static_url_path = None + _static_folder: t.Optional[str] = None + _static_url_path: t.Optional[str] = None #: JSON encoder class used by :func:`flask.json.dumps`. If a #: blueprint sets this, it will be used instead of the app's value. - json_encoder = None + json_encoder: t.Optional[t.Type[JSONEncoder]] = None #: JSON decoder class used by :func:`flask.json.loads`. If a #: blueprint sets this, it will be used instead of the app's value. - json_decoder = None + json_decoder: t.Optional[t.Type[JSONDecoder]] = None def __init__( self, - import_name, - static_folder=None, - static_url_path=None, - template_folder=None, - root_path=None, + import_name: str, + static_folder: t.Optional[str] = None, + static_url_path: t.Optional[str] = None, + template_folder: t.Optional[str] = None, + root_path: t.Optional[str] = None, ): #: The name of the package or module that this object belongs #: to. Do not change this once it is set by the constructor. @@ -110,7 +124,7 @@ class Scaffold: #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.view_functions = {} + self.view_functions: t.Dict[str, t.Callable] = {} #: A data structure of registered error handlers, in the format #: ``{scope: {code: {class: handler}}}```. The ``scope`` key is @@ -125,7 +139,10 @@ class Scaffold: #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.error_handler_spec = defaultdict(lambda: defaultdict(dict)) + self.error_handler_spec: t.Dict[ + AppOrBlueprintKey, + t.Dict[t.Optional[int], t.Dict[t.Type[Exception], ErrorHandlerCallable]], + ] = defaultdict(lambda: defaultdict(dict)) #: A data structure of functions to call at the beginning of #: each request, in the format ``{scope: [functions]}``. The @@ -137,7 +154,9 @@ class Scaffold: #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.before_request_funcs = defaultdict(list) + self.before_request_funcs: t.Dict[ + AppOrBlueprintKey, t.List[BeforeRequestCallable] + ] = defaultdict(list) #: A data structure of functions to call at the end of each #: request, in the format ``{scope: [functions]}``. The @@ -149,7 +168,9 @@ class Scaffold: #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.after_request_funcs = defaultdict(list) + self.after_request_funcs: t.Dict[ + AppOrBlueprintKey, t.List[AfterRequestCallable] + ] = defaultdict(list) #: A data structure of functions to call at the end of each #: request even if an exception is raised, in the format @@ -162,7 +183,9 @@ class Scaffold: #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.teardown_request_funcs = defaultdict(list) + self.teardown_request_funcs: t.Dict[ + AppOrBlueprintKey, t.List[TeardownCallable] + ] = defaultdict(list) #: A data structure of functions to call to pass extra context #: values when rendering templates, in the format @@ -175,9 +198,9 @@ class Scaffold: #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.template_context_processors = defaultdict( - list, {None: [_default_template_ctx_processor]} - ) + self.template_context_processors: t.Dict[ + AppOrBlueprintKey, t.List[TemplateContextProcessorCallable] + ] = defaultdict(list, {None: [_default_template_ctx_processor]}) #: A data structure of functions to call to modify the keyword #: arguments passed to the view function, in the format @@ -190,7 +213,10 @@ class Scaffold: #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.url_value_preprocessors = defaultdict(list) + self.url_value_preprocessors: t.Dict[ + AppOrBlueprintKey, + t.List[URLValuePreprocessorCallable], + ] = defaultdict(list) #: A data structure of functions to call to modify the keyword #: arguments when generating URLs, in the format @@ -203,31 +229,35 @@ class Scaffold: #: #: This data structure is internal. It should not be modified #: directly and its format may change at any time. - self.url_default_functions = defaultdict(list) + self.url_default_functions: t.Dict[ + AppOrBlueprintKey, t.List[URLDefaultCallable] + ] = defaultdict(list) - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} {self.name!r}>" - def _is_setup_finished(self): + def _is_setup_finished(self) -> bool: raise NotImplementedError @property - def static_folder(self): + def static_folder(self) -> t.Optional[str]: """The absolute path to the configured static folder. ``None`` if no static folder is set. """ if self._static_folder is not None: return os.path.join(self.root_path, self._static_folder) + else: + return None @static_folder.setter - def static_folder(self, value): + def static_folder(self, value: t.Optional[str]) -> None: if value is not None: value = os.fspath(value).rstrip(r"\/") self._static_folder = value @property - def has_static_folder(self): + def has_static_folder(self) -> bool: """``True`` if :attr:`static_folder` is set. .. versionadded:: 0.5 @@ -235,7 +265,7 @@ class Scaffold: return self.static_folder is not None @property - def static_url_path(self): + def static_url_path(self) -> t.Optional[str]: """The URL prefix that the static route will be accessible from. If it was not configured during init, it is derived from @@ -248,14 +278,16 @@ class Scaffold: basename = os.path.basename(self.static_folder) return f"/{basename}".rstrip("/") + return None + @static_url_path.setter - def static_url_path(self, value): + def static_url_path(self, value: t.Optional[str]) -> None: if value is not None: value = value.rstrip("/") self._static_url_path = value - def get_send_file_max_age(self, filename): + def get_send_file_max_age(self, filename: str) -> t.Optional[int]: """Used by :func:`send_file` to determine the ``max_age`` cache value for a given file path if it wasn't passed. @@ -276,7 +308,7 @@ class Scaffold: return int(value.total_seconds()) - def send_static_file(self, filename): + def send_static_file(self, filename: str) -> "Response": """The view function used to serve files from :attr:`static_folder`. A route is automatically registered for this view at :attr:`static_url_path` if :attr:`static_folder` is @@ -290,10 +322,12 @@ class Scaffold: # send_file only knows to call get_send_file_max_age on the app, # call it here so it works for blueprints too. max_age = self.get_send_file_max_age(filename) - return send_from_directory(self.static_folder, filename, max_age=max_age) + return send_from_directory( + t.cast(str, self.static_folder), filename, max_age=max_age + ) @locked_cached_property - def jinja_loader(self): + def jinja_loader(self) -> t.Optional[FileSystemLoader]: """The Jinja loader for this object's templates. By default this is a class :class:`jinja2.loaders.FileSystemLoader` to :attr:`template_folder` if it is set. @@ -302,8 +336,10 @@ class Scaffold: """ if self.template_folder is not None: return FileSystemLoader(os.path.join(self.root_path, self.template_folder)) + else: + return None - def open_resource(self, resource, mode="rb"): + def open_resource(self, resource: str, mode: str = "rb") -> t.IO[t.AnyStr]: """Open a resource file relative to :attr:`root_path` for reading. @@ -326,48 +362,48 @@ class Scaffold: return open(os.path.join(self.root_path, resource), mode) - def _method_route(self, method, rule, options): + def _method_route(self, method: str, rule: str, options: dict) -> t.Callable: if "methods" in options: raise TypeError("Use the 'route' decorator to use the 'methods' argument.") return self.route(rule, methods=[method], **options) - def get(self, rule, **options): + def get(self, rule: str, **options: t.Any) -> t.Callable: """Shortcut for :meth:`route` with ``methods=["GET"]``. .. versionadded:: 2.0 """ return self._method_route("GET", rule, options) - def post(self, rule, **options): + def post(self, rule: str, **options: t.Any) -> t.Callable: """Shortcut for :meth:`route` with ``methods=["POST"]``. .. versionadded:: 2.0 """ return self._method_route("POST", rule, options) - def put(self, rule, **options): + def put(self, rule: str, **options: t.Any) -> t.Callable: """Shortcut for :meth:`route` with ``methods=["PUT"]``. .. versionadded:: 2.0 """ return self._method_route("PUT", rule, options) - def delete(self, rule, **options): + def delete(self, rule: str, **options: t.Any) -> t.Callable: """Shortcut for :meth:`route` with ``methods=["DELETE"]``. .. versionadded:: 2.0 """ return self._method_route("DELETE", rule, options) - def patch(self, rule, **options): + def patch(self, rule: str, **options: t.Any) -> t.Callable: """Shortcut for :meth:`route` with ``methods=["PATCH"]``. .. versionadded:: 2.0 """ return self._method_route("PATCH", rule, options) - def route(self, rule, **options): + def route(self, rule: str, **options: t.Any) -> t.Callable: """Decorate a view function to register it with the given URL rule and options. Calls :meth:`add_url_rule`, which has more details about the implementation. @@ -391,7 +427,7 @@ class Scaffold: :class:`~werkzeug.routing.Rule` object. """ - def decorator(f): + def decorator(f: t.Callable) -> t.Callable: endpoint = options.pop("endpoint", None) self.add_url_rule(rule, endpoint, f, **options) return f @@ -401,12 +437,12 @@ class Scaffold: @setupmethod def add_url_rule( self, - rule, - endpoint=None, - view_func=None, - provide_automatic_options=None, - **options, - ): + rule: str, + endpoint: t.Optional[str] = None, + view_func: t.Optional[t.Callable] = None, + provide_automatic_options: t.Optional[bool] = None, + **options: t.Any, + ) -> t.Callable: """Register a rule for routing incoming requests and building URLs. The :meth:`route` decorator is a shortcut to call this with the ``view_func`` argument. These are equivalent: @@ -466,7 +502,7 @@ class Scaffold: """ raise NotImplementedError - def endpoint(self, endpoint): + def endpoint(self, endpoint: str) -> t.Callable: """Decorate a view function to register it for the given endpoint. Used if a rule is added without a ``view_func`` with :meth:`add_url_rule`. @@ -490,7 +526,7 @@ class Scaffold: return decorator @setupmethod - def before_request(self, f): + def before_request(self, f: BeforeRequestCallable) -> BeforeRequestCallable: """Register a function to run before each request. For example, this can be used to open a database connection, or @@ -512,7 +548,7 @@ class Scaffold: return f @setupmethod - def after_request(self, f): + def after_request(self, f: AfterRequestCallable) -> AfterRequestCallable: """Register a function to run after each request to this object. The function is called with the response object, and must return @@ -528,7 +564,7 @@ class Scaffold: return f @setupmethod - def teardown_request(self, f): + def teardown_request(self, f: TeardownCallable) -> TeardownCallable: """Register a function to be run at the end of each request, regardless of whether there was an exception or not. These functions are executed when the request context is popped, even if not an @@ -567,13 +603,17 @@ class Scaffold: return f @setupmethod - def context_processor(self, f): + def context_processor( + self, f: TemplateContextProcessorCallable + ) -> TemplateContextProcessorCallable: """Registers a template context processor function.""" self.template_context_processors[None].append(f) return f @setupmethod - def url_value_preprocessor(self, f): + def url_value_preprocessor( + self, f: URLValuePreprocessorCallable + ) -> URLValuePreprocessorCallable: """Register a URL value preprocessor function for all view functions in the application. These functions will be called before the :meth:`before_request` functions. @@ -590,7 +630,7 @@ class Scaffold: return f @setupmethod - def url_defaults(self, f): + def url_defaults(self, f: URLDefaultCallable) -> URLDefaultCallable: """Callback function for URL defaults for all view functions of the application. It's called with the endpoint and values and should update the values passed in place. @@ -599,7 +639,9 @@ class Scaffold: return f @setupmethod - def errorhandler(self, code_or_exception): + def errorhandler( + self, code_or_exception: t.Union[t.Type[Exception], int] + ) -> t.Callable: """Register a function to handle errors by code or exception class. A decorator that is used to register a function given an @@ -629,14 +671,18 @@ class Scaffold: an arbitrary exception """ - def decorator(f): + def decorator(f: ErrorHandlerCallable) -> ErrorHandlerCallable: self.register_error_handler(code_or_exception, f) return f return decorator @setupmethod - def register_error_handler(self, code_or_exception, f): + def register_error_handler( + self, + code_or_exception: t.Union[t.Type[Exception], int], + f: ErrorHandlerCallable, + ) -> None: """Alternative error attach function to the :meth:`errorhandler` decorator that is more straightforward to use for non decorator usage. @@ -662,7 +708,9 @@ class Scaffold: self.error_handler_spec[None][code][exc_class] = self.ensure_sync(f) @staticmethod - def _get_exc_class_and_code(exc_class_or_code): + def _get_exc_class_and_code( + exc_class_or_code: t.Union[t.Type[Exception], int] + ) -> t.Tuple[t.Type[Exception], t.Optional[int]]: """Get the exception class being handled. For HTTP status codes or ``HTTPException`` subclasses, return both the exception and status code. @@ -670,6 +718,7 @@ class Scaffold: :param exc_class_or_code: Any exception class, or an HTTP status code as an integer. """ + exc_class: t.Type[Exception] if isinstance(exc_class_or_code, int): exc_class = default_exceptions[exc_class_or_code] else: @@ -684,11 +733,11 @@ class Scaffold: else: return exc_class, None - def ensure_sync(self, func): + def ensure_sync(self, func: t.Callable) -> t.Callable: raise NotImplementedError() -def _endpoint_from_view_func(view_func): +def _endpoint_from_view_func(view_func: t.Callable) -> str: """Internal helper that returns the default endpoint for a given function. This always is the function name. """ @@ -696,7 +745,7 @@ def _endpoint_from_view_func(view_func): return view_func.__name__ -def get_root_path(import_name): +def get_root_path(import_name: str) -> str: """Find the root path of a package, or the path that contains a module. If it cannot be found, returns the current working directory. @@ -721,7 +770,7 @@ def get_root_path(import_name): return os.getcwd() if hasattr(loader, "get_filename"): - filepath = loader.get_filename(import_name) + filepath = loader.get_filename(import_name) # type: ignore else: # Fall back to imports. __import__(import_name) @@ -822,7 +871,7 @@ def _find_package_path(root_mod_name): return package_path -def find_package(import_name): +def find_package(import_name: str): """Find the prefix that a package is installed under, and the path that it would be imported from. diff --git a/src/flask/sessions.py b/src/flask/sessions.py index 795a922c..0e68e884 100644 --- a/src/flask/sessions.py +++ b/src/flask/sessions.py @@ -1,4 +1,5 @@ import hashlib +import typing as t import warnings from collections.abc import MutableMapping from datetime import datetime @@ -10,17 +11,21 @@ from werkzeug.datastructures import CallbackDict from .helpers import is_ip from .json.tag import TaggedJSONSerializer +if t.TYPE_CHECKING: + from .app import Flask + from .wrappers import Request, Response + class SessionMixin(MutableMapping): """Expands a basic dictionary with session attributes.""" @property - def permanent(self): + def permanent(self) -> bool: """This reflects the ``'_permanent'`` key in the dict.""" return self.get("_permanent", False) @permanent.setter - def permanent(self, value): + def permanent(self, value: bool) -> None: self["_permanent"] = bool(value) #: Some implementations can detect whether a session is newly @@ -61,22 +66,22 @@ class SecureCookieSession(CallbackDict, SessionMixin): #: different users. accessed = False - def __init__(self, initial=None): - def on_update(self): + def __init__(self, initial: t.Any = None) -> None: + def on_update(self) -> None: self.modified = True self.accessed = True super().__init__(initial, on_update) - def __getitem__(self, key): + def __getitem__(self, key: str) -> t.Any: self.accessed = True return super().__getitem__(key) - def get(self, key, default=None): + def get(self, key: str, default: t.Any = None) -> t.Any: self.accessed = True return super().get(key, default) - def setdefault(self, key, default=None): + def setdefault(self, key: str, default: t.Any = None) -> t.Any: self.accessed = True return super().setdefault(key, default) @@ -87,14 +92,14 @@ class NullSession(SecureCookieSession): but fail on setting. """ - def _fail(self, *args, **kwargs): + def _fail(self, *args: t.Any, **kwargs: t.Any) -> t.NoReturn: raise RuntimeError( "The session is unavailable because no secret " "key was set. Set the secret_key on the " "application to something unique and secret." ) - __setitem__ = __delitem__ = clear = pop = popitem = update = setdefault = _fail + __setitem__ = __delitem__ = clear = pop = popitem = update = setdefault = _fail # type: ignore # noqa: B950 del _fail @@ -141,7 +146,7 @@ class SessionInterface: #: .. versionadded:: 0.10 pickle_based = False - def make_null_session(self, app): + def make_null_session(self, app: "Flask") -> NullSession: """Creates a null session which acts as a replacement object if the real session support could not be loaded due to a configuration error. This mainly aids the user experience because the job of the @@ -153,7 +158,7 @@ class SessionInterface: """ return self.null_session_class() - def is_null_session(self, obj): + def is_null_session(self, obj: object) -> bool: """Checks if a given object is a null session. Null sessions are not asked to be saved. @@ -162,14 +167,14 @@ class SessionInterface: """ return isinstance(obj, self.null_session_class) - def get_cookie_name(self, app): + def get_cookie_name(self, app: "Flask") -> str: """Returns the name of the session cookie. Uses ``app.session_cookie_name`` which is set to ``SESSION_COOKIE_NAME`` """ return app.session_cookie_name - def get_cookie_domain(self, app): + def get_cookie_domain(self, app: "Flask") -> t.Optional[str]: """Returns the domain that should be set for the session cookie. Uses ``SESSION_COOKIE_DOMAIN`` if it is configured, otherwise @@ -227,7 +232,7 @@ class SessionInterface: app.config["SESSION_COOKIE_DOMAIN"] = rv return rv - def get_cookie_path(self, app): + def get_cookie_path(self, app: "Flask") -> str: """Returns the path for which the cookie should be valid. The default implementation uses the value from the ``SESSION_COOKIE_PATH`` config var if it's set, and falls back to ``APPLICATION_ROOT`` or @@ -235,27 +240,29 @@ class SessionInterface: """ return app.config["SESSION_COOKIE_PATH"] or app.config["APPLICATION_ROOT"] - def get_cookie_httponly(self, app): + def get_cookie_httponly(self, app: "Flask") -> bool: """Returns True if the session cookie should be httponly. This currently just returns the value of the ``SESSION_COOKIE_HTTPONLY`` config var. """ return app.config["SESSION_COOKIE_HTTPONLY"] - def get_cookie_secure(self, app): + def get_cookie_secure(self, app: "Flask") -> bool: """Returns True if the cookie should be secure. This currently just returns the value of the ``SESSION_COOKIE_SECURE`` setting. """ return app.config["SESSION_COOKIE_SECURE"] - def get_cookie_samesite(self, app): + def get_cookie_samesite(self, app: "Flask") -> str: """Return ``'Strict'`` or ``'Lax'`` if the cookie should use the ``SameSite`` attribute. This currently just returns the value of the :data:`SESSION_COOKIE_SAMESITE` setting. """ return app.config["SESSION_COOKIE_SAMESITE"] - def get_expiration_time(self, app, session): + def get_expiration_time( + self, app: "Flask", session: SessionMixin + ) -> t.Optional[datetime]: """A helper method that returns an expiration date for the session or ``None`` if the session is linked to the browser session. The default implementation returns now + the permanent session @@ -263,8 +270,9 @@ class SessionInterface: """ if session.permanent: return datetime.utcnow() + app.permanent_session_lifetime + return None - def should_set_cookie(self, app, session): + def should_set_cookie(self, app: "Flask", session: SessionMixin) -> bool: """Used by session backends to determine if a ``Set-Cookie`` header should be set for this session cookie for this response. If the session has been modified, the cookie is set. If the session is permanent and @@ -280,7 +288,9 @@ class SessionInterface: session.permanent and app.config["SESSION_REFRESH_EACH_REQUEST"] ) - def open_session(self, app, request): + def open_session( + self, app: "Flask", request: "Request" + ) -> t.Optional[SessionMixin]: """This method has to be implemented and must either return ``None`` in case the loading failed because of a configuration error or an instance of a session object which implements a dictionary like @@ -288,7 +298,9 @@ class SessionInterface: """ raise NotImplementedError() - def save_session(self, app, session, response): + def save_session( + self, app: "Flask", session: SessionMixin, response: "Response" + ) -> None: """This is called for actual sessions returned by :meth:`open_session` at the end of the request. This is still called during a request context so if you absolutely need access to the request you can do @@ -319,7 +331,9 @@ class SecureCookieSessionInterface(SessionInterface): serializer = session_json_serializer session_class = SecureCookieSession - def get_signing_serializer(self, app): + def get_signing_serializer( + self, app: "Flask" + ) -> t.Optional[URLSafeTimedSerializer]: if not app.secret_key: return None signer_kwargs = dict( @@ -332,7 +346,9 @@ class SecureCookieSessionInterface(SessionInterface): signer_kwargs=signer_kwargs, ) - def open_session(self, app, request): + def open_session( + self, app: "Flask", request: "Request" + ) -> t.Optional[SecureCookieSession]: s = self.get_signing_serializer(app) if s is None: return None @@ -346,7 +362,9 @@ class SecureCookieSessionInterface(SessionInterface): except BadSignature: return self.session_class() - def save_session(self, app, session, response): + def save_session( + self, app: "Flask", session: SessionMixin, response: "Response" + ) -> None: name = self.get_cookie_name(app) domain = self.get_cookie_domain(app) path = self.get_cookie_path(app) @@ -372,10 +390,10 @@ class SecureCookieSessionInterface(SessionInterface): httponly = self.get_cookie_httponly(app) expires = self.get_expiration_time(app, session) - val = self.get_signing_serializer(app).dumps(dict(session)) + val = self.get_signing_serializer(app).dumps(dict(session)) # type: ignore response.set_cookie( name, - val, + val, # type: ignore expires=expires, httponly=httponly, domain=domain, diff --git a/src/flask/signals.py b/src/flask/signals.py index d2179c65..63667bdb 100644 --- a/src/flask/signals.py +++ b/src/flask/signals.py @@ -1,3 +1,5 @@ +import typing as t + try: from blinker import Namespace @@ -5,8 +7,8 @@ try: except ImportError: signals_available = False - class Namespace: - def signal(self, name, doc=None): + class Namespace: # type: ignore + def signal(self, name: str, doc: t.Optional[str] = None) -> "_FakeSignal": return _FakeSignal(name, doc) class _FakeSignal: @@ -16,14 +18,14 @@ except ImportError: will just ignore the arguments and do nothing instead. """ - def __init__(self, name, doc=None): + def __init__(self, name: str, doc: t.Optional[str] = None) -> None: self.name = name self.__doc__ = doc - def send(self, *args, **kwargs): + def send(self, *args: t.Any, **kwargs: t.Any) -> t.Any: pass - def _fail(self, *args, **kwargs): + def _fail(self, *args: t.Any, **kwargs: t.Any) -> t.Any: raise RuntimeError( "Signalling support is unavailable because the blinker" " library is not installed." diff --git a/src/flask/templating.py b/src/flask/templating.py index 6eebb13d..1987d9e9 100644 --- a/src/flask/templating.py +++ b/src/flask/templating.py @@ -1,5 +1,8 @@ +import typing as t + from jinja2 import BaseLoader from jinja2 import Environment as BaseEnvironment +from jinja2 import Template from jinja2 import TemplateNotFound from .globals import _app_ctx_stack @@ -7,8 +10,12 @@ from .globals import _request_ctx_stack from .signals import before_render_template from .signals import template_rendered +if t.TYPE_CHECKING: + from .app import Flask + from .scaffold import Scaffold -def _default_template_ctx_processor(): + +def _default_template_ctx_processor() -> t.Dict[str, t.Any]: """Default template context processor. Injects `request`, `session` and `g`. """ @@ -29,7 +36,7 @@ class Environment(BaseEnvironment): name of the blueprint to referenced templates if necessary. """ - def __init__(self, app, **options): + def __init__(self, app: "Flask", **options: t.Any) -> None: if "loader" not in options: options["loader"] = app.create_global_jinja_loader() BaseEnvironment.__init__(self, **options) @@ -41,15 +48,19 @@ class DispatchingJinjaLoader(BaseLoader): the blueprint folders. """ - def __init__(self, app): + def __init__(self, app: "Flask") -> None: self.app = app - def get_source(self, environment, template): + def get_source( + self, environment: Environment, template: str + ) -> t.Tuple[str, t.Optional[str], t.Callable]: if self.app.config["EXPLAIN_TEMPLATE_LOADING"]: return self._get_source_explained(environment, template) return self._get_source_fast(environment, template) - def _get_source_explained(self, environment, template): + def _get_source_explained( + self, environment: Environment, template: str + ) -> t.Tuple[str, t.Optional[str], t.Callable]: attempts = [] trv = None @@ -70,7 +81,9 @@ class DispatchingJinjaLoader(BaseLoader): return trv raise TemplateNotFound(template) - def _get_source_fast(self, environment, template): + def _get_source_fast( + self, environment: Environment, template: str + ) -> t.Tuple[str, t.Optional[str], t.Callable]: for _srcobj, loader in self._iter_loaders(template): try: return loader.get_source(environment, template) @@ -78,7 +91,9 @@ class DispatchingJinjaLoader(BaseLoader): continue raise TemplateNotFound(template) - def _iter_loaders(self, template): + def _iter_loaders( + self, template: str + ) -> t.Generator[t.Tuple["Scaffold", BaseLoader], None, None]: loader = self.app.jinja_loader if loader is not None: yield self.app, loader @@ -88,7 +103,7 @@ class DispatchingJinjaLoader(BaseLoader): if loader is not None: yield blueprint, loader - def list_templates(self): + def list_templates(self) -> t.List[str]: result = set() loader = self.app.jinja_loader if loader is not None: @@ -103,7 +118,7 @@ class DispatchingJinjaLoader(BaseLoader): return list(result) -def _render(template, context, app): +def _render(template: Template, context: dict, app: "Flask") -> str: """Renders the template and fires the signal""" before_render_template.send(app, template=template, context=context) @@ -112,7 +127,9 @@ def _render(template, context, app): return rv -def render_template(template_name_or_list, **context): +def render_template( + template_name_or_list: t.Union[str, t.List[str]], **context: t.Any +) -> str: """Renders a template from the template folder with the given context. @@ -131,7 +148,7 @@ def render_template(template_name_or_list, **context): ) -def render_template_string(source, **context): +def render_template_string(source: str, **context: t.Any) -> str: """Renders a template from the given template source string with the given context. Template variables will be autoescaped. diff --git a/src/flask/testing.py b/src/flask/testing.py index 247e6605..fe3b846a 100644 --- a/src/flask/testing.py +++ b/src/flask/testing.py @@ -1,5 +1,7 @@ +import typing as t from contextlib import contextmanager from copy import copy +from types import TracebackType import werkzeug.test from click.testing import CliRunner @@ -10,6 +12,11 @@ from werkzeug.wrappers import Request as BaseRequest from . import _request_ctx_stack from .cli import ScriptInfo from .json import dumps as json_dumps +from .sessions import SessionMixin + +if t.TYPE_CHECKING: + from .app import Flask + from .wrappers import Response class EnvironBuilder(werkzeug.test.EnvironBuilder): @@ -36,14 +43,14 @@ class EnvironBuilder(werkzeug.test.EnvironBuilder): def __init__( self, - app, - path="/", - base_url=None, - subdomain=None, - url_scheme=None, - *args, - **kwargs, - ): + app: "Flask", + path: str = "/", + base_url: t.Optional[str] = None, + subdomain: t.Optional[str] = None, + url_scheme: t.Optional[str] = None, + *args: t.Any, + **kwargs: t.Any, + ) -> None: assert not (base_url or subdomain or url_scheme) or ( base_url is not None ) != bool( @@ -74,7 +81,7 @@ class EnvironBuilder(werkzeug.test.EnvironBuilder): self.app = app super().__init__(path, base_url, *args, **kwargs) - def json_dumps(self, obj, **kwargs): + def json_dumps(self, obj: t.Any, **kwargs: t.Any) -> str: # type: ignore """Serialize ``obj`` to a JSON-formatted string. The serialization will be configured according to the config associated @@ -99,9 +106,10 @@ class FlaskClient(Client): Basic usage is outlined in the :doc:`/testing` chapter. """ + application: "Flask" preserve_context = False - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) self.environ_base = { "REMOTE_ADDR": "127.0.0.1", @@ -109,7 +117,9 @@ class FlaskClient(Client): } @contextmanager - def session_transaction(self, *args, **kwargs): + def session_transaction( + self, *args: t.Any, **kwargs: t.Any + ) -> t.Generator[SessionMixin, None, None]: """When used in combination with a ``with`` statement this opens a session transaction. This can be used to modify the session that the test client uses. Once the ``with`` block is left the session is @@ -161,9 +171,14 @@ class FlaskClient(Client): headers = resp.get_wsgi_headers(c.request.environ) self.cookie_jar.extract_wsgi(c.request.environ, headers) - def open( - self, *args, as_tuple=False, buffered=False, follow_redirects=False, **kwargs - ): + def open( # type: ignore + self, + *args: t.Any, + as_tuple: bool = False, + buffered: bool = False, + follow_redirects: bool = False, + **kwargs: t.Any, + ) -> "Response": # Same logic as super.open, but apply environ_base and preserve_context. request = None @@ -198,20 +213,22 @@ class FlaskClient(Client): finally: builder.close() - return super().open( + return super().open( # type: ignore request, as_tuple=as_tuple, buffered=buffered, follow_redirects=follow_redirects, ) - def __enter__(self): + def __enter__(self) -> "FlaskClient": if self.preserve_context: raise RuntimeError("Cannot nest client invocations") self.preserve_context = True return self - def __exit__(self, exc_type, exc_value, tb): + def __exit__( + self, exc_type: type, exc_value: BaseException, tb: TracebackType + ) -> None: self.preserve_context = False # Normally the request context is preserved until the next @@ -233,11 +250,13 @@ class FlaskCliRunner(CliRunner): :meth:`~flask.Flask.test_cli_runner`. See :ref:`testing-cli`. """ - def __init__(self, app, **kwargs): + def __init__(self, app: "Flask", **kwargs: t.Any) -> None: self.app = app super().__init__(**kwargs) - def invoke(self, cli=None, args=None, **kwargs): + def invoke( # type: ignore + self, cli: t.Any = None, args: t.Any = None, **kwargs: t.Any + ) -> t.Any: """Invokes a CLI command in an isolated environment. See :meth:`CliRunner.invoke ` for full method documentation. See :ref:`testing-cli` for examples. diff --git a/src/flask/typing.py b/src/flask/typing.py new file mode 100644 index 00000000..9a664e41 --- /dev/null +++ b/src/flask/typing.py @@ -0,0 +1,46 @@ +import typing as t + + +if t.TYPE_CHECKING: + from werkzeug.datastructures import Headers # noqa: F401 + from wsgiref.types import WSGIApplication # noqa: F401 + from .wrappers import Response # noqa: F401 + +# The possible types that are directly convertible or are a Response object. +ResponseValue = t.Union[ + "Response", + t.AnyStr, + t.Dict[str, t.Any], # any jsonify-able dict + t.Generator[t.AnyStr, None, None], +] +StatusCode = int + +# the possible types for an individual HTTP header +HeaderName = str +HeaderValue = t.Union[str, t.List[str], t.Tuple[str, ...]] + +# the possible types for HTTP headers +HeadersValue = t.Union[ + "Headers", t.Dict[HeaderName, HeaderValue], t.List[t.Tuple[HeaderName, HeaderValue]] +] + +# The possible types returned by a route function. +ResponseReturnValue = t.Union[ + ResponseValue, + t.Tuple[ResponseValue, HeadersValue], + t.Tuple[ResponseValue, StatusCode], + t.Tuple[ResponseValue, StatusCode, HeadersValue], + "WSGIApplication", +] + +AppOrBlueprintKey = t.Optional[str] # The App key is None, whereas blueprints are named +AfterRequestCallable = t.Callable[["Response"], "Response"] +BeforeRequestCallable = t.Callable[[], None] +ErrorHandlerCallable = t.Callable[[Exception], ResponseReturnValue] +TeardownCallable = t.Callable[[t.Optional[BaseException]], "Response"] +TemplateContextProcessorCallable = t.Callable[[], t.Dict[str, t.Any]] +TemplateFilterCallable = t.Callable[[t.Any], str] +TemplateGlobalCallable = t.Callable[[], t.Any] +TemplateTestCallable = t.Callable[[t.Any], bool] +URLDefaultCallable = t.Callable[[str, dict], None] +URLValuePreprocessorCallable = t.Callable[[t.Optional[str], t.Optional[dict]], None] diff --git a/src/flask/views.py b/src/flask/views.py index 323e6118..339ffa18 100644 --- a/src/flask/views.py +++ b/src/flask/views.py @@ -1,4 +1,7 @@ +import typing as t + from .globals import request +from .typing import ResponseReturnValue http_method_funcs = frozenset( @@ -39,10 +42,10 @@ class View: """ #: A list of methods this view can handle. - methods = None + methods: t.Optional[t.List[str]] = None #: Setting this disables or force-enables the automatic options handling. - provide_automatic_options = None + provide_automatic_options: t.Optional[bool] = None #: The canonical way to decorate class-based views is to decorate the #: return value of as_view(). However since this moves parts of the @@ -53,9 +56,9 @@ class View: #: view function is created the result is automatically decorated. #: #: .. versionadded:: 0.8 - decorators = () + decorators: t.List[t.Callable] = [] - def dispatch_request(self): + def dispatch_request(self) -> ResponseReturnValue: """Subclasses have to override this method to implement the actual view function code. This method is called with all the arguments from the URL rule. @@ -63,7 +66,9 @@ class View: raise NotImplementedError() @classmethod - def as_view(cls, name, *class_args, **class_kwargs): + def as_view( + cls, name: str, *class_args: t.Any, **class_kwargs: t.Any + ) -> t.Callable: """Converts the class into an actual view function that can be used with the routing system. Internally this generates a function on the fly which will instantiate the :class:`View` on each request and call @@ -73,8 +78,8 @@ class View: constructor of the class. """ - def view(*args, **kwargs): - self = view.view_class(*class_args, **class_kwargs) + def view(*args: t.Any, **kwargs: t.Any) -> ResponseReturnValue: + self = view.view_class(*class_args, **class_kwargs) # type: ignore return self.dispatch_request(*args, **kwargs) if cls.decorators: @@ -88,12 +93,12 @@ class View: # view this thing came from, secondly it's also used for instantiating # the view class so you can actually replace it with something else # for testing purposes and debugging. - view.view_class = cls + view.view_class = cls # type: ignore view.__name__ = name view.__doc__ = cls.__doc__ view.__module__ = cls.__module__ - view.methods = cls.methods - view.provide_automatic_options = cls.provide_automatic_options + view.methods = cls.methods # type: ignore + view.provide_automatic_options = cls.provide_automatic_options # type: ignore return view @@ -140,7 +145,7 @@ class MethodView(View, metaclass=MethodViewType): app.add_url_rule('/counter', view_func=CounterAPI.as_view('counter')) """ - def dispatch_request(self, *args, **kwargs): + def dispatch_request(self, *args: t.Any, **kwargs: t.Any) -> ResponseReturnValue: meth = getattr(self, request.method.lower(), None) # If the request method is HEAD and we don't have a handler for it diff --git a/src/flask/wrappers.py b/src/flask/wrappers.py index 1d8f17d7..48fcc34b 100644 --- a/src/flask/wrappers.py +++ b/src/flask/wrappers.py @@ -1,3 +1,5 @@ +import typing as t + from werkzeug.exceptions import BadRequest from werkzeug.wrappers import Request as RequestBase from werkzeug.wrappers import Response as ResponseBase @@ -5,6 +7,9 @@ from werkzeug.wrappers import Response as ResponseBase from . import json from .globals import current_app +if t.TYPE_CHECKING: + from werkzeug.routing import Rule + class Request(RequestBase): """The request object used by default in Flask. Remembers the @@ -31,26 +36,28 @@ class Request(RequestBase): #: because the request was never internally bound. #: #: .. versionadded:: 0.6 - url_rule = None + url_rule: t.Optional["Rule"] = None #: A dict of view arguments that matched the request. If an exception #: happened when matching, this will be ``None``. - view_args = None + view_args: t.Optional[t.Dict[str, t.Any]] = None #: If matching the URL failed, this is the exception that will be #: raised / was raised as part of the request handling. This is #: usually a :exc:`~werkzeug.exceptions.NotFound` exception or #: something similar. - routing_exception = None + routing_exception: t.Optional[Exception] = None @property - def max_content_length(self): + def max_content_length(self) -> t.Optional[int]: # type: ignore """Read-only view of the ``MAX_CONTENT_LENGTH`` config key.""" if current_app: return current_app.config["MAX_CONTENT_LENGTH"] + else: + return None @property - def endpoint(self): + def endpoint(self) -> t.Optional[str]: """The endpoint that matched the request. This in combination with :attr:`view_args` can be used to reconstruct the same or a modified URL. If an exception happened when matching, this will @@ -58,14 +65,18 @@ class Request(RequestBase): """ if self.url_rule is not None: return self.url_rule.endpoint + else: + return None @property - def blueprint(self): + def blueprint(self) -> t.Optional[str]: """The name of the current blueprint""" if self.url_rule and "." in self.url_rule.endpoint: return self.url_rule.endpoint.rsplit(".", 1)[0] + else: + return None - def _load_form_data(self): + def _load_form_data(self) -> None: RequestBase._load_form_data(self) # In debug mode we're replacing the files multidict with an ad-hoc @@ -80,7 +91,7 @@ class Request(RequestBase): attach_enctype_error_multidict(self) - def on_json_loading_failed(self, e): + def on_json_loading_failed(self, e: Exception) -> t.NoReturn: if current_app and current_app.debug: raise BadRequest(f"Failed to decode JSON object: {e}") @@ -110,7 +121,7 @@ class Response(ResponseBase): json_module = json @property - def max_cookie_size(self): + def max_cookie_size(self) -> int: # type: ignore """Read-only view of the :data:`MAX_COOKIE_SIZE` config key. See :attr:`~werkzeug.wrappers.Response.max_cookie_size` in