This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0010bf1269 Make _get_template_context an RPC call (#38567)
0010bf1269 is described below
commit 0010bf126909a7385b731de80668b91af7cc74e5
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Apr 2 09:25:49 2024 -0700
Make _get_template_context an RPC call (#38567)
Provide way of serializing the template context over RPC
---
airflow/api_internal/endpoints/rpc_api_endpoint.py | 2 ++
airflow/models/taskinstance.py | 23 +++++++++++++++++++++-
airflow/serialization/enums.py | 1 +
airflow/serialization/serialized_objects.py | 16 ++++++++++++++-
4 files changed, 40 insertions(+), 2 deletions(-)
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 5074504b8d..243fcfa284 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable
from flask import Response
from airflow.jobs.job import Job, most_recent_job
+from airflow.models.taskinstance import _get_template_context
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.session import create_session
@@ -48,6 +49,7 @@ def _initialize_map() -> dict[str, Callable]:
from airflow.utils.log.file_task_handler import FileTaskHandler
functions: list[Callable] = [
+ _get_template_context,
DagFileProcessor.update_import_errors,
DagFileProcessor.manage_slas,
DagFileProcessorManager.deactivate_stale_dags,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 14fc0fc8f7..e7fdc5bec1 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -597,6 +597,7 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydantic
task_instance.next_kwargs = None
+@internal_api_call
def _get_template_context(
*,
task_instance: TaskInstance | TaskInstancePydantic,
@@ -623,10 +624,30 @@ def _get_template_context(
task = task_instance.task
if TYPE_CHECKING:
+ assert task_instance.task
assert task
assert task.dag
- dag: DAG = task.dag
+ try:
+ dag: DAG = task.dag
+ except AirflowException:
+ from airflow.serialization.pydantic.taskinstance import
TaskInstancePydantic
+ if isinstance(task_instance, TaskInstancePydantic):
+ ti = session.scalar(
+ select(TaskInstance).where(
+ TaskInstance.task_id == task_instance.task_id,
+ TaskInstance.dag_id == task_instance.dag_id,
+ TaskInstance.run_id == task_instance.run_id,
+ TaskInstance.map_index == task_instance.map_index,
+ )
+ )
+ dag = ti.dag_model.serialized_dag.dag
+ if hasattr(task_instance.task, "_dag"): # BaseOperator
+ task_instance.task._dag = dag
+ else: # MappedOperator
+ task_instance.task.dag = dag
+ else:
+ raise
dag_run = task_instance.get_dagrun(session)
data_interval = dag.get_run_data_interval(dag_run)
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index 2a4387eeb4..9b7cdbcc73 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -61,4 +61,5 @@ class DagAttributeTypes(str, Enum):
DATA_SET = "data_set"
LOG_TEMPLATE = "log_template"
CONNECTION = "connection"
+ TASK_CONTEXT = "task_context"
ARG_NOT_SET = "arg_not_set"
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 16a5c9e481..98d3d3a654 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -67,6 +67,7 @@ from airflow.task.priority_strategy import (
airflow_priority_weight_strategies_classes,
)
from airflow.utils.code_utils import get_python_source
+from airflow.utils.context import Context
from airflow.utils.docs import get_docs_url
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.operator_resources import Resources
@@ -602,6 +603,12 @@ class BaseSerialization:
)
elif isinstance(var, Connection):
            return cls._encode(var.to_dict(validate=True), type_=DAT.CONNECTION)
+ elif var.__class__ == Context:
+ d = {}
+ for k, v in var._context.items():
+                obj = cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models)
+ d[str(k)] = obj
+ return cls._encode(d, type_=DAT.TASK_CONTEXT)
elif use_pydantic_models and _ENABLE_AIP_44:
            def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]:
@@ -648,7 +655,14 @@ class BaseSerialization:
            raise ValueError(f"The encoded_var should be dict and is {type(encoded_var)}")
var = encoded_var[Encoding.VAR]
type_ = encoded_var[Encoding.TYPE]
-
+ if type_ == DAT.TASK_CONTEXT:
+ d = {}
+ for k, v in var.items():
+                if k == "task":  # todo: add TaskPydantic so we don't need this?
+ continue
+ d[k] = cls.deserialize(v, use_pydantic_models=True)
+            d["task"] = d["task_instance"].task  # todo: add TaskPydantic so we don't need this?
+ return Context(**d)
if type_ == DAT.DICT:
            return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
elif type_ == DAT.DAG: