This is an automated email from the ASF dual-hosted git repository. michaelsmolina pushed a commit to branch 5.0 in repository https://gitbox.apache.org/repos/asf/superset.git
commit 654062db27cad5c5a7034e6d09012f9b87a331fe Author: Beto Dealmeida <[email protected]> AuthorDate: Mon Feb 24 18:08:50 2025 -0500 fix: ensure metric_macro expands templates (#32344) --- superset/jinja_context.py | 45 ++++++-- tests/unit_tests/jinja_context_test.py | 194 ++++++++++++++++++++++++++++++--- 2 files changed, 210 insertions(+), 29 deletions(-) diff --git a/superset/jinja_context.py b/superset/jinja_context.py index b0e29505a0..4b02a69a75 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -601,7 +601,12 @@ class BaseTemplateProcessor: kwargs.update(self._context) context = validate_template_context(self.engine, kwargs) - return template.render(context) + try: + return template.render(context) + except RecursionError as ex: + raise SupersetTemplateException( + "Infinite recursion detected in template" + ) from ex class JinjaTemplateProcessor(BaseTemplateProcessor): @@ -658,11 +663,18 @@ class JinjaTemplateProcessor(BaseTemplateProcessor): "filter_values": partial(safe_proxy, extra_cache.filter_values), "get_filters": partial(safe_proxy, extra_cache.get_filters), "dataset": partial(safe_proxy, dataset_macro_with_context), - "metric": partial(safe_proxy, metric_macro), "get_time_filter": partial(safe_proxy, extra_cache.get_time_filter), } ) + # The `metric` filter needs the full context, in order to expand other filters + self._context["metric"] = partial( + safe_proxy, + metric_macro, + self.env, + self._context, + ) + class NoOpTemplateProcessor(BaseTemplateProcessor): def process_template(self, sql: str, **kwargs: Any) -> str: @@ -888,7 +900,12 @@ def get_dataset_id_from_context(metric_key: str) -> int: raise SupersetTemplateException(exc_message) -def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str: +def metric_macro( + env: Environment, + context: dict[str, Any], + metric_key: str, + dataset_id: Optional[int] = None, +) -> str: """ Given a metric key, returns its syntax. @@ -911,13 +928,17 @@ def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str: metrics: dict[str, str] = { metric.metric_name: metric.expression for metric in dataset.metrics } - dataset_name = dataset.table_name - if metric := metrics.get(metric_key): - return metric - raise SupersetTemplateException( - _( - "Metric ``%(metric_name)s`` not found in %(dataset_name)s.", - metric_name=metric_key, - dataset_name=dataset_name, + if metric_key not in metrics: + raise SupersetTemplateException( + _( + "Metric ``%(metric_name)s`` not found in %(dataset_name)s.", + metric_name=metric_key, + dataset_name=dataset.table_name, + ) ) - ) + + definition = metrics[metric_key] + template = env.from_string(definition) + definition = template.render(context) + + return definition diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index c17c066b9d..d2ec9c8345 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -21,6 +21,8 @@ from typing import Any import pytest from freezegun import freeze_time +from jinja2 import DebugUndefined +from jinja2.sandbox import SandboxedEnvironment from pytest_mock import MockerFixture from sqlalchemy.dialects import mysql from sqlalchemy.dialects.postgresql import dialect @@ -32,6 +34,7 @@ from superset.exceptions import SupersetTemplateException from superset.jinja_context import ( dataset_macro, ExtraCache, + get_template_processor, metric_macro, safe_proxy, TimeFilter, @@ -540,10 +543,156 @@ def test_metric_macro_with_dataset_id(mocker: MockerFixture) -> None: schema="my_schema", sql=None, ) - assert metric_macro("count", 1) == "COUNT(*)" + env = SandboxedEnvironment(undefined=DebugUndefined) + assert metric_macro(env, {}, "count", 1) == "COUNT(*)" mock_get_form_data.assert_not_called() +def test_metric_macro_recursive(mocker: MockerFixture) -> None: + """ + Test the ``metric_macro`` when the definition is recursive. + """ + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + dataset = SqlaTable( + id=1, + metrics=[ + SqlMetric(metric_name="a", expression="COUNT(*)"), + SqlMetric(metric_name="b", expression="{{ metric('a') }}"), + SqlMetric(metric_name="c", expression="{{ metric('b') }}"), + ], + table_name="test_dataset", + database=database, + schema="my_schema", + sql=None, + ) + + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = dataset + + processor = get_template_processor(database=database) + assert processor.process_template("{{ metric('c', 1) }}") == "COUNT(*)" + + +def test_metric_macro_expansion(mocker: MockerFixture) -> None: + """ + Test that the ``metric_macro`` expands other macros. + """ + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + dataset = SqlaTable( + id=1, + metrics=[ + SqlMetric(metric_name="a", expression="{{ current_user_id() }}"), + SqlMetric(metric_name="b", expression="{{ metric('a') }}"), + SqlMetric(metric_name="c", expression="{{ metric('b') }}"), + ], + table_name="test_dataset", + database=database, + schema="my_schema", + sql=None, + ) + + mocker.patch("superset.jinja_context.get_user_id", return_value=42) + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = dataset + + processor = get_template_processor(database=database) + assert processor.process_template("{{ metric('c') }}") == "42" + + +def test_metric_macro_recursive_compound(mocker: MockerFixture) -> None: + """ + Test the ``metric_macro`` when the definition is compound. + """ + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + dataset = SqlaTable( + id=1, + metrics=[ + SqlMetric(metric_name="a", expression="SUM(*)"), + SqlMetric(metric_name="b", expression="COUNT(*)"), + SqlMetric( + metric_name="c", + expression="{{ metric('a') }} / {{ metric('b') }}", + ), + ], + table_name="test_dataset", + database=database, + schema="my_schema", + sql=None, + ) + + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = dataset + + processor = get_template_processor(database=database) + assert processor.process_template("{{ metric('c') }}") == "SUM(*) / COUNT(*)" + + +def test_metric_macro_recursive_cyclic(mocker: MockerFixture) -> None: + """ + Test the ``metric_macro`` when the definition is cyclic. + + In this case it should stop, and not go into an infinite loop. + """ + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + dataset = SqlaTable( + id=1, + metrics=[ + SqlMetric(metric_name="a", expression="{{ metric('c') }}"), + SqlMetric(metric_name="b", expression="{{ metric('a') }}"), + SqlMetric(metric_name="c", expression="{{ metric('b') }}"), + ], + table_name="test_dataset", + database=database, + schema="my_schema", + sql=None, + ) + + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = dataset + + processor = get_template_processor(database=database) + with pytest.raises(SupersetTemplateException) as excinfo: + processor.process_template("{{ metric('c') }}") + assert str(excinfo.value) == "Infinite recursion detected in template" + + +def test_metric_macro_recursive_infinite(mocker: MockerFixture) -> None: + """ + Test the ``metric_macro`` when the definition is cyclic. + + In this case it should stop, and not go into an infinite loop. + """ + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + dataset = SqlaTable( + id=1, + metrics=[ + SqlMetric(metric_name="a", expression="{{ metric('a') }}"), + ], + table_name="test_dataset", + database=database, + schema="my_schema", + sql=None, + ) + + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = dataset + + processor = get_template_processor(database=database) + with pytest.raises(SupersetTemplateException) as excinfo: + processor.process_template("{{ metric('a') }}") + assert str(excinfo.value) == "Infinite recursion detected in template" + + def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None: """ Test the ``metric_macro`` when passing a dataset ID and an invalid key. @@ -559,8 +708,9 @@ def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None schema="my_schema", sql=None, ) + env = SandboxedEnvironment(undefined=DebugUndefined) with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("blah", 1) + metric_macro(env, {}, "blah", 1) assert str(excinfo.value) == "Metric ``blah`` not found in test_dataset." mock_get_form_data.assert_not_called() @@ -572,8 +722,9 @@ def test_metric_macro_invalid_dataset_id(mocker: MockerFixture) -> None: mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 DatasetDAO.find_by_id.return_value = None + env = SandboxedEnvironment(undefined=DebugUndefined) with pytest.raises(DatasetNotFoundError) as excinfo: - metric_macro("macro_key", 100) + metric_macro(env, {}, "macro_key", 100) assert str(excinfo.value) == "Dataset ID 100 not found." mock_get_form_data.assert_not_called() @@ -586,9 +737,10 @@ def test_metric_macro_no_dataset_id_no_context(mocker: MockerFixture) -> None: DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 mock_g = mocker.patch("superset.jinja_context.g") mock_g.form_data = {} + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -605,6 +757,8 @@ def test_metric_macro_no_dataset_id_with_context_missing_info( DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 mock_g = mocker.patch("superset.jinja_context.g") mock_g.form_data = {"queries": []} + + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -623,7 +777,7 @@ def test_metric_macro_no_dataset_id_with_context_missing_info( } ): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -651,6 +805,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -666,7 +821,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id( ) } ): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" # Getting data from g's form_data mock_g.form_data = { @@ -679,7 +834,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id( ], } with app.test_request_context(): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" def test_metric_macro_no_dataset_id_with_context_datasource_id_none( @@ -693,6 +848,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -709,7 +865,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none( } ): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -726,7 +882,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none( } with app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -758,6 +914,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -773,7 +930,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id( ) } ): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" # Getting data from g's form_data mock_g.form_data = { @@ -786,7 +943,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id( ], } with app.test_request_context(): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" def test_metric_macro_no_dataset_id_with_context_slice_id_none( @@ -800,6 +957,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -816,7 +974,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none( } ): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -833,7 +991,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none( } with app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -852,6 +1010,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -868,7 +1027,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart( } ): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -885,7 +1044,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart( } with app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -913,6 +1072,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -924,7 +1084,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data( ) } ): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" # Getting data from g's form_data mock_g.form_data = { @@ -932,7 +1092,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data( } with app.test_request_context(): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" @pytest.mark.parametrize(
