diff --git a/CHANGES.rst b/CHANGES.rst index be263970..d8dd974f 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -11,6 +11,7 @@ Unreleased - Fixed the issue where typing requires template global decorators to accept functions with no arguments. :issue:`4098` - Support View and MethodView instances with async handlers. :issue:`4112` +- Enhance typing of ``app.errorhandler`` decorator. :issue:`4095` Version 2.0.1 diff --git a/src/flask/app.py b/src/flask/app.py index f6656458..22cc9abc 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -61,7 +61,6 @@ from .templating import Environment from .typing import AfterRequestCallable from .typing import BeforeFirstRequestCallable from .typing import BeforeRequestCallable -from .typing import ErrorHandlerCallable from .typing import ResponseReturnValue from .typing import TeardownCallable from .typing import TemplateContextProcessorCallable @@ -78,6 +77,7 @@ if t.TYPE_CHECKING: from .blueprints import Blueprint from .testing import FlaskClient from .testing import FlaskCliRunner + from .typing import ErrorHandlerCallable if sys.version_info >= (3, 8): iscoroutinefunction = inspect.iscoroutinefunction @@ -1268,7 +1268,9 @@ class Flask(Scaffold): self.shell_context_processors.append(f) return f - def _find_error_handler(self, e: Exception) -> t.Optional[ErrorHandlerCallable]: + def _find_error_handler( + self, e: Exception + ) -> t.Optional["ErrorHandlerCallable[Exception]"]: """Return a registered error handler for an exception in this order: blueprint handler for a specific code, app handler for a specific code, blueprint handler for an exception class, app handler for an exception diff --git a/src/flask/blueprints.py b/src/flask/blueprints.py index 883fc2ff..8241420b 100644 --- a/src/flask/blueprints.py +++ b/src/flask/blueprints.py @@ -8,7 +8,6 @@ from .scaffold import Scaffold from .typing import AfterRequestCallable from .typing import BeforeFirstRequestCallable from .typing import BeforeRequestCallable -from .typing import ErrorHandlerCallable from .typing import TeardownCallable from .typing import TemplateContextProcessorCallable from .typing import TemplateFilterCallable @@ -19,6 +18,7 @@ from .typing import URLValuePreprocessorCallable if t.TYPE_CHECKING: from .app import Flask + from .typing import ErrorHandlerCallable DeferredSetupFunction = t.Callable[["BlueprintSetupState"], t.Callable] @@ -581,7 +581,9 @@ class Blueprint(Scaffold): handler is used for all requests, even if outside of the blueprint. """ - def decorator(f: ErrorHandlerCallable) -> ErrorHandlerCallable: + def decorator( + f: "ErrorHandlerCallable[Exception]", + ) -> "ErrorHandlerCallable[Exception]": self.record_once(lambda s: s.app.errorhandler(code)(f)) return f diff --git a/src/flask/scaffold.py b/src/flask/scaffold.py index 239bc46a..e80e1915 100644 --- a/src/flask/scaffold.py +++ b/src/flask/scaffold.py @@ -21,7 +21,7 @@ from .templating import _default_template_ctx_processor from .typing import AfterRequestCallable from .typing import AppOrBlueprintKey from .typing import BeforeRequestCallable -from .typing import ErrorHandlerCallable +from .typing import GenericException from .typing import TeardownCallable from .typing import TemplateContextProcessorCallable from .typing import URLDefaultCallable @@ -29,6 +29,7 @@ from .typing import URLValuePreprocessorCallable if t.TYPE_CHECKING: from .wrappers import Response + from .typing import ErrorHandlerCallable # a singleton sentinel value for parameter defaults _sentinel = object() @@ -144,7 +145,10 @@ class Scaffold: #: directly and its format may change at any time. self.error_handler_spec: t.Dict[ AppOrBlueprintKey, - t.Dict[t.Optional[int], t.Dict[t.Type[Exception], ErrorHandlerCallable]], + t.Dict[ + t.Optional[int], + t.Dict[t.Type[Exception], "ErrorHandlerCallable[Exception]"], + ], ] = defaultdict(lambda: defaultdict(dict)) #: A data structure of functions to call at the beginning of @@ -643,8 +647,11 @@ class Scaffold: @setupmethod def errorhandler( - self, code_or_exception: t.Union[t.Type[Exception], int] - ) -> t.Callable[[ErrorHandlerCallable], ErrorHandlerCallable]: + self, code_or_exception: t.Union[t.Type[GenericException], int] + ) -> t.Callable[ + ["ErrorHandlerCallable[GenericException]"], + "ErrorHandlerCallable[GenericException]", + ]: """Register a function to handle errors by code or exception class. A decorator that is used to register a function given an @@ -674,7 +681,9 @@ class Scaffold: an arbitrary exception """ - def decorator(f: ErrorHandlerCallable) -> ErrorHandlerCallable: + def decorator( + f: "ErrorHandlerCallable[GenericException]", + ) -> "ErrorHandlerCallable[GenericException]": self.register_error_handler(code_or_exception, f) return f @@ -683,8 +692,8 @@ class Scaffold: @setupmethod def register_error_handler( self, - code_or_exception: t.Union[t.Type[Exception], int], - f: ErrorHandlerCallable, + code_or_exception: t.Union[t.Type[GenericException], int], + f: "ErrorHandlerCallable[GenericException]", ) -> None: """Alternative error attach function to the :meth:`errorhandler` decorator that is more straightforward to use for non decorator @@ -708,7 +717,9 @@ class Scaffold: " instead." ) - self.error_handler_spec[None][code][exc_class] = f + self.error_handler_spec[None][code][exc_class] = t.cast( + "ErrorHandlerCallable[Exception]", f + ) @staticmethod def _get_exc_class_and_code( diff --git a/src/flask/typing.py b/src/flask/typing.py index b1a6cbdc..f1c84670 100644 --- a/src/flask/typing.py +++ b/src/flask/typing.py @@ -33,11 +33,12 @@ ResponseReturnValue = t.Union[ "WSGIApplication", ] +GenericException = t.TypeVar("GenericException", bound=Exception, contravariant=True) + AppOrBlueprintKey = t.Optional[str] # The App key is None, whereas blueprints are named AfterRequestCallable = t.Callable[["Response"], "Response"] BeforeFirstRequestCallable = t.Callable[[], None] BeforeRequestCallable = t.Callable[[], t.Optional[ResponseReturnValue]] -ErrorHandlerCallable = t.Callable[[Exception], ResponseReturnValue] TeardownCallable = t.Callable[[t.Optional[BaseException]], None] TemplateContextProcessorCallable = t.Callable[[], t.Dict[str, t.Any]] TemplateFilterCallable = t.Callable[..., t.Any] @@ -45,3 +46,11 @@ TemplateGlobalCallable = t.Callable[..., t.Any] TemplateTestCallable = t.Callable[..., bool] URLDefaultCallable = t.Callable[[str, dict], None] URLValuePreprocessorCallable = t.Callable[[t.Optional[str], t.Optional[dict]], None] + + +if t.TYPE_CHECKING: + import typing_extensions as te + + class ErrorHandlerCallable(te.Protocol[GenericException]): + def __call__(self, error: GenericException) -> ResponseReturnValue: + ...