This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch metric-macro-expansion
in repository https://gitbox.apache.org/repos/asf/superset.git

commit d47ff5ea676311e0e8d0a9d878af8b2a4dbf8a5a
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri Feb 21 08:04:57 2025 -0600

    fix: ensure metric_macro expands templates
---
 superset/jinja_context.py              |  56 ++++--------
 tests/unit_tests/jinja_context_test.py | 161 ++++++++++++++++++++++++---------
 2 files changed, 136 insertions(+), 81 deletions(-)

diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index 713bee777c..7bb7b58dbf 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -602,7 +602,10 @@ class BaseTemplateProcessor:
         kwargs.update(self._context)
 
         context = validate_template_context(self.engine, kwargs)
-        return template.render(context)
+        try:
+            return template.render(context)
+        except RecursionError as ex:
+            raise SupersetTemplateException("Cyclic filters detected") from ex
 
 
 class JinjaTemplateProcessor(BaseTemplateProcessor):
@@ -659,11 +662,18 @@ class JinjaTemplateProcessor(BaseTemplateProcessor):
                 "filter_values": partial(safe_proxy, extra_cache.filter_values),
                 "get_filters": partial(safe_proxy, extra_cache.get_filters),
                 "dataset": partial(safe_proxy, dataset_macro_with_context),
-                "metric": partial(safe_proxy, metric_macro),
                 "get_time_filter": partial(safe_proxy, extra_cache.get_time_filter),
             }
         )
 
+        # The `metric` macro needs the full context in order to expand other macros
+        self._context["metric"] = partial(
+            safe_proxy,
+            metric_macro,
+            self.env,
+            self._context,
+        )
+
 
 class NoOpTemplateProcessor(BaseTemplateProcessor):
     def process_template(self, sql: str, **kwargs: Any) -> str:
@@ -889,27 +899,12 @@ def get_dataset_id_from_context(metric_key: str) -> int:
     raise SupersetTemplateException(exc_message)
 
 
