This is an automated email from the ASF dual-hosted git repository.
gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 07945cbcc7f Add @task.analytics Decorator (#62648)
07945cbcc7f is described below
commit 07945cbcc7fe17d76cafefcd900b30485bb8a944
Author: GPK <[email protected]>
AuthorDate: Sat Feb 28 22:46:48 2026 +0000
Add @task.analytics Decorator (#62648)
* Add @task.analytics decorator
* Fix publish-docs to respect not ready packages when requested to build
---
dev/breeze/src/airflow_breeze/utils/packages.py | 5 +-
providers/common/sql/docs/operators.rst | 11 ++
providers/common/sql/provider.yaml | 2 +
.../providers/common/sql/decorators/analytics.py | 137 +++++++++++++++++++++
.../common/sql/example_dags/example_analytics.py | 11 +-
.../providers/common/sql/get_provider_info.py | 6 +-
.../unit/common/sql/decorators/test_analytics.py | 122 ++++++++++++++++++
7 files changed, 291 insertions(+), 3 deletions(-)
diff --git a/dev/breeze/src/airflow_breeze/utils/packages.py
b/dev/breeze/src/airflow_breeze/utils/packages.py
index 1c7a3e0dd92..aba09920103 100644
--- a/dev/breeze/src/airflow_breeze/utils/packages.py
+++ b/dev/breeze/src/airflow_breeze/utils/packages.py
@@ -419,7 +419,10 @@ def find_matching_long_package_names(
removed_packages: list[str] = [
f"apache-airflow-providers-{provider.replace('.', '-')}" for provider
in get_removed_provider_ids()
]
- all_packages_including_removed: list[str] = available_doc_packages +
removed_packages
+ not_ready_packages: list[str] = [
+ f"apache-airflow-providers-{provider.replace('.', '-')}" for provider
in get_not_ready_provider_ids()
+ ]
+ all_packages_including_removed: list[str] = available_doc_packages +
removed_packages + not_ready_packages
invalid_filters = [
f
for f in processed_package_filters
diff --git a/providers/common/sql/docs/operators.rst
b/providers/common/sql/docs/operators.rst
index 97832e74854..0354c491559 100644
--- a/providers/common/sql/docs/operators.rst
+++ b/providers/common/sql/docs/operators.rst
@@ -317,3 +317,14 @@ Local File System Storage
:dedent: 4
:start-after: [START howto_analytics_operator_with_local]
:end-before: [END howto_analytics_operator_with_local]
+
+Analytics TaskFlow Decorator
+----------------------------
+
+The ``@task.analytics`` decorator lets you write a function that returns the
+analytics SQL queries:
+
+.. exampleinclude::
/../../sql/src/airflow/providers/common/sql/example_dags/example_analytics.py
+ :language: python
+ :start-after: [START howto_analytics_decorator]
+ :end-before: [END howto_analytics_decorator]
diff --git a/providers/common/sql/provider.yaml
b/providers/common/sql/provider.yaml
index aca5da531af..8afdd4bfc58 100644
--- a/providers/common/sql/provider.yaml
+++ b/providers/common/sql/provider.yaml
@@ -129,3 +129,5 @@ sensors:
task-decorators:
- class-name: airflow.providers.common.sql.decorators.sql.sql_task
name: sql
+ - class-name:
airflow.providers.common.sql.decorators.analytics.analytics_task
+ name: analytics
diff --git
a/providers/common/sql/src/airflow/providers/common/sql/decorators/analytics.py
b/providers/common/sql/src/airflow/providers/common/sql/decorators/analytics.py
new file mode 100644
index 00000000000..29fb66d872a
--- /dev/null
+++
b/providers/common/sql/src/airflow/providers/common/sql/decorators/analytics.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable, Collection, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar
+
+from airflow.providers.common.compat.sdk import (
+ AIRFLOW_V_3_0_PLUS,
+ DecoratedOperator,
+ TaskDecorator,
+ context_merge,
+ task_decorator_factory,
+)
+from airflow.providers.common.sql.operators.analytics import AnalyticsOperator
+from airflow.utils.operator_helpers import determine_kwargs
+
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
+else:
+ from airflow.utils.types import NOTSET as SET_DURING_EXECUTION # type:
ignore[attr-defined,no-redef]
+
+
+if TYPE_CHECKING:
+ from airflow.providers.common.compat.sdk import Context
+
+
+class _AnalyticsDecoratedOperator(DecoratedOperator, AnalyticsOperator):
+ """
+ Wraps a Python callable and uses the callable return value as the SQL
commands to be executed.
+
+ :param python_callable: A reference to an object that is callable.
+ :param op_kwargs: A dictionary of keyword arguments that will get unpacked
(templated).
+ :param op_args: A list of positional arguments that will get unpacked
(templated).
+ """
+
+ template_fields: Sequence[str] = (
+ *DecoratedOperator.template_fields,
+ *AnalyticsOperator.template_fields,
+ )
+ template_fields_renderers: ClassVar[dict[str, str]] = {
+ **DecoratedOperator.template_fields_renderers,
+ **AnalyticsOperator.template_fields_renderers,
+ }
+
+ overwrite_rtif_after_execution: bool = True
+
+ custom_operator_name: str = "@task.analytics"
+
+ def __init__(
+ self,
+ python_callable: Callable,
+ op_args: Collection[Any] | None = None,
+ op_kwargs: Mapping[str, Any] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ python_callable=python_callable,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ queries=SET_DURING_EXECUTION,
+ **kwargs,
+ )
+
+ @property
+ def xcom_push(self) -> bool:
+ """Compatibility property for BaseDecorator that expects xcom_push
attribute."""
+ return self.do_xcom_push
+
+ @xcom_push.setter
+ def xcom_push(self, value: bool) -> None:
+ """Compatibility setter for BaseDecorator that expects xcom_push
attribute."""
+ self.do_xcom_push = value
+
+ def execute(self, context: Context) -> Any:
+ """
+ Build the SQL and execute the generated query (or queries).
+
+ :param context: Airflow context.
+ :return: Any
+ """
+ context_merge(context, self.op_kwargs)
+ kwargs = determine_kwargs(self.python_callable, self.op_args, context)
+
+ # Set the queries using the Python callable
+ result = self.python_callable(*self.op_args, **kwargs)
+
+ # Only non-empty strings and non-empty lists of non-empty strings are
acceptable return types
+ if (
+ not isinstance(result, (str, list))
+ or (isinstance(result, str) and not result.strip())
+ or (
+ isinstance(result, list)
+ and (not result or not all(isinstance(s, str) and s.strip()
for s in result))
+ )
+ ):
+ raise TypeError(
+ "The returned value from the @task.analytics callable must be
a non-empty string "
+ "or a non-empty list of non-empty strings."
+ )
+
+ # AnalyticsOperator expects queries as a list of strings
+ self.queries = [result] if isinstance(result, str) else result
+
+ self.render_template_fields(context)
+
+ return AnalyticsOperator.execute(self, context)
+
+
+def analytics_task(python_callable=None, **kwargs) -> TaskDecorator:
+ """
+    Wrap a Python function into an AnalyticsOperator.
+
+ :param python_callable: Function to decorate.
+
+ :meta private:
+ """
+ return task_decorator_factory(
+ python_callable=python_callable,
+ decorated_operator_class=_AnalyticsDecoratedOperator,
+ **kwargs,
+ )
diff --git
a/providers/common/sql/src/airflow/providers/common/sql/example_dags/example_analytics.py
b/providers/common/sql/src/airflow/providers/common/sql/example_dags/example_analytics.py
index 4871757b3bb..bee7829df7e 100644
---
a/providers/common/sql/src/airflow/providers/common/sql/example_dags/example_analytics.py
+++
b/providers/common/sql/src/airflow/providers/common/sql/example_dags/example_analytics.py
@@ -20,7 +20,7 @@ import datetime
from airflow.providers.common.sql.config import DataSourceConfig
from airflow.providers.common.sql.operators.analytics import AnalyticsOperator
-from airflow.sdk import DAG
+from airflow.sdk import DAG, task
datasource_config_s3 = DataSourceConfig(
conn_id="aws_default", table_name="users_data", uri="s3://bucket/path/",
format="parquet"
@@ -56,3 +56,12 @@ with DAG(
)
analytics_with_s3 >> analytics_with_local
# [END howto_analytics_operator_with_local]
+
+ # [START howto_analytics_decorator]
+ @task.analytics(datasource_configs=[datasource_config_s3])
+ def get_user_summary_queries():
+ return ["SELECT * FROM users_data LIMIT 10", "SELECT count(*) FROM
users_data"]
+
+ # [END howto_analytics_decorator]
+
+ analytics_with_local >> get_user_summary_queries()
diff --git
a/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py
b/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py
index ca5892917f6..6ee5011fc55 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py
@@ -71,6 +71,10 @@ def get_provider_info():
{"integration-name": "Common SQL", "python-modules":
["airflow.providers.common.sql.sensors.sql"]}
],
"task-decorators": [
- {"class-name":
"airflow.providers.common.sql.decorators.sql.sql_task", "name": "sql"}
+ {"class-name":
"airflow.providers.common.sql.decorators.sql.sql_task", "name": "sql"},
+ {
+ "class-name":
"airflow.providers.common.sql.decorators.analytics.analytics_task",
+ "name": "analytics",
+ },
],
}
diff --git
a/providers/common/sql/tests/unit/common/sql/decorators/test_analytics.py
b/providers/common/sql/tests/unit/common/sql/decorators/test_analytics.py
new file mode 100644
index 00000000000..36e964bbece
--- /dev/null
+++ b/providers/common/sql/tests/unit/common/sql/decorators/test_analytics.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.sql.config import DataSourceConfig
+from airflow.providers.common.sql.decorators.analytics import
_AnalyticsDecoratedOperator
+
+DATASOURCE_CONFIGS = [
+ DataSourceConfig(conn_id="", table_name="users_data",
uri="file:///path/to/", format="parquet")
+]
+
+
+class TestAnalyticsDecoratedOperator:
+ def test_custom_operator_name(self):
+ assert _AnalyticsDecoratedOperator.custom_operator_name ==
"@task.analytics"
+
+ @patch(
+
"airflow.providers.common.sql.operators.analytics.AnalyticsOperator.execute",
+ autospec=True,
+ )
+ def test_execute_calls_callable_and_sets_queries_from_list(self,
mock_execute):
+ """The callable return value (list) becomes self.queries."""
+ mock_execute.return_value = "mocked output"
+
+ def get_user_queries():
+ return ["SELECT * FROM users_data", "SELECT count(*) FROM
users_data"]
+
+ op = _AnalyticsDecoratedOperator(
+ task_id="test",
+ python_callable=get_user_queries,
+ datasource_configs=DATASOURCE_CONFIGS,
+ )
+ result = op.execute(context={})
+
+ assert result == "mocked output"
+ assert op.queries == ["SELECT * FROM users_data", "SELECT count(*)
FROM users_data"]
+ mock_execute.assert_called_once()
+
+ @patch(
+
"airflow.providers.common.sql.operators.analytics.AnalyticsOperator.execute",
+ autospec=True,
+ )
+ def test_execute_wraps_single_string_into_list(self, mock_execute):
+ """A single string return value is wrapped into a list for
self.queries."""
+ mock_execute.return_value = "mocked output"
+
+ def get_single_query():
+ return "SELECT 1"
+
+ op = _AnalyticsDecoratedOperator(
+ task_id="test",
+ python_callable=get_single_query,
+ datasource_configs=DATASOURCE_CONFIGS,
+ )
+ op.execute(context={})
+
+ assert op.queries == ["SELECT 1"]
+
+ @pytest.mark.parametrize(
+ "return_value",
+ [42, "", " ", None, [], [""], ["SELECT 1", ""], ["SELECT 1", " "],
[42]],
+ ids=[
+ "non-string",
+ "empty-string",
+ "whitespace-string",
+ "none",
+ "empty-list",
+ "list-with-empty-string",
+ "list-with-one-valid-one-empty",
+ "list-with-one-valid-one-whitespace",
+ "list-with-non-string",
+ ],
+ )
+ def test_execute_raises_on_invalid_return_value(self, return_value):
+ """TypeError when the callable returns an invalid value."""
+ op = _AnalyticsDecoratedOperator(
+ task_id="test",
+ python_callable=lambda: return_value,
+ datasource_configs=DATASOURCE_CONFIGS,
+ )
+ with pytest.raises(TypeError, match="non-empty string"):
+ op.execute(context={})
+
+ @patch(
+
"airflow.providers.common.sql.operators.analytics.AnalyticsOperator.execute",
+ autospec=True,
+ )
+ def test_execute_merges_op_kwargs_into_callable(self, mock_execute):
+ """op_kwargs are forwarded to the callable to build queries."""
+ mock_execute.return_value = "mocked output"
+
+ def get_queries_for_table(table_name):
+ return [f"SELECT * FROM {table_name}", f"SELECT count(*) FROM
{table_name}"]
+
+ op = _AnalyticsDecoratedOperator(
+ task_id="test",
+ python_callable=get_queries_for_table,
+ datasource_configs=DATASOURCE_CONFIGS,
+ op_kwargs={"table_name": "orders"},
+ )
+ op.execute(context={"task_instance": MagicMock()})
+
+ assert op.queries == ["SELECT * FROM orders", "SELECT count(*) FROM
orders"]