This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0010bf1269 Make _get_template_context an RPC call (#38567)
0010bf1269 is described below

commit 0010bf126909a7385b731de80668b91af7cc74e5
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Apr 2 09:25:49 2024 -0700

    Make _get_template_context an RPC call (#38567)
    
    Provide a way of serializing the template context over RPC
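    
    Decorated with @internal_api_call, the function is expected to run
    over the internal API when database isolation mode is enabled: the
    keyword arguments are serialized, sent to the RPC endpoint, and the
    returned Context is deserialized on the caller's side. A rough
    sketch of the resulting call site (only the parameters visible in
    this diff are shown):
    
        # ti may be a TaskInstance or, on the non-DB side, a
        # TaskInstancePydantic; the same call works for both.
        context = _get_template_context(task_instance=ti, session=session)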
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py |  2 ++
 airflow/models/taskinstance.py                     | 23 +++++++++++++++++++++-
 airflow/serialization/enums.py                     |  1 +
 airflow/serialization/serialized_objects.py        | 16 ++++++++++++++-
 4 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 5074504b8d..243fcfa284 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable
 from flask import Response
 
 from airflow.jobs.job import Job, most_recent_job
+from airflow.models.taskinstance import _get_template_context
 from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.utils.session import create_session
 
@@ -48,6 +49,7 @@ def _initialize_map() -> dict[str, Callable]:
     from airflow.utils.log.file_task_handler import FileTaskHandler
 
     functions: list[Callable] = [
+        _get_template_context,
         DagFileProcessor.update_import_errors,
         DagFileProcessor.manage_slas,
         DagFileProcessorManager.deactivate_stale_dags,
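
Note: adding _get_template_context to this map is what exposes it over
the internal API. Assuming the map is keyed by module path and qualname
as elsewhere in AIP-44, a request would look roughly like the following
(the exact envelope and endpoint path are assumptions, not taken from
this diff):

    POST /internal_api/v1/rpcapi
    {"jsonrpc": "2.0",
     "method": "airflow.models.taskinstance._get_template_context",
     "params": "<BaseSerialization-encoded keyword arguments>"}
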
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 14fc0fc8f7..e7fdc5bec1 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -597,6 +597,7 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydanti
     task_instance.next_kwargs = None
 
 
+@internal_api_call
 def _get_template_context(
     *,
     task_instance: TaskInstance | TaskInstancePydantic,
@@ -623,10 +624,30 @@ def _get_template_context(
 
     task = task_instance.task
     if TYPE_CHECKING:
+        assert task_instance.task
         assert task
         assert task.dag
-    dag: DAG = task.dag
+    try:
+        dag: DAG = task.dag
+    except AirflowException:
+        from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
 
+        if isinstance(task_instance, TaskInstancePydantic):
+            ti = session.scalar(
+                select(TaskInstance).where(
+                    TaskInstance.task_id == task_instance.task_id,
+                    TaskInstance.dag_id == task_instance.dag_id,
+                    TaskInstance.run_id == task_instance.run_id,
+                    TaskInstance.map_index == task_instance.map_index,
+                )
+            )
+            dag = ti.dag_model.serialized_dag.dag
+            if hasattr(task_instance.task, "_dag"):  # BaseOperator
+                task_instance.task._dag = dag
+            else:  # MappedOperator
+                task_instance.task.dag = dag
+        else:
+            raise
     dag_run = task_instance.get_dagrun(session)
     data_interval = dag.get_run_data_interval(dag_run)
 
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index 2a4387eeb4..9b7cdbcc73 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -61,4 +61,5 @@ class DagAttributeTypes(str, Enum):
     DATA_SET = "data_set"
     LOG_TEMPLATE = "log_template"
     CONNECTION = "connection"
+    TASK_CONTEXT = "task_context"
     ARG_NOT_SET = "arg_not_set"
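
Note: TASK_CONTEXT is the type tag for the encoded context.
BaseSerialization wraps non-primitive values in a two-key
__type/__var envelope, while primitives pass through as-is, so an
encoded context should look roughly like this (values illustrative):

    {"__type": "task_context",
     "__var": {"ds": "2024-04-02",
               "task_instance": {"__type": "...", "__var": {...}},
               ...}}
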
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 16a5c9e481..98d3d3a654 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -67,6 +67,7 @@ from airflow.task.priority_strategy import (
     airflow_priority_weight_strategies_classes,
 )
 from airflow.utils.code_utils import get_python_source
+from airflow.utils.context import Context
 from airflow.utils.docs import get_docs_url
 from airflow.utils.module_loading import import_string, qualname
 from airflow.utils.operator_resources import Resources
@@ -602,6 +603,12 @@ class BaseSerialization:
             )
         elif isinstance(var, Connection):
             return cls._encode(var.to_dict(validate=True), type_=DAT.CONNECTION)
+        elif var.__class__ == Context:
+            d = {}
+            for k, v in var._context.items():
+                obj = cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models)
+                d[str(k)] = obj
+            return cls._encode(d, type_=DAT.TASK_CONTEXT)
         elif use_pydantic_models and _ENABLE_AIP_44:
 
             def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]:
@@ -648,7 +655,14 @@ class BaseSerialization:
             raise ValueError(f"The encoded_var should be dict and is 
{type(encoded_var)}")
         var = encoded_var[Encoding.VAR]
         type_ = encoded_var[Encoding.TYPE]
-
+        if type_ == DAT.TASK_CONTEXT:
+            d = {}
+            for k, v in var.items():
+                if k == "task":  # todo: add TaskPydantic so we don't need 
this?
+                    continue
+                d[k] = cls.deserialize(v, use_pydantic_models=True)
+            d["task"] = d["task_instance"].task  # todo: add TaskPydantic so 
we don't need this?
+            return Context(**d)
         if type_ == DAT.DICT:
             return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
         elif type_ == DAT.DAG:
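
Note: taken together, a round trip through the new path looks roughly
like the following sketch, where context is any Context built by
_get_template_context:

    from airflow.serialization.serialized_objects import BaseSerialization

    # Serialize: walks context._context and encodes each value,
    # tagging the result with DAT.TASK_CONTEXT.
    encoded = BaseSerialization.serialize(context, use_pydantic_models=True)

    # Deserialize: rebuilds the Context; "task" is deliberately not
    # shipped and is re-derived from the deserialized task_instance.
    restored = BaseSerialization.deserialize(encoded)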