-def has_metric_macro(template_string: str, env: Environment) -> bool:
-    """
-    Checks if a template string contains a metric macro.
-
-        >>> has_metric_macro("{{ metric('my_metric') }}")
-        True
-
-    """
-    ast = env.parse(template_string)
-
-    def visit_node(node: Node) -> bool:
-        return (
-            isinstance(node, Call)
-            and isinstance(node.node, nodes.Name)
-            and node.node.name == "metric"
-        ) or any(visit_node(child) for child in node.iter_child_nodes())
-
-    return visit_node(ast)
-
-
-def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str:
+def metric_macro(
+    env: Environment,
+    context: dict[str, Any],
+    metric_key: str,
+    dataset_id: Optional[int] = None,
+) -> str:
     """
     Given a metric key, returns its syntax.
 
@@ -943,18 +938,7 @@ def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str:
         )
 
     definition = metrics[metric_key]
-
-    env = SandboxedEnvironment(undefined=DebugUndefined)
-    context = {"metric": partial(safe_proxy, metric_macro)}
-    while has_metric_macro(definition, env):
-        old_definition = definition
-        template = env.from_string(definition)
-        try:
-            definition = template.render(context)
-        except RecursionError as ex:
-            raise SupersetTemplateException("Cyclic metric macro detected") from ex
-
-        if definition == old_definition:
-            break
+    template = env.from_string(definition)
+    definition = template.render(context)
 
     return definition
diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py
index faba808128..99a35b0aa2 100644
--- a/tests/unit_tests/jinja_context_test.py
+++ b/tests/unit_tests/jinja_context_test.py
@@ -21,6 +21,8 @@ from typing import Any
 
 import pytest
 from freezegun import freeze_time
+from jinja2 import DebugUndefined
+from jinja2.sandbox import SandboxedEnvironment
 from pytest_mock import MockerFixture
 from sqlalchemy.dialects import mysql
 from sqlalchemy.dialects.postgresql import dialect
@@ -32,6 +34,7 @@ from superset.exceptions import SupersetTemplateException
 from superset.jinja_context import (
     dataset_macro,
     ExtraCache,
+    get_template_processor,
     metric_macro,
     safe_proxy,
     TimeFilter,
@@ -540,7 +543,8 @@ def test_metric_macro_with_dataset_id(mocker: MockerFixture) -> None:
         schema="my_schema",
         sql=None,
     )
-    assert metric_macro("count", 1) == "COUNT(*)"
+    env = SandboxedEnvironment(undefined=DebugUndefined)
+    assert metric_macro(env, {}, "count", 1) == "COUNT(*)"
     mock_get_form_data.assert_not_called()
 
 
@@ -548,32 +552,65 @@ def test_metric_macro_recursive(mocker: MockerFixture) -> None:
     """
     Test the ``metric_macro`` when the definition is recursive.
     """
+    database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+    dataset = SqlaTable(
+        id=1,
+        metrics=[
+            SqlMetric(metric_name="a", expression="COUNT(*)"),
+            SqlMetric(metric_name="b", expression="{{ metric('a') }}"),
+            SqlMetric(metric_name="c", expression="{{ metric('b') }}"),
+        ],
+        table_name="test_dataset",
+        database=database,
+        schema="my_schema",
+        sql=None,
+    )
+
+    mocker.patch("superset.jinja_context.get_user_id", return_value=42)
     mock_g = mocker.patch("superset.jinja_context.g")
     mock_g.form_data = {"datasource": {"id": 1}}
     DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
-    DatasetDAO.find_by_id.return_value = SqlaTable(
-        table_name="test_dataset",
+    DatasetDAO.find_by_id.return_value = dataset
+
+    processor = get_template_processor(database=database)
+    assert processor.process_template("{{ metric('c', 1) }}") == "COUNT(*)"
+
+
+def test_metric_macro_expansion(mocker: MockerFixture) -> None:
+    """
+    Test that the ``metric_macro`` expands other macros.
+    """
+    database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+    dataset = SqlaTable(
+        id=1,
         metrics=[
-            SqlMetric(metric_name="a", expression="COUNT(*)"),
+            SqlMetric(metric_name="a", expression="{{ current_user_id() }}"),
             SqlMetric(metric_name="b", expression="{{ metric('a') }}"),
             SqlMetric(metric_name="c", expression="{{ metric('b') }}"),
         ],
-        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
+        table_name="test_dataset",
+        database=database,
         schema="my_schema",
         sql=None,
     )
-    assert metric_macro("c", 1) == "COUNT(*)"
+
+    mocker.patch("superset.jinja_context.get_user_id", return_value=42)
+    mock_g = mocker.patch("superset.jinja_context.g")
+    mock_g.form_data = {"datasource": {"id": 1}}
+    DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
+    DatasetDAO.find_by_id.return_value = dataset
+
+    processor = get_template_processor(database=database)
+    assert processor.process_template("{{ metric('c') }}") == "42"
 
 
 def test_metric_macro_recursive_compound(mocker: MockerFixture) -> None:
     """
     Test the ``metric_macro`` when the definition is compound.
     """
-    mock_g = mocker.patch("superset.jinja_context.g")
-    mock_g.form_data = {"datasource": {"id": 1}}
-    DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
-    DatasetDAO.find_by_id.return_value = SqlaTable(
-        table_name="test_dataset",
+    database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+    dataset = SqlaTable(
+        id=1,
         metrics=[
             SqlMetric(metric_name="a", expression="SUM(*)"),
             SqlMetric(metric_name="b", expression="COUNT(*)"),
@@ -582,11 +619,20 @@ def test_metric_macro_recursive_compound(mocker: MockerFixture) -> None:
                 expression="{{ metric('a') }} / {{ metric('b') }}",
             ),
         ],
-        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
+        table_name="test_dataset",
+        database=database,
         schema="my_schema",
         sql=None,
     )
-    assert metric_macro("c", 1) == "SUM(*) / COUNT(*)"
+
+    mocker.patch("superset.jinja_context.get_user_id", return_value=42)
+    mock_g = mocker.patch("superset.jinja_context.g")
+    mock_g.form_data = {"datasource": {"id": 1}}
+    DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
+    DatasetDAO.find_by_id.return_value = dataset
+
+    processor = get_template_processor(database=database)
+    assert processor.process_template("{{ metric('c') }}") == "SUM(*) / COUNT(*)"
 
 
 def test_metric_macro_recursive_cyclic(mocker: MockerFixture) -> None:
@@ -595,23 +641,30 @@ def test_metric_macro_recursive_cyclic(mocker: MockerFixture) -> None:
 
     In this case it should stop, and not go into an infinite loop.
     """
