Add initial type hints

This should make it easier for users to correctly use Flask. The hints
are from Quart.
This commit is contained in:
pgjones 2021-04-24 12:22:26 +01:00
parent f405c6f19e
commit 77237093da
20 changed files with 820 additions and 461 deletions

View File

@ -74,6 +74,7 @@ Unreleased
``python`` shell if ``readline`` is installed. :issue:`3941`
- ``helpers.total_seconds()`` is deprecated. Use
``timedelta.total_seconds()`` instead. :pr:`3962`
- Add type hinting. :pr:`3973`.
Version 1.1.2

View File

@ -1,11 +1,14 @@
import functools
import inspect
import logging
import os
import sys
import typing as t
import weakref
from datetime import timedelta
from itertools import chain
from threading import Lock
from types import TracebackType
from werkzeug.datastructures import Headers
from werkzeug.datastructures import ImmutableDict
@ -15,6 +18,7 @@ from werkzeug.exceptions import HTTPException
from werkzeug.exceptions import InternalServerError
from werkzeug.routing import BuildError
from werkzeug.routing import Map
from werkzeug.routing import MapAdapter
from werkzeug.routing import RequestRedirect
from werkzeug.routing import RoutingException
from werkzeug.routing import Rule
@ -53,15 +57,30 @@ from .signals import request_started
from .signals import request_tearing_down
from .templating import DispatchingJinjaLoader
from .templating import Environment
from .typing import AfterRequestCallable
from .typing import BeforeRequestCallable
from .typing import ErrorHandlerCallable
from .typing import ResponseReturnValue
from .typing import TeardownCallable
from .typing import TemplateContextProcessorCallable
from .typing import TemplateFilterCallable
from .typing import TemplateGlobalCallable
from .typing import TemplateTestCallable
from .typing import URLDefaultCallable
from .typing import URLValuePreprocessorCallable
from .wrappers import Request
from .wrappers import Response
if t.TYPE_CHECKING:
from .blueprints import Blueprint
from .testing import FlaskClient
from .testing import FlaskCliRunner
if sys.version_info >= (3, 8):
iscoroutinefunction = inspect.iscoroutinefunction
else:
def iscoroutinefunction(func):
def iscoroutinefunction(func: t.Any) -> bool:
while inspect.ismethod(func):
func = func.__func__
@ -71,7 +90,7 @@ else:
return inspect.iscoroutinefunction(func)
def _make_timedelta(value):
def _make_timedelta(value: t.Optional[timedelta]) -> t.Optional[timedelta]:
if value is None or isinstance(value, timedelta):
return value
@ -295,7 +314,7 @@ class Flask(Scaffold):
#: This is a ``dict`` instead of an ``ImmutableDict`` to allow
#: easier configuration.
#:
jinja_options = {}
jinja_options: dict = {}
#: Default configuration parameters.
default_config = ImmutableDict(
@ -347,7 +366,7 @@ class Flask(Scaffold):
#: the test client that is used with when `test_client` is used.
#:
#: .. versionadded:: 0.7
test_client_class = None
test_client_class: t.Optional[t.Type["FlaskClient"]] = None
#: The :class:`~click.testing.CliRunner` subclass, by default
#: :class:`~flask.testing.FlaskCliRunner` that is used by
@ -355,7 +374,7 @@ class Flask(Scaffold):
#: Flask app object as the first argument.
#:
#: .. versionadded:: 1.0
test_cli_runner_class = None
test_cli_runner_class: t.Optional[t.Type["FlaskCliRunner"]] = None
#: the session interface to use. By default an instance of
#: :class:`~flask.sessions.SecureCookieSessionInterface` is used here.
@ -365,16 +384,16 @@ class Flask(Scaffold):
def __init__(
self,
import_name,
static_url_path=None,
static_folder="static",
static_host=None,
host_matching=False,
subdomain_matching=False,
template_folder="templates",
instance_path=None,
instance_relative_config=False,
root_path=None,
import_name: str,
static_url_path: t.Optional[str] = None,
static_folder: t.Optional[str] = "static",
static_host: t.Optional[str] = None,
host_matching: bool = False,
subdomain_matching: bool = False,
template_folder: t.Optional[str] = "templates",
instance_path: t.Optional[str] = None,
instance_relative_config: bool = False,
root_path: t.Optional[str] = None,
):
super().__init__(
import_name=import_name,
@ -409,14 +428,16 @@ class Flask(Scaffold):
#: tried.
#:
#: .. versionadded:: 0.9
self.url_build_error_handlers = []
self.url_build_error_handlers: t.List[
t.Callable[[Exception, str, dict], str]
] = []
#: A list of functions that will be called at the beginning of the
#: first request to this instance. To register a function, use the
#: :meth:`before_first_request` decorator.
#:
#: .. versionadded:: 0.8
self.before_first_request_funcs = []
self.before_first_request_funcs: t.List[BeforeRequestCallable] = []
#: A list of functions that are called when the application context
#: is destroyed. Since the application context is also torn down
@ -424,13 +445,13 @@ class Flask(Scaffold):
#: from databases.
#:
#: .. versionadded:: 0.9
self.teardown_appcontext_funcs = []
self.teardown_appcontext_funcs: t.List[TeardownCallable] = []
#: A list of shell context processor functions that should be run
#: when a shell context is created.
#:
#: .. versionadded:: 0.11
self.shell_context_processors = []
self.shell_context_processors: t.List[t.Callable[[], t.Dict[str, t.Any]]] = []
#: Maps registered blueprint names to blueprint objects. The
#: dict retains the order the blueprints were registered in.
@ -438,7 +459,7 @@ class Flask(Scaffold):
#: not track how often they were attached.
#:
#: .. versionadded:: 0.7
self.blueprints = {}
self.blueprints: t.Dict[str, "Blueprint"] = {}
#: a place where extensions can store application specific state. For
#: example this is where an extension could store database engines and
@ -449,7 +470,7 @@ class Flask(Scaffold):
#: ``'foo'``.
#:
#: .. versionadded:: 0.7
self.extensions = {}
self.extensions: dict = {}
#: The :class:`~werkzeug.routing.Map` for this instance. You can use
#: this to change the routing converters after the class was created
@ -492,18 +513,18 @@ class Flask(Scaffold):
f"{self.static_url_path}/<path:filename>",
endpoint="static",
host=static_host,
view_func=lambda **kw: self_ref().send_static_file(**kw),
view_func=lambda **kw: self_ref().send_static_file(**kw), # type: ignore # noqa: B950
)
# Set the name of the Click group in case someone wants to add
# the app's commands to another CLI tool.
self.cli.name = self.name
def _is_setup_finished(self):
def _is_setup_finished(self) -> bool:
return self.debug and self._got_first_request
@locked_cached_property
def name(self):
def name(self) -> str: # type: ignore
"""The name of the application. This is usually the import name
with the difference that it's guessed from the run file if the
import name is main. This name is used as a display name when
@ -520,7 +541,7 @@ class Flask(Scaffold):
return self.import_name
@property
def propagate_exceptions(self):
def propagate_exceptions(self) -> bool:
"""Returns the value of the ``PROPAGATE_EXCEPTIONS`` configuration
value in case it's set, otherwise a sensible default is returned.
@ -532,7 +553,7 @@ class Flask(Scaffold):
return self.testing or self.debug
@property
def preserve_context_on_exception(self):
def preserve_context_on_exception(self) -> bool:
"""Returns the value of the ``PRESERVE_CONTEXT_ON_EXCEPTION``
configuration value in case it's set, otherwise a sensible default
is returned.
@ -545,7 +566,7 @@ class Flask(Scaffold):
return self.debug
@locked_cached_property
def logger(self):
def logger(self) -> logging.Logger:
"""A standard Python :class:`~logging.Logger` for the app, with
the same name as :attr:`name`.
@ -572,7 +593,7 @@ class Flask(Scaffold):
return create_logger(self)
@locked_cached_property
def jinja_env(self):
def jinja_env(self) -> Environment:
"""The Jinja environment used to load templates.
The environment is created the first time this property is
@ -582,7 +603,7 @@ class Flask(Scaffold):
return self.create_jinja_environment()
@property
def got_first_request(self):
def got_first_request(self) -> bool:
"""This attribute is set to ``True`` if the application started
handling the first request.
@ -590,7 +611,7 @@ class Flask(Scaffold):
"""
return self._got_first_request
def make_config(self, instance_relative=False):
def make_config(self, instance_relative: bool = False) -> Config:
"""Used to create the config attribute by the Flask constructor.
The `instance_relative` parameter is passed in from the constructor
of Flask (there named `instance_relative_config`) and indicates if
@ -607,7 +628,7 @@ class Flask(Scaffold):
defaults["DEBUG"] = get_debug_flag()
return self.config_class(root_path, defaults)
def auto_find_instance_path(self):
def auto_find_instance_path(self) -> str:
"""Tries to locate the instance path if it was not provided to the
constructor of the application class. It will basically calculate
the path to a folder named ``instance`` next to your main file or
@ -620,7 +641,7 @@ class Flask(Scaffold):
return os.path.join(package_path, "instance")
return os.path.join(prefix, "var", f"{self.name}-instance")
def open_instance_resource(self, resource, mode="rb"):
def open_instance_resource(self, resource: str, mode: str = "rb") -> t.IO[t.AnyStr]:
"""Opens a resource from the application's instance folder
(:attr:`instance_path`). Otherwise works like
:meth:`open_resource`. Instance resources can also be opened for
@ -633,7 +654,7 @@ class Flask(Scaffold):
return open(os.path.join(self.instance_path, resource), mode)
@property
def templates_auto_reload(self):
def templates_auto_reload(self) -> bool:
"""Reload templates when they are changed. Used by
:meth:`create_jinja_environment`.
@ -648,10 +669,10 @@ class Flask(Scaffold):
return rv if rv is not None else self.debug
@templates_auto_reload.setter
def templates_auto_reload(self, value):
def templates_auto_reload(self, value: bool) -> None:
self.config["TEMPLATES_AUTO_RELOAD"] = value
def create_jinja_environment(self):
def create_jinja_environment(self) -> Environment:
"""Create the Jinja environment based on :attr:`jinja_options`
and the various Jinja-related methods of the app. Changing
:attr:`jinja_options` after this will have no effect. Also adds
@ -683,10 +704,10 @@ class Flask(Scaffold):
session=session,
g=g,
)
rv.policies["json.dumps_function"] = json.dumps
rv.policies["json.dumps_function"] = json.dumps # type: ignore
return rv
def create_global_jinja_loader(self):
def create_global_jinja_loader(self) -> DispatchingJinjaLoader:
"""Creates the loader for the Jinja2 environment. Can be used to
override just the loader and keeping the rest unchanged. It's
discouraged to override this function. Instead one should override
@ -699,7 +720,7 @@ class Flask(Scaffold):
"""
return DispatchingJinjaLoader(self)
def select_jinja_autoescape(self, filename):
def select_jinja_autoescape(self, filename: str) -> bool:
"""Returns ``True`` if autoescaping should be active for the given
template name. If no template name is given, returns `True`.
@ -709,7 +730,7 @@ class Flask(Scaffold):
return True
return filename.endswith((".html", ".htm", ".xml", ".xhtml"))
def update_template_context(self, context):
def update_template_context(self, context: dict) -> None:
"""Update the template context with some commonly used variables.
This injects request, session, config and g into the template
context as well as everything template context processors want
@ -720,7 +741,9 @@ class Flask(Scaffold):
:param context: the context as a dictionary that is updated in place
to add extra variables.
"""
funcs = self.template_context_processors[None]
funcs: t.Iterable[
TemplateContextProcessorCallable
] = self.template_context_processors[None]
reqctx = _request_ctx_stack.top
if reqctx is not None:
for bp in self._request_blueprints():
@ -734,7 +757,7 @@ class Flask(Scaffold):
# existing views.
context.update(orig_ctx)
def make_shell_context(self):
def make_shell_context(self) -> dict:
"""Returns the shell context for an interactive shell for this
application. This runs all the registered shell context
processors.
@ -758,7 +781,7 @@ class Flask(Scaffold):
env = ConfigAttribute("ENV")
@property
def debug(self):
def debug(self) -> bool:
"""Whether debug mode is enabled. When using ``flask run`` to start
the development server, an interactive debugger will be shown for
unhandled exceptions, and the server will be reloaded when code
@ -775,11 +798,18 @@ class Flask(Scaffold):
return self.config["DEBUG"]
@debug.setter
def debug(self, value):
def debug(self, value: bool) -> None:
self.config["DEBUG"] = value
self.jinja_env.auto_reload = self.templates_auto_reload
def run(self, host=None, port=None, debug=None, load_dotenv=True, **options):
def run(
self,
host: t.Optional[str] = None,
port: t.Optional[int] = None,
debug: t.Optional[bool] = None,
load_dotenv: bool = True,
**options: t.Any,
) -> None:
"""Runs the application on a local development server.
Do not use ``run()`` in a production setting. It is not intended to
@ -887,14 +917,14 @@ class Flask(Scaffold):
from werkzeug.serving import run_simple
try:
run_simple(host, port, self, **options)
run_simple(t.cast(str, host), port, self, **options)
finally:
# reset the first request information if the development server
# reset normally. This makes it possible to restart the server
# without reloader and that stuff from an interactive shell.
self._got_first_request = False
def test_client(self, use_cookies=True, **kwargs):
def test_client(self, use_cookies: bool = True, **kwargs: t.Any) -> "FlaskClient":
"""Creates a test client for this application. For information
about unit testing head over to :doc:`/testing`.
@ -947,10 +977,12 @@ class Flask(Scaffold):
"""
cls = self.test_client_class
if cls is None:
from .testing import FlaskClient as cls
return cls(self, self.response_class, use_cookies=use_cookies, **kwargs)
from .testing import FlaskClient as cls # type: ignore
return cls( # type: ignore
self, self.response_class, use_cookies=use_cookies, **kwargs
)
def test_cli_runner(self, **kwargs):
def test_cli_runner(self, **kwargs: t.Any) -> "FlaskCliRunner":
"""Create a CLI runner for testing CLI commands.
See :ref:`testing-cli`.
@ -963,12 +995,12 @@ class Flask(Scaffold):
cls = self.test_cli_runner_class
if cls is None:
from .testing import FlaskCliRunner as cls
from .testing import FlaskCliRunner as cls # type: ignore
return cls(self, **kwargs)
return cls(self, **kwargs) # type: ignore
@setupmethod
def register_blueprint(self, blueprint, **options):
def register_blueprint(self, blueprint: "Blueprint", **options: t.Any) -> None:
"""Register a :class:`~flask.Blueprint` on the application. Keyword
arguments passed to this method will override the defaults set on the
blueprint.
@ -989,7 +1021,7 @@ class Flask(Scaffold):
"""
blueprint.register(self, options)
def iter_blueprints(self):
def iter_blueprints(self) -> t.ValuesView["Blueprint"]:
"""Iterates over all blueprints by the order they were registered.
.. versionadded:: 0.11
@ -999,14 +1031,14 @@ class Flask(Scaffold):
@setupmethod
def add_url_rule(
self,
rule,
endpoint=None,
view_func=None,
provide_automatic_options=None,
**options,
):
rule: str,
endpoint: t.Optional[str] = None,
view_func: t.Optional[t.Callable] = None,
provide_automatic_options: t.Optional[bool] = None,
**options: t.Any,
) -> None:
if endpoint is None:
endpoint = _endpoint_from_view_func(view_func)
endpoint = _endpoint_from_view_func(view_func) # type: ignore
options["endpoint"] = endpoint
methods = options.pop("methods", None)
@ -1043,13 +1075,13 @@ class Flask(Scaffold):
methods |= required_methods
rule = self.url_rule_class(rule, methods=methods, **options)
rule.provide_automatic_options = provide_automatic_options
rule.provide_automatic_options = provide_automatic_options # type: ignore
self.url_map.add(rule)
if view_func is not None:
old_func = self.view_functions.get(endpoint)
if getattr(old_func, "_flask_sync_wrapper", False):
old_func = old_func.__wrapped__
old_func = old_func.__wrapped__ # type: ignore
if old_func is not None and old_func != view_func:
raise AssertionError(
"View function mapping is overwriting an existing"
@ -1058,7 +1090,7 @@ class Flask(Scaffold):
self.view_functions[endpoint] = self.ensure_sync(view_func)
@setupmethod
def template_filter(self, name=None):
def template_filter(self, name: t.Optional[str] = None) -> t.Callable:
"""A decorator that is used to register custom template filter.
You can specify a name for the filter, otherwise the function
name will be used. Example::
@ -1071,14 +1103,16 @@ class Flask(Scaffold):
function name will be used.
"""
def decorator(f):
def decorator(f: TemplateFilterCallable) -> TemplateFilterCallable:
self.add_template_filter(f, name=name)
return f
return decorator
@setupmethod
def add_template_filter(self, f, name=None):
def add_template_filter(
self, f: TemplateFilterCallable, name: t.Optional[str] = None
) -> None:
"""Register a custom template filter. Works exactly like the
:meth:`template_filter` decorator.
@ -1088,7 +1122,7 @@ class Flask(Scaffold):
self.jinja_env.filters[name or f.__name__] = f
@setupmethod
def template_test(self, name=None):
def template_test(self, name: t.Optional[str] = None) -> t.Callable:
"""A decorator that is used to register custom template test.
You can specify a name for the test, otherwise the function
name will be used. Example::
@ -1108,14 +1142,16 @@ class Flask(Scaffold):
function name will be used.
"""
def decorator(f):
def decorator(f: TemplateTestCallable) -> TemplateTestCallable:
self.add_template_test(f, name=name)
return f
return decorator
@setupmethod
def add_template_test(self, f, name=None):
def add_template_test(
self, f: TemplateTestCallable, name: t.Optional[str] = None
) -> None:
"""Register a custom template test. Works exactly like the
:meth:`template_test` decorator.
@ -1127,7 +1163,7 @@ class Flask(Scaffold):
self.jinja_env.tests[name or f.__name__] = f
@setupmethod
def template_global(self, name=None):
def template_global(self, name: t.Optional[str] = None) -> t.Callable:
"""A decorator that is used to register a custom template global function.
You can specify a name for the global function, otherwise the function
name will be used. Example::
@ -1142,14 +1178,16 @@ class Flask(Scaffold):
function name will be used.
"""
def decorator(f):
def decorator(f: TemplateGlobalCallable) -> TemplateGlobalCallable:
self.add_template_global(f, name=name)
return f
return decorator
@setupmethod
def add_template_global(self, f, name=None):
def add_template_global(
self, f: TemplateGlobalCallable, name: t.Optional[str] = None
) -> None:
"""Register a custom template global function. Works exactly like the
:meth:`template_global` decorator.
@ -1161,7 +1199,7 @@ class Flask(Scaffold):
self.jinja_env.globals[name or f.__name__] = f
@setupmethod
def before_first_request(self, f):
def before_first_request(self, f: BeforeRequestCallable) -> BeforeRequestCallable:
"""Registers a function to be run before the first request to this
instance of the application.
@ -1174,7 +1212,7 @@ class Flask(Scaffold):
return f
@setupmethod
def teardown_appcontext(self, f):
def teardown_appcontext(self, f: TeardownCallable) -> TeardownCallable:
"""Registers a function to be called when the application context
ends. These functions are typically also called when the request
context is popped.
@ -1207,7 +1245,7 @@ class Flask(Scaffold):
return f
@setupmethod
def shell_context_processor(self, f):
def shell_context_processor(self, f: t.Callable) -> t.Callable:
"""Registers a shell context processor function.
.. versionadded:: 0.11
@ -1215,7 +1253,7 @@ class Flask(Scaffold):
self.shell_context_processors.append(f)
return f
def _find_error_handler(self, e):
def _find_error_handler(self, e: Exception) -> t.Optional[ErrorHandlerCallable]:
"""Return a registered error handler for an exception in this order:
blueprint handler for a specific code, app handler for a specific code,
blueprint handler for an exception class, app handler for an exception
@ -1235,8 +1273,11 @@ class Flask(Scaffold):
if handler is not None:
return handler
return None
def handle_http_exception(self, e):
def handle_http_exception(
self, e: HTTPException
) -> t.Union[HTTPException, ResponseReturnValue]:
"""Handles an HTTP exception. By default this will invoke the
registered error handlers and fall back to returning the
exception as response.
@ -1269,7 +1310,7 @@ class Flask(Scaffold):
return e
return handler(e)
def trap_http_exception(self, e):
def trap_http_exception(self, e: Exception) -> bool:
"""Checks if an HTTP exception should be trapped or not. By default
this will return ``False`` for all exceptions except for a bad request
key error if ``TRAP_BAD_REQUEST_ERRORS`` is set to ``True``. It
@ -1304,7 +1345,9 @@ class Flask(Scaffold):
return False
def handle_user_exception(self, e):
def handle_user_exception(
self, e: Exception
) -> t.Union[HTTPException, ResponseReturnValue]:
"""This method is called whenever an exception occurs that
should be handled. A special case is :class:`~werkzeug
.exceptions.HTTPException` which is forwarded to the
@ -1334,7 +1377,7 @@ class Flask(Scaffold):
return handler(e)
def handle_exception(self, e):
def handle_exception(self, e: Exception) -> Response:
"""Handle an exception that did not have an error handler
associated with it, or that was raised from an error handler.
This always causes a 500 ``InternalServerError``.
@ -1374,6 +1417,7 @@ class Flask(Scaffold):
raise e
self.log_exception(exc_info)
server_error: t.Union[InternalServerError, ResponseReturnValue]
server_error = InternalServerError(original_exception=e)
handler = self._find_error_handler(server_error)
@ -1382,7 +1426,12 @@ class Flask(Scaffold):
return self.finalize_request(server_error, from_error_handler=True)
def log_exception(self, exc_info):
def log_exception(
self,
exc_info: t.Union[
t.Tuple[type, BaseException, TracebackType], t.Tuple[None, None, None]
],
) -> None:
"""Logs an exception. This is called by :meth:`handle_exception`
if debugging is disabled and right before the handler is called.
The default implementation logs the exception as error on the
@ -1394,7 +1443,7 @@ class Flask(Scaffold):
f"Exception on {request.path} [{request.method}]", exc_info=exc_info
)
def raise_routing_exception(self, request):
def raise_routing_exception(self, request: Request) -> t.NoReturn:
"""Exceptions that are recording during routing are reraised with
this method. During debug we are not reraising redirect requests
for non ``GET``, ``HEAD``, or ``OPTIONS`` requests and we're raising
@ -1407,13 +1456,13 @@ class Flask(Scaffold):
or not isinstance(request.routing_exception, RequestRedirect)
or request.method in ("GET", "HEAD", "OPTIONS")
):
raise request.routing_exception
raise request.routing_exception # type: ignore
from .debughelpers import FormDataRoutingRedirect
raise FormDataRoutingRedirect(request)
def dispatch_request(self):
def dispatch_request(self) -> ResponseReturnValue:
"""Does the request dispatching. Matches the URL and returns the
return value of the view or error handler. This does not have to
be a response object. In order to convert the return value to a
@ -1437,7 +1486,7 @@ class Flask(Scaffold):
# otherwise dispatch to the handler for that endpoint
return self.view_functions[rule.endpoint](**req.view_args)
def full_dispatch_request(self):
def full_dispatch_request(self) -> Response:
"""Dispatches the request and on top of that performs request
pre and postprocessing as well as HTTP exception catching and
error handling.
@ -1454,7 +1503,11 @@ class Flask(Scaffold):
rv = self.handle_user_exception(e)
return self.finalize_request(rv)
def finalize_request(self, rv, from_error_handler=False):
def finalize_request(
self,
rv: t.Union[ResponseReturnValue, HTTPException],
from_error_handler: bool = False,
) -> Response:
"""Given the return value from a view function this finalizes
the request by converting it into a response and invoking the
postprocessing functions. This is invoked for both normal
@ -1479,7 +1532,7 @@ class Flask(Scaffold):
)
return response
def try_trigger_before_first_request_functions(self):
def try_trigger_before_first_request_functions(self) -> None:
"""Called before each request and will ensure that it triggers
the :attr:`before_first_request_funcs` and only exactly once per
application instance (which means process usually).
@ -1495,7 +1548,7 @@ class Flask(Scaffold):
func()
self._got_first_request = True
def make_default_options_response(self):
def make_default_options_response(self) -> Response:
"""This method is called to create the default ``OPTIONS`` response.
This can be changed through subclassing to change the default
behavior of ``OPTIONS`` responses.
@ -1508,7 +1561,7 @@ class Flask(Scaffold):
rv.allow.update(methods)
return rv
def should_ignore_error(self, error):
def should_ignore_error(self, error: t.Optional[BaseException]) -> bool:
"""This is called to figure out if an error should be ignored
or not as far as the teardown system is concerned. If this
function returns ``True`` then the teardown handlers will not be
@ -1518,7 +1571,7 @@ class Flask(Scaffold):
"""
return False
def ensure_sync(self, func):
def ensure_sync(self, func: t.Callable) -> t.Callable:
"""Ensure that the function is synchronous for WSGI workers.
Plain ``def`` functions are returned as-is. ``async def``
functions are wrapped to run and wait for the response.
@ -1532,7 +1585,7 @@ class Flask(Scaffold):
return func
def make_response(self, rv):
def make_response(self, rv: ResponseReturnValue) -> Response:
"""Convert the return value from a view function to an instance of
:attr:`response_class`.
@ -1620,7 +1673,7 @@ class Flask(Scaffold):
# evaluate a WSGI callable, or coerce a different response
# class to the correct type
try:
rv = self.response_class.force_type(rv, request.environ)
rv = self.response_class.force_type(rv, request.environ) # type: ignore # noqa: B950
except TypeError as e:
raise TypeError(
f"{e}\nThe view function did not return a valid"
@ -1636,10 +1689,11 @@ class Flask(Scaffold):
f" callable, but it was a {type(rv).__name__}."
)
rv = t.cast(Response, rv)
# prefer the status if it was provided
if status is not None:
if isinstance(status, (str, bytes, bytearray)):
rv.status = status
rv.status = status # type: ignore
else:
rv.status_code = status
@ -1649,7 +1703,9 @@ class Flask(Scaffold):
return rv
def create_url_adapter(self, request):
def create_url_adapter(
self, request: t.Optional[Request]
) -> t.Optional[MapAdapter]:
"""Creates a URL adapter for the given request. The URL adapter
is created at a point where the request context is not yet set
up so the request is passed explicitly.
@ -1687,21 +1743,25 @@ class Flask(Scaffold):
url_scheme=self.config["PREFERRED_URL_SCHEME"],
)
def inject_url_defaults(self, endpoint, values):
return None
def inject_url_defaults(self, endpoint: str, values: dict) -> None:
"""Injects the URL defaults for the given endpoint directly into
the values dictionary passed. This is used internally and
automatically called on URL building.
.. versionadded:: 0.7
"""
funcs = self.url_default_functions[None]
funcs: t.Iterable[URLDefaultCallable] = self.url_default_functions[None]
if "." in endpoint:
bp = endpoint.rsplit(".", 1)[0]
funcs = chain(funcs, self.url_default_functions[bp])
for func in funcs:
func(endpoint, values)
def handle_url_build_error(self, error, endpoint, values):
def handle_url_build_error(
self, error: Exception, endpoint: str, values: dict
) -> str:
"""Handle :class:`~werkzeug.routing.BuildError` on
:meth:`url_for`.
"""
@ -1722,7 +1782,7 @@ class Flask(Scaffold):
raise error
def preprocess_request(self):
def preprocess_request(self) -> t.Optional[ResponseReturnValue]:
"""Called before the request is dispatched. Calls
:attr:`url_value_preprocessors` registered with the app and the
current blueprint (if any). Then calls :attr:`before_request_funcs`
@ -1733,14 +1793,16 @@ class Flask(Scaffold):
further request handling is stopped.
"""
funcs = self.url_value_preprocessors[None]
funcs: t.Iterable[URLValuePreprocessorCallable] = self.url_value_preprocessors[
None
]
for bp in self._request_blueprints():
if bp in self.url_value_preprocessors:
funcs = chain(funcs, self.url_value_preprocessors[bp])
for func in funcs:
func(request.endpoint, request.view_args)
funcs = self.before_request_funcs[None]
funcs: t.Iterable[BeforeRequestCallable] = self.before_request_funcs[None]
for bp in self._request_blueprints():
if bp in self.before_request_funcs:
funcs = chain(funcs, self.before_request_funcs[bp])
@ -1749,7 +1811,9 @@ class Flask(Scaffold):
if rv is not None:
return rv
def process_response(self, response):
return None
def process_response(self, response: Response) -> Response:
"""Can be overridden in order to modify the response object
before it's sent to the WSGI server. By default this will
call all the :meth:`after_request` decorated functions.
@ -1763,7 +1827,7 @@ class Flask(Scaffold):
instance of :attr:`response_class`.
"""
ctx = _request_ctx_stack.top
funcs = ctx._after_request_functions
funcs: t.Iterable[AfterRequestCallable] = ctx._after_request_functions
for bp in self._request_blueprints():
if bp in self.after_request_funcs:
funcs = chain(funcs, reversed(self.after_request_funcs[bp]))
@ -1775,7 +1839,9 @@ class Flask(Scaffold):
self.session_interface.save_session(self, ctx.session, response)
return response
def do_teardown_request(self, exc=_sentinel):
def do_teardown_request(
self, exc: t.Optional[BaseException] = _sentinel # type: ignore
) -> None:
"""Called after the request is dispatched and the response is
returned, right before the request context is popped.
@ -1798,7 +1864,9 @@ class Flask(Scaffold):
"""
if exc is _sentinel:
exc = sys.exc_info()[1]
funcs = reversed(self.teardown_request_funcs[None])
funcs: t.Iterable[TeardownCallable] = reversed(
self.teardown_request_funcs[None]
)
for bp in self._request_blueprints():
if bp in self.teardown_request_funcs:
funcs = chain(funcs, reversed(self.teardown_request_funcs[bp]))
@ -1806,7 +1874,9 @@ class Flask(Scaffold):
func(exc)
request_tearing_down.send(self, exc=exc)
def do_teardown_appcontext(self, exc=_sentinel):
def do_teardown_appcontext(
self, exc: t.Optional[BaseException] = _sentinel # type: ignore
) -> None:
"""Called right before the application context is popped.
When handling a request, the application context is popped
@ -1827,7 +1897,7 @@ class Flask(Scaffold):
func(exc)
appcontext_tearing_down.send(self, exc=exc)
def app_context(self):
def app_context(self) -> AppContext:
"""Create an :class:`~flask.ctx.AppContext`. Use as a ``with``
block to push the context, which will make :data:`current_app`
point at this application.
@ -1848,7 +1918,7 @@ class Flask(Scaffold):
"""
return AppContext(self)
def request_context(self, environ):
def request_context(self, environ: dict) -> RequestContext:
"""Create a :class:`~flask.ctx.RequestContext` representing a
WSGI environment. Use a ``with`` block to push the context,
which will make :data:`request` point at this request.
@ -1864,7 +1934,7 @@ class Flask(Scaffold):
"""
return RequestContext(self, environ)
def test_request_context(self, *args, **kwargs):
def test_request_context(self, *args: t.Any, **kwargs: t.Any) -> RequestContext:
"""Create a :class:`~flask.ctx.RequestContext` for a WSGI
environment created from the given values. This is mostly useful
during testing, where you may want to run a function that uses
@ -1920,7 +1990,7 @@ class Flask(Scaffold):
finally:
builder.close()
def wsgi_app(self, environ, start_response):
def wsgi_app(self, environ: dict, start_response: t.Callable) -> t.Any:
"""The actual WSGI application. This is not implemented in
:meth:`__call__` so that middlewares can be applied without
losing a reference to the app object. Instead of doing this::
@ -1946,7 +2016,7 @@ class Flask(Scaffold):
start the response.
"""
ctx = self.request_context(environ)
error = None
error: t.Optional[BaseException] = None
try:
try:
ctx.push()
@ -1963,14 +2033,14 @@ class Flask(Scaffold):
error = None
ctx.auto_pop(error)
def __call__(self, environ, start_response):
def __call__(self, environ: dict, start_response: t.Callable) -> t.Any:
"""The WSGI server calls the Flask application object as the
WSGI application. This calls :meth:`wsgi_app`, which can be
wrapped to apply middleware.
"""
return self.wsgi_app(environ, start_response)
def _request_blueprints(self):
def _request_blueprints(self) -> t.Iterable[str]:
if _request_ctx_stack.top.request.blueprint is None:
return []
else:

View File

@ -1,9 +1,25 @@
import typing as t
from collections import defaultdict
from functools import update_wrapper
from .scaffold import _endpoint_from_view_func
from .scaffold import _sentinel
from .scaffold import Scaffold
from .typing import AfterRequestCallable
from .typing import BeforeRequestCallable
from .typing import ErrorHandlerCallable
from .typing import TeardownCallable
from .typing import TemplateContextProcessorCallable
from .typing import TemplateFilterCallable
from .typing import TemplateGlobalCallable
from .typing import TemplateTestCallable
from .typing import URLDefaultCallable
from .typing import URLValuePreprocessorCallable
if t.TYPE_CHECKING:
from .app import Flask
DeferredSetupFunction = t.Callable[["BlueprintSetupState"], t.Callable]
class BlueprintSetupState:
@ -13,7 +29,13 @@ class BlueprintSetupState:
to all register callback functions.
"""
def __init__(self, blueprint, app, options, first_registration):
def __init__(
self,
blueprint: "Blueprint",
app: "Flask",
options: t.Any,
first_registration: bool,
) -> None:
#: a reference to the current application
self.app = app
@ -52,7 +74,13 @@ class BlueprintSetupState:
self.url_defaults = dict(self.blueprint.url_values_defaults)
self.url_defaults.update(self.options.get("url_defaults", ()))
def add_url_rule(self, rule, endpoint=None, view_func=None, **options):
def add_url_rule(
self,
rule: str,
endpoint: t.Optional[str] = None,
view_func: t.Optional[t.Callable] = None,
**options: t.Any,
) -> None:
"""A helper method to register a rule (and optionally a view function)
to the application. The endpoint is automatically prefixed with the
blueprint's name.
@ -64,7 +92,7 @@ class BlueprintSetupState:
rule = self.url_prefix
options.setdefault("subdomain", self.subdomain)
if endpoint is None:
endpoint = _endpoint_from_view_func(view_func)
endpoint = _endpoint_from_view_func(view_func) # type: ignore
defaults = self.url_defaults
if "defaults" in options:
defaults = dict(defaults, **options.pop("defaults"))
@ -142,16 +170,16 @@ class Blueprint(Scaffold):
def __init__(
self,
name,
import_name,
static_folder=None,
static_url_path=None,
template_folder=None,
url_prefix=None,
subdomain=None,
url_defaults=None,
root_path=None,
cli_group=_sentinel,
name: str,
import_name: str,
static_folder: t.Optional[str] = None,
static_url_path: t.Optional[str] = None,
template_folder: t.Optional[str] = None,
url_prefix: t.Optional[str] = None,
subdomain: t.Optional[str] = None,
url_defaults: t.Optional[dict] = None,
root_path: t.Optional[str] = None,
cli_group: t.Optional[str] = _sentinel, # type: ignore
):
super().__init__(
import_name=import_name,
@ -163,19 +191,19 @@ class Blueprint(Scaffold):
self.name = name
self.url_prefix = url_prefix
self.subdomain = subdomain
self.deferred_functions = []
self.deferred_functions: t.List[DeferredSetupFunction] = []
if url_defaults is None:
url_defaults = {}
self.url_values_defaults = url_defaults
self.cli_group = cli_group
self._blueprints = []
self._blueprints: t.List[t.Tuple["Blueprint", dict]] = []
def _is_setup_finished(self):
def _is_setup_finished(self) -> bool:
return self.warn_on_modifications and self._got_registered_once
def record(self, func):
def record(self, func: t.Callable) -> None:
"""Registers a function that is called when the blueprint is
registered on the application. This function is called with the
state as argument as returned by the :meth:`make_setup_state`
@ -193,27 +221,29 @@ class Blueprint(Scaffold):
)
self.deferred_functions.append(func)
def record_once(self, func):
def record_once(self, func: t.Callable) -> None:
"""Works like :meth:`record` but wraps the function in another
function that will ensure the function is only called once. If the
blueprint is registered a second time on the application, the
function passed is not called.
"""
def wrapper(state):
def wrapper(state: BlueprintSetupState) -> None:
if state.first_registration:
func(state)
return self.record(update_wrapper(wrapper, func))
def make_setup_state(self, app, options, first_registration=False):
def make_setup_state(
self, app: "Flask", options: dict, first_registration: bool = False
) -> BlueprintSetupState:
"""Creates an instance of :meth:`~flask.blueprints.BlueprintSetupState`
object that is later passed to the register callback functions.
Subclasses can override this to return a subclass of the setup state.
"""
return BlueprintSetupState(self, app, options, first_registration)
def register_blueprint(self, blueprint, **options):
def register_blueprint(self, blueprint: "Blueprint", **options: t.Any) -> None:
"""Register a :class:`~flask.Blueprint` on this blueprint. Keyword
arguments passed to this method will override the defaults set
on the blueprint.
@ -222,7 +252,7 @@ class Blueprint(Scaffold):
"""
self._blueprints.append((blueprint, options))
def register(self, app, options):
def register(self, app: "Flask", options: dict) -> None:
"""Called by :meth:`Flask.register_blueprint` to register all
views and callbacks registered on the blueprint with the
application. Creates a :class:`.BlueprintSetupState` and calls
@ -327,7 +357,13 @@ class Blueprint(Scaffold):
bp_options["name_prefix"] = options.get("name_prefix", "") + self.name + "."
blueprint.register(app, bp_options)
def add_url_rule(self, rule, endpoint=None, view_func=None, **options):
def add_url_rule(
self,
rule: str,
endpoint: t.Optional[str] = None,
view_func: t.Optional[t.Callable] = None,
**options: t.Any,
) -> None:
"""Like :meth:`Flask.add_url_rule` but for a blueprint. The endpoint for
the :func:`url_for` function is prefixed with the name of the blueprint.
"""
@ -339,7 +375,7 @@ class Blueprint(Scaffold):
), "Blueprint view function name should not contain dots"
self.record(lambda s: s.add_url_rule(rule, endpoint, view_func, **options))
def app_template_filter(self, name=None):
def app_template_filter(self, name: t.Optional[str] = None) -> t.Callable:
"""Register a custom template filter, available application wide. Like
:meth:`Flask.template_filter` but for a blueprint.
@ -347,13 +383,15 @@ class Blueprint(Scaffold):
function name will be used.
"""
def decorator(f):
def decorator(f: TemplateFilterCallable) -> TemplateFilterCallable:
self.add_app_template_filter(f, name=name)
return f
return decorator
def add_app_template_filter(self, f, name=None):
def add_app_template_filter(
self, f: TemplateFilterCallable, name: t.Optional[str] = None
) -> None:
"""Register a custom template filter, available application wide. Like
:meth:`Flask.add_template_filter` but for a blueprint. Works exactly
like the :meth:`app_template_filter` decorator.
@ -362,12 +400,12 @@ class Blueprint(Scaffold):
function name will be used.
"""
def register_template(state):
def register_template(state: BlueprintSetupState) -> None:
state.app.jinja_env.filters[name or f.__name__] = f
self.record_once(register_template)
def app_template_test(self, name=None):
def app_template_test(self, name: t.Optional[str] = None) -> t.Callable:
"""Register a custom template test, available application wide. Like
:meth:`Flask.template_test` but for a blueprint.
@ -377,13 +415,15 @@ class Blueprint(Scaffold):
function name will be used.
"""
def decorator(f):
def decorator(f: TemplateTestCallable) -> TemplateTestCallable:
self.add_app_template_test(f, name=name)
return f
return decorator
def add_app_template_test(self, f, name=None):
def add_app_template_test(
self, f: TemplateTestCallable, name: t.Optional[str] = None
) -> None:
"""Register a custom template test, available application wide. Like
:meth:`Flask.add_template_test` but for a blueprint. Works exactly
like the :meth:`app_template_test` decorator.
@ -394,12 +434,12 @@ class Blueprint(Scaffold):
function name will be used.
"""
def register_template(state):
def register_template(state: BlueprintSetupState) -> None:
state.app.jinja_env.tests[name or f.__name__] = f
self.record_once(register_template)
def app_template_global(self, name=None):
def app_template_global(self, name: t.Optional[str] = None) -> t.Callable:
"""Register a custom template global, available application wide. Like
:meth:`Flask.template_global` but for a blueprint.
@ -409,13 +449,15 @@ class Blueprint(Scaffold):
function name will be used.
"""
def decorator(f):
def decorator(f: TemplateGlobalCallable) -> TemplateGlobalCallable:
self.add_app_template_global(f, name=name)
return f
return decorator
def add_app_template_global(self, f, name=None):
def add_app_template_global(
self, f: TemplateGlobalCallable, name: t.Optional[str] = None
) -> None:
"""Register a custom template global, available application wide. Like
:meth:`Flask.add_template_global` but for a blueprint. Works exactly
like the :meth:`app_template_global` decorator.
@ -426,12 +468,12 @@ class Blueprint(Scaffold):
function name will be used.
"""
def register_template(state):
def register_template(state: BlueprintSetupState) -> None:
state.app.jinja_env.globals[name or f.__name__] = f
self.record_once(register_template)
def before_app_request(self, f):
def before_app_request(self, f: BeforeRequestCallable) -> BeforeRequestCallable:
"""Like :meth:`Flask.before_request`. Such a function is executed
before each request, even if outside of a blueprint.
"""
@ -442,7 +484,9 @@ class Blueprint(Scaffold):
)
return f
def before_app_first_request(self, f):
def before_app_first_request(
self, f: BeforeRequestCallable
) -> BeforeRequestCallable:
"""Like :meth:`Flask.before_first_request`. Such a function is
executed before the first request to the application.
"""
@ -451,7 +495,7 @@ class Blueprint(Scaffold):
)
return f
def after_app_request(self, f):
def after_app_request(self, f: AfterRequestCallable) -> AfterRequestCallable:
"""Like :meth:`Flask.after_request` but for a blueprint. Such a function
is executed after each request, even if outside of the blueprint.
"""
@ -462,7 +506,7 @@ class Blueprint(Scaffold):
)
return f
def teardown_app_request(self, f):
def teardown_app_request(self, f: TeardownCallable) -> TeardownCallable:
"""Like :meth:`Flask.teardown_request` but for a blueprint. Such a
function is executed when tearing down each request, even if outside of
the blueprint.
@ -472,7 +516,9 @@ class Blueprint(Scaffold):
)
return f
def app_context_processor(self, f):
def app_context_processor(
self, f: TemplateContextProcessorCallable
) -> TemplateContextProcessorCallable:
"""Like :meth:`Flask.context_processor` but for a blueprint. Such a
function is executed each request, even if outside of the blueprint.
"""
@ -481,32 +527,34 @@ class Blueprint(Scaffold):
)
return f
def app_errorhandler(self, code):
def app_errorhandler(self, code: t.Union[t.Type[Exception], int]) -> t.Callable:
"""Like :meth:`Flask.errorhandler` but for a blueprint. This
handler is used for all requests, even if outside of the blueprint.
"""
def decorator(f):
def decorator(f: ErrorHandlerCallable) -> ErrorHandlerCallable:
self.record_once(lambda s: s.app.errorhandler(code)(f))
return f
return decorator
def app_url_value_preprocessor(self, f):
def app_url_value_preprocessor(
self, f: URLValuePreprocessorCallable
) -> URLValuePreprocessorCallable:
"""Same as :meth:`url_value_preprocessor` but application wide."""
self.record_once(
lambda s: s.app.url_value_preprocessors.setdefault(None, []).append(f)
)
return f
def app_url_defaults(self, f):
def app_url_defaults(self, f: URLDefaultCallable) -> URLDefaultCallable:
"""Same as :meth:`url_defaults` but application wide."""
self.record_once(
lambda s: s.app.url_default_functions.setdefault(None, []).append(f)
)
return f
def ensure_sync(self, f):
def ensure_sync(self, f: t.Callable) -> t.Callable:
"""Ensure the function is synchronous.
Override if you would like custom async to sync behaviour in

View File

@ -27,7 +27,7 @@ except ImportError:
try:
import ssl
except ImportError:
ssl = None
ssl = None # type: ignore
class NoAppException(click.UsageError):
@ -860,7 +860,7 @@ def run_command(
@click.command("shell", short_help="Run a shell in the app context.")
@with_appcontext
def shell_command():
def shell_command() -> None:
"""Run an interactive Python shell in the context of a given
Flask application. The application will populate the default
namespace of this shell according to its configuration.
@ -877,7 +877,7 @@ def shell_command():
f"App: {app.import_name} [{app.env}]\n"
f"Instance: {app.instance_path}"
)
ctx = {}
ctx: dict = {}
# Support the regular Python interpreter startup script if someone
# is using it.
@ -922,7 +922,7 @@ def shell_command():
)
@click.option("--all-methods", is_flag=True, help="Show HEAD and OPTIONS methods.")
@with_appcontext
def routes_command(sort, all_methods):
def routes_command(sort: str, all_methods: bool) -> None:
"""Show all registered routes with endpoints and methods."""
rules = list(current_app.url_map.iter_rules())
@ -935,9 +935,12 @@ def routes_command(sort, all_methods):
if sort in ("endpoint", "rule"):
rules = sorted(rules, key=attrgetter(sort))
elif sort == "methods":
rules = sorted(rules, key=lambda rule: sorted(rule.methods))
rules = sorted(rules, key=lambda rule: sorted(rule.methods)) # type: ignore
rule_methods = [", ".join(sorted(rule.methods - ignored_methods)) for rule in rules]
rule_methods = [
", ".join(sorted(rule.methods - ignored_methods)) # type: ignore
for rule in rules
]
headers = ("Endpoint", "Methods", "Rule")
widths = (
@ -975,7 +978,7 @@ debug mode.
)
def main():
def main() -> None:
# TODO omit sys.argv once https://github.com/pallets/click/issues/536 is fixed
cli.main(args=sys.argv[1:])

