This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e75522b063 fix: PythonVirtualenvOperator crashes if any python_callable function is defined in the same source as DAG (#37165)
e75522b063 is described below
commit e75522b0636fb5115d73da43da244d0c3832794f
Author: Kalyan <[email protected]>
AuthorDate: Wed Feb 7 20:46:09 2024 +0530
fix: PythonVirtualenvOperator crashes if any python_callable function is defined in the same source as DAG (#37165)
---------
Signed-off-by: kalyanr <[email protected]>
---
airflow/models/dagbag.py | 12 +++++++-----
airflow/operators/python.py | 23 +++++++++++++++--------
airflow/utils/file.py | 12 ++++++++++++
airflow/utils/python_virtualenv_script.jinja2 | 19 +++++++++++++++++--
4 files changed, 51 insertions(+), 15 deletions(-)
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index ca81af8105..ce9bf5587b 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import hashlib
import importlib
import importlib.machinery
import importlib.util
@@ -48,7 +47,12 @@ from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.docs import get_docs_url
-from airflow.utils.file import correct_maybe_zipped, list_py_file_paths, might_contain_dag
+from airflow.utils.file import (
+ correct_maybe_zipped,
+ get_unique_dag_module_name,
+ list_py_file_paths,
+ might_contain_dag,
+)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.session import NEW_SESSION, provide_session
@@ -326,9 +330,7 @@ class DagBag(LoggingMixin):
return []
self.log.debug("Importing %s", filepath)
- path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
- org_mod_name = Path(filepath).stem
- mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
+ mod_name = get_unique_dag_module_name(filepath)
if mod_name in sys.modules:
del sys.modules[mod_name]
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 1c5c9d3f69..1b1453cc5e 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -52,6 +52,7 @@ from airflow.models.variable import Variable
from airflow.operators.branch import BranchMixIn
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
+from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
@@ -437,15 +438,21 @@ class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta):
self._write_args(input_path)
self._write_string_args(string_args_path)
+
+ jinja_context = {
+ "op_args": self.op_args,
+ "op_kwargs": op_kwargs,
+ "expect_airflow": self.expect_airflow,
+ "pickling_library": self.pickling_library.__name__,
+ "python_callable": self.python_callable.__name__,
+ "python_callable_source": self.get_python_source(),
+ }
+
+ if inspect.getfile(self.python_callable) == self.dag.fileloc:
+ jinja_context["modified_dag_module_name"] = get_unique_dag_module_name(self.dag.fileloc)
+
write_python_script(
- jinja_context={
- "op_args": self.op_args,
- "op_kwargs": op_kwargs,
- "expect_airflow": self.expect_airflow,
- "pickling_library": self.pickling_library.__name__,
- "python_callable": self.python_callable.__name__,
- "python_callable_source": self.get_python_source(),
- },
+ jinja_context=jinja_context,
filename=os.fspath(script_path),
render_template_as_native_obj=self.dag.render_template_as_native_obj,
)
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
index 013d9ea36a..7e15eeb2f8 100644
--- a/airflow/utils/file.py
+++ b/airflow/utils/file.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import ast
+import hashlib
import logging
import os
import zipfile
@@ -33,6 +34,8 @@ from airflow.exceptions import RemovedInAirflow3Warning
log = logging.getLogger(__name__)
+MODIFIED_DAG_MODULE_NAME = "unusual_prefix_{path_hash}_{module_name}"
+
class _IgnoreRule(Protocol):
"""Interface for ignore rules for structural subtyping."""
@@ -379,3 +382,12 @@ def iter_airflow_imports(file_path: str) -> Generator[str, None, None]:
for m in _find_imported_modules(parsed):
if m.startswith("airflow."):
yield m
+
+
+def get_unique_dag_module_name(file_path: str) -> str:
+ """Returns a unique module name in the format unusual_prefix_{sha1 of module's file path}_{original module name}."""
+ if isinstance(file_path, str):
+ path_hash = hashlib.sha1(file_path.encode("utf-8")).hexdigest()
+ org_mod_name = Path(file_path).stem
+ return MODIFIED_DAG_MODULE_NAME.format(path_hash=path_hash, module_name=org_mod_name)
+ raise ValueError("file_path should be a string to generate unique module name")
diff --git a/airflow/utils/python_virtualenv_script.jinja2 b/airflow/utils/python_virtualenv_script.jinja2
index 7bbf6a9531..4199a47130 100644
--- a/airflow/utils/python_virtualenv_script.jinja2
+++ b/airflow/utils/python_virtualenv_script.jinja2
@@ -34,6 +34,22 @@ if sys.version_info >= (3,6):
pass
{% endif %}
+# Script
+{{ python_callable_source }}
+
+# monkey patching for the cases when python_callable is part of the dag module.
+{% if modified_dag_module_name is defined %}
+
+import types
+
+{{ modified_dag_module_name }} = types.ModuleType("{{ modified_dag_module_name }}")
+
+{{ modified_dag_module_name }}.{{ python_callable }} = {{ python_callable }}
+
+sys.modules["{{modified_dag_module_name}}"] = {{modified_dag_module_name}}
+
+{% endif%}
+
{% if op_args or op_kwargs %}
with open(sys.argv[1], "rb") as file:
arg_dict = {{ pickling_library }}.load(file)
@@ -47,8 +63,7 @@ with open(sys.argv[3], "r") as file:
virtualenv_string_args = list(map(lambda x: x.strip(), list(file)))
{% endif %}
-# Script
-{{ python_callable_source }}
+
try:
res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
except Exception as e: