From 166a2a6207027bff07fdfc5590ce04f9b37e9e8f Mon Sep 17 00:00:00 2001 From: Matthias Paulsen Date: Tue, 10 Aug 2021 00:23:57 +0200 Subject: [PATCH 1/2] Fix callback order for nested blueprints Handlers registered via url_value_preprocessor, before_request, context_processor, and url_defaults are called in downward order: First on the app and last on the current blueprint. Handlers registered via after_request and teardown_request are called in upward order: First on the current blueprint and last on the app. --- CHANGES.rst | 3 ++ src/flask/app.py | 48 +++++++++++------------- tests/test_blueprints.py | 80 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 26 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index eeba61ab..1b4551ab 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -21,6 +21,9 @@ Unreleased :issue:`4096` - The CLI loader handles ``**kwargs`` in a ``create_app`` function. :issue:`4170` +- Fix the order of ``before_request`` and other callbacks that trigger + before the view returns. They are called from the app down to the + closest nested blueprint. :issue:`4229` Version 2.0.1 diff --git a/src/flask/app.py b/src/flask/app.py index 5179305a..a9bdba65 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -745,12 +745,12 @@ class Flask(Scaffold): :param context: the context as a dictionary that is updated in place to add extra variables. """ - funcs: t.Iterable[ - TemplateContextProcessorCallable - ] = self.template_context_processors[None] + funcs: t.Iterable[TemplateContextProcessorCallable] = [] + if None in self.template_context_processors: + funcs = chain(funcs, self.template_context_processors[None]) reqctx = _request_ctx_stack.top if reqctx is not None: - for bp in request.blueprints: + for bp in reversed(request.blueprints): if bp in self.template_context_processors: funcs = chain(funcs, self.template_context_processors[bp]) orig_ctx = context.copy() @@ -1806,7 +1806,9 @@ class Flask(Scaffold): # This is called by url_for, which can be called outside a # request, can't use request.blueprints. bps = _split_blueprint_path(endpoint.rpartition(".")[0]) - bp_funcs = chain.from_iterable(self.url_default_functions[bp] for bp in bps) + bp_funcs = chain.from_iterable( + self.url_default_functions[bp] for bp in reversed(bps) + ) funcs = chain(funcs, bp_funcs) for func in funcs: @@ -1846,19 +1848,17 @@ class Flask(Scaffold): further request handling is stopped. """ - funcs: t.Iterable[URLValuePreprocessorCallable] = self.url_value_preprocessors[ - None - ] - for bp in request.blueprints: - if bp in self.url_value_preprocessors: - funcs = chain(funcs, self.url_value_preprocessors[bp]) + funcs: t.Iterable[URLValuePreprocessorCallable] = [] + for name in chain([None], reversed(request.blueprints)): + if name in self.url_value_preprocessors: + funcs = chain(funcs, self.url_value_preprocessors[name]) for func in funcs: func(request.endpoint, request.view_args) - funcs: t.Iterable[BeforeRequestCallable] = self.before_request_funcs[None] - for bp in request.blueprints: - if bp in self.before_request_funcs: - funcs = chain(funcs, self.before_request_funcs[bp]) + funcs: t.Iterable[BeforeRequestCallable] = [] + for name in chain([None], reversed(request.blueprints)): + if name in self.before_request_funcs: + funcs = chain(funcs, self.before_request_funcs[name]) for func in funcs: rv = self.ensure_sync(func)() if rv is not None: @@ -1881,11 +1881,9 @@ class Flask(Scaffold): """ ctx = _request_ctx_stack.top funcs: t.Iterable[AfterRequestCallable] = ctx._after_request_functions - for bp in request.blueprints: - if bp in self.after_request_funcs: - funcs = chain(funcs, reversed(self.after_request_funcs[bp])) - if None in self.after_request_funcs: - funcs = chain(funcs, reversed(self.after_request_funcs[None])) + for name in chain(request.blueprints, [None]): + if name in self.after_request_funcs: + funcs = chain(funcs, reversed(self.after_request_funcs[name])) for handler in funcs: response = self.ensure_sync(handler)(response) if not self.session_interface.is_null_session(ctx.session): @@ -1917,12 +1915,10 @@ class Flask(Scaffold): """ if exc is _sentinel: exc = sys.exc_info()[1] - funcs: t.Iterable[TeardownCallable] = reversed( - self.teardown_request_funcs[None] - ) - for bp in request.blueprints: - if bp in self.teardown_request_funcs: - funcs = chain(funcs, reversed(self.teardown_request_funcs[bp])) + funcs: t.Iterable[TeardownCallable] = [] + for name in chain(request.blueprints, [None]): + if name in self.teardown_request_funcs: + funcs = chain(funcs, reversed(self.teardown_request_funcs[name])) for func in funcs: self.ensure_sync(func)(exc) request_tearing_down.send(self, exc=exc) diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index a124c612..e02cd4be 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -837,6 +837,86 @@ def test_nested_blueprint(app, client): assert client.get("/parent/child/grandchild/no").data == b"Grandchild no" +def test_nested_callback_order(app, client): + parent = flask.Blueprint("parent", __name__) + child = flask.Blueprint("child", __name__) + + @app.before_request + def app_before1(): + flask.g.setdefault("seen", []).append("app_1") + + @app.teardown_request + def app_teardown1(e=None): + assert flask.g.seen.pop() == "app_1" + + @app.before_request + def app_before2(): + flask.g.setdefault("seen", []).append("app_2") + + @app.teardown_request + def app_teardown2(e=None): + assert flask.g.seen.pop() == "app_2" + + @app.context_processor + def app_ctx(): + return dict(key="app") + + @parent.before_request + def parent_before1(): + flask.g.setdefault("seen", []).append("parent_1") + + @parent.teardown_request + def parent_teardown1(e=None): + assert flask.g.seen.pop() == "parent_1" + + @parent.before_request + def parent_before2(): + flask.g.setdefault("seen", []).append("parent_2") + + @parent.teardown_request + def parent_teardown2(e=None): + assert flask.g.seen.pop() == "parent_2" + + @parent.context_processor + def parent_ctx(): + return dict(key="parent") + + @child.before_request + def child_before1(): + flask.g.setdefault("seen", []).append("child_1") + + @child.teardown_request + def child_teardown1(e=None): + assert flask.g.seen.pop() == "child_1" + + @child.before_request + def child_before2(): + flask.g.setdefault("seen", []).append("child_2") + + @child.teardown_request + def child_teardown2(e=None): + assert flask.g.seen.pop() == "child_2" + + @child.context_processor + def child_ctx(): + return dict(key="child") + + @child.route("/a") + def a(): + return ", ".join(flask.g.seen) + + @child.route("/b") + def b(): + return flask.render_template_string("{{ key }}") + + parent.register_blueprint(child) + app.register_blueprint(parent) + assert ( + client.get("/a").data == b"app_1, app_2, parent_1, parent_2, child_1, child_2" + ) + assert client.get("/b").data == b"child" + + @pytest.mark.parametrize( "parent_init, child_init, parent_registration, child_registration", [ From 3f6cdbd8b35f2b887dddbdac5ef6461833b6dd3b Mon Sep 17 00:00:00 2001 From: David Lord Date: Sun, 3 Oct 2021 20:19:33 -0700 Subject: [PATCH 2/2] use similar code for all callback-applying methods avoid building nested chain iterables avoid triggering defaultdict when looking up registries apply functions as they are looked up --- src/flask/app.py | 103 ++++++++++++++++++++++++----------------------- 1 file changed, 52 insertions(+), 51 deletions(-) diff --git a/src/flask/app.py b/src/flask/app.py index a9bdba65..23b99e2c 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -58,17 +58,12 @@ from .signals import request_started from .signals import request_tearing_down from .templating import DispatchingJinjaLoader from .templating import Environment -from .typing import AfterRequestCallable from .typing import BeforeFirstRequestCallable -from .typing import BeforeRequestCallable from .typing import ResponseReturnValue from .typing import TeardownCallable -from .typing import TemplateContextProcessorCallable from .typing import TemplateFilterCallable from .typing import TemplateGlobalCallable from .typing import TemplateTestCallable -from .typing import URLDefaultCallable -from .typing import URLValuePreprocessorCallable from .wrappers import Request from .wrappers import Response @@ -745,20 +740,21 @@ class Flask(Scaffold): :param context: the context as a dictionary that is updated in place to add extra variables. """ - funcs: t.Iterable[TemplateContextProcessorCallable] = [] - if None in self.template_context_processors: - funcs = chain(funcs, self.template_context_processors[None]) - reqctx = _request_ctx_stack.top - if reqctx is not None: - for bp in reversed(request.blueprints): - if bp in self.template_context_processors: - funcs = chain(funcs, self.template_context_processors[bp]) + names: t.Iterable[t.Optional[str]] = (None,) + + # A template may be rendered outside a request context. + if request: + names = chain(names, reversed(request.blueprints)) + + # The values passed to render_template take precedence. Keep a + # copy to re-apply after all context functions. orig_ctx = context.copy() - for func in funcs: - context.update(func()) - # make sure the original values win. This makes it possible to - # easier add new variables in context processors without breaking - # existing views. + + for name in names: + if name in self.template_context_processors: + for func in self.template_context_processors[name]: + context.update(func()) + context.update(orig_ctx) def make_shell_context(self) -> dict: @@ -1278,9 +1274,10 @@ class Flask(Scaffold): class, or ``None`` if a suitable handler is not found. """ exc_class, code = self._get_exc_class_and_code(type(e)) + names = (*request.blueprints, None) - for c in [code, None] if code is not None else [None]: - for name in chain(request.blueprints, [None]): + for c in (code, None) if code is not None else (None,): + for name in names: handler_map = self.error_handler_spec[name][c] if not handler_map: @@ -1800,19 +1797,19 @@ class Flask(Scaffold): .. versionadded:: 0.7 """ - funcs: t.Iterable[URLDefaultCallable] = self.url_default_functions[None] + names: t.Iterable[t.Optional[str]] = (None,) + # url_for may be called outside a request context, parse the + # passed endpoint instead of using request.blueprints. if "." in endpoint: - # This is called by url_for, which can be called outside a - # request, can't use request.blueprints. - bps = _split_blueprint_path(endpoint.rpartition(".")[0]) - bp_funcs = chain.from_iterable( - self.url_default_functions[bp] for bp in reversed(bps) + names = chain( + names, reversed(_split_blueprint_path(endpoint.rpartition(".")[0])) ) - funcs = chain(funcs, bp_funcs) - for func in funcs: - func(endpoint, values) + for name in names: + if name in self.url_default_functions: + for func in self.url_default_functions[name]: + func(endpoint, values) def handle_url_build_error( self, error: Exception, endpoint: str, values: dict @@ -1847,22 +1844,20 @@ class Flask(Scaffold): value is handled as if it was the return value from the view, and further request handling is stopped. """ + names = (None, *reversed(request.blueprints)) - funcs: t.Iterable[URLValuePreprocessorCallable] = [] - for name in chain([None], reversed(request.blueprints)): + for name in names: if name in self.url_value_preprocessors: - funcs = chain(funcs, self.url_value_preprocessors[name]) - for func in funcs: - func(request.endpoint, request.view_args) + for url_func in self.url_value_preprocessors[name]: + url_func(request.endpoint, request.view_args) - funcs: t.Iterable[BeforeRequestCallable] = [] - for name in chain([None], reversed(request.blueprints)): + for name in names: if name in self.before_request_funcs: - funcs = chain(funcs, self.before_request_funcs[name]) - for func in funcs: - rv = self.ensure_sync(func)() - if rv is not None: - return rv + for before_func in self.before_request_funcs[name]: + rv = self.ensure_sync(before_func)() + + if rv is not None: + return rv return None @@ -1880,14 +1875,18 @@ class Flask(Scaffold): instance of :attr:`response_class`. """ ctx = _request_ctx_stack.top - funcs: t.Iterable[AfterRequestCallable] = ctx._after_request_functions - for name in chain(request.blueprints, [None]): + + for func in ctx._after_request_functions: + response = self.ensure_sync(func)(response) + + for name in chain(request.blueprints, (None,)): if name in self.after_request_funcs: - funcs = chain(funcs, reversed(self.after_request_funcs[name])) - for handler in funcs: - response = self.ensure_sync(handler)(response) + for func in reversed(self.after_request_funcs[name]): + response = self.ensure_sync(func)(response) + if not self.session_interface.is_null_session(ctx.session): self.session_interface.save_session(self, ctx.session, response) + return response def do_teardown_request( @@ -1915,12 +1914,12 @@ class Flask(Scaffold): """ if exc is _sentinel: exc = sys.exc_info()[1] - funcs: t.Iterable[TeardownCallable] = [] - for name in chain(request.blueprints, [None]): + + for name in chain(request.blueprints, (None,)): if name in self.teardown_request_funcs: - funcs = chain(funcs, reversed(self.teardown_request_funcs[name])) - for func in funcs: - self.ensure_sync(func)(exc) + for func in reversed(self.teardown_request_funcs[name]): + self.ensure_sync(func)(exc) + request_tearing_down.send(self, exc=exc) def do_teardown_appcontext( @@ -1942,8 +1941,10 @@ class Flask(Scaffold): """ if exc is _sentinel: exc = sys.exc_info()[1] + for func in reversed(self.teardown_appcontext_funcs): self.ensure_sync(func)(exc) + appcontext_tearing_down.send(self, exc=exc) def app_context(self) -> AppContext: