This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch recursive-metrics
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 5d6eff5f64535ba547dbc4ed8f48243032a62f56
Author: Beto Dealmeida <[email protected]>
AuthorDate: Tue Feb 11 16:53:23 2025 -0500

    feat: recursive metric definitions
---
 superset/jinja_context.py              | 50 +++++++++++++++++++++++++++-------
 tests/unit_tests/jinja_context_test.py | 23 ++++++++++++++++
 2 files changed, 63 insertions(+), 10 deletions(-)

diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index b0e29505a0..97e6edd42f 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -27,7 +27,8 @@ from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Unio
 import dateutil
 from flask import current_app, g, has_request_context, request
 from flask_babel import gettext as _
-from jinja2 import DebugUndefined, Environment
+from jinja2 import DebugUndefined, Environment, nodes
+from jinja2.nodes import Call, Node
 from jinja2.sandbox import SandboxedEnvironment
 from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.sql.expression import bindparam
@@ -888,6 +889,26 @@ def get_dataset_id_from_context(metric_key: str) -> int:
     raise SupersetTemplateException(exc_message)
 
 
+def has_metric_macro(template_string: str, env: Environment) -> bool:
+    """
+    Return whether ``template_string`` (parsed with ``env``) calls ``metric()``.
+
+        >>> has_metric_macro("{{ metric('my_metric') }}", Environment())
+        True
+
+    """
+    ast = env.parse(template_string)
+
+    def visit_node(node: Node) -> bool:
+        return (
+            isinstance(node, Call)
+            and isinstance(node.node, nodes.Name)
+            and node.node.name == "metric"
+        ) or any(visit_node(child) for child in node.iter_child_nodes())
+
+    return visit_node(ast)
+
+
 def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str:
     """
     Given a metric key, returns its syntax.
@@ -908,16 +929,25 @@ def metric_macro(metric_key: str, dataset_id: Optional[int] = None) -> str:
     dataset = DatasetDAO.find_by_id(dataset_id)
     if not dataset:
         raise DatasetNotFoundError(f"Dataset ID {dataset_id} not found.")
+
     metrics: dict[str, str] = {
         metric.metric_name: metric.expression for metric in dataset.metrics
     }
-    dataset_name = dataset.table_name
-    if metric := metrics.get(metric_key):
-        return metric
-    raise SupersetTemplateException(
-        _(
-            "Metric ``%(metric_name)s`` not found in %(dataset_name)s.",
-            metric_name=metric_key,
-            dataset_name=dataset_name,
+    if metric_key not in metrics:
+        raise SupersetTemplateException(
+            _(
+                "Metric ``%(metric_name)s`` not found in %(dataset_name)s.",
+                metric_name=metric_key,
+                dataset_name=dataset.table_name,
+            )
         )
-    )
+
+    definition = metrics[metric_key]
+
+    env = SandboxedEnvironment(undefined=DebugUndefined)
+    context = {"metric": partial(safe_proxy, metric_macro)}
+    while has_metric_macro(definition, env):
+        template = env.from_string(definition)
+        definition = template.render(context)
+
+    return definition
diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py
index c17c066b9d..f09af6ab88 100644
--- a/tests/unit_tests/jinja_context_test.py
+++ b/tests/unit_tests/jinja_context_test.py
@@ -544,6 +544,29 @@ def test_metric_macro_with_dataset_id(mocker: MockerFixture) -> None:
     mock_get_form_data.assert_not_called()
 
 
+def test_metric_macro_recursive(mocker: MockerFixture) -> None:
+    """
+    Test that ``metric_macro`` expands metrics defined via other metrics.
+    """
+    mock_g = mocker.patch("superset.jinja_context.g")
+    mock_g.form_data = {"datasource": {"id": 1}}
+    mock_get_form_data = mocker.patch("superset.views.utils.get_form_data")
+    DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO")  # noqa: N806
+    DatasetDAO.find_by_id.return_value = SqlaTable(
+        table_name="test_dataset",
+        metrics=[
+            SqlMetric(metric_name="a", expression="COUNT(*)"),
+            SqlMetric(metric_name="b", expression="{{ metric('a') }}"),
+            SqlMetric(metric_name="c", expression="{{ metric('b') }}"),
+        ],
+        database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
+        schema="my_schema",
+        sql=None,
+    )
+    assert metric_macro("c", 1) == "COUNT(*)"
+    mock_get_form_data.assert_not_called()
+
+
 def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None:
     """
     Test the ``metric_macro`` when passing a dataset ID and an invalid key.

Reply via email to