This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch metric-macro-expansion in repository https://gitbox.apache.org/repos/asf/superset.git
commit d47ff5ea676311e0e8d0a9d878af8b2a4dbf8a5a Author: Beto Dealmeida <[email protected]> AuthorDate: Fri Feb 21 08:04:57 2025 -0600 fix: ensure metric_macro expands templates --- superset/jinja_context.py | 56 ++++-------- tests/unit_tests/jinja_context_test.py | 161 ++++++++++++++++++++++++--------- 2 files changed, 136 insertions(+), 81 deletions(-) diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 713bee777c..7bb7b58dbf 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -602,7 +602,10 @@ class BaseTemplateProcessor: kwargs.update(self._context) context = validate_template_context(self.engine, kwargs) - return template.render(context) + try: + return template.render(context) + except RecursionError as ex: + raise SupersetTemplateException("Cyclic filters detected") from ex class JinjaTemplateProcessor(BaseTemplateProcessor): @@ -659,11 +662,18 @@ class JinjaTemplateProcessor(BaseTemplateProcessor): "filter_values": partial(safe_proxy, extra_cache.filter_values), "get_filters": partial(safe_proxy, extra_cache.get_filters), "dataset": partial(safe_proxy, dataset_macro_with_context), - "metric": partial(safe_proxy, metric_macro), "get_time_filter": partial(safe_proxy, extra_cache.get_time_filter), } ) + # The `metric` filter needs the full context, in order to expand other filters + self._context["metric"] = partial( + safe_proxy, + metric_macro, + self.env, + self._context, + ) + class NoOpTemplateProcessor(BaseTemplateProcessor): def process_template(self, sql: str, **kwargs: Any) -> str: @@ -889,27 +899,12 @@ def get_dataset_id_from_context(metric_key: str) -> int: raise SupersetTemplateException(exc_message) -def has_metric_macro(template_string: str, env: Environment) -> bool: - """ - Checks if a template string contains a metric macro. - - >>> has_metric_macro("{{ metric('my_metric') }}") - True - - """ - ast = env.parse(template_string) - - def visit_node(node: Node) -> bool: - return ( - isinstance(node, Call) - and isinstance(node.node, nodes.Name) - and node.node.name == "metric" - ) or any(visit_node(child) for child in node.iter_child_nodes()) - - return visit_node(ast) - - -def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str: +def metric_macro( + env: Environment, + context: dict[str, Any], + metric_key: str, + dataset_id: Optional[int] = None, +) -> str: """ Given a metric key, returns its syntax. @@ -943,18 +938,7 @@ def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str: ) definition = metrics[metric_key] - - env = SandboxedEnvironment(undefined=DebugUndefined) - context = {"metric": partial(safe_proxy, metric_macro)} - while has_metric_macro(definition, env): - old_definition = definition - template = env.from_string(definition) - try: - definition = template.render(context) - except RecursionError as ex: - raise SupersetTemplateException("Cyclic metric macro detected") from ex - - if definition == old_definition: - break + template = env.from_string(definition) + definition = template.render(context) return definition diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index faba808128..99a35b0aa2 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -21,6 +21,8 @@ from typing import Any import pytest from freezegun import freeze_time +from jinja2 import DebugUndefined +from jinja2.sandbox import SandboxedEnvironment from pytest_mock import MockerFixture from sqlalchemy.dialects import mysql from sqlalchemy.dialects.postgresql import dialect @@ -32,6 +34,7 @@ from superset.exceptions import SupersetTemplateException from superset.jinja_context import ( dataset_macro, ExtraCache, + get_template_processor, metric_macro, safe_proxy, TimeFilter, @@ -540,7 +543,8 @@ def test_metric_macro_with_dataset_id(mocker: MockerFixture) -> None: schema="my_schema", sql=None, ) - assert metric_macro("count", 1) == "COUNT(*)" + env = SandboxedEnvironment(undefined=DebugUndefined) + assert metric_macro(env, {}, "count", 1) == "COUNT(*)" mock_get_form_data.assert_not_called() @@ -548,32 +552,65 @@ def test_metric_macro_recursive(mocker: MockerFixture) -> None: """ Test the ``metric_macro`` when the definition is recursive. """ + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + dataset = SqlaTable( + id=1, + metrics=[ + SqlMetric(metric_name="a", expression="COUNT(*)"), + SqlMetric(metric_name="b", expression="{{ metric('a') }}"), + SqlMetric(metric_name="c", expression="{{ metric('b') }}"), + ], + table_name="test_dataset", + database=database, + schema="my_schema", + sql=None, + ) + + mocker.patch("superset.jinja_context.get_user_id", return_value=42) mock_g = mocker.patch("superset.jinja_context.g") mock_g.form_data = {"datasource": {"id": 1}} DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 - DatasetDAO.find_by_id.return_value = SqlaTable( - table_name="test_dataset", + DatasetDAO.find_by_id.return_value = dataset + + processor = get_template_processor(database=database) + assert processor.process_template("{{ metric('c', 1) }}") == "COUNT(*)" + + +def test_metric_macro_expansion(mocker: MockerFixture) -> None: + """ + Test that the ``metric_macro`` expands other macros. + """ + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + dataset = SqlaTable( + id=1, metrics=[ - SqlMetric(metric_name="a", expression="COUNT(*)"), + SqlMetric(metric_name="a", expression="{{ current_user_id() }}"), SqlMetric(metric_name="b", expression="{{ metric('a') }}"), SqlMetric(metric_name="c", expression="{{ metric('b') }}"), ], - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + table_name="test_dataset", + database=database, schema="my_schema", sql=None, ) - assert metric_macro("c", 1) == "COUNT(*)" + + mocker.patch("superset.jinja_context.get_user_id", return_value=42) + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = dataset + + processor = get_template_processor(database=database) + assert processor.process_template("{{ metric('c') }}") == "42" def test_metric_macro_recursive_compound(mocker: MockerFixture) -> None: """ Test the ``metric_macro`` when the definition is compound. """ - mock_g = mocker.patch("superset.jinja_context.g") - mock_g.form_data = {"datasource": {"id": 1}} - DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 - DatasetDAO.find_by_id.return_value = SqlaTable( - table_name="test_dataset", + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + dataset = SqlaTable( + id=1, metrics=[ SqlMetric(metric_name="a", expression="SUM(*)"), SqlMetric(metric_name="b", expression="COUNT(*)"), @@ -582,11 +619,20 @@ def test_metric_macro_recursive_compound(mocker: MockerFixture) -> None: expression="{{ metric('a') }} / {{ metric('b') }}", ), ], - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + table_name="test_dataset", + database=database, schema="my_schema", sql=None, ) - assert metric_macro("c", 1) == "SUM(*) / COUNT(*)" + + mocker.patch("superset.jinja_context.get_user_id", return_value=42) + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = dataset + + processor = get_template_processor(database=database) + assert processor.process_template("{{ metric('c') }}") == "SUM(*) / COUNT(*)" def test_metric_macro_recursive_cyclic(mocker: MockerFixture) -> None: @@ -595,23 +641,30 @@ def test_metric_macro_recursive_cyclic(mocker: MockerFixture) -> None: In this case it should stop, and not go into an infinite loop. """ - mock_g = mocker.patch("superset.jinja_context.g") - mock_g.form_data = {"datasource": {"id": 1}} - DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 - DatasetDAO.find_by_id.return_value = SqlaTable( - table_name="test_dataset", + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + dataset = SqlaTable( + id=1, metrics=[ SqlMetric(metric_name="a", expression="{{ metric('c') }}"), SqlMetric(metric_name="b", expression="{{ metric('a') }}"), SqlMetric(metric_name="c", expression="{{ metric('b') }}"), ], - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + table_name="test_dataset", + database=database, schema="my_schema", sql=None, ) + + mocker.patch("superset.jinja_context.get_user_id", return_value=42) + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = dataset + + processor = get_template_processor(database=database) with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("c", 1) - assert str(excinfo.value) == "Cyclic metric macro detected" + processor.process_template("{{ metric('c') }}") + assert str(excinfo.value) == "Cyclic filters detected" def test_metric_macro_recursive_infinite(mocker: MockerFixture) -> None: @@ -620,21 +673,28 @@ def test_metric_macro_recursive_infinite(mocker: MockerFixture) -> None: In this case it should stop, and not go into an infinite loop. """ - mock_g = mocker.patch("superset.jinja_context.g") - mock_g.form_data = {"datasource": {"id": 1}} - DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 - DatasetDAO.find_by_id.return_value = SqlaTable( - table_name="test_dataset", + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + dataset = SqlaTable( + id=1, metrics=[ SqlMetric(metric_name="a", expression="{{ metric('a') }}"), ], - database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + table_name="test_dataset", + database=database, schema="my_schema", sql=None, ) + + mocker.patch("superset.jinja_context.get_user_id", return_value=42) + mock_g = mocker.patch("superset.jinja_context.g") + mock_g.form_data = {"datasource": {"id": 1}} + DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 + DatasetDAO.find_by_id.return_value = dataset + + processor = get_template_processor(database=database) with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("a", 1) - assert str(excinfo.value) == "Cyclic metric macro detected" + processor.process_template("{{ metric('a') }}") + assert str(excinfo.value) == "Cyclic filters detected" def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None: @@ -652,8 +712,9 @@ def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None schema="my_schema", sql=None, ) + env = SandboxedEnvironment(undefined=DebugUndefined) with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("blah", 1) + metric_macro(env, {}, "blah", 1) assert str(excinfo.value) == "Metric ``blah`` not found in test_dataset." mock_get_form_data.assert_not_called() @@ -665,8 +726,9 @@ def test_metric_macro_invalid_dataset_id(mocker: MockerFixture) -> None: mock_get_form_data = mocker.patch("superset.views.utils.get_form_data") DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 DatasetDAO.find_by_id.return_value = None + env = SandboxedEnvironment(undefined=DebugUndefined) with pytest.raises(DatasetNotFoundError) as excinfo: - metric_macro("macro_key", 100) + metric_macro(env, {}, "macro_key", 100) assert str(excinfo.value) == "Dataset ID 100 not found." mock_get_form_data.assert_not_called() @@ -679,9 +741,10 @@ def test_metric_macro_no_dataset_id_no_context(mocker: MockerFixture) -> None: DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 mock_g = mocker.patch("superset.jinja_context.g") mock_g.form_data = {} + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -698,6 +761,8 @@ def test_metric_macro_no_dataset_id_with_context_missing_info( DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806 mock_g = mocker.patch("superset.jinja_context.g") mock_g.form_data = {"queries": []} + + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -716,7 +781,7 @@ def test_metric_macro_no_dataset_id_with_context_missing_info( } ): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -744,6 +809,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -759,7 +825,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id( ) } ): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" # Getting data from g's form_data mock_g.form_data = { @@ -772,7 +838,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id( ], } with app.test_request_context(): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" def test_metric_macro_no_dataset_id_with_context_datasource_id_none( @@ -786,6 +852,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -802,7 +869,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none( } ): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -819,7 +886,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none( } with app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -851,6 +918,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -866,7 +934,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id( ) } ): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" # Getting data from g's form_data mock_g.form_data = { @@ -879,7 +947,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id( ], } with app.test_request_context(): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" def test_metric_macro_no_dataset_id_with_context_slice_id_none( @@ -893,6 +961,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -909,7 +978,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none( } ): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -926,7 +995,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none( } with app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -945,6 +1014,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -961,7 +1031,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart( } ): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -978,7 +1048,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart( } with app.test_request_context(): with pytest.raises(SupersetTemplateException) as excinfo: - metric_macro("macro_key") + metric_macro(env, {}, "macro_key") assert str(excinfo.value) == ( "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501 ) @@ -1006,6 +1076,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data( mock_g.form_data = {} # Getting the data from the request context + env = SandboxedEnvironment(undefined=DebugUndefined) with app.test_request_context( data={ "form_data": json.dumps( @@ -1017,7 +1088,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data( ) } ): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" # Getting data from g's form_data mock_g.form_data = { @@ -1025,7 +1096,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data( } with app.test_request_context(): - assert metric_macro("macro_key") == "COUNT(*)" + assert metric_macro(env, {}, "macro_key") == "COUNT(*)" @pytest.mark.parametrize(