View File

@ -1,6 +1,7 @@
import errno
import os
import types
import typing as t
from werkzeug.utils import import_string
@ -8,11 +9,11 @@ from werkzeug.utils import import_string
class ConfigAttribute:
"""Makes an attribute forward to the config"""
def __init__(self, name, get_converter=None):
def __init__(self, name: str, get_converter: t.Optional[t.Callable] = None) -> None:
self.__name__ = name
self.get_converter = get_converter
def __get__(self, obj, type=None):
def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
if obj is None:
return self
rv = obj.config[self.__name__]
@ -20,7 +21,7 @@ class ConfigAttribute:
rv = self.get_converter(rv)
return rv
def __set__(self, obj, value):
def __set__(self, obj: t.Any, value: t.Any) -> None:
obj.config[self.__name__] = value
@ -68,11 +69,11 @@ class Config(dict):
:param defaults: an optional dictionary of default values
"""
def __init__(self, root_path, defaults=None):
def __init__(self, root_path: str, defaults: t.Optional[dict] = None) -> None:
dict.__init__(self, defaults or {})
self.root_path = root_path
def from_envvar(self, variable_name, silent=False):
def from_envvar(self, variable_name: str, silent: bool = False) -> bool:
"""Loads a configuration from an environment variable pointing to
a configuration file. This is basically just a shortcut with nicer
error messages for this line of code::
@ -96,7 +97,7 @@ class Config(dict):
)
return self.from_pyfile(rv, silent=silent)
def from_pyfile(self, filename, silent=False):
def from_pyfile(self, filename: str, silent: bool = False) -> bool:
"""Updates the values in the config from a Python file. This function
behaves as if the file was imported as module with the
:meth:`from_object` function.
@ -124,7 +125,7 @@ class Config(dict):
self.from_object(d)
return True
def from_object(self, obj):
def from_object(self, obj: t.Union[object, str]) -> None:
"""Updates the values from the given object. An object can be of one
of the following two types:
@ -162,7 +163,12 @@ class Config(dict):
if key.isupper():
self[key] = getattr(obj, key)
def from_file(self, filename, load, silent=False):
def from_file(
self,
filename: str,
load: t.Callable[[t.IO[t.Any]], t.Mapping],
silent: bool = False,
) -> bool:
"""Update the values in the config from a file that is loaded
using the ``load`` parameter. The loaded data is passed to the
:meth:`from_mapping` method.
@ -196,30 +202,26 @@ class Config(dict):
return self.from_mapping(obj)
def from_mapping(self, *mapping, **kwargs):
def from_mapping(
self, mapping: t.Optional[t.Mapping[str, t.Any]] = None, **kwargs: t.Any
) -> bool:
"""Updates the config like :meth:`update` ignoring items with non-upper
keys.
.. versionadded:: 0.11
"""
mappings = []
if len(mapping) == 1:
if hasattr(mapping[0], "items"):
mappings.append(mapping[0].items())
else:
mappings.append(mapping[0])
elif len(mapping) > 1:
raise TypeError(
f"expected at most 1 positional argument, got {len(mapping)}"
)
mappings.append(kwargs.items())
for mapping in mappings:
for (key, value) in mapping:
mappings: t.Dict[str, t.Any] = {}
if mapping is not None:
mappings.update(mapping)
mappings.update(kwargs)
for key, value in mappings.items():
if key.isupper():
self[key] = value
return True
def get_namespace(self, namespace, lowercase=True, trim_namespace=True):
def get_namespace(
self, namespace: str, lowercase: bool = True, trim_namespace: bool = True
) -> t.Dict[str, t.Any]:
"""Returns a dictionary containing a subset of configuration options
that match the specified namespace/prefix. Example usage::
@ -260,5 +262,5 @@ class Config(dict):
rv[key] = v
return rv
def __repr__(self):
def __repr__(self) -> str:
return f"<{type(self).__name__} {dict.__repr__(self)}>"

