This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8f1a53801a4 Support `@task.bash` with Task SDK (#48060)
8f1a53801a4 is described below
commit 8f1a53801a4da94fb81f65c11dcccf74601e1859
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Mar 22 00:28:55 2025 +0530
Support `@task.bash` with Task SDK (#48060)
---
.pre-commit-config.yaml | 2 +-
airflow-core/src/airflow/decorators/bash.py | 8 +++-
airflow-core/tests/unit/decorators/test_bash.py | 52 ++++++++++++----------
.../airflow/providers/standard/operators/bash.py | 52 +++-------------------
.../tests/unit/standard/operators/test_bash.py | 2 -
.../check_base_operator_partial_arguments.py | 1 +
.../src/airflow/sdk/definitions/_internal/types.py | 12 +++++
.../src/airflow/sdk/definitions/baseoperator.py | 5 +++
.../src/airflow/sdk/execution_time/task_runner.py | 10 +++++
.../task_sdk/execution_time/test_task_runner.py | 21 +++++++++
10 files changed, 90 insertions(+), 75 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 84dbf743439..36cec5f5a4a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1356,7 +1356,7 @@ repos:
name: Check templated fields mapped in operators/sensors
language: python
entry: ./scripts/ci/pre_commit/check_template_fields.py
- files: ^(providers/.*/)?airflow/.*/(sensors|operators)/.*\.py$
+ files: ^(providers/.*/)?airflow-core/.*/(sensors|operators)/.*\.py$
additional_dependencies: [ 'rich>=12.4.4' ]
require_serial: true
- id: update-migration-references
diff --git a/airflow-core/src/airflow/decorators/bash.py b/airflow-core/src/airflow/decorators/bash.py
index 996ac5ffe05..a82575ce3ac 100644
--- a/airflow-core/src/airflow/decorators/bash.py
+++ b/airflow-core/src/airflow/decorators/bash.py
@@ -23,9 +23,9 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar
from airflow.decorators.base import DecoratedOperator, TaskDecorator,
task_decorator_factory
from airflow.providers.standard.operators.bash import BashOperator
+from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
from airflow.utils.context import context_merge
from airflow.utils.operator_helpers import determine_kwargs
-from airflow.utils.types import NOTSET
if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context
@@ -49,6 +49,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
}
custom_operator_name: str = "@task.bash"
+ overwrite_rtif_after_execution: bool = True
def __init__(
self,
@@ -69,7 +70,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
python_callable=python_callable,
op_args=op_args,
op_kwargs=op_kwargs,
- bash_command=NOTSET,
+ bash_command=SET_DURING_EXECUTION,
multiple_outputs=False,
**kwargs,
)
@@ -83,6 +84,9 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
if not isinstance(self.bash_command, str) or self.bash_command.strip()
== "":
raise TypeError("The returned value from the TaskFlow callable
must be a non-empty string.")
+ self._is_inline_cmd =
self._is_inline_command(bash_command=self.bash_command)
+ context["ti"].render_templates() # type: ignore[attr-defined]
+
return super().execute(context)
diff --git a/airflow-core/tests/unit/decorators/test_bash.py b/airflow-core/tests/unit/decorators/test_bash.py
index 3dfbf2cc4e9..619326a5632 100644
--- a/airflow-core/tests/unit/decorators/test_bash.py
+++ b/airflow-core/tests/unit/decorators/test_bash.py
@@ -29,10 +29,11 @@ import pytest
from airflow.decorators import task
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models.renderedtifields import RenderedTaskInstanceFields
+from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
from airflow.utils import timezone
-from airflow.utils.types import NOTSET
from tests_common.test_utils.db import clear_db_dags, clear_db_runs,
clear_rendered_ti_fields
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
if TYPE_CHECKING:
from airflow.models import TaskInstance
@@ -69,7 +70,10 @@ class TestBashDecorator:
@staticmethod
def validate_bash_command_rtif(ti, expected_command):
- assert
RenderedTaskInstanceFields.get_templated_fields(ti)["bash_command"] ==
expected_command
+ if AIRFLOW_V_3_0_PLUS:
+ assert ti.task.overwrite_rtif_after_execution
+ else:
+ assert
RenderedTaskInstanceFields.get_templated_fields(ti)["bash_command"] ==
expected_command
def test_bash_decorator_init(self):
"""Test the initialization of the @task.bash decorator."""
@@ -81,13 +85,13 @@ class TestBashDecorator:
bash_task = bash()
assert bash_task.operator.task_id == "bash"
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
assert bash_task.operator.env is None
assert bash_task.operator.append_env is False
assert bash_task.operator.output_encoding == "utf-8"
assert bash_task.operator.skip_on_exit_code == [99]
assert bash_task.operator.cwd is None
- assert bash_task.operator._init_bash_command_not_set is True
+ assert bash_task.operator._is_inline_cmd is None
@pytest.mark.parametrize(
argnames=["command", "expected_command", "expected_return_val"],
@@ -108,13 +112,12 @@ class TestBashDecorator:
bash_task = bash()
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
ti, return_val = self.execute_task(bash_task)
assert bash_task.operator.bash_command == expected_command
assert return_val == expected_return_val
-
self.validate_bash_command_rtif(ti, expected_command)
def test_op_args_kwargs(self):
@@ -127,7 +130,7 @@ class TestBashDecorator:
bash_task = bash("world", other_id="2")
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
ti, return_val = self.execute_task(bash_task)
@@ -152,7 +155,7 @@ class TestBashDecorator:
bash_task = bash("foo")
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
ti, return_val = self.execute_task(bash_task)
@@ -178,7 +181,7 @@ class TestBashDecorator:
bash_task = bash()
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
with mock.patch.dict("os.environ", {"AIRFLOW_HOME":
"path/to/airflow/home"}):
ti, return_val = self.execute_task(bash_task)
@@ -207,7 +210,7 @@ class TestBashDecorator:
bash_task = bash(exit_code)
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
with expected:
ti, return_val = self.execute_task(bash_task)
@@ -251,7 +254,7 @@ class TestBashDecorator:
bash_task = bash(exit_code)
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
with expected:
ti, return_val = self.execute_task(bash_task)
@@ -297,7 +300,7 @@ class TestBashDecorator:
with mock.patch.dict("os.environ", {"AIRFLOW_HOME":
"path/to/airflow/home"}):
bash_task = bash(f"{cmd_file} ")
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
ti, return_val = self.execute_task(bash_task)
@@ -319,7 +322,7 @@ class TestBashDecorator:
bash_task = bash()
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
ti, return_val = self.execute_task(bash_task)
@@ -339,7 +342,7 @@ class TestBashDecorator:
bash_task = bash()
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
dr = self.dag_maker.create_dagrun()
ti = dr.task_instances[0]
@@ -360,7 +363,7 @@ class TestBashDecorator:
bash_task = bash()
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
dr = self.dag_maker.create_dagrun()
ti = dr.task_instances[0]
@@ -378,7 +381,7 @@ class TestBashDecorator:
bash_task = bash()
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
dr = self.dag_maker.create_dagrun()
ti = dr.task_instances[0]
@@ -401,7 +404,7 @@ class TestBashDecorator:
):
bash_task = bash()
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
ti, _ = self.execute_task(bash_task)
@@ -409,12 +412,13 @@ class TestBashDecorator:
self.validate_bash_command_rtif(ti, "echo")
@pytest.mark.parametrize(
- "multiple_outputs", [False, pytest.param(None, id="none"),
pytest.param(NOTSET, id="not-set")]
+ "multiple_outputs",
+ [False, pytest.param(None, id="none"),
pytest.param(SET_DURING_EXECUTION, id="not-set")],
)
def test_multiple_outputs(self, multiple_outputs):
"""Verify setting `multiple_outputs` for a @task.bash-decorated
function is ignored."""
decorator_kwargs = {}
- if multiple_outputs is not NOTSET:
+ if multiple_outputs is not SET_DURING_EXECUTION:
decorator_kwargs["multiple_outputs"] = multiple_outputs
with self.dag:
@@ -428,7 +432,7 @@ class TestBashDecorator:
bash_task = bash()
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
ti, _ = self.execute_task(bash_task)
@@ -440,7 +444,9 @@ class TestBashDecorator:
argvalues=[
pytest.param(None, pytest.raises(TypeError),
id="return_none_typeerror"),
pytest.param(1, pytest.raises(TypeError),
id="return_int_typeerror"),
- pytest.param(NOTSET, pytest.raises(TypeError),
id="return_notset_typeerror"),
+ pytest.param(
+ SET_DURING_EXECUTION, pytest.raises(TypeError),
id="return_SET_DURING_EXECUTION_typeerror"
+ ),
pytest.param(True, pytest.raises(TypeError),
id="return_boolean_typeerror"),
pytest.param("", pytest.raises(TypeError),
id="return_empty_string_typerror"),
pytest.param(" ", pytest.raises(TypeError),
id="return_spaces_string_typerror"),
@@ -458,7 +464,7 @@ class TestBashDecorator:
bash_task = bash()
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
with expected:
ti, _ = self.execute_task(bash_task)
@@ -475,7 +481,7 @@ class TestBashDecorator:
bash_task = bash()
- assert bash_task.operator.bash_command == NOTSET
+ assert bash_task.operator.bash_command == SET_DURING_EXECUTION
dr = self.dag_maker.create_dagrun()
ti = dr.task_instances[0]
diff --git a/providers/standard/src/airflow/providers/standard/operators/bash.py b/providers/standard/src/airflow/providers/standard/operators/bash.py
index 02d53737588..02a3c03afbd 100644
--- a/providers/standard/src/airflow/providers/standard/operators/bash.py
+++ b/providers/standard/src/airflow/providers/standard/operators/bash.py
@@ -28,8 +28,6 @@ from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models.baseoperator import BaseOperator
from airflow.providers.standard.hooks.subprocess import SubprocessHook,
SubprocessResult, working_directory
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.types import ArgNotSet
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk.execution_time.context import context_to_airflow_vars
@@ -37,7 +35,7 @@ else:
from airflow.utils.operator_helpers import context_to_airflow_vars #
type: ignore[no-redef, attr-defined]
if TYPE_CHECKING:
- from sqlalchemy.orm import Session as SASession
+ from airflow.utils.types import ArgNotSet
try:
from airflow.sdk.definitions.context import Context
@@ -187,43 +185,15 @@ class BashOperator(BaseOperator):
self.cwd = cwd
self.append_env = append_env
self.output_processor = output_processor
-
- # When using the @task.bash decorator, the Bash command is not known
until the underlying Python
- # callable is executed and therefore set to NOTSET initially. This
flag is useful during execution to
- # determine whether the bash_command value needs to re-rendered.
- self._init_bash_command_not_set = isinstance(self.bash_command,
ArgNotSet)
-
- # Keep a copy of the original bash_command, without the Jinja template
rendered.
- # This is later used to determine if the bash_command is a script or
an inline string command.
- # We do this later, because the bash_command is not available in
__init__ when using @task.bash.
- self._unrendered_bash_command: str | ArgNotSet = bash_command
+ self._is_inline_cmd = None
+ if isinstance(bash_command, str):
+ self._is_inline_cmd =
self._is_inline_command(bash_command=bash_command)
@cached_property
def subprocess_hook(self):
"""Returns hook for running the bash command."""
return SubprocessHook()
- # TODO: This should be replaced with Task SDK API call
- @staticmethod
- @provide_session
- def refresh_bash_command(ti, session: SASession = NEW_SESSION) -> None:
- """
- Rewrite the underlying rendered bash_command value for a task instance
in the metadatabase.
-
- TaskInstance.get_rendered_template_fields() cannot be used because
this will retrieve the
- RenderedTaskInstanceFields from the metadatabase which doesn't have
the runtime-evaluated bash_command
- value.
-
- :meta private:
- """
- from airflow.models.renderedtifields import RenderedTaskInstanceFields
-
- """Update rendered task instance fields for cases where runtime
evaluated, not templated."""
-
- rtif = RenderedTaskInstanceFields(ti)
- RenderedTaskInstanceFields.write(rtif, session=session)
- RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id,
session=session)
-
def get_env(self, context) -> dict:
"""Build the set of environment variables to be exposed for the bash
command."""
system_env = os.environ.copy()
@@ -252,19 +222,7 @@ class BashOperator(BaseOperator):
raise AirflowException(f"The cwd {self.cwd} must be a
directory")
env = self.get_env(context)
- # Because the bash_command value is evaluated at runtime using the
@task.bash decorator, the
- # RenderedTaskInstanceField data needs to be rewritten and the
bash_command value re-rendered -- the
- # latter because the returned command from the decorated callable
could contain a Jinja expression.
- # Both will ensure the correct Bash command is executed and that the
Rendered Template view in the UI
- # displays the executed command (otherwise it will display as an
ArgNotSet type).
- if self._init_bash_command_not_set:
- is_inline_command = self._is_inline_command(bash_command=cast(str,
self.bash_command))
- ti = context["ti"]
- self.refresh_bash_command(ti)
- else:
- is_inline_command = self._is_inline_command(bash_command=cast(str,
self._unrendered_bash_command))
-
- if is_inline_command:
+ if self._is_inline_cmd:
result = self._run_inline_command(bash_path=bash_path, env=env)
else:
result = self._run_rendered_script_file(bash_path=bash_path,
env=env)
diff --git a/providers/standard/tests/unit/standard/operators/test_bash.py b/providers/standard/tests/unit/standard/operators/test_bash.py
index 59a0c8bbf23..fe33689d00e 100644
--- a/providers/standard/tests/unit/standard/operators/test_bash.py
+++ b/providers/standard/tests/unit/standard/operators/test_bash.py
@@ -60,8 +60,6 @@ class TestBashOperator:
assert op.output_encoding == "utf-8"
assert op.skip_on_exit_code == [99]
assert op.cwd is None
- assert op._init_bash_command_not_set is False
- assert op._unrendered_bash_command == "echo"
@pytest.mark.db_test
@pytest.mark.parametrize(
diff --git a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py
index e17c304a7ad..970d5d50aa7 100755
--- a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py
+++ b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py
@@ -43,6 +43,7 @@ IGNORED = {
"post_execute",
"pre_execute",
"multiple_outputs",
+ "overwrite_rtif_after_execution",
# Doesn't matter, not used anywhere.
"default_args",
# Deprecated and is aliased to max_active_tis_per_dag.
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/types.py b/task-sdk/src/airflow/sdk/definitions/_internal/types.py
index 0e3a39cde20..b8bd9fbfc4b 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/types.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/types.py
@@ -49,6 +49,18 @@ NOTSET = ArgNotSet()
"""Sentinel value for argument default. See ``ArgNotSet``."""
+class SetDuringExecution(ArgNotSet):
+ """Sentinel type for annotations, useful when a value is dynamic and set
during Execution but not parsing."""
+
+ @staticmethod
+ def serialize() -> str:
+ return "DYNAMIC (set during execution)"
+
+
+SET_DURING_EXECUTION = SetDuringExecution()
+"""Sentinel value for argument default. See ``SetDuringExecution``."""
+
+
if TYPE_CHECKING:
import logging
diff --git a/task-sdk/src/airflow/sdk/definitions/baseoperator.py b/task-sdk/src/airflow/sdk/definitions/baseoperator.py
index fd419cd6b57..4aa19108b68 100644
--- a/task-sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -899,6 +899,11 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
# Defines if the operator supports lineage without manual definitions
supports_lineage: bool = False
+ # If True, the Rendered Template fields will be overwritten in DB after
execution
+ # This is useful for Taskflow decorators that modify the template fields
during execution like
+ # @task.bash decorator.
+ overwrite_rtif_after_execution: bool = False
+
# If True then the class constructor was called
__instantiated: bool = False
# List of args as passed to `init()`, after apply_defaults() has been
updated. Used to "recreate" the task
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 5f23902df77..31d44f1116f 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -899,6 +899,16 @@ def finalize(
log.debug("Setting xcom for operator extra link", link=link,
xcom_key=xcom_key)
_xcom_push(ti, key=xcom_key, value=link)
+ if getattr(ti.task, "overwrite_rtif_after_execution", False):
+ log.debug("Overwriting Rendered template fields.")
+ if ti.task.template_fields:
+ SUPERVISOR_COMMS.send_request(
+ log=log,
+ msg=SetRenderedFields(
+ rendered_fields={field: getattr(ti.task, field) for field
in ti.task.template_fields}
+ ),
+ )
+
log.debug("Running finalizers", ti=ti)
if state == TerminalTIState.SUCCESS:
get_listener_manager().hook.on_task_instance_success(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 96d7e2fb977..8f59eddbd11 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -1396,6 +1396,27 @@ class TestRuntimeTaskInstance:
log=mock.ANY,
)
+ def test_overwrite_rtif_after_execution_sets_rtif(self, create_runtime_ti,
mock_supervisor_comms):
+ """Test that the RTIF is overwritten after execution for certain
operators."""
+
+ class CustomOperator(BaseOperator):
+ overwrite_rtif_after_execution = True
+ template_fields = ["bash_command"]
+
+ def __init__(self, bash_command, *args, **kwargs):
+ self.bash_command = bash_command
+ super().__init__(*args, **kwargs)
+
+ task = CustomOperator(task_id="hello", bash_command="echo 'hi'")
+ runtime_ti = create_runtime_ti(task=task)
+
+ finalize(runtime_ti, log=mock.MagicMock(),
state=TerminalTIState.SUCCESS)
+
+ mock_supervisor_comms.send_request.assert_called_with(
+ msg=SetRenderedFields(rendered_fields={"bash_command": "echo
'hi'"}),
+ log=mock.ANY,
+ )
+
class TestXComAfterTaskExecution:
@pytest.mark.parametrize(