This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit a69b668e1c4777070fd31483eab07b46d14b3553
Author: Kaxil Naik <[email protected]>
AuthorDate: Wed Oct 23 00:13:59 2024 +0100

    Use DAG Context from Task SDK [skip ci]
    
     [skip ci]
---
 airflow/models/dag.py                                  | 2 ++
 airflow/models/dagbag.py                               | 9 +++++----
 task_sdk/src/airflow/sdk/definitions/contextmanager.py | 5 +++--
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 61e260f2fbc..230087602ca 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2427,10 +2427,12 @@ if STATICA_HACK:  # pragma: no cover
 class DagContext(airflow.sdk.definitions.contextmanager.DagContext, share_parent_context=True):
     """:meta private:"""  # noqa: D400
 
+    # TODO: Method is not used anywhere. Remove them if not needed.
     @classmethod
     def push_context_managed_dag(cls, dag: DAG):
         cls.push(dag)
 
+    # TODO: Method is not used anywhere. Remove them if not needed.
     @classmethod
     def pop_context_managed_dag(cls) -> DAG | None:
         return cast(DAG, cls.pop())
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 4ad82325e9e..c9ad8edaa40 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -280,7 +280,7 @@ class DagBag(LoggingMixin):
 
     def process_file(self, filepath, only_if_updated=True, safe_mode=True):
         """Given a path to a python module or zip file, import the module and look for dag objects within."""
-        from airflow.models.dag import DagContext
+        from airflow.sdk.definitions.contextmanager import DagContext
 
         # if the source file no longer exists in the DB or in the filesystem,
         # return an empty list
@@ -328,7 +328,7 @@ class DagBag(LoggingMixin):
         return found_dags
 
     def _load_modules_from_file(self, filepath, safe_mode):
-        from airflow.models.dag import DagContext
+        from airflow.sdk.definitions.contextmanager import DagContext
 
         if not might_contain_dag(filepath, safe_mode):
             # Don't want to spam user with skip messages
@@ -384,7 +384,7 @@ class DagBag(LoggingMixin):
             return parse(mod_name, filepath)
 
     def _load_modules_from_zip(self, filepath, safe_mode):
-        from airflow.models.dag import DagContext
+        from airflow.sdk.definitions.contextmanager import DagContext
 
         mods = []
         with zipfile.ZipFile(filepath) as current_zip_file:
@@ -433,7 +433,8 @@ class DagBag(LoggingMixin):
         return mods
 
     def _process_modules(self, filepath, mods, file_last_changed_on_disk):
-        from airflow.models.dag import DAG, DagContext  # Avoid circular import
+        from airflow.models.dag import DAG  # Avoid circular import
+        from airflow.sdk.definitions.contextmanager import DagContext
 
         top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)}
 
diff --git a/task_sdk/src/airflow/sdk/definitions/contextmanager.py b/task_sdk/src/airflow/sdk/definitions/contextmanager.py
index 69bef998420..a22662a0619 100644
--- a/task_sdk/src/airflow/sdk/definitions/contextmanager.py
+++ b/task_sdk/src/airflow/sdk/definitions/contextmanager.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import sys
 from collections import deque
 from types import ModuleType
 from typing import Any, Generic, TypeVar
@@ -102,8 +103,8 @@ class DagContext(ContextStack[DAG]):
         dag = super().pop()
         # In a few cases around serialization we explicitly push None in to the stack
         if cls.current_autoregister_module_name is not None and dag and getattr(dag, "auto_register", True):
-            # mod = sys.modules[cls.current_autoregister_module_name]
-            cls.autoregistered_dags.add((dag, None))
+            mod = sys.modules[cls.current_autoregister_module_name]
+            cls.autoregistered_dags.add((dag, mod))
         return dag
 
 

Reply via email to