View File

@ -1,5 +1,7 @@
import sys
import typing as t
from functools import update_wrapper
from types import TracebackType
from werkzeug.exceptions import HTTPException
@ -7,6 +9,12 @@ from .globals import _app_ctx_stack
from .globals import _request_ctx_stack
from .signals import appcontext_popped
from .signals import appcontext_pushed
from .typing import AfterRequestCallable
if t.TYPE_CHECKING:
from .app import Flask
from .sessions import SessionMixin
from .wrappers import Request
# a singleton sentinel value for parameter defaults
@ -33,7 +41,7 @@ class _AppCtxGlobals:
.. versionadded:: 0.10
"""
def get(self, name, default=None):
def get(self, name: str, default: t.Optional[t.Any] = None) -> t.Any:
"""Get an attribute by name, or a default value. Like
:meth:`dict.get`.
@ -44,7 +52,7 @@ class _AppCtxGlobals:
"""
return self.__dict__.get(name, default)
def pop(self, name, default=_sentinel):
def pop(self, name: str, default: t.Any = _sentinel) -> t.Any:
"""Get and remove an attribute by name. Like :meth:`dict.pop`.
:param name: Name of attribute to pop.
@ -58,7 +66,7 @@ class _AppCtxGlobals:
else:
return self.__dict__.pop(name, default)
def setdefault(self, name, default=None):
def setdefault(self, name: str, default: t.Any = None) -> t.Any:
"""Get the value of an attribute if it is present, otherwise
set and return a default value. Like :meth:`dict.setdefault`.
@ -70,20 +78,20 @@ class _AppCtxGlobals:
"""
return self.__dict__.setdefault(name, default)
def __contains__(self, item):
def __contains__(self, item: t.Any) -> bool:
return item in self.__dict__
def __iter__(self):
def __iter__(self) -> t.Iterator:
return iter(self.__dict__)
def __repr__(self):
def __repr__(self) -> str:
top = _app_ctx_stack.top
if top is not None:
return f"<flask.g of {top.app.name!r}>"
return object.__repr__(self)
def after_this_request(f):
def after_this_request(f: AfterRequestCallable) -> AfterRequestCallable:
"""Executes a function after this request. This is useful to modify
response objects. The function is passed the response object and has
to return the same or a new one.
@ -108,7 +116,7 @@ def after_this_request(f):
return f
def copy_current_request_context(f):
def copy_current_request_context(f: t.Callable) -> t.Callable:
"""A helper function that decorates a function to retain the current
request context. This is useful when working with greenlets. The moment
the function is decorated a copy of the request context is created and
@ -148,7 +156,7 @@ def copy_current_request_context(f):
return update_wrapper(wrapper, f)
def has_request_context():
def has_request_context() -> bool:
"""If you have code that wants to test if a request context is there or
not this function can be used. For instance, you may want to take advantage
of request information if the request object is available, but fail
@ -180,7 +188,7 @@ def has_request_context():
return _request_ctx_stack.top is not None
def has_app_context():
def has_app_context() -> bool:
"""Works like :func:`has_request_context` but for the application
context. You can also just do a boolean check on the
:data:`current_app` object instead.
@ -199,7 +207,7 @@ class AppContext:
context.
"""
def __init__(self, app):
def __init__(self, app: "Flask") -> None:
self.app = app
self.url_adapter = app.create_url_adapter(None)
self.g = app.app_ctx_globals_class()
@ -208,13 +216,13 @@ class AppContext:
# but there a basic "refcount" is enough to track them.
self._refcnt = 0
def push(self):
def push(self) -> None:
"""Binds the app context to the current context."""
self._refcnt += 1
_app_ctx_stack.push(self)
appcontext_pushed.send(self.app)
def pop(self, exc=_sentinel):
def pop(self, exc: t.Optional[BaseException] = _sentinel) -> None: # type: ignore
"""Pops the app context."""
try:
self._refcnt -= 1
@ -227,11 +235,13 @@ class AppContext:
assert rv is self, f"Popped wrong app context. ({rv!r} instead of {self!r})"
appcontext_popped.send(self.app)
def __enter__(self):
def __enter__(self) -> "AppContext":
self.push()
return self
def __exit__(self, exc_type, exc_value, tb):
def __exit__(
self, exc_type: type, exc_value: BaseException, tb: TracebackType
) -> None:
self.pop(exc_value)
@ -265,7 +275,13 @@ class RequestContext:
that situation, otherwise your unittests will leak memory.
"""
def __init__(self, app, environ, request=None, session=None):
def __init__(
self,
app: "Flask",
environ: dict,
request: t.Optional["Request"] = None,
session: t.Optional["SessionMixin"] = None,
) -> None:
self.app = app
if request is None:
request = app.request_class(environ)
@ -282,7 +298,7 @@ class RequestContext:
# other request contexts. Now only if the last level is popped we
# get rid of them. Additionally if an application context is missing
# one is created implicitly so for each level we add this information
self._implicit_app_ctx_stack = []
self._implicit_app_ctx_stack: t.List[t.Optional["AppContext"]] = []
# indicator if the context was preserved. Next time another context
# is pushed the preserved context is popped.
@ -295,17 +311,17 @@ class RequestContext:
# Functions that should be executed after the request on the response
# object. These will be called before the regular "after_request"
# functions.
self._after_request_functions = []
self._after_request_functions: t.List[AfterRequestCallable] = []
@property
def g(self):
def g(self) -> AppContext:
return _app_ctx_stack.top.g
@g.setter
def g(self, value):
def g(self, value: AppContext) -> None:
_app_ctx_stack.top.g = value
def copy(self):
def copy(self) -> "RequestContext":
"""Creates a copy of this request context with the same request object.
This can be used to move a request context to a different greenlet.
Because the actual request object is the same this cannot be used to
@ -325,17 +341,17 @@ class RequestContext:
session=self.session,
)
def match_request(self):
def match_request(self) -> None:
"""Can be overridden by a subclass to hook into the matching
of the request.
"""
try:
result = self.url_adapter.match(return_rule=True)
self.request.url_rule, self.request.view_args = result
result = self.url_adapter.match(return_rule=True) # type: ignore
self.request.url_rule, self.request.view_args = result # type: ignore
except HTTPException as e:
self.request.routing_exception = e
def push(self):
def push(self) -> None:
"""Binds the request context to the current context."""
# If an exception occurs in debug mode or if context preservation is
# activated under exception situations exactly one context stays
@ -375,7 +391,7 @@ class RequestContext:
if self.session is None:
self.session = session_interface.make_null_session(self.app)
def pop(self, exc=_sentinel):
def pop(self, exc: t.Optional[BaseException] = _sentinel) -> None: # type: ignore
"""Pops the request context and unbinds it by doing that. This will
also trigger the execution of functions registered by the
:meth:`~flask.Flask.teardown_request` decorator.
@ -414,20 +430,22 @@ class RequestContext:
rv is self
), f"Popped wrong request context. ({rv!r} instead of {self!r})"
def auto_pop(self, exc):
def auto_pop(self, exc: t.Optional[BaseException]) -> None:
if self.request.environ.get("flask._preserve_context") or (
exc is not None and self.app.preserve_context_on_exception
):
self.preserved = True
self._preserved_exc = exc
self._preserved_exc = exc # type: ignore
else:
self.pop(exc)
def __enter__(self):
def __enter__(self) -> "RequestContext":
self.push()
return self
def __exit__(self, exc_type, exc_value, tb):
def __exit__(
self, exc_type: type, exc_value: BaseException, tb: TracebackType
) -> None:
# do not pop the request stack if we are in debug mode and an
# exception happened. This will allow the debugger to still
# access the request object in the interactive shell. Furthermore
@ -435,7 +453,7 @@ class RequestContext:
# See flask.testing for how this works.
self.auto_pop(exc_value)
def __repr__(self):
def __repr__(self) -> str:
return (
f"<{type(self).__name__} {self.request.url!r}"
f" [{self.request.method}] of {self.app.name}>"

View File

@ -1,4 +1,5 @@
import os
import typing as t
from warnings import warn
from .app import Flask
@ -92,7 +93,7 @@ def attach_enctype_error_multidict(request):
request.files.__class__ = newcls
def _dump_loader_info(loader):
def _dump_loader_info(loader) -> t.Generator:
yield f"class: {type(loader).__module__}.{type(loader).__name__}"
for key, value in sorted(loader.__dict__.items()):
if key.startswith("_"):
@ -109,7 +110,7 @@ def _dump_loader_info(loader):
yield f"{key}: {value!r}"
def explain_template_loading_attempts(app, template, attempts):
def explain_template_loading_attempts(app: Flask, template, attempts) -> None:
"""This should help developers understand what failed"""
info = [f"Locating template {template!r}:"]
total_found = 0
@ -157,7 +158,7 @@ def explain_template_loading_attempts(app, template, attempts):
app.logger.info("\n".join(info))
def explain_ignored_app_run():
def explain_ignored_app_run() -> None:
if os.environ.get("WERKZEUG_RUN_MAIN") != "true":
warn(
Warning(

View File

@ -1,8 +1,14 @@
import typing as t
from functools import partial
from werkzeug.local import LocalProxy
from werkzeug.local import LocalStack
if t.TYPE_CHECKING:
from .app import Flask
from .ctx import AppContext
from .sessions import SessionMixin
from .wrappers import Request
_request_ctx_err_msg = """\
Working outside of request context.
@ -45,7 +51,7 @@ def _find_app():
# context locals
_request_ctx_stack = LocalStack()
_app_ctx_stack = LocalStack()
current_app = LocalProxy(_find_app)
request = LocalProxy(partial(_lookup_req_object, "request"))
session = LocalProxy(partial(_lookup_req_object, "session"))
g = LocalProxy(partial(_lookup_app_object, "g"))
current_app: "Flask" = LocalProxy(_find_app) # type: ignore
request: "Request" = LocalProxy(partial(_lookup_req_object, "request")) # type: ignore
session: "SessionMixin" = LocalProxy(partial(_lookup_req_object, "session")) # type: ignore # noqa: B950
g: "AppContext" = LocalProxy(partial(_lookup_app_object, "g")) # type: ignore

View File

@ -1,6 +1,8 @@
import os
import socket
import typing as t
import warnings
from datetime import timedelta
from functools import update_wrapper
from functools import wraps
from threading import RLock
@ -18,8 +20,11 @@ from .globals import request
from .globals import session
from .signals import message_flashed
if t.TYPE_CHECKING:
from .wrappers import Response
def get_env():
def get_env() -> str:
"""Get the environment the app is running in, indicated by the
:envvar:`FLASK_ENV` environment variable. The default is
``'production'``.
@ -27,7 +32,7 @@ def get_env():
return os.environ.get("FLASK_ENV") or "production"
def get_debug_flag():
def get_debug_flag() -> bool:
"""Get whether debug mode should be enabled for the app, indicated
by the :envvar:`FLASK_DEBUG` environment variable. The default is
``True`` if :func:`.get_env` returns ``'development'``, or ``False``
@ -41,7 +46,7 @@ def get_debug_flag():
return val.lower() not in ("0", "false", "no")
def get_load_dotenv(default=True):
def get_load_dotenv(default: bool = True) -> bool:
"""Get whether the user has disabled loading dotenv files by setting
:envvar:`FLASK_SKIP_DOTENV`. The default is ``True``, load the
files.
@ -56,7 +61,9 @@ def get_load_dotenv(default=True):
return val.lower() in ("0", "false", "no")
def stream_with_context(generator_or_function):
def stream_with_context(
generator_or_function: t.Union[t.Generator, t.Callable]
) -> t.Generator:
"""Request contexts disappear when the response is started on the server.
This is done for efficiency reasons and to make it less likely to encounter
memory leaks with badly written WSGI middlewares. The downside is that if
@ -91,16 +98,16 @@ def stream_with_context(generator_or_function):
.. versionadded:: 0.9
"""
try:
gen = iter(generator_or_function)
gen = iter(generator_or_function) # type: ignore
except TypeError:
def decorator(*args, **kwargs):
gen = generator_or_function(*args, **kwargs)
def decorator(*args: t.Any, **kwargs: t.Any) -> t.Any:
gen = generator_or_function(*args, **kwargs) # type: ignore
return stream_with_context(gen)
return update_wrapper(decorator, generator_or_function)
return update_wrapper(decorator, generator_or_function) # type: ignore
def generator():
def generator() -> t.Generator:
ctx = _request_ctx_stack.top
if ctx is None:
raise RuntimeError(
@ -120,7 +127,7 @@ def stream_with_context(generator_or_function):
yield from gen
finally:
if hasattr(gen, "close"):
gen.close()
gen.close() # type: ignore
# The trick is to start the generator. Then the code execution runs until
# the first dummy None is yielded at which point the context was already
@ -131,7 +138,7 @@ def stream_with_context(generator_or_function):
return wrapped_g
def make_response(*args):
def make_response(*args: t.Any) -> "Response":
"""Sometimes it is necessary to set additional headers in a view. Because
views do not have to return response objects but can return a value that
is converted into a response object by Flask itself, it becomes tricky to
@ -180,7 +187,7 @@ def make_response(*args):
return current_app.make_response(args)
def url_for(endpoint, **values):
def url_for(endpoint: str, **values: t.Any) -> str:
"""Generates a URL to the given endpoint with the method provided.
Variable arguments that are unknown to the target endpoint are appended
@ -331,7 +338,7 @@ def url_for(endpoint, **values):
return rv
def get_template_attribute(template_name, attribute):
def get_template_attribute(template_name: str, attribute: str) -> t.Any:
"""Loads a macro (or variable) a template exports. This can be used to
invoke a macro from within Python code. If you for example have a
template named :file:`_cider.html` with the following contents:
@ -353,7 +360,7 @@ def get_template_attribute(template_name, attribute):
return getattr(current_app.jinja_env.get_template(template_name).module, attribute)
def flash(message, category="message"):
def flash(message: str, category: str = "message") -> None:
"""Flashes a message to the next request. In order to remove the
flashed message from the session and to display it to the user,
the template has to call :func:`get_flashed_messages`.
@ -379,11 +386,15 @@ def flash(message, category="message"):
flashes.append((category, message))
session["_flashes"] = flashes
message_flashed.send(
current_app._get_current_object(), message=message, category=category
current_app._get_current_object(), # type: ignore
message=message,
category=category,
)
def get_flashed_messages(with_categories=False, category_filter=()):
def get_flashed_messages(
with_categories: bool = False, category_filter: t.Iterable[str] = ()
) -> t.Union[t.List[str], t.List[t.Tuple[str, str]]]:
"""Pulls all flashed messages from the session and returns them.
Further calls in the same request to the function will return
the same messages. By default just the messages are returned,
@ -608,7 +619,7 @@ def send_file(
)
def safe_join(directory, *pathnames):
def safe_join(directory: str, *pathnames: str) -> str:
"""Safely join zero or more untrusted path components to a base
directory to avoid escaping the base directory.
@ -631,7 +642,7 @@ def safe_join(directory, *pathnames):
return path
def send_from_directory(directory, path, **kwargs):
def send_from_directory(directory: str, path: str, **kwargs: t.Any) -> "Response":
"""Send a file from within a directory using :func:`send_file`.
.. code-block:: python
@ -661,7 +672,7 @@ def send_from_directory(directory, path, **kwargs):
.. versionadded:: 0.5
"""
return werkzeug.utils.send_from_directory(
return werkzeug.utils.send_from_directory( # type: ignore
directory, path, **_prepare_send_file_kwargs(**kwargs)
)
@ -675,27 +686,32 @@ class locked_cached_property(werkzeug.utils.cached_property):
Inherits from Werkzeug's ``cached_property`` (and ``property``).
"""
def __init__(self, fget, name=None, doc=None):
def __init__(
self,
fget: t.Callable[[t.Any], t.Any],
name: t.Optional[str] = None,
doc: t.Optional[str] = None,
) -> None:
super().__init__(fget, name=name, doc=doc)
self.lock = RLock()
def __get__(self, obj, type=None):
def __get__(self, obj: object, type: type = None) -> t.Any: # type: ignore
if obj is None:
return self
with self.lock:
return super().__get__(obj, type=type)
def __set__(self, obj, value):
def __set__(self, obj: object, value: t.Any) -> None:
with self.lock:
super().__set__(obj, value)
def __delete__(self, obj):
def __delete__(self, obj: object) -> None:
with self.lock:
super().__delete__(obj)
def total_seconds(td):
def total_seconds(td: timedelta) -> int:
"""Returns the total seconds from a timedelta object.
:param timedelta td: the timedelta to be converted in seconds
@ -716,7 +732,7 @@ def total_seconds(td):
return td.days * 60 * 60 * 24 + td.seconds
def is_ip(value):
def is_ip(value: str) -> bool:
"""Determine if the given string is an IP address.
:param value: value to check
@ -736,7 +752,7 @@ def is_ip(value):
return False
def run_async(func):
def run_async(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function *func*."""
try:
from asgiref.sync import async_to_sync
@ -752,7 +768,7 @@ def run_async(func):
)
@wraps(func)
def outer(*args, **kwargs):
def outer(*args: t.Any, **kwargs: t.Any) -> t.Any:
"""This function grabs the current context for the inner function.
This is similar to the copy_current_xxx_context functions in the
@ -764,7 +780,7 @@ def run_async(func):
ctx = _request_ctx_stack.top.copy()
@wraps(func)
async def inner(*a, **k):
async def inner(*a: t.Any, **k: t.Any) -> t.Any:
"""This restores the context before awaiting the func.
This is required as the function must be awaited within the
@ -780,5 +796,5 @@ def run_async(func):
return async_to_sync(inner)(*args, **kwargs)
outer._flask_sync_wrapper = True
outer._flask_sync_wrapper = True # type: ignore
return outer

View File

@ -1,20 +1,25 @@
import io
import json as _json
import typing as t
import uuid
import warnings
from datetime import date
from jinja2.utils import htmlsafe_json_dumps as _jinja_htmlsafe_dumps
from jinja2.utils import htmlsafe_json_dumps as _jinja_htmlsafe_dumps # type: ignore
from werkzeug.http import http_date
from ..globals import current_app
from ..globals import request
if t.TYPE_CHECKING:
from ..app import Flask
from ..wrappers import Response
try:
import dataclasses
except ImportError:
# Python < 3.7
dataclasses = None
dataclasses = None # type: ignore
class JSONEncoder(_json.JSONEncoder):
@ -34,7 +39,7 @@ class JSONEncoder(_json.JSONEncoder):
:attr:`flask.Blueprint.json_encoder` to override the default.
"""
def default(self, o):
def default(self, o: t.Any) -> t.Any:
"""Convert ``o`` to a JSON serializable type. See
:meth:`json.JSONEncoder.default`. Python does not support
overriding how basic types like ``str`` or ``list`` are
@ -48,7 +53,7 @@ class JSONEncoder(_json.JSONEncoder):
return dataclasses.asdict(o)
if hasattr(o, "__html__"):
return str(o.__html__())
return super().default(self, o)
return super().default(o)
class JSONDecoder(_json.JSONDecoder):
@ -62,14 +67,19 @@ class JSONDecoder(_json.JSONDecoder):
"""
def _dump_arg_defaults(kwargs, app=None):
def _dump_arg_defaults(
kwargs: t.Dict[str, t.Any], app: t.Optional["Flask"] = None
) -> None:
"""Inject default arguments for dump functions."""
if app is None:
app = current_app
if app:
bp = app.blueprints.get(request.blueprint) if request else None
cls = bp.json_encoder if bp and bp.json_encoder else app.json_encoder
cls = app.json_encoder
bp = app.blueprints.get(request.blueprint) if request else None # type: ignore
if bp is not None and bp.json_encoder is not None:
cls = bp.json_encoder
kwargs.setdefault("cls", cls)
kwargs.setdefault("ensure_ascii", app.config["JSON_AS_ASCII"])
kwargs.setdefault("sort_keys", app.config["JSON_SORT_KEYS"])
@ -78,20 +88,25 @@ def _dump_arg_defaults(kwargs, app=None):
kwargs.setdefault("cls", JSONEncoder)
def _load_arg_defaults(kwargs, app=None):
def _load_arg_defaults(
kwargs: t.Dict[str, t.Any], app: t.Optional["Flask"] = None
) -> None:
"""Inject default arguments for load functions."""
if app is None:
app = current_app
if app:
bp = app.blueprints.get(request.blueprint) if request else None
cls = bp.json_decoder if bp and bp.json_decoder else app.json_decoder
cls = app.json_decoder
bp = app.blueprints.get(request.blueprint) if request else None # type: ignore
if bp is not None and bp.json_decoder is not None:
cls = bp.json_decoder
kwargs.setdefault("cls", cls)
else:
kwargs.setdefault("cls", JSONDecoder)
def dumps(obj, app=None, **kwargs):
def dumps(obj: t.Any, app: t.Optional["Flask"] = None, **kwargs: t.Any) -> str:
"""Serialize an object to a string of JSON.
Takes the same arguments as the built-in :func:`json.dumps`, with
@ -121,12 +136,14 @@ def dumps(obj, app=None, **kwargs):
)
if isinstance(rv, str):
return rv.encode(encoding)
return rv.encode(encoding) # type: ignore
return rv
def dump(obj, fp, app=None, **kwargs):
def dump(
obj: t.Any, fp: t.IO[str], app: t.Optional["Flask"] = None, **kwargs: t.Any
) -> None:
"""Serialize an object to JSON written to a file object.
Takes the same arguments as the built-in :func:`json.dump`, with
@ -150,7 +167,7 @@ def dump(obj, fp, app=None, **kwargs):
fp.write("")
except TypeError:
show_warning = True
fp = io.TextIOWrapper(fp, encoding or "utf-8")
fp = io.TextIOWrapper(fp, encoding or "utf-8") # type: ignore
if show_warning:
warnings.warn(
@ -163,7 +180,7 @@ def dump(obj, fp, app=None, **kwargs):
_json.dump(obj, fp, **kwargs)
def loads(s, app=None, **kwargs):
def loads(s: str, app: t.Optional["Flask"] = None, **kwargs: t.Any) -> t.Any:
"""Deserialize an object from a string of JSON.
Takes the same arguments as the built-in :func:`json.loads`, with
@ -199,7 +216,7 @@ def loads(s, app=None, **kwargs):
return _json.loads(s, **kwargs)
def load(fp, app=None, **kwargs):
def load(fp: t.IO[str], app: t.Optional["Flask"] = None, **kwargs: t.Any) -> t.Any:
"""Deserialize an object from JSON read from a file object.
Takes the same arguments as the built-in :func:`json.load`, with
@ -227,12 +244,12 @@ def load(fp, app=None, **kwargs):
)
if isinstance(fp.read(0), bytes):
fp = io.TextIOWrapper(fp, encoding)
fp = io.TextIOWrapper(fp, encoding) # type: ignore
return _json.load(fp, **kwargs)
def htmlsafe_dumps(obj, **kwargs):
def htmlsafe_dumps(obj: t.Any, **kwargs: t.Any) -> str:
"""Serialize an object to a string of JSON with :func:`dumps`, then
replace HTML-unsafe characters with Unicode escapes and mark the
result safe with :class:`~markupsafe.Markup`.
@ -256,7 +273,7 @@ def htmlsafe_dumps(obj, **kwargs):
return _jinja_htmlsafe_dumps(obj, dumps=dumps, **kwargs)
def htmlsafe_dump(obj, fp, **kwargs):
def htmlsafe_dump(obj: t.Any, fp: t.IO[str], **kwargs: t.Any) -> None:
"""Serialize an object to JSON written to a file object, replacing
HTML-unsafe characters with Unicode escapes. See
:func:`htmlsafe_dumps` and :func:`dumps`.
@ -264,7 +281,7 @@ def htmlsafe_dump(obj, fp, **kwargs):
fp.write(htmlsafe_dumps(obj, **kwargs))
def jsonify(*args, **kwargs):
def jsonify(*args: t.Any, **kwargs: t.Any) -> "Response":
"""Serialize data to JSON and wrap it in a :class:`~flask.Response`
with the :mimetype:`application/json` mimetype.

View File

@ -40,6 +40,7 @@ be processed before ``dict``.
app.session_interface.serializer.register(TagOrderedDict, index=0)
"""
import typing as t
from base64 import b64decode
from base64 import b64encode
from datetime import datetime
@ -60,27 +61,27 @@ class JSONTag:
#: The tag to mark the serialized object with. If ``None``, this tag is
#: only used as an intermediate step during tagging.
key = None
key: t.Optional[str] = None
def __init__(self, serializer):
def __init__(self, serializer: "TaggedJSONSerializer") -> None:
"""Create a tagger for the given serializer."""
self.serializer = serializer
def check(self, value):
def check(self, value: t.Any) -> bool:
"""Check if the given value should be tagged by this tag."""
raise NotImplementedError
def to_json(self, value):
def to_json(self, value: t.Any) -> t.Any:
"""Convert the Python object to an object that is a valid JSON type.
The tag will be added later."""
raise NotImplementedError
def to_python(self, value):
def to_python(self, value: t.Any) -> t.Any:
"""Convert the JSON representation back to the correct type. The tag
will already be removed."""
raise NotImplementedError
def tag(self, value):
def tag(self, value: t.Any) -> t.Any:
"""Convert the value to a valid JSON type and add the tag structure
around it."""
return {self.key: self.to_json(value)}
@ -96,18 +97,18 @@ class TagDict(JSONTag):
__slots__ = ()
key = " di"
def check(self, value):
def check(self, value: t.Any) -> bool:
return (
isinstance(value, dict)
and len(value) == 1
and next(iter(value)) in self.serializer.tags
)
def to_json(self, value):
def to_json(self, value: t.Any) -> t.Any:
key = next(iter(value))
return {f"{key}__": self.serializer.tag(value[key])}
def to_python(self, value):
def to_python(self, value: t.Any) -> t.Any:
key = next(iter(value))
return {key[:-2]: value[key]}
@ -115,10 +116,10 @@ class TagDict(JSONTag):
class PassDict(JSONTag):
__slots__ = ()
def check(self, value):
def check(self, value: t.Any) -> bool:
return isinstance(value, dict)
def to_json(self, value):
def to_json(self, value: t.Any) -> t.Any:
# JSON objects may only have string keys, so don't bother tagging the
# key here.
return {k: self.serializer.tag(v) for k, v in value.items()}
@ -130,23 +131,23 @@ class TagTuple(JSONTag):
__slots__ = ()
key = " t"
def check(self, value):
def check(self, value: t.Any) -> bool:
return isinstance(value, tuple)
def to_json(self, value):
def to_json(self, value: t.Any) -> t.Any:
return [self.serializer.tag(item) for item in value]
def to_python(self, value):
def to_python(self, value: t.Any) -> t.Any:
return tuple(value)
class PassList(JSONTag):
__slots__ = ()
def check(self, value):
def check(self, value: t.Any) -> bool:
return isinstance(value, list)
def to_json(self, value):
def to_json(self, value: t.Any) -> t.Any:
return [self.serializer.tag(item) for item in value]
tag = to_json
@ -156,13 +157,13 @@ class TagBytes(JSONTag):
__slots__ = ()
key = " b"
def check(self, value):
def check(self, value: t.Any) -> bool:
return isinstance(value, bytes)
def to_json(self, value):
def to_json(self, value: t.Any) -> t.Any:
return b64encode(value).decode("ascii")
def to_python(self, value):
def to_python(self, value: t.Any) -> t.Any:
return b64decode(value)
@ -174,13 +175,13 @@ class TagMarkup(JSONTag):
__slots__ = ()
key = " m"
def check(self, value):
def check(self, value: t.Any) -> bool:
return callable(getattr(value, "__html__", None))
def to_json(self, value):
def to_json(self, value: t.Any) -> t.Any:
return str(value.__html__())
def to_python(self, value):
def to_python(self, value: t.Any) -> t.Any:
return Markup(value)
@ -188,13 +189,13 @@ class TagUUID(JSONTag):
__slots__ = ()
key = " u"
def check(self, value):
def check(self, value: t.Any) -> bool:
return isinstance(value, UUID)
def to_json(self, value):
def to_json(self, value: t.Any) -> t.Any:
return value.hex
def to_python(self, value):
def to_python(self, value: t.Any) -> t.Any:
return UUID(value)
@ -202,13 +203,13 @@ class TagDateTime(JSONTag):
__slots__ = ()
key = " d"
def check(self, value):
def check(self, value: t.Any) -> bool:
return isinstance(value, datetime)
def to_json(self, value):
def to_json(self, value: t.Any) -> t.Any:
return http_date(value)
def to_python(self, value):
def to_python(self, value: t.Any) -> t.Any:
return parse_date(value)
@ -242,14 +243,19 @@ class TaggedJSONSerializer:
TagDateTime,
]
def __init__(self):
self.tags = {}
self.order = []
def __init__(self) -> None:
self.tags: t.Dict[str, JSONTag] = {}
self.order: t.List[JSONTag] = []
for cls in self.default_tags:
self.register(cls)
def register(self, tag_class, force=False, index=None):
def register(
self,
tag_class: t.Type[JSONTag],
force: bool = False,
index: t.Optional[int] = None,
) -> None:
"""Register a new tag with this serializer.
:param tag_class: tag class to register. Will be instantiated with this
@ -277,7 +283,7 @@ class TaggedJSONSerializer:
else:
self.order.insert(index, tag)
def tag(self, value):
def tag(self, value: t.Any) -> t.Dict[str, t.Any]:
"""Convert a value to a tagged representation if necessary."""
for tag in self.order:
if tag.check(value):
@ -285,7 +291,7 @@ class TaggedJSONSerializer:
return value
def untag(self, value):
def untag(self, value: t.Dict[str, t.Any]) -> t.Any:
"""Convert a tagged representation back to the original type."""
if len(value) != 1:
return value
@ -297,10 +303,10 @@ class TaggedJSONSerializer:
return self.tags[key].to_python(value[key])
def dumps(self, value):
def dumps(self, value: t.Any) -> str:
"""Tag the value and dump it to a compact JSON string."""
return dumps(self.tag(value), separators=(",", ":"))
def loads(self, value):
def loads(self, value: str) -> t.Any:
"""Load data from a JSON string and deserialized any tagged objects."""
return loads(value, object_hook=self.untag)

View File

@ -1,13 +1,17 @@
import logging
import sys
import typing as t
from werkzeug.local import LocalProxy
from .globals import request
if t.TYPE_CHECKING:
from .app import Flask
@LocalProxy
def wsgi_errors_stream():
def wsgi_errors_stream() -> t.TextIO:
"""Find the most appropriate error stream for the application. If a request
is active, log to ``wsgi.errors``, otherwise use ``sys.stderr``.
@ -19,7 +23,7 @@ def wsgi_errors_stream():
return request.environ["wsgi.errors"] if request else sys.stderr
def has_level_handler(logger):
def has_level_handler(logger: logging.Logger) -> bool:
"""Check if there is a handler in the logging chain that will handle the
given logger's :meth:`effective level <~logging.Logger.getEffectiveLevel>`.
"""
@ -33,20 +37,20 @@ def has_level_handler(logger):
if not current.propagate:
break
current = current.parent
current = current.parent # type: ignore
return False
#: Log messages to :func:`~flask.logging.wsgi_errors_stream` with the format
#: ``[%(asctime)s] %(levelname)s in %(module)s: %(message)s``.
default_handler = logging.StreamHandler(wsgi_errors_stream)
default_handler = logging.StreamHandler(wsgi_errors_stream) # type: ignore
default_handler.setFormatter(
logging.Formatter("[%(asctime)s] %(levelname)s in %(module)s: %(message)s")
)
def create_logger(app):
def create_logger(app: "Flask") -> logging.Logger:
"""Get the Flask app's logger and configure it if needed.
The logger name will be the same as

View File

@ -2,8 +2,11 @@ import importlib.util
import os
import pkgutil
import sys
import typing as t
from collections import defaultdict
from functools import update_wrapper
from json import JSONDecoder
from json import JSONEncoder
from jinja2 import FileSystemLoader
from werkzeug.exceptions import default_exceptions
@ -14,17 +17,28 @@ from .globals import current_app
from .helpers import locked_cached_property
from .helpers import send_from_directory
from .templating import _default_template_ctx_processor
from .typing import AfterRequestCallable
from .typing import AppOrBlueprintKey
from .typing import BeforeRequestCallable
from .typing import ErrorHandlerCallable
from .typing import TeardownCallable
from .typing import TemplateContextProcessorCallable
from .typing import URLDefaultCallable
from .typing import URLValuePreprocessorCallable
if t.TYPE_CHECKING:
from .wrappers import Response
# a singleton sentinel value for parameter defaults
_sentinel = object()
def setupmethod(f):
def setupmethod(f: t.Callable) -> t.Callable:
"""Wraps a method so that it performs a check in debug mode if the
first request was already handled.
"""
def wrapper_func(self, *args, **kwargs):
def wrapper_func(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if self._is_setup_finished():
raise AssertionError(
"A setup function was called after the first request "
@ -60,24 +74,24 @@ class Scaffold:
"""
name: str
_static_folder = None
_static_url_path = None
_static_folder: t.Optional[str] = None
_static_url_path: t.Optional[str] = None
#: JSON encoder class used by :func:`flask.json.dumps`. If a
#: blueprint sets this, it will be used instead of the app's value.
json_encoder = None
json_encoder: t.Optional[t.Type[JSONEncoder]] = None
#: JSON decoder class used by :func:`flask.json.loads`. If a
#: blueprint sets this, it will be used instead of the app's value.
json_decoder = None
json_decoder: t.Optional[t.Type[JSONDecoder]] = None
def __init__(
self,
import_name,
static_folder=None,
static_url_path=None,
template_folder=None,
root_path=None,
import_name: str,
static_folder: t.Optional[str] = None,
static_url_path: t.Optional[str] = None,
template_folder: t.Optional[str] = None,
root_path: t.Optional[str] = None,
):
#: The name of the package or module that this object belongs
#: to. Do not change this once it is set by the constructor.
@ -110,7 +124,7 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.view_functions = {}
self.view_functions: t.Dict[str, t.Callable] = {}
#: A data structure of registered error handlers, in the format
#: ``{scope: {code: {class: handler}}}```. The ``scope`` key is
@ -125,7 +139,10 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.error_handler_spec = defaultdict(lambda: defaultdict(dict))
self.error_handler_spec: t.Dict[
AppOrBlueprintKey,
t.Dict[t.Optional[int], t.Dict[t.Type[Exception], ErrorHandlerCallable]],
] = defaultdict(lambda: defaultdict(dict))
#: A data structure of functions to call at the beginning of
#: each request, in the format ``{scope: [functions]}``. The
@ -137,7 +154,9 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.before_request_funcs = defaultdict(list)
self.before_request_funcs: t.Dict[
AppOrBlueprintKey, t.List[BeforeRequestCallable]
] = defaultdict(list)
#: A data structure of functions to call at the end of each
#: request, in the format ``{scope: [functions]}``. The
@ -149,7 +168,9 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.after_request_funcs = defaultdict(list)
self.after_request_funcs: t.Dict[
AppOrBlueprintKey, t.List[AfterRequestCallable]
] = defaultdict(list)
#: A data structure of functions to call at the end of each
#: request even if an exception is raised, in the format
@ -162,7 +183,9 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.teardown_request_funcs = defaultdict(list)
self.teardown_request_funcs: t.Dict[
AppOrBlueprintKey, t.List[TeardownCallable]
] = defaultdict(list)
#: A data structure of functions to call to pass extra context
#: values when rendering templates, in the format
@ -175,9 +198,9 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.template_context_processors = defaultdict(
list, {None: [_default_template_ctx_processor]}
)
self.template_context_processors: t.Dict[
AppOrBlueprintKey, t.List[TemplateContextProcessorCallable]
] = defaultdict(list, {None: [_default_template_ctx_processor]})
#: A data structure of functions to call to modify the keyword
#: arguments passed to the view function, in the format
@ -190,7 +213,10 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.url_value_preprocessors = defaultdict(list)
self.url_value_preprocessors: t.Dict[
AppOrBlueprintKey,
t.List[URLValuePreprocessorCallable],
] = defaultdict(list)
#: A data structure of functions to call to modify the keyword
#: arguments when generating URLs, in the format
@ -203,31 +229,35 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.url_default_functions = defaultdict(list)
self.url_default_functions: t.Dict[
AppOrBlueprintKey, t.List[URLDefaultCallable]
] = defaultdict(list)
def __repr__(self):
def __repr__(self) -> str:
return f"<{type(self).__name__} {self.name!r}>"
def _is_setup_finished(self):
def _is_setup_finished(self) -> bool:
raise NotImplementedError
@property
def static_folder(self):
def static_folder(self) -> t.Optional[str]:
"""The absolute path to the configured static folder. ``None``
if no static folder is set.
"""
if self._static_folder is not None:
return os.path.join(self.root_path, self._static_folder)
else:
return None
@static_folder.setter
def static_folder(self, value):
def static_folder(self, value: t.Optional[str]) -> None:
if value is not None:
value = os.fspath(value).rstrip(r"\/")
self._static_folder = value
@property
def has_static_folder(self):
def has_static_folder(self) -> bool:
"""``True`` if :attr:`static_folder` is set.
.. versionadded:: 0.5
@ -235,7 +265,7 @@ class Scaffold:
return self.static_folder is not None
@property
def static_url_path(self):
def static_url_path(self) -> t.Optional[str]:
"""The URL prefix that the static route will be accessible from.
If it was not configured during init, it is derived from
@ -248,14 +278,16 @@ class Scaffold:
basename = os.path.basename(self.static_folder)
return f"/{basename}".rstrip("/")
return None
@static_url_path.setter
def static_url_path(self, value):
def static_url_path(self, value: t.Optional[str]) -> None:
if value is not None:
value = value.rstrip("/")
self._static_url_path = value
def get_send_file_max_age(self, filename):
def get_send_file_max_age(self, filename: str) -> t.Optional[int]:
"""Used by :func:`send_file` to determine the ``max_age`` cache
value for a given file path if it wasn't passed.
@ -276,7 +308,7 @@ class Scaffold:
return int(value.total_seconds())
def send_static_file(self, filename):
def send_static_file(self, filename: str) -> "Response":
"""The view function used to serve files from
:attr:`static_folder`. A route is automatically registered for
this view at :attr:`static_url_path` if :attr:`static_folder` is
@ -290,10 +322,12 @@ class Scaffold:
# send_file only knows to call get_send_file_max_age on the app,
# call it here so it works for blueprints too.
max_age = self.get_send_file_max_age(filename)
return send_from_directory(self.static_folder, filename, max_age=max_age)
return send_from_directory(
t.cast(str, self.static_folder), filename, max_age=max_age
)
@locked_cached_property
def jinja_loader(self):
def jinja_loader(self) -> t.Optional[FileSystemLoader]:
"""The Jinja loader for this object's templates. By default this
is a class :class:`jinja2.loaders.FileSystemLoader` to
:attr:`template_folder` if it is set.
@ -302,8 +336,10 @@ class Scaffold:
"""
if self.template_folder is not None:
return FileSystemLoader(os.path.join(self.root_path, self.template_folder))
else:
return None
def open_resource(self, resource, mode="rb"):
def open_resource(self, resource: str, mode: str = "rb") -> t.IO[t.AnyStr]:
"""Open a resource file relative to :attr:`root_path` for
reading.
@ -326,48 +362,48 @@ class Scaffold:
return open(os.path.join(self.root_path, resource), mode)
def _method_route(self, method, rule, options):
def _method_route(self, method: str, rule: str, options: dict) -> t.Callable:
if "methods" in options:
raise TypeError("Use the 'route' decorator to use the 'methods' argument.")
return self.route(rule, methods=[method], **options)
def get(self, rule, **options):
def get(self, rule: str, **options: t.Any) -> t.Callable:
"""Shortcut for :meth:`route` with ``methods=["GET"]``.
.. versionadded:: 2.0
"""
return self._method_route("GET", rule, options)
def post(self, rule, **options):
def post(self, rule: str, **options: t.Any) -> t.Callable:
"""Shortcut for :meth:`route` with ``methods=["POST"]``.
.. versionadded:: 2.0
"""
return self._method_route("POST", rule, options)
def put(self, rule, **options):
def put(self, rule: str, **options: t.Any) -> t.Callable:
"""Shortcut for :meth:`route` with ``methods=["PUT"]``.
.. versionadded:: 2.0
"""
return self._method_route("PUT", rule, options)
def delete(self, rule, **options):
def delete(self, rule: str, **options: t.Any) -> t.Callable:
"""Shortcut for :meth:`route` with ``methods=["DELETE"]``.
.. versionadded:: 2.0
"""
return self._method_route("DELETE", rule, options)
def patch(self, rule, **options):
def patch(self, rule: str, **options: t.Any) -> t.Callable:
"""Shortcut for :meth:`route` with ``methods=["PATCH"]``.
.. versionadded:: 2.0
"""
return self._method_route("PATCH", rule, options)
def route(self, rule, **options):
def route(self, rule: str, **options: t.Any) -> t.Callable:
"""Decorate a view function to register it with the given URL
rule and options. Calls :meth:`add_url_rule`, which has more
details about the implementation.
@ -391,7 +427,7 @@ class Scaffold:
:class:`~werkzeug.routing.Rule` object.
"""
def decorator(f):
def decorator(f: t.Callable) -> t.Callable:
endpoint = options.pop("endpoint", None)
self.add_url_rule(rule, endpoint, f, **options)
return f
@ -401,12 +437,12 @@ class Scaffold:
@setupmethod
def add_url_rule(
self,
rule,
endpoint=None,
view_func=None,
provide_automatic_options=None,
**options,
):
rule: str,
endpoint: t.Optional[str] = None,
view_func: t.Optional[t.Callable] = None,
provide_automatic_options: t.Optional[bool] = None,
**options: t.Any,
) -> t.Callable:
"""Register a rule for routing incoming requests and building
URLs. The :meth:`route` decorator is a shortcut to call this
with the ``view_func`` argument. These are equivalent:
@ -466,7 +502,7 @@ class Scaffold:
"""
raise NotImplementedError
def endpoint(self, endpoint):
def endpoint(self, endpoint: str) -> t.Callable:
"""Decorate a view function to register it for the given
endpoint. Used if a rule is added without a ``view_func`` with
:meth:`add_url_rule`.
@ -490,7 +526,7 @@ class Scaffold:
return decorator
@setupmethod
def before_request(self, f):
def before_request(self, f: BeforeRequestCallable) -> BeforeRequestCallable:
"""Register a function to run before each request.
For example, this can be used to open a database connection, or
@ -512,7 +548,7 @@ class Scaffold:
return f
@setupmethod
def after_request(self, f):
def after_request(self, f: AfterRequestCallable) -> AfterRequestCallable:
"""Register a function to run after each request to this object.
The function is called with the response object, and must return
@ -528,7 +564,7 @@ class Scaffold:
return f
@setupmethod
def teardown_request(self, f):
def teardown_request(self, f: TeardownCallable) -> TeardownCallable:
"""Register a function to be run at the end of each request,
regardless of whether there was an exception or not. These functions
are executed when the request context is popped, even if not an
@ -567,13 +603,17 @@ class Scaffold:
return f
@setupmethod
def context_processor(self, f):
def context_processor(
self, f: TemplateContextProcessorCallable
) -> TemplateContextProcessorCallable:
"""Registers a template context processor function."""
self.template_context_processors[None].append(f)
return f
@setupmethod
def url_value_preprocessor(self, f):
def url_value_preprocessor(
self, f: URLValuePreprocessorCallable
) -> URLValuePreprocessorCallable:
"""Register a URL value preprocessor function for all view
functions in the application. These functions will be called before the
:meth:`before_request` functions.
@ -590,7 +630,7 @@ class Scaffold:
return f
@setupmethod
def url_defaults(self, f):
def url_defaults(self, f: URLDefaultCallable) -> URLDefaultCallable:
"""Callback function for URL defaults for all view functions of the
application. It's called with the endpoint and values and should
update the values passed in place.
@ -599,7 +639,9 @@ class Scaffold:
return f
@setupmethod
def errorhandler(self, code_or_exception):
def errorhandler(
self, code_or_exception: t.Union[t.Type[Exception], int]
) -> t.Callable:
"""Register a function to handle errors by code or exception class.
A decorator that is used to register a function given an
@ -629,14 +671,18 @@ class Scaffold:
an arbitrary exception
"""
def decorator(f):
def decorator(f: ErrorHandlerCallable) -> ErrorHandlerCallable:
self.register_error_handler(code_or_exception, f)
return f
return decorator
@setupmethod
def register_error_handler(self, code_or_exception, f):
def register_error_handler(
self,
code_or_exception: t.Union[t.Type[Exception], int],
f: ErrorHandlerCallable,
) -> None:
"""Alternative error attach function to the :meth:`errorhandler`
decorator that is more straightforward to use for non decorator
usage.
@ -662,7 +708,9 @@ class Scaffold:
self.error_handler_spec[None][code][exc_class] = self.ensure_sync(f)
@staticmethod
def _get_exc_class_and_code(exc_class_or_code):
def _get_exc_class_and_code(
exc_class_or_code: t.Union[t.Type[Exception], int]
) -> t.Tuple[t.Type[Exception], t.Optional[int]]:
"""Get the exception class being handled. For HTTP status codes
or ``HTTPException`` subclasses, return both the exception and
status code.
@ -670,6 +718,7 @@ class Scaffold:
:param exc_class_or_code: Any exception class, or an HTTP status
code as an integer.
"""
exc_class: t.Type[Exception]
if isinstance(exc_class_or_code, int):
exc_class = default_exceptions[exc_class_or_code]
else:
@ -684,11 +733,11 @@ class Scaffold:
else:
return exc_class, None
def ensure_sync(self, func):
def ensure_sync(self, func: t.Callable) -> t.Callable:
raise NotImplementedError()
def _endpoint_from_view_func(view_func):
def _endpoint_from_view_func(view_func: t.Callable) -> str:
"""Internal helper that returns the default endpoint for a given
function. This always is the function name.
"""
@ -696,7 +745,7 @@ def _endpoint_from_view_func(view_func):
return view_func.__name__
def get_root_path(import_name):
def get_root_path(import_name: str) -> str:
"""Find the root path of a package, or the path that contains a
module. If it cannot be found, returns the current working
directory.
@ -721,7 +770,7 @@ def get_root_path(import_name):
return os.getcwd()
if hasattr(loader, "get_filename"):
filepath = loader.get_filename(import_name)
filepath = loader.get_filename(import_name) # type: ignore
else:
# Fall back to imports.
__import__(import_name)
@ -822,7 +871,7 @@ def _find_package_path(root_mod_name):
return package_path
def find_package(import_name):
def find_package(import_name: str):
"""Find the prefix that a package is installed under, and the path
that it would be imported from.

View File

@ -1,4 +1,5 @@
import hashlib
import typing as t
import warnings
from collections.abc import MutableMapping
from datetime import datetime
@ -10,17 +11,21 @@ from werkzeug.datastructures import CallbackDict
from .helpers import is_ip
from .json.tag import TaggedJSONSerializer
if t.TYPE_CHECKING:
from .app import Flask
from .wrappers import Request, Response
class SessionMixin(MutableMapping):
"""Expands a basic dictionary with session attributes."""
@property
def permanent(self):
def permanent(self) -> bool:
"""This reflects the ``'_permanent'`` key in the dict."""
return self.get("_permanent", False)
@permanent.setter
def permanent(self, value):
def permanent(self, value: bool) -> None:
self["_permanent"] = bool(value)
#: Some implementations can detect whether a session is newly
@ -61,22 +66,22 @@ class SecureCookieSession(CallbackDict, SessionMixin):
#: different users.
accessed = False
def __init__(self, initial=None):
def on_update(self):
def __init__(self, initial: t.Any = None) -> None:
def on_update(self) -> None:
self.modified = True
self.accessed = True
super().__init__(initial, on_update)
def __getitem__(self, key):
def __getitem__(self, key: str) -> t.Any:
self.accessed = True
return super().__getitem__(key)
def get(self, key, default=None):
def get(self, key: str, default: t.Any = None) -> t.Any:
self.accessed = True
return super().get(key, default)
def setdefault(self, key, default=None):
def setdefault(self, key: str, default: t.Any = None) -> t.Any:
self.accessed = True
return super().setdefault(key, default)
@ -87,14 +92,14 @@ class NullSession(SecureCookieSession):
but fail on setting.
"""
def _fail(self, *args, **kwargs):
def _fail(self, *args: t.Any, **kwargs: t.Any) -> t.NoReturn:
raise RuntimeError(
"The session is unavailable because no secret "
"key was set. Set the secret_key on the "
"application to something unique and secret."
)
__setitem__ = __delitem__ = clear = pop = popitem = update = setdefault = _fail
__setitem__ = __delitem__ = clear = pop = popitem = update = setdefault = _fail # type: ignore # noqa: B950
del _fail
@ -141,7 +146,7 @@ class SessionInterface:
#: .. versionadded:: 0.10
pickle_based = False
def make_null_session(self, app):
def make_null_session(self, app: "Flask") -> NullSession:
"""Creates a null session which acts as a replacement object if the
real session support could not be loaded due to a configuration
error. This mainly aids the user experience because the job of the
@ -153,7 +158,7 @@ class SessionInterface:
"""
return self.null_session_class()
def is_null_session(self, obj):
def is_null_session(self, obj: object) -> bool:
"""Checks if a given object is a null session. Null sessions are
not asked to be saved.
@ -162,14 +167,14 @@ class SessionInterface:
"""
return isinstance(obj, self.null_session_class)
def get_cookie_name(self, app):
def get_cookie_name(self, app: "Flask") -> str:
"""Returns the name of the session cookie.
Uses ``app.session_cookie_name`` which is set to ``SESSION_COOKIE_NAME``
"""
return app.session_cookie_name
def get_cookie_domain(self, app):
def get_cookie_domain(self, app: "Flask") -> t.Optional[str]:
"""Returns the domain that should be set for the session cookie.
Uses ``SESSION_COOKIE_DOMAIN`` if it is configured, otherwise
@ -227,7 +232,7 @@ class SessionInterface:
app.config["SESSION_COOKIE_DOMAIN"] = rv
return rv
def get_cookie_path(self, app):
def get_cookie_path(self, app: "Flask") -> str:
"""Returns the path for which the cookie should be valid. The
default implementation uses the value from the ``SESSION_COOKIE_PATH``
config var if it's set, and falls back to ``APPLICATION_ROOT`` or
@ -235,27 +240,29 @@ class SessionInterface:
"""
return app.config["SESSION_COOKIE_PATH"] or app.config["APPLICATION_ROOT"]
def get_cookie_httponly(self, app):
def get_cookie_httponly(self, app: "Flask") -> bool:
"""Returns True if the session cookie should be httponly. This
currently just returns the value of the ``SESSION_COOKIE_HTTPONLY``
config var.
"""
return app.config["SESSION_COOKIE_HTTPONLY"]
def get_cookie_secure(self, app):
def get_cookie_secure(self, app: "Flask") -> bool:
"""Returns True if the cookie should be secure. This currently
just returns the value of the ``SESSION_COOKIE_SECURE`` setting.
"""
return app.config["SESSION_COOKIE_SECURE"]
def get_cookie_samesite(self, app):
def get_cookie_samesite(self, app: "Flask") -> str:
"""Return ``'Strict'`` or ``'Lax'`` if the cookie should use the
``SameSite`` attribute. This currently just returns the value of
the :data:`SESSION_COOKIE_SAMESITE` setting.
"""
return app.config["SESSION_COOKIE_SAMESITE"]
def get_expiration_time(self, app, session):
def get_expiration_time(
self, app: "Flask", session: SessionMixin
) -> t.Optional[datetime]:
"""A helper method that returns an expiration date for the session
or ``None`` if the session is linked to the browser session. The
default implementation returns now + the permanent session
@ -263,8 +270,9 @@ class SessionInterface:
"""
if session.permanent:
return datetime.utcnow() + app.permanent_session_lifetime
return None
def should_set_cookie(self, app, session):
def should_set_cookie(self, app: "Flask", session: SessionMixin) -> bool:
"""Used by session backends to determine if a ``Set-Cookie`` header
should be set for this session cookie for this response. If the session
has been modified, the cookie is set. If the session is permanent and
@ -280,7 +288,9 @@ class SessionInterface:
session.permanent and app.config["SESSION_REFRESH_EACH_REQUEST"]
)
def open_session(self, app, request):
def open_session(
self, app: "Flask", request: "Request"
) -> t.Optional[SessionMixin]:
"""This method has to be implemented and must either return ``None``
in case the loading failed because of a configuration error or an
instance of a session object which implements a dictionary like
@ -288,7 +298,9 @@ class SessionInterface:
"""
raise NotImplementedError()
def save_session(self, app, session, response):
def save_session(
self, app: "Flask", session: SessionMixin, response: "Response"
) -> None:
"""This is called for actual sessions returned by :meth:`open_session`
at the end of the request. This is still called during a request
context so if you absolutely need access to the request you can do
@ -319,7 +331,9 @@ class SecureCookieSessionInterface(SessionInterface):
serializer = session_json_serializer
session_class = SecureCookieSession
def get_signing_serializer(self, app):
def get_signing_serializer(
self, app: "Flask"
) -> t.Optional[URLSafeTimedSerializer]:
if not app.secret_key:
return None
signer_kwargs = dict(
@ -332,7 +346,9 @@ class SecureCookieSessionInterface(SessionInterface):
signer_kwargs=signer_kwargs,
)
def open_session(self, app, request):
def open_session(
self, app: "Flask", request: "Request"
) -> t.Optional[SecureCookieSession]:
s = self.get_signing_serializer(app)
if s is None:
return None
@ -346,7 +362,9 @@ class SecureCookieSessionInterface(SessionInterface):
except BadSignature:
return self.session_class()
def save_session(self, app, session, response):
def save_session(
self, app: "Flask", session: SessionMixin, response: "Response"
) -> None:
name = self.get_cookie_name(app)
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
@ -372,10 +390,10 @@ class SecureCookieSessionInterface(SessionInterface):
httponly = self.get_cookie_httponly(app)
expires = self.get_expiration_time(app, session)
val = self.get_signing_serializer(app).dumps(dict(session))
val = self.get_signing_serializer(app).dumps(dict(session)) # type: ignore
response.set_cookie(
name,
val,
val, # type: ignore
expires=expires,
httponly=httponly,
domain=domain,

View File

@ -1,3 +1,5 @@
import typing as t
try:
from blinker import Namespace
@ -5,8 +7,8 @@ try:
except ImportError:
signals_available = False
class Namespace:
def signal(self, name, doc=None):
class Namespace: # type: ignore
def signal(self, name: str, doc: t.Optional[str] = None) -> "_FakeSignal":
return _FakeSignal(name, doc)
class _FakeSignal:
@ -16,14 +18,14 @@ except ImportError:
will just ignore the arguments and do nothing instead.
"""
def __init__(self, name, doc=None):
def __init__(self, name: str, doc: t.Optional[str] = None) -> None:
self.name = name
self.__doc__ = doc
def send(self, *args, **kwargs):
def send(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
pass
def _fail(self, *args, **kwargs):
def _fail(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
raise RuntimeError(
"Signalling support is unavailable because the blinker"
" library is not installed."

View File

@ -1,5 +1,8 @@
import typing as t
from jinja2 import BaseLoader
from jinja2 import Environment as BaseEnvironment
from jinja2 import Template
from jinja2 import TemplateNotFound
from .globals import _app_ctx_stack
@ -7,8 +10,12 @@ from .globals import _request_ctx_stack
from .signals import before_render_template
from .signals import template_rendered
if t.TYPE_CHECKING:
from .app import Flask
from .scaffold import Scaffold
def _default_template_ctx_processor():
def _default_template_ctx_processor() -> t.Dict[str, t.Any]:
"""Default template context processor. Injects `request`,
`session` and `g`.
"""
@ -29,7 +36,7 @@ class Environment(BaseEnvironment):
name of the blueprint to referenced templates if necessary.
"""
def __init__(self, app, **options):
def __init__(self, app: "Flask", **options: t.Any) -> None:
if "loader" not in options:
options["loader"] = app.create_global_jinja_loader()
BaseEnvironment.__init__(self, **options)
@ -41,15 +48,19 @@ class DispatchingJinjaLoader(BaseLoader):
the blueprint folders.
"""
def __init__(self, app):
def __init__(self, app: "Flask") -> None:
self.app = app
def get_source(self, environment, template):
def get_source(
self, environment: Environment, template: str
) -> t.Tuple[str, t.Optional[str], t.Callable]:
if self.app.config["EXPLAIN_TEMPLATE_LOADING"]:
return self._get_source_explained(environment, template)
return self._get_source_fast(environment, template)
def _get_source_explained(self, environment, template):
def _get_source_explained(
self, environment: Environment, template: str
) -> t.Tuple[str, t.Optional[str], t.Callable]:
attempts = []
trv = None
@ -70,7 +81,9 @@ class DispatchingJinjaLoader(BaseLoader):
return trv
raise TemplateNotFound(template)
def _get_source_fast(self, environment, template):
def _get_source_fast(
self, environment: Environment, template: str
) -> t.Tuple[str, t.Optional[str], t.Callable]:
for _srcobj, loader in self._iter_loaders(template):
try:
return loader.get_source(environment, template)
@ -78,7 +91,9 @@ class DispatchingJinjaLoader(BaseLoader):
continue
raise TemplateNotFound(template)
def _iter_loaders(self, template):
def _iter_loaders(
self, template: str
) -> t.Generator[t.Tuple["Scaffold", BaseLoader], None, None]:
loader = self.app.jinja_loader
if loader is not None:
yield self.app, loader
@ -88,7 +103,7 @@ class DispatchingJinjaLoader(BaseLoader):
if loader is not None:
yield blueprint, loader
def list_templates(self):
def list_templates(self) -> t.List[str]:
result = set()
loader = self.app.jinja_loader
if loader is not None:
@ -103,7 +118,7 @@ class DispatchingJinjaLoader(BaseLoader):
return list(result)
def _render(template, context, app):
def _render(template: Template, context: dict, app: "Flask") -> str:
"""Renders the template and fires the signal"""
before_render_template.send(app, template=template, context=context)
@ -112,7 +127,9 @@ def _render(template, context, app):
return rv
def render_template(template_name_or_list, **context):
def render_template(
template_name_or_list: t.Union[str, t.List[str]], **context: t.Any
) -> str:
"""Renders a template from the template folder with the given
context.
@ -131,7 +148,7 @@ def render_template(template_name_or_list, **context):
)
def render_template_string(source, **context):
def render_template_string(source: str, **context: t.Any) -> str:
"""Renders a template from the given template source string
with the given context. Template variables will be autoescaped.

View File

@ -1,5 +1,7 @@
import typing as t
from contextlib import contextmanager
from copy import copy
from types import TracebackType
import werkzeug.test
from click.testing import CliRunner
@ -10,6 +12,11 @@ from werkzeug.wrappers import Request as BaseRequest
from . import _request_ctx_stack
from .cli import ScriptInfo
from .json import dumps as json_dumps
from .sessions import SessionMixin
if t.TYPE_CHECKING:
from .app import Flask
from .wrappers import Response
class EnvironBuilder(werkzeug.test.EnvironBuilder):
@ -36,14 +43,14 @@ class EnvironBuilder(werkzeug.test.EnvironBuilder):
def __init__(
self,
app,
path="/",
base_url=None,
subdomain=None,
url_scheme=None,
*args,
**kwargs,
):
app: "Flask",
path: str = "/",
base_url: t.Optional[str] = None,
subdomain: t.Optional[str] = None,
url_scheme: t.Optional[str] = None,
*args: t.Any,
**kwargs: t.Any,
) -> None:
assert not (base_url or subdomain or url_scheme) or (
base_url is not None
) != bool(
@ -74,7 +81,7 @@ class EnvironBuilder(werkzeug.test.EnvironBuilder):
self.app = app
super().__init__(path, base_url, *args, **kwargs)
def json_dumps(self, obj, **kwargs):
def json_dumps(self, obj: t.Any, **kwargs: t.Any) -> str: # type: ignore
"""Serialize ``obj`` to a JSON-formatted string.
The serialization will be configured according to the config associated
@ -99,9 +106,10 @@ class FlaskClient(Client):
Basic usage is outlined in the :doc:`/testing` chapter.
"""
application: "Flask"
preserve_context = False
def __init__(self, *args, **kwargs):
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.environ_base = {
"REMOTE_ADDR": "127.0.0.1",
@ -109,7 +117,9 @@ class FlaskClient(Client):
}
@contextmanager
def session_transaction(self, *args, **kwargs):
def session_transaction(
self, *args: t.Any, **kwargs: t.Any
) -> t.Generator[SessionMixin, None, None]:
"""When used in combination with a ``with`` statement this opens a
session transaction. This can be used to modify the session that
the test client uses. Once the ``with`` block is left the session is
@ -161,9 +171,14 @@ class FlaskClient(Client):
headers = resp.get_wsgi_headers(c.request.environ)
self.cookie_jar.extract_wsgi(c.request.environ, headers)
def open(
self, *args, as_tuple=False, buffered=False, follow_redirects=False, **kwargs
):
def open( # type: ignore
self,
*args: t.Any,
as_tuple: bool = False,
buffered: bool = False,
follow_redirects: bool = False,
**kwargs: t.Any,
) -> "Response":
# Same logic as super.open, but apply environ_base and preserve_context.
request = None
@ -198,20 +213,22 @@ class FlaskClient(Client):
finally:
builder.close()
return super().open(
return super().open( # type: ignore
request,
as_tuple=as_tuple,
buffered=buffered,
follow_redirects=follow_redirects,
)
def __enter__(self):
def __enter__(self) -> "FlaskClient":
if self.preserve_context:
raise RuntimeError("Cannot nest client invocations")
self.preserve_context = True
return self
def __exit__(self, exc_type, exc_value, tb):
def __exit__(
self, exc_type: type, exc_value: BaseException, tb: TracebackType
) -> None:
self.preserve_context = False
# Normally the request context is preserved until the next
@ -233,11 +250,13 @@ class FlaskCliRunner(CliRunner):
:meth:`~flask.Flask.test_cli_runner`. See :ref:`testing-cli`.
"""
def __init__(self, app, **kwargs):
def __init__(self, app: "Flask", **kwargs: t.Any) -> None:
self.app = app
super().__init__(**kwargs)
def invoke(self, cli=None, args=None, **kwargs):
def invoke( # type: ignore
self, cli: t.Any = None, args: t.Any = None, **kwargs: t.Any
) -> t.Any:
"""Invokes a CLI command in an isolated environment. See
:meth:`CliRunner.invoke <click.testing.CliRunner.invoke>` for
full method documentation. See :ref:`testing-cli` for examples.

46
src/flask/typing.py Normal file
View File

@ -0,0 +1,46 @@
import typing as t
if t.TYPE_CHECKING:
from werkzeug.datastructures import Headers # noqa: F401
from wsgiref.types import WSGIApplication # noqa: F401
from .wrappers import Response # noqa: F401
# The possible types that are directly convertible or are a Response object.
ResponseValue = t.Union[
"Response",
t.AnyStr,
t.Dict[str, t.Any], # any jsonify-able dict
t.Generator[t.AnyStr, None, None],
]
StatusCode = int
# the possible types for an individual HTTP header
HeaderName = str
HeaderValue = t.Union[str, t.List[str], t.Tuple[str, ...]]
# the possible types for HTTP headers
HeadersValue = t.Union[
"Headers", t.Dict[HeaderName, HeaderValue], t.List[t.Tuple[HeaderName, HeaderValue]]
]
# The possible types returned by a route function.
ResponseReturnValue = t.Union[
ResponseValue,
t.Tuple[ResponseValue, HeadersValue],
t.Tuple[ResponseValue, StatusCode],
t.Tuple[ResponseValue, StatusCode, HeadersValue],
"WSGIApplication",
]
AppOrBlueprintKey = t.Optional[str] # The App key is None, whereas blueprints are named
AfterRequestCallable = t.Callable[["Response"], "Response"]
BeforeRequestCallable = t.Callable[[], None]
ErrorHandlerCallable = t.Callable[[Exception], ResponseReturnValue]
TeardownCallable = t.Callable[[t.Optional[BaseException]], "Response"]
TemplateContextProcessorCallable = t.Callable[[], t.Dict[str, t.Any]]
TemplateFilterCallable = t.Callable[[t.Any], str]
TemplateGlobalCallable = t.Callable[[], t.Any]
TemplateTestCallable = t.Callable[[t.Any], bool]
URLDefaultCallable = t.Callable[[str, dict], None]
URLValuePreprocessorCallable = t.Callable[[t.Optional[str], t.Optional[dict]], None]

View File

@ -1,4 +1,7 @@
import typing as t
from .globals import request
from .typing import ResponseReturnValue
http_method_funcs = frozenset(
@ -39,10 +42,10 @@ class View:
"""
#: A list of methods this view can handle.
methods = None
methods: t.Optional[t.List[str]] = None
#: Setting this disables or force-enables the automatic options handling.
provide_automatic_options = None
provide_automatic_options: t.Optional[bool] = None
#: The canonical way to decorate class-based views is to decorate the
#: return value of as_view(). However since this moves parts of the
@ -53,9 +56,9 @@ class View:
#: view function is created the result is automatically decorated.
#:
#: .. versionadded:: 0.8
decorators = ()
decorators: t.List[t.Callable] = []
def dispatch_request(self):
def dispatch_request(self) -> ResponseReturnValue:
"""Subclasses have to override this method to implement the
actual view function code. This method is called with all
the arguments from the URL rule.
@ -63,7 +66,9 @@ class View:
raise NotImplementedError()
@classmethod
def as_view(cls, name, *class_args, **class_kwargs):
def as_view(
cls, name: str, *class_args: t.Any, **class_kwargs: t.Any
) -> t.Callable:
"""Converts the class into an actual view function that can be used
with the routing system. Internally this generates a function on the
fly which will instantiate the :class:`View` on each request and call
@ -73,8 +78,8 @@ class View:
constructor of the class.
"""
def view(*args, **kwargs):
self = view.view_class(*class_args, **class_kwargs)
def view(*args: t.Any, **kwargs: t.Any) -> ResponseReturnValue:
self = view.view_class(*class_args, **class_kwargs) # type: ignore
return self.dispatch_request(*args, **kwargs)
if cls.decorators:
@ -88,12 +93,12 @@ class View:
# view this thing came from, secondly it's also used for instantiating
# the view class so you can actually replace it with something else
# for testing purposes and debugging.
view.view_class = cls
view.view_class = cls # type: ignore
view.__name__ = name
view.__doc__ = cls.__doc__
view.__module__ = cls.__module__
view.methods = cls.methods
view.provide_automatic_options = cls.provide_automatic_options
view.methods = cls.methods # type: ignore
view.provide_automatic_options = cls.provide_automatic_options # type: ignore
return view
@ -140,7 +145,7 @@ class MethodView(View, metaclass=MethodViewType):
app.add_url_rule('/counter', view_func=CounterAPI.as_view('counter'))
"""
def dispatch_request(self, *args, **kwargs):
def dispatch_request(self, *args: t.Any, **kwargs: t.Any) -> ResponseReturnValue:
meth = getattr(self, request.method.lower(), None)
# If the request method is HEAD and we don't have a handler for it

View File

@ -1,3 +1,5 @@
import typing as t
from werkzeug.exceptions import BadRequest
from werkzeug.wrappers import Request as RequestBase
from werkzeug.wrappers import Response as ResponseBase
@ -5,6 +7,9 @@ from werkzeug.wrappers import Response as ResponseBase
from . import json
from .globals import current_app
if t.TYPE_CHECKING:
from werkzeug.routing import Rule
class Request(RequestBase):
"""The request object used by default in Flask. Remembers the
@ -31,26 +36,28 @@ class Request(RequestBase):
#: because the request was never internally bound.
#:
#: .. versionadded:: 0.6
url_rule = None
url_rule: t.Optional["Rule"] = None
#: A dict of view arguments that matched the request. If an exception
#: happened when matching, this will be ``None``.
view_args = None
view_args: t.Optional[t.Dict[str, t.Any]] = None
#: If matching the URL failed, this is the exception that will be
#: raised / was raised as part of the request handling. This is
#: usually a :exc:`~werkzeug.exceptions.NotFound` exception or
#: something similar.
routing_exception = None
routing_exception: t.Optional[Exception] = None
@property
def max_content_length(self):
def max_content_length(self) -> t.Optional[int]: # type: ignore
"""Read-only view of the ``MAX_CONTENT_LENGTH`` config key."""
if current_app:
return current_app.config["MAX_CONTENT_LENGTH"]
else:
return None
@property
def endpoint(self):
def endpoint(self) -> t.Optional[str]:
"""The endpoint that matched the request. This in combination with
:attr:`view_args` can be used to reconstruct the same or a
modified URL. If an exception happened when matching, this will
@ -58,14 +65,18 @@ class Request(RequestBase):
"""
if self.url_rule is not None:
return self.url_rule.endpoint
else:
return None
@property
def blueprint(self):
def blueprint(self) -> t.Optional[str]:
"""The name of the current blueprint"""
if self.url_rule and "." in self.url_rule.endpoint:
return self.url_rule.endpoint.rsplit(".", 1)[0]
else:
return None
def _load_form_data(self):
def _load_form_data(self) -> None:
RequestBase._load_form_data(self)
# In debug mode we're replacing the files multidict with an ad-hoc
@ -80,7 +91,7 @@ class Request(RequestBase):
attach_enctype_error_multidict(self)
def on_json_loading_failed(self, e):
def on_json_loading_failed(self, e: Exception) -> t.NoReturn:
if current_app and current_app.debug:
raise BadRequest(f"Failed to decode JSON object: {e}")
@ -110,7 +121,7 @@ class Response(ResponseBase):
json_module = json
@property
def max_cookie_size(self):
def max_cookie_size(self) -> int: # type: ignore
"""Read-only view of the :data:`MAX_COOKIE_SIZE` config key.
See :attr:`~werkzeug.wrappers.Response.max_cookie_size` in