-    mock_g = mocker.patch("superset.jinja_context.g")
-    mock_g.form_data = {"datasource": {"id": 1}}
-    DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
-    DatasetDAO.find_by_id.return_value = SqlaTable(
-        table_name="test_dataset",
+    database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+    dataset = SqlaTable(
+        id=1,
         metrics=[
             SqlMetric(metric_name="a", expression="{{ metric('c') }}"),
             SqlMetric(metric_name="b", expression="{{ metric('a') }}"),
             SqlMetric(metric_name="c", expression="{{ metric('b') }}"),
         ],
-        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
+        table_name="test_dataset",
+        database=database,
         schema="my_schema",
         sql=None,
     )
+
+    mocker.patch("superset.jinja_context.get_user_id", return_value=42)
+    mock_g = mocker.patch("superset.jinja_context.g")
+    mock_g.form_data = {"datasource": {"id": 1}}
+    DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
+    DatasetDAO.find_by_id.return_value = dataset
+
+    processor = get_template_processor(database=database)
     with pytest.raises(SupersetTemplateException) as excinfo:
-        metric_macro("c", 1)
-    assert str(excinfo.value) == "Cyclic metric macro detected"
+        processor.process_template("{{ metric('c') }}")
+    assert str(excinfo.value) == "Cyclic filters detected"
 
 
 def test_metric_macro_recursive_infinite(mocker: MockerFixture) -> None:
@@ -620,21 +673,28 @@ def test_metric_macro_recursive_infinite(mocker: MockerFixture) -> None:
 
     In this case it should stop, and not go into an infinite loop.
     """
-    mock_g = mocker.patch("superset.jinja_context.g")
-    mock_g.form_data = {"datasource": {"id": 1}}
-    DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
-    DatasetDAO.find_by_id.return_value = SqlaTable(
-        table_name="test_dataset",
+    database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+    dataset = SqlaTable(
+        id=1,
         metrics=[
             SqlMetric(metric_name="a", expression="{{ metric('a') }}"),
         ],
-        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
+        table_name="test_dataset",
+        database=database,
         schema="my_schema",
         sql=None,
     )
+
+    mocker.patch("superset.jinja_context.get_user_id", return_value=42)
+    mock_g = mocker.patch("superset.jinja_context.g")
+    mock_g.form_data = {"datasource": {"id": 1}}
+    DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
+    DatasetDAO.find_by_id.return_value = dataset
+
+    processor = get_template_processor(database=database)
     with pytest.raises(SupersetTemplateException) as excinfo:
-        metric_macro("a", 1)
-    assert str(excinfo.value) == "Cyclic metric macro detected"
+        processor.process_template("{{ metric('a') }}")
+    assert str(excinfo.value) == "Cyclic filters detected"
 
 
 def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None:
@@ -652,8 +712,9 @@ def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None
         schema="my_schema",
         sql=None,
     )
+    env = SandboxedEnvironment(undefined=DebugUndefined)
     with pytest.raises(SupersetTemplateException) as excinfo:
-        metric_macro("blah", 1)
+        metric_macro(env, {}, "blah", 1)
     assert str(excinfo.value) == "Metric ``blah`` not found in test_dataset."
     mock_get_form_data.assert_not_called()
 
@@ -665,8 +726,9 @@ def test_metric_macro_invalid_dataset_id(mocker: MockerFixture) -> None:
     mock_get_form_data = mocker.patch("superset.views.utils.get_form_data")
     DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
     DatasetDAO.find_by_id.return_value = None
+    env = SandboxedEnvironment(undefined=DebugUndefined)
     with pytest.raises(DatasetNotFoundError) as excinfo:
-        metric_macro("macro_key", 100)
+        metric_macro(env, {}, "macro_key", 100)
     assert str(excinfo.value) == "Dataset ID 100 not found."
     mock_get_form_data.assert_not_called()
 
@@ -679,9 +741,10 @@ def test_metric_macro_no_dataset_id_no_context(mocker: MockerFixture) -> None:
     DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
     mock_g = mocker.patch("superset.jinja_context.g")
     mock_g.form_data = {}
+    env = SandboxedEnvironment(undefined=DebugUndefined)
     with app.test_request_context():
         with pytest.raises(SupersetTemplateException) as excinfo:
-            metric_macro("macro_key")
+            metric_macro(env, {}, "macro_key")
         assert str(excinfo.value) == (
             "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro."  # noqa: E501
         )
@@ -698,6 +761,8 @@ def test_metric_macro_no_dataset_id_with_context_missing_info(
     DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
     mock_g = mocker.patch("superset.jinja_context.g")
     mock_g.form_data = {"queries": []}
+
+    env = SandboxedEnvironment(undefined=DebugUndefined)
     with app.test_request_context(
         data={
             "form_data": json.dumps(
@@ -716,7 +781,7 @@ def test_metric_macro_no_dataset_id_with_context_missing_info(
         }
     ):
         with pytest.raises(SupersetTemplateException) as excinfo:
-            metric_macro("macro_key")
+            metric_macro(env, {}, "macro_key")
         assert str(excinfo.value) == (
             "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro."  # noqa: E501
         )
@@ -744,6 +809,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id(
     mock_g.form_data = {}
 
     # Getting the data from the request context
+    env = SandboxedEnvironment(undefined=DebugUndefined)
     with app.test_request_context(
         data={
             "form_data": json.dumps(
@@ -759,7 +825,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id(
             )
         }
     ):
-        assert metric_macro("macro_key") == "COUNT(*)"
+        assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
 
     # Getting data from g's form_data
     mock_g.form_data = {
@@ -772,7 +838,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id(
         ],
     }
     with app.test_request_context():
-        assert metric_macro("macro_key") == "COUNT(*)"
+        assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
 
 
 def test_metric_macro_no_dataset_id_with_context_datasource_id_none(
@@ -786,6 +852,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none(
     mock_g.form_data = {}
 
     # Getting the data from the request context
+    env = SandboxedEnvironment(undefined=DebugUndefined)
     with app.test_request_context(
         data={
             "form_data": json.dumps(
@@ -802,7 +869,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none(
         }
     ):
         with pytest.raises(SupersetTemplateException) as excinfo:
-            metric_macro("macro_key")
+            metric_macro(env, {}, "macro_key")
         assert str(excinfo.value) == (
             "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro."  # noqa: E501
         )
@@ -819,7 +886,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none(
     }
     with app.test_request_context():
         with pytest.raises(SupersetTemplateException) as excinfo:
-            metric_macro("macro_key")
+            metric_macro(env, {}, "macro_key")
         assert str(excinfo.value) == (
             "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro."  # noqa: E501
         )
@@ -851,6 +918,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id(
     mock_g.form_data = {}
 
     # Getting the data from the request context
+    env = SandboxedEnvironment(undefined=DebugUndefined)
     with app.test_request_context(
         data={
             "form_data": json.dumps(
@@ -866,7 +934,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id(
             )
         }
     ):
-        assert metric_macro("macro_key") == "COUNT(*)"
+        assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
 
     # Getting data from g's form_data
     mock_g.form_data = {
@@ -879,7 +947,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id(
         ],
     }
     with app.test_request_context():
-        assert metric_macro("macro_key") == "COUNT(*)"
+        assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
 
 
 def test_metric_macro_no_dataset_id_with_context_slice_id_none(
@@ -893,6 +961,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none(
     mock_g.form_data = {}
 
     # Getting the data from the request context
+    env = SandboxedEnvironment(undefined=DebugUndefined)
     with app.test_request_context(
         data={
             "form_data": json.dumps(
@@ -909,7 +978,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none(
         }
     ):
         with pytest.raises(SupersetTemplateException) as excinfo:
-            metric_macro("macro_key")
+            metric_macro(env, {}, "macro_key")
         assert str(excinfo.value) == (
             "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro."  # noqa: E501
         )
@@ -926,7 +995,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none(
     }
     with app.test_request_context():
         with pytest.raises(SupersetTemplateException) as excinfo:
-            metric_macro("macro_key")
+            metric_macro(env, {}, "macro_key")
         assert str(excinfo.value) == (
             "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro."  # noqa: E501
         )
@@ -945,6 +1014,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart(
     mock_g.form_data = {}
 
     # Getting the data from the request context
+    env = SandboxedEnvironment(undefined=DebugUndefined)
     with app.test_request_context(
         data={
             "form_data": json.dumps(
@@ -961,7 +1031,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart(
         }
     ):
         with pytest.raises(SupersetTemplateException) as excinfo:
-            metric_macro("macro_key")
+            metric_macro(env, {}, "macro_key")
         assert str(excinfo.value) == (
             "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro."  # noqa: E501
         )
@@ -978,7 +1048,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart(
     }
     with app.test_request_context():
         with pytest.raises(SupersetTemplateException) as excinfo:
-            metric_macro("macro_key")
+            metric_macro(env, {}, "macro_key")
         assert str(excinfo.value) == (
             "Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro."  # noqa: E501
         )
@@ -1006,6 +1076,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data(
     mock_g.form_data = {}
 
     # Getting the data from the request context
+    env = SandboxedEnvironment(undefined=DebugUndefined)
     with app.test_request_context(
         data={
             "form_data": json.dumps(
@@ -1017,7 +1088,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data(
             )
         }
     ):
-        assert metric_macro("macro_key") == "COUNT(*)"
+        assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
 
     # Getting data from g's form_data
     mock_g.form_data = {
@@ -1025,7 +1096,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data(
     }
 
     with app.test_request_context():
-        assert metric_macro("macro_key") == "COUNT(*)"
+        assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
 
 
 @pytest.mark.parametrize(

Reply via email to