This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e75522b063 fix: PythonVirtualenvOperator crashes if any python_callable function is defined in the same source as DAG (#37165)
e75522b063 is described below

commit e75522b0636fb5115d73da43da244d0c3832794f
Author: Kalyan <[email protected]>
AuthorDate: Wed Feb 7 20:46:09 2024 +0530

    fix: PythonVirtualenvOperator crashes if any python_callable function is defined in the same source as DAG (#37165)
    
    
    
    ---------
    
    Signed-off-by: kalyanr <[email protected]>
---
 airflow/models/dagbag.py                      | 12 +++++++-----
 airflow/operators/python.py                   | 23 +++++++++++++++--------
 airflow/utils/file.py                         | 12 ++++++++++++
 airflow/utils/python_virtualenv_script.jinja2 | 19 +++++++++++++++++--
 4 files changed, 51 insertions(+), 15 deletions(-)

diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index ca81af8105..ce9bf5587b 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import hashlib
 import importlib
 import importlib.machinery
 import importlib.util
@@ -48,7 +47,12 @@ from airflow.stats import Stats
 from airflow.utils import timezone
 from airflow.utils.dag_cycle_tester import check_cycle
 from airflow.utils.docs import get_docs_url
-from airflow.utils.file import correct_maybe_zipped, list_py_file_paths, might_contain_dag
+from airflow.utils.file import (
+    correct_maybe_zipped,
+    get_unique_dag_module_name,
+    list_py_file_paths,
+    might_contain_dag,
+)
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
 from airflow.utils.session import NEW_SESSION, provide_session
@@ -326,9 +330,7 @@ class DagBag(LoggingMixin):
             return []
 
         self.log.debug("Importing %s", filepath)
-        path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
-        org_mod_name = Path(filepath).stem
-        mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
+        mod_name = get_unique_dag_module_name(filepath)
 
         if mod_name in sys.modules:
             del sys.modules[mod_name]
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 1c5c9d3f69..1b1453cc5e 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -52,6 +52,7 @@ from airflow.models.variable import Variable
 from airflow.operators.branch import BranchMixIn
 from airflow.utils import hashlib_wrapper
 from airflow.utils.context import context_copy_partial, context_merge
+from airflow.utils.file import get_unique_dag_module_name
 from airflow.utils.operator_helpers import KeywordParameters
 from airflow.utils.process_utils import execute_in_subprocess
 from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
@@ -437,15 +438,21 @@ class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta):
 
             self._write_args(input_path)
             self._write_string_args(string_args_path)
+
+            jinja_context = {
+                "op_args": self.op_args,
+                "op_kwargs": op_kwargs,
+                "expect_airflow": self.expect_airflow,
+                "pickling_library": self.pickling_library.__name__,
+                "python_callable": self.python_callable.__name__,
+                "python_callable_source": self.get_python_source(),
+            }
+
+            if inspect.getfile(self.python_callable) == self.dag.fileloc:
+                jinja_context["modified_dag_module_name"] = get_unique_dag_module_name(self.dag.fileloc)
+
             write_python_script(
-                jinja_context={
-                    "op_args": self.op_args,
-                    "op_kwargs": op_kwargs,
-                    "expect_airflow": self.expect_airflow,
-                    "pickling_library": self.pickling_library.__name__,
-                    "python_callable": self.python_callable.__name__,
-                    "python_callable_source": self.get_python_source(),
-                },
+                jinja_context=jinja_context,
                 filename=os.fspath(script_path),
                 render_template_as_native_obj=self.dag.render_template_as_native_obj,
             )
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
index 013d9ea36a..7e15eeb2f8 100644
--- a/airflow/utils/file.py
+++ b/airflow/utils/file.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import ast
+import hashlib
 import logging
 import os
 import zipfile
@@ -33,6 +34,8 @@ from airflow.exceptions import RemovedInAirflow3Warning
 
 log = logging.getLogger(__name__)
 
+MODIFIED_DAG_MODULE_NAME = "unusual_prefix_{path_hash}_{module_name}"
+
 
 class _IgnoreRule(Protocol):
     """Interface for ignore rules for structural subtyping."""
@@ -379,3 +382,12 @@ def iter_airflow_imports(file_path: str) -> Generator[str, None, None]:
     for m in _find_imported_modules(parsed):
         if m.startswith("airflow."):
             yield m
+
+
+def get_unique_dag_module_name(file_path: str) -> str:
+    """Returns a unique module name in the format unusual_prefix_{sha1 of module's file path}_{original module name}."""
+    if isinstance(file_path, str):
+        path_hash = hashlib.sha1(file_path.encode("utf-8")).hexdigest()
+        org_mod_name = Path(file_path).stem
+        return MODIFIED_DAG_MODULE_NAME.format(path_hash=path_hash, module_name=org_mod_name)
+    raise ValueError("file_path should be a string to generate unique module name")
diff --git a/airflow/utils/python_virtualenv_script.jinja2 b/airflow/utils/python_virtualenv_script.jinja2
index 7bbf6a9531..4199a47130 100644
--- a/airflow/utils/python_virtualenv_script.jinja2
+++ b/airflow/utils/python_virtualenv_script.jinja2
@@ -34,6 +34,22 @@ if sys.version_info >= (3,6):
         pass
 {% endif %}
 
+# Script
+{{ python_callable_source }}
+
+# monkey patching for the cases when python_callable is part of the dag module.
+{% if modified_dag_module_name is defined %}
+
+import types
+
+{{ modified_dag_module_name }}  = types.ModuleType("{{ modified_dag_module_name }}")
+
+{{ modified_dag_module_name }}.{{ python_callable }} = {{ python_callable }}
+
+sys.modules["{{modified_dag_module_name}}"] = {{modified_dag_module_name}}
+
+{% endif%}
+
 {% if op_args or op_kwargs %}
 with open(sys.argv[1], "rb") as file:
     arg_dict = {{ pickling_library }}.load(file)
@@ -47,8 +63,7 @@ with open(sys.argv[3], "r") as file:
     virtualenv_string_args = list(map(lambda x: x.strip(), list(file)))
 {% endif %}
 
-# Script
-{{ python_callable_source }}
+
 try:
     res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
 except Exception as e:

Reply via email to