This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3904206b69 Fix Python-based decorators templating (#36103)
3904206b69 is described below

commit 3904206b69428525db31ff7813daa0322f7b83e8
Author: Jarek Potiuk <[email protected]>
AuthorDate: Thu Dec 7 10:19:54 2023 +0100

    Fix Python-based decorators templating (#36103)
    
    Templating of Python-based decorators has been broken since they
    were first implemented. The decorators used the template_fields
    definition from PythonOperator rather than the one from the operator
    being decorated (e.g. PythonVirtualenvOperator), because
    template_fields was redefined in the _PythonDecoratedOperator class
    and that definition took precedence in the method resolution order
    (MRO).
    
    This PR adds explicit copying of template_fields from the operators
    that the decorators wrap.
    
    Fixes: #36102
---
 airflow/decorators/branch_external_python.py |  1 +
 airflow/decorators/branch_python.py          |  1 +
 airflow/decorators/branch_virtualenv.py      |  1 +
 airflow/decorators/external_python.py        |  1 +
 airflow/decorators/python_virtualenv.py      |  1 +
 airflow/decorators/short_circuit.py          |  1 +
 airflow/models/abstractoperator.py           |  1 -
 tests/decorators/test_branch_virtualenv.py   |  9 ++++++---
 tests/decorators/test_external_python.py     | 14 ++++++++++++++
 tests/decorators/test_python_virtualenv.py   | 26 ++++++++++++++++++++++++++
 10 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/airflow/decorators/branch_external_python.py 
b/airflow/decorators/branch_external_python.py
index 8e945541c5..2902a47c67 100644
--- a/airflow/decorators/branch_external_python.py
+++ b/airflow/decorators/branch_external_python.py
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
 class _BranchExternalPythonDecoratedOperator(_PythonDecoratedOperator, 
BranchExternalPythonOperator):
     """Wraps a Python callable and captures args/kwargs when called for 
execution."""
 
+    template_fields = BranchExternalPythonOperator.template_fields
     custom_operator_name: str = "@task.branch_external_python"
 
 
diff --git a/airflow/decorators/branch_python.py 
b/airflow/decorators/branch_python.py
index 3ac11f0efa..31750ef657 100644
--- a/airflow/decorators/branch_python.py
+++ b/airflow/decorators/branch_python.py
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
 class _BranchPythonDecoratedOperator(_PythonDecoratedOperator, 
BranchPythonOperator):
     """Wraps a Python callable and captures args/kwargs when called for 
execution."""
 
+    template_fields = BranchPythonOperator.template_fields
     custom_operator_name: str = "@task.branch"
 
 
diff --git a/airflow/decorators/branch_virtualenv.py 
b/airflow/decorators/branch_virtualenv.py
index 3e4c3fcaf1..c96638ee20 100644
--- a/airflow/decorators/branch_virtualenv.py
+++ b/airflow/decorators/branch_virtualenv.py
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
 class _BranchPythonVirtualenvDecoratedOperator(_PythonDecoratedOperator, 
BranchPythonVirtualenvOperator):
     """Wraps a Python callable and captures args/kwargs when called for 
execution."""
 
+    template_fields = BranchPythonVirtualenvOperator.template_fields
     custom_operator_name: str = "@task.branch_virtualenv"
 
 
diff --git a/airflow/decorators/external_python.py 
b/airflow/decorators/external_python.py
index 1e39ed561b..2d8e2603f9 100644
--- a/airflow/decorators/external_python.py
+++ b/airflow/decorators/external_python.py
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
 class _PythonExternalDecoratedOperator(_PythonDecoratedOperator, 
ExternalPythonOperator):
     """Wraps a Python callable and captures args/kwargs when called for 
execution."""
 
+    template_fields = ExternalPythonOperator.template_fields
     custom_operator_name: str = "@task.external_python"
 
 
diff --git a/airflow/decorators/python_virtualenv.py 
b/airflow/decorators/python_virtualenv.py
index 2eb8678779..d0eb93a0d7 100644
--- a/airflow/decorators/python_virtualenv.py
+++ b/airflow/decorators/python_virtualenv.py
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
 class _PythonVirtualenvDecoratedOperator(_PythonDecoratedOperator, 
PythonVirtualenvOperator):
     """Wraps a Python callable and captures args/kwargs when called for 
execution."""
 
+    template_fields = PythonVirtualenvOperator.template_fields
     custom_operator_name: str = "@task.virtualenv"
 
 
diff --git a/airflow/decorators/short_circuit.py 
b/airflow/decorators/short_circuit.py
index 210a0e0453..c964ed6bb7 100644
--- a/airflow/decorators/short_circuit.py
+++ b/airflow/decorators/short_circuit.py
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
 class _ShortCircuitDecoratedOperator(_PythonDecoratedOperator, 
ShortCircuitOperator):
     """Wraps a Python callable and captures args/kwargs when called for 
execution."""
 
+    template_fields = ShortCircuitOperator.template_fields
     custom_operator_name: str = "@task.short_circuit"
 
 
diff --git a/airflow/models/abstractoperator.py 
b/airflow/models/abstractoperator.py
index df0e6cb349..f5a266f4b1 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -679,7 +679,6 @@ class AbstractOperator(Templater, DAGNode):
                     f"{attr_name!r} is configured as a template field "
                     f"but {parent.task_type} does not have this attribute."
                 )
-
             try:
                 if not value:
                     continue
diff --git a/tests/decorators/test_branch_virtualenv.py 
b/tests/decorators/test_branch_virtualenv.py
index 2b5f9bb95e..57db52f167 100644
--- a/tests/decorators/test_branch_virtualenv.py
+++ b/tests/decorators/test_branch_virtualenv.py
@@ -31,7 +31,10 @@ class Test_BranchPythonVirtualenvDecoratedOperator:
     # possibilities. So we are increasing the timeout for this test to 3x of 
the default timeout
     @pytest.mark.execution_timeout(180)
     @pytest.mark.parametrize("branch_task_name", ["task_1", "task_2"])
-    def test_branch_one(self, dag_maker, branch_task_name):
+    def test_branch_one(self, dag_maker, branch_task_name, tmp_path):
+        requirements_file = tmp_path / "requirements.txt"
+        requirements_file.write_text("funcsigs==0.4")
+
         @task
         def dummy_f():
             pass
@@ -57,14 +60,14 @@ class Test_BranchPythonVirtualenvDecoratedOperator:
 
         else:
 
-            @task.branch_virtualenv(task_id="branching", 
requirements=["funcsigs"])
+            @task.branch_virtualenv(task_id="branching", 
requirements="requirements.txt")
             def branch_operator():
                 import funcsigs
 
                 print(f"We successfully imported funcsigs version 
{funcsigs.__version__}")
                 return "task_2"
 
-        with dag_maker():
+        with dag_maker(template_searchpath=tmp_path.as_posix()):
             branchoperator = branch_operator()
             df = dummy_f()
             task_1 = task_1()
diff --git a/tests/decorators/test_external_python.py 
b/tests/decorators/test_external_python.py
index cdd8c6cd49..27d8b0ed10 100644
--- a/tests/decorators/test_external_python.py
+++ b/tests/decorators/test_external_python.py
@@ -74,6 +74,20 @@ class TestExternalPythonDecorator:
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
+    def test_with_templated_python(self, dag_maker, venv_python_with_dill):
+        # add template that produces empty string when rendered
+        templated_python_with_dill = venv_python_with_dill.as_posix() + "{{ '' 
}}"
+
+        @task.external_python(python=templated_python_with_dill, use_dill=True)
+        def f():
+            """Import dill to double-check it is installed ."""
+            import dill  # noqa: F401
+
+        with dag_maker():
+            ret = f()
+
+        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
     def test_no_dill_installed_raises_exception_when_use_dill(self, dag_maker, 
venv_python):
         @task.external_python(python=venv_python, use_dill=True)
         def f():
diff --git a/tests/decorators/test_python_virtualenv.py 
b/tests/decorators/test_python_virtualenv.py
index fc604ac464..a069aee8b1 100644
--- a/tests/decorators/test_python_virtualenv.py
+++ b/tests/decorators/test_python_virtualenv.py
@@ -103,6 +103,32 @@ class TestPythonVirtualenvDecorator:
 
         ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
+    def test_with_requirements_file(self, dag_maker, tmp_path):
+        requirements_file = tmp_path / "requirements.txt"
+        requirements_file.write_text("funcsigs==0.4\nattrs==23.1.0")
+
+        @task.virtualenv(
+            system_site_packages=False,
+            requirements="requirements.txt",
+            python_version=PYTHON_VERSION,
+            use_dill=True,
+        )
+        def f():
+            import funcsigs
+
+            if funcsigs.__version__ != "0.4":
+                raise Exception
+
+            import attrs
+
+            if attrs.__version__ != "23.1.0":
+                raise Exception
+
+        with dag_maker(template_searchpath=tmp_path.as_posix()):
+            ret = f()
+
+        ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
     def test_unpinned_requirements(self, dag_maker):
         @task.virtualenv(
             system_site_packages=False,

Reply via email to