This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch task-sdk-first-code in repository https://gitbox.apache.org/repos/asf/airflow.git
commit a69b668e1c4777070fd31483eab07b46d14b3553 Author: Kaxil Naik <[email protected]> AuthorDate: Wed Oct 23 00:13:59 2024 +0100 Use DAG Context from Task SDK [skip ci] [skip ci] --- airflow/models/dag.py | 2 ++ airflow/models/dagbag.py | 9 +++++---- task_sdk/src/airflow/sdk/definitions/contextmanager.py | 5 +++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 61e260f2fbc..230087602ca 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2427,10 +2427,12 @@ if STATICA_HACK: # pragma: no cover class DagContext(airflow.sdk.definitions.contextmanager.DagContext, share_parent_context=True): """:meta private:""" # noqa: D400 + # TODO: Method is not used anywhere. Remove them if not needed. @classmethod def push_context_managed_dag(cls, dag: DAG): cls.push(dag) + # TODO: Method is not used anywhere. Remove them if not needed. @classmethod def pop_context_managed_dag(cls) -> DAG | None: return cast(DAG, cls.pop()) diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 4ad82325e9e..c9ad8edaa40 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -280,7 +280,7 @@ class DagBag(LoggingMixin): def process_file(self, filepath, only_if_updated=True, safe_mode=True): """Given a path to a python module or zip file, import the module and look for dag objects within.""" - from airflow.models.dag import DagContext + from airflow.sdk.definitions.contextmanager import DagContext # if the source file no longer exists in the DB or in the filesystem, # return an empty list @@ -328,7 +328,7 @@ class DagBag(LoggingMixin): return found_dags def _load_modules_from_file(self, filepath, safe_mode): - from airflow.models.dag import DagContext + from airflow.sdk.definitions.contextmanager import DagContext if not might_contain_dag(filepath, safe_mode): # Don't want to spam user with skip messages @@ -384,7 +384,7 @@ class DagBag(LoggingMixin): return parse(mod_name, filepath) def _load_modules_from_zip(self, filepath, safe_mode): - from airflow.models.dag import DagContext + from airflow.sdk.definitions.contextmanager import DagContext mods = [] with zipfile.ZipFile(filepath) as current_zip_file: @@ -433,7 +433,8 @@ class DagBag(LoggingMixin): return mods def _process_modules(self, filepath, mods, file_last_changed_on_disk): - from airflow.models.dag import DAG, DagContext # Avoid circular import + from airflow.models.dag import DAG # Avoid circular import + from airflow.sdk.definitions.contextmanager import DagContext top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)} diff --git a/task_sdk/src/airflow/sdk/definitions/contextmanager.py b/task_sdk/src/airflow/sdk/definitions/contextmanager.py index 69bef998420..a22662a0619 100644 --- a/task_sdk/src/airflow/sdk/definitions/contextmanager.py +++ b/task_sdk/src/airflow/sdk/definitions/contextmanager.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import sys from collections import deque from types import ModuleType from typing import Any, Generic, TypeVar @@ -102,8 +103,8 @@ class DagContext(ContextStack[DAG]): dag = super().pop() # In a few cases around serialization we explicitly push None in to the stack if cls.current_autoregister_module_name is not None and dag and getattr(dag, "auto_register", True): - # mod = sys.modules[cls.current_autoregister_module_name] - cls.autoregistered_dags.add((dag, None)) + mod = sys.modules[cls.current_autoregister_module_name] + cls.autoregistered_dags.add((dag, mod)) return dag
