This is an automated email from the ASF dual-hosted git repository.

gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 07945cbcc7f Add @task.analytics Decorator (#62648)
07945cbcc7f is described below

commit 07945cbcc7fe17d76cafefcd900b30485bb8a944
Author: GPK <[email protected]>
AuthorDate: Sat Feb 28 22:46:48 2026 +0000

    Add @task.analytics Decorator (#62648)
    
    * Add @task.analytics decorator
    
    * Fix publish-docs to respect not ready packages when requested to build
---
 dev/breeze/src/airflow_breeze/utils/packages.py    |   5 +-
 providers/common/sql/docs/operators.rst            |  11 ++
 providers/common/sql/provider.yaml                 |   2 +
 .../providers/common/sql/decorators/analytics.py   | 137 +++++++++++++++++++++
 .../common/sql/example_dags/example_analytics.py   |  11 +-
 .../providers/common/sql/get_provider_info.py      |   6 +-
 .../unit/common/sql/decorators/test_analytics.py   | 122 ++++++++++++++++++
 7 files changed, 291 insertions(+), 3 deletions(-)

diff --git a/dev/breeze/src/airflow_breeze/utils/packages.py 
b/dev/breeze/src/airflow_breeze/utils/packages.py
index 1c7a3e0dd92..aba09920103 100644
--- a/dev/breeze/src/airflow_breeze/utils/packages.py
+++ b/dev/breeze/src/airflow_breeze/utils/packages.py
@@ -419,7 +419,10 @@ def find_matching_long_package_names(
     removed_packages: list[str] = [
         f"apache-airflow-providers-{provider.replace('.', '-')}" for provider 
in get_removed_provider_ids()
     ]
-    all_packages_including_removed: list[str] = available_doc_packages + 
removed_packages
+    not_ready_packages: list[str] = [
+        f"apache-airflow-providers-{provider.replace('.', '-')}" for provider 
in get_not_ready_provider_ids()
+    ]
+    all_packages_including_removed: list[str] = available_doc_packages + 
removed_packages + not_ready_packages
     invalid_filters = [
         f
         for f in processed_package_filters
diff --git a/providers/common/sql/docs/operators.rst 
b/providers/common/sql/docs/operators.rst
index 97832e74854..0354c491559 100644
--- a/providers/common/sql/docs/operators.rst
+++ b/providers/common/sql/docs/operators.rst
@@ -317,3 +317,14 @@ Local File System Storage
     :dedent: 4
     :start-after: [START howto_analytics_operator_with_local]
     :end-before: [END howto_analytics_operator_with_local]
+
+Analytics TaskFlow Decorator
+----------------------------
+
+The ``@task.analytics`` decorator lets you write a function that returns the
+analytics SQL queries:
+
+.. exampleinclude:: 
/../../sql/src/airflow/providers/common/sql/example_dags/example_analytics.py
+    :language: python
+    :start-after: [START howto_analytics_decorator]
+    :end-before: [END howto_analytics_decorator]
diff --git a/providers/common/sql/provider.yaml 
b/providers/common/sql/provider.yaml
index aca5da531af..8afdd4bfc58 100644
--- a/providers/common/sql/provider.yaml
+++ b/providers/common/sql/provider.yaml
@@ -129,3 +129,5 @@ sensors:
 task-decorators:
   - class-name: airflow.providers.common.sql.decorators.sql.sql_task
     name: sql
+  - class-name: 
airflow.providers.common.sql.decorators.analytics.analytics_task
+    name: analytics
diff --git 
a/providers/common/sql/src/airflow/providers/common/sql/decorators/analytics.py 
b/providers/common/sql/src/airflow/providers/common/sql/decorators/analytics.py
new file mode 100644
index 00000000000..29fb66d872a
--- /dev/null
+++ 
b/providers/common/sql/src/airflow/providers/common/sql/decorators/analytics.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable, Collection, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar
+
+from airflow.providers.common.compat.sdk import (
+    AIRFLOW_V_3_0_PLUS,
+    DecoratedOperator,
+    TaskDecorator,
+    context_merge,
+    task_decorator_factory,
+)
+from airflow.providers.common.sql.operators.analytics import AnalyticsOperator
+from airflow.utils.operator_helpers import determine_kwargs
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
+else:
+    from airflow.utils.types import NOTSET as SET_DURING_EXECUTION  # type: 
ignore[attr-defined,no-redef]
+
+
+if TYPE_CHECKING:
+    from airflow.providers.common.compat.sdk import Context
+
+
+class _AnalyticsDecoratedOperator(DecoratedOperator, AnalyticsOperator):
+    """
+    Wraps a Python callable and uses the callable return value as the SQL 
commands to be executed.
+
+    :param python_callable: A reference to an object that is callable.
+    :param op_kwargs: A dictionary of keyword arguments that will get unpacked 
(templated).
+    :param op_args: A list of positional arguments that will get unpacked 
(templated).
+    """
+
+    template_fields: Sequence[str] = (
+        *DecoratedOperator.template_fields,
+        *AnalyticsOperator.template_fields,
+    )
+    template_fields_renderers: ClassVar[dict[str, str]] = {
+        **DecoratedOperator.template_fields_renderers,
+        **AnalyticsOperator.template_fields_renderers,
+    }
+
+    overwrite_rtif_after_execution: bool = True
+
+    custom_operator_name: str = "@task.analytics"
+
+    def __init__(
+        self,
+        python_callable: Callable,
+        op_args: Collection[Any] | None = None,
+        op_kwargs: Mapping[str, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            python_callable=python_callable,
+            op_args=op_args,
+            op_kwargs=op_kwargs,
+            queries=SET_DURING_EXECUTION,
+            **kwargs,
+        )
+
+    @property
+    def xcom_push(self) -> bool:
+        """Compatibility property for BaseDecorator that expects xcom_push 
attribute."""
+        return self.do_xcom_push
+
+    @xcom_push.setter
+    def xcom_push(self, value: bool) -> None:
+        """Compatibility setter for BaseDecorator that expects xcom_push 
attribute."""
+        self.do_xcom_push = value
+
+    def execute(self, context: Context) -> Any:
+        """
+        Build the SQL and execute the generated query (or queries).
+
+        :param context: Airflow context.
+        :return: Any
+        """
+        context_merge(context, self.op_kwargs)
+        kwargs = determine_kwargs(self.python_callable, self.op_args, context)
+
+        # Set the queries using the Python callable
+        result = self.python_callable(*self.op_args, **kwargs)
+
+        # Only non-empty strings and non-empty lists of non-empty strings are 
acceptable return types
+        if (
+            not isinstance(result, (str, list))
+            or (isinstance(result, str) and not result.strip())
+            or (
+                isinstance(result, list)
+                and (not result or not all(isinstance(s, str) and s.strip() 
for s in result))
+            )
+        ):
+            raise TypeError(
+                "The returned value from the @task.analytics callable must be 
a non-empty string "
+                "or a non-empty list of non-empty strings."
+            )
+
+        # AnalyticsOperator expects queries as a list of strings
+        self.queries = [result] if isinstance(result, str) else result
+
+        self.render_template_fields(context)
+
+        return AnalyticsOperator.execute(self, context)
+
+
+def analytics_task(python_callable=None, **kwargs) -> TaskDecorator:
+    """
+    Wrap a Python function into an AnalyticsOperator.
+
+    :param python_callable: Function to decorate.
+
+    :meta private:
+    """
+    return task_decorator_factory(
+        python_callable=python_callable,
+        decorated_operator_class=_AnalyticsDecoratedOperator,
+        **kwargs,
+    )
diff --git 
a/providers/common/sql/src/airflow/providers/common/sql/example_dags/example_analytics.py
 
b/providers/common/sql/src/airflow/providers/common/sql/example_dags/example_analytics.py
index 4871757b3bb..bee7829df7e 100644
--- 
a/providers/common/sql/src/airflow/providers/common/sql/example_dags/example_analytics.py
+++ 
b/providers/common/sql/src/airflow/providers/common/sql/example_dags/example_analytics.py
@@ -20,7 +20,7 @@ import datetime
 
 from airflow.providers.common.sql.config import DataSourceConfig
 from airflow.providers.common.sql.operators.analytics import AnalyticsOperator
-from airflow.sdk import DAG
+from airflow.sdk import DAG, task
 
 datasource_config_s3 = DataSourceConfig(
     conn_id="aws_default", table_name="users_data", uri="s3://bucket/path/", 
format="parquet"
@@ -56,3 +56,12 @@ with DAG(
     )
     analytics_with_s3 >> analytics_with_local
     # [END howto_analytics_operator_with_local]
+
+    # [START howto_analytics_decorator]
+    @task.analytics(datasource_configs=[datasource_config_s3])
+    def get_user_summary_queries():
+        return ["SELECT * FROM users_data LIMIT 10", "SELECT count(*) FROM 
users_data"]
+
+    # [END howto_analytics_decorator]
+
+    analytics_with_local >> get_user_summary_queries()
diff --git 
a/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py 
b/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py
index ca5892917f6..6ee5011fc55 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py
@@ -71,6 +71,10 @@ def get_provider_info():
             {"integration-name": "Common SQL", "python-modules": 
["airflow.providers.common.sql.sensors.sql"]}
         ],
         "task-decorators": [
-            {"class-name": 
"airflow.providers.common.sql.decorators.sql.sql_task", "name": "sql"}
+            {"class-name": 
"airflow.providers.common.sql.decorators.sql.sql_task", "name": "sql"},
+            {
+                "class-name": 
"airflow.providers.common.sql.decorators.analytics.analytics_task",
+                "name": "analytics",
+            },
         ],
     }
diff --git 
a/providers/common/sql/tests/unit/common/sql/decorators/test_analytics.py 
b/providers/common/sql/tests/unit/common/sql/decorators/test_analytics.py
new file mode 100644
index 00000000000..36e964bbece
--- /dev/null
+++ b/providers/common/sql/tests/unit/common/sql/decorators/test_analytics.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.sql.config import DataSourceConfig
+from airflow.providers.common.sql.decorators.analytics import 
_AnalyticsDecoratedOperator
+
+DATASOURCE_CONFIGS = [
+    DataSourceConfig(conn_id="", table_name="users_data", 
uri="file:///path/to/", format="parquet")
+]
+
+
+class TestAnalyticsDecoratedOperator:
+    def test_custom_operator_name(self):
+        assert _AnalyticsDecoratedOperator.custom_operator_name == 
"@task.analytics"
+
+    @patch(
+        
"airflow.providers.common.sql.operators.analytics.AnalyticsOperator.execute",
+        autospec=True,
+    )
+    def test_execute_calls_callable_and_sets_queries_from_list(self, 
mock_execute):
+        """The callable return value (list) becomes self.queries."""
+        mock_execute.return_value = "mocked output"
+
+        def get_user_queries():
+            return ["SELECT * FROM users_data", "SELECT count(*) FROM 
users_data"]
+
+        op = _AnalyticsDecoratedOperator(
+            task_id="test",
+            python_callable=get_user_queries,
+            datasource_configs=DATASOURCE_CONFIGS,
+        )
+        result = op.execute(context={})
+
+        assert result == "mocked output"
+        assert op.queries == ["SELECT * FROM users_data", "SELECT count(*) 
FROM users_data"]
+        mock_execute.assert_called_once()
+
+    @patch(
+        
"airflow.providers.common.sql.operators.analytics.AnalyticsOperator.execute",
+        autospec=True,
+    )
+    def test_execute_wraps_single_string_into_list(self, mock_execute):
+        """A single string return value is wrapped into a list for 
self.queries."""
+        mock_execute.return_value = "mocked output"
+
+        def get_single_query():
+            return "SELECT 1"
+
+        op = _AnalyticsDecoratedOperator(
+            task_id="test",
+            python_callable=get_single_query,
+            datasource_configs=DATASOURCE_CONFIGS,
+        )
+        op.execute(context={})
+
+        assert op.queries == ["SELECT 1"]
+
+    @pytest.mark.parametrize(
+        "return_value",
+        [42, "", "   ", None, [], [""], ["SELECT 1", ""], ["SELECT 1", "   "], 
[42]],
+        ids=[
+            "non-string",
+            "empty-string",
+            "whitespace-string",
+            "none",
+            "empty-list",
+            "list-with-empty-string",
+            "list-with-one-valid-one-empty",
+            "list-with-one-valid-one-whitespace",
+            "list-with-non-string",
+        ],
+    )
+    def test_execute_raises_on_invalid_return_value(self, return_value):
+        """TypeError when the callable returns an invalid value."""
+        op = _AnalyticsDecoratedOperator(
+            task_id="test",
+            python_callable=lambda: return_value,
+            datasource_configs=DATASOURCE_CONFIGS,
+        )
+        with pytest.raises(TypeError, match="non-empty string"):
+            op.execute(context={})
+
+    @patch(
+        
"airflow.providers.common.sql.operators.analytics.AnalyticsOperator.execute",
+        autospec=True,
+    )
+    def test_execute_merges_op_kwargs_into_callable(self, mock_execute):
+        """op_kwargs are forwarded to the callable to build queries."""
+        mock_execute.return_value = "mocked output"
+
+        def get_queries_for_table(table_name):
+            return [f"SELECT * FROM {table_name}", f"SELECT count(*) FROM 
{table_name}"]
+
+        op = _AnalyticsDecoratedOperator(
+            task_id="test",
+            python_callable=get_queries_for_table,
+            datasource_configs=DATASOURCE_CONFIGS,
+            op_kwargs={"table_name": "orders"},
+        )
+        op.execute(context={"task_instance": MagicMock()})
+
+        assert op.queries == ["SELECT * FROM orders", "SELECT count(*) FROM 
orders"]

Reply via email to