This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0af5d923d9 Make _get_ti compatible with RPC (#38570)
0af5d923d9 is described below
commit 0af5d923d99591576b3758ab3c694d02dbe152bf
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Apr 9 16:34:03 2024 -0700
Make _get_ti compatible with RPC (#38570)
This is for AIP-44. I had to pull out the "db access" parts of `_get_ti`
and move them to the RPC function `_get_ti_db_access`. To make that work, I also
had to ensure that "task" objects (i.e. instances of AbstractOperator) can be
properly roundtripped through BaseSerialization.serialize. Until now they
could not be, and were instead "manually" serialized as part of SerializedDAG.
This changes the way we serialize task objects slightly, so we had to handle
backcompat and update a fair [...]
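
A minimal sketch of the new roundtrip this enables (the operator used here is
illustrative, not taken from the commit):

    from airflow.operators.bash import BashOperator
    from airflow.serialization.serialized_objects import BaseSerialization

    task = BashOperator(task_id="example", bash_command="echo hi")

    # Operators are now encoded with the standard BaseSerialization type
    # envelope, {"__type": "operator", "__var": {...}}, instead of a bare
    # dict that only SerializedDAG knew how to decode.
    ser = BaseSerialization.serialize(task)
    assert ser["__type"] == "operator"

    # ...so BaseSerialization.deserialize can reconstruct the task, which is
    # needed to pass task objects through the AIP-44 internal RPC API
    # (e.g. as the `task` argument of _get_ti_db_access).
    deser = BaseSerialization.deserialize(ser)
    assert deser.task_id == task.task_id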
---
airflow/api_internal/endpoints/rpc_api_endpoint.py | 2 +
airflow/cli/commands/task_command.py | 39 ++++-
airflow/models/taskinstance.py | 1 -
airflow/serialization/serialized_objects.py | 14 +-
tests/providers/amazon/aws/links/test_base_aws.py | 2 +-
.../amazon/aws/operators/test_emr_serverless.py | 16 +-
.../google/cloud/operators/test_bigquery.py | 8 +-
.../google/cloud/operators/test_dataproc.py | 14 +-
tests/serialization/test_dag_serialization.py | 184 +++++++++++----------
9 files changed, 163 insertions(+), 117 deletions(-)
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 97449810bd..c428e8e481 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -37,6 +37,7 @@ log = logging.getLogger(__name__)
@functools.lru_cache
def _initialize_map() -> dict[str, Callable]:
+ from airflow.cli.commands.task_command import _get_ti_db_access
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.dag_processing.processor import DagFileProcessor
from airflow.models import Trigger, Variable, XCom
@@ -51,6 +52,7 @@ def _initialize_map() -> dict[str, Callable]:
functions: list[Callable] = [
_get_template_context,
_update_rtif,
+ _get_ti_db_access,
DagFileProcessor.update_import_errors,
DagFileProcessor.manage_slas,
DagFileProcessorManager.deactivate_stale_dags,
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 05adb0abda..23a1ed460e 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -34,7 +34,7 @@ from pendulum.parsing.exceptions import ParserError
from sqlalchemy import select
from airflow import settings
-from airflow.api_internal.internal_api_call import InternalApiConfig
+from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call
from airflow.cli.simple_table import AirflowConsole
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagRunNotFound, TaskDeferred, TaskInstanceNotFound
@@ -156,8 +156,10 @@ def _get_dag_run(
raise ValueError(f"unknown create_if_necessary value:
{create_if_necessary!r}")
+@internal_api_call
@provide_session
-def _get_ti(
+def _get_ti_db_access(
+ dag: DAG,
task: Operator,
map_index: int,
*,
@@ -167,9 +169,9 @@ def _get_ti(
session: Session = NEW_SESSION,
) -> tuple[TaskInstance | TaskInstancePydantic, bool]:
"""Get the task instance through DagRun.run_id, if that fails, get the TI
the old way."""
- dag = task.dag
- if dag is None:
- raise ValueError("Cannot get task instance for a task not assigned to
a DAG")
+ if task.dag_id != dag.dag_id:
+ raise ValueError(f"Provided task '{task.task_id}' is not assigned to
provided dag {dag.dag_id}.")
+
if not exec_date_or_run_id and not create_if_necessary:
raise ValueError("Must provide `exec_date_or_run_id` if not
`create_if_necessary`.")
if needs_expansion(task):
@@ -201,6 +203,33 @@ def _get_ti(
return ti, dr_created
+def _get_ti(
+ task: Operator,
+ map_index: int,
+ *,
+ exec_date_or_run_id: str | None = None,
+ pool: str | None = None,
+ create_if_necessary: CreateIfNecessary = False,
+):
+ dag = task.dag
+ if dag is None:
+ raise ValueError("Cannot get task instance for a task not assigned to
a DAG")
+
+ ti, dr_created = _get_ti_db_access(
+ dag=dag,
+ task=task,
+ map_index=map_index,
+ exec_date_or_run_id=exec_date_or_run_id,
+ pool=pool,
+ create_if_necessary=create_if_necessary,
+ )
+ # setting ti.task is necessary for AIP-44 since the task object does not serialize perfectly
+ # if we update the serialization logic for Operator to also serialize the dag object on it,
+ # then this would not be necessary;
+ ti.task = task
+ return ti, dr_created
+
+
def _run_task_by_selected_method(
args, dag: DAG, ti: TaskInstance | TaskInstancePydantic
) -> None | TaskReturnCode:
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index d52a71c5b2..a55ea0fe77 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -776,7 +776,6 @@ def _get_template_context(
nonlocal dag_run
if dag_run not in session:
dag_run = session.merge(dag_run, load=False)
-
dataset_events = dag_run.consumed_dataset_events
triggering_events: dict[str, list[DatasetEvent | DatasetEventPydantic]] = defaultdict(list)
for event in dataset_events:
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 9c86bd205d..9cc1d41931 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -539,9 +539,9 @@ class BaseSerialization:
elif isinstance(var, Resources):
return var.to_dict()
elif isinstance(var, MappedOperator):
- return SerializedBaseOperator.serialize_mapped_operator(var)
+ return cls._encode(SerializedBaseOperator.serialize_mapped_operator(var), type_=DAT.OP)
elif isinstance(var, BaseOperator):
- return SerializedBaseOperator.serialize_operator(var)
+ return cls._encode(SerializedBaseOperator.serialize_operator(var), type_=DAT.OP)
elif isinstance(var, cls._datetime_types):
return cls._encode(var.timestamp(), type_=DAT.DATETIME)
elif isinstance(var, datetime.timedelta):
@@ -1476,9 +1476,15 @@ class SerializedDAG(DAG, BaseSerialization):
v = set(v)
elif k == "tasks":
SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links
-
- v = {task["task_id"]:
SerializedBaseOperator.deserialize_operator(task) for task in v}
+ tasks = {}
+ for obj in v:
+ if obj.get(Encoding.TYPE) == DAT.OP:
+ deser = SerializedBaseOperator.deserialize_operator(obj[Encoding.VAR])
+ tasks[deser.task_id] = deser
+ else: # todo: remove in Airflow 3.0 (backcompat for pre-2.10)
+ tasks[obj["task_id"]] =
SerializedBaseOperator.deserialize_operator(obj)
k = "task_dict"
+ v = tasks
elif k == "timezone":
v = cls._deserialize_timezone(v)
elif k == "dagrun_timeout":
diff --git a/tests/providers/amazon/aws/links/test_base_aws.py b/tests/providers/amazon/aws/links/test_base_aws.py
index 546ead164d..222be31458 100644
--- a/tests/providers/amazon/aws/links/test_base_aws.py
+++ b/tests/providers/amazon/aws/links/test_base_aws.py
@@ -203,7 +203,7 @@ class BaseAwsLinksTestCase:
"""Test: Operator links should exist for serialized DAG."""
self.create_op_and_ti(self.link_class, dag_id="test_link_serialize", task_id=self.task_id)
serialized_dag = self.dag_maker.get_serialized_data()
- operator_extra_link = serialized_dag["dag"]["tasks"][0]["_operator_extra_links"]
+ operator_extra_link = serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"]
error_message = "Operator links should exist for serialized DAG"
assert operator_extra_link == [{self.full_qualname: {}}], error_message
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index a5527fb4a4..be42cf63ba 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -38,7 +38,7 @@ from airflow.providers.amazon.aws.operators.emr import (
EmrServerlessStopApplicationOperator,
)
from airflow.serialization.serialized_objects import (
- SerializedBaseOperator,
+ BaseSerialization,
)
from airflow.utils.types import NOTSET
@@ -1118,11 +1118,10 @@ class TestEmrServerlessStartJobOperator:
configuration_overrides=[s3_configuration_overrides, cloudwatch_configuration_overrides],
)
- serialize = SerializedBaseOperator.serialize
- deserialize = SerializedBaseOperator.deserialize_operator
- deserialized_operator = deserialize(serialize(operator))
+ ser_operator = BaseSerialization.serialize(operator)
+ deser_operator = BaseSerialization.deserialize(ser_operator)
- assert deserialized_operator.operator_extra_links == [
+ assert deser_operator.operator_extra_links == [
EmrServerlessS3LogsLink(),
EmrServerlessCloudWatchLogsLink(),
]
@@ -1140,11 +1139,10 @@ class TestEmrServerlessStartJobOperator:
configuration_overrides=[s3_configuration_overrides, cloudwatch_configuration_overrides],
)
- serialize = SerializedBaseOperator.serialize
- deserialize = SerializedBaseOperator.deserialize_operator
- deserialized_operator = deserialize(serialize(operator))
+ ser_operator = BaseSerialization.serialize(operator)
+ deser_operator = BaseSerialization.deserialize(ser_operator)
- assert deserialized_operator.operator_extra_links == [
+ assert deser_operator.operator_extra_links == [
EmrServerlessS3LogsLink(),
EmrServerlessCloudWatchLogsLink(),
EmrServerlessDashboardLink(),
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index ee57be6e65..15a6e8d971 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -729,7 +729,7 @@ class TestBigQueryOperator:
sql="SELECT * FROM test_table",
)
serialized_dag = dag_maker.get_serialized_data()
- assert "sql" in serialized_dag["dag"]["tasks"][0]
+ assert "sql" in serialized_dag["dag"]["tasks"][0]["__var"]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict[TASK_ID]
@@ -740,7 +740,7 @@ class TestBigQueryOperator:
#########################################################
# Check Serialized version of operator link
- assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ assert serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"] == [
{"airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink": {}}
]
@@ -766,7 +766,7 @@ class TestBigQueryOperator:
sql=["SELECT * FROM test_table", "SELECT * FROM test_table2"],
)
serialized_dag = dag_maker.get_serialized_data()
- assert "sql" in serialized_dag["dag"]["tasks"][0]
+ assert "sql" in serialized_dag["dag"]["tasks"][0]["__var"]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict[TASK_ID]
@@ -777,7 +777,7 @@ class TestBigQueryOperator:
#########################################################
# Check Serialized version of operator link
- assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ assert serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"] == [
{"airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink": {"index": 0}},
{"airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink": {"index": 1}},
]
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index a56e8fe130..7d986b5017 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -1065,7 +1065,7 @@ def test_create_cluster_operator_extra_links(dag_maker, create_task_instance_of_
deserialized_task = deserialized_dag.task_dict[TASK_ID]
# Assert operator links for serialized DAG
- assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ assert serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"] == [
{"airflow.providers.google.cloud.links.dataproc.DataprocClusterLink": {}}
]
@@ -1167,7 +1167,7 @@ def test_scale_cluster_operator_extra_links(dag_maker, create_task_instance_of_o
deserialized_task = deserialized_dag.task_dict[TASK_ID]
# Assert operator links for serialized DAG
- assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ assert serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"] == [
{"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
]
@@ -1562,7 +1562,7 @@ def test_submit_job_operator_extra_links(mock_hook, dag_maker, create_task_insta
deserialized_task = deserialized_dag.task_dict[TASK_ID]
# Assert operator links for serialized_dag
- assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ assert serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"] == [
{"airflow.providers.google.cloud.links.dataproc.DataprocJobLink": {}}
]
@@ -1767,7 +1767,7 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_
deserialized_task = deserialized_dag.task_dict[TASK_ID]
# Assert operator links for serialized_dag
- assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ assert serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"] == [
{"airflow.providers.google.cloud.links.dataproc.DataprocClusterLink": {}}
]
@@ -1989,7 +1989,7 @@ def test_instantiate_workflow_operator_extra_links(mock_hook, dag_maker, create_
deserialized_task = deserialized_dag.task_dict[TASK_ID]
# Assert operator links for serialized_dag
- assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ assert serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"] == [
{"airflow.providers.google.cloud.links.dataproc.DataprocWorkflowLink": {}}
]
@@ -2151,7 +2151,7 @@ def test_instantiate_inline_workflow_operator_extra_links(
deserialized_task = deserialized_dag.task_dict[TASK_ID]
# Assert operator links for serialized_dag
- assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ assert serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"] == [
{"airflow.providers.google.cloud.links.dataproc.DataprocWorkflowLink": {}}
]
@@ -2472,7 +2472,7 @@ def test_submit_spark_job_operator_extra_links(mock_hook, dag_maker, create_task
deserialized_task = deserialized_dag.task_dict[TASK_ID]
# Assert operator links for serialized DAG
- assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+ assert serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"] == [
{"airflow.providers.google.cloud.links.dataproc.DataprocLink": {}}
]
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 992d237aac..39270f51a9 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -161,62 +161,68 @@ serialized_simple_dag_ground_truth = {
"_processor_dags_folder": f"{repo_root}/tests/dags",
"tasks": [
{
- "task_id": "bash_task",
- "owner": "airflow",
- "retries": 1,
- "retry_delay": 300.0,
- "max_retry_delay": 600.0,
- "sla": 100.0,
- "downstream_task_ids": [],
- "_is_empty": False,
- "ui_color": "#f0ede4",
- "ui_fgcolor": "#000",
- "template_ext": [".sh", ".bash"],
- "template_fields": ["bash_command", "env", "cwd"],
- "template_fields_renderers": {"bash_command": "bash", "env":
"json"},
- "bash_command": "echo {{ task.task_id }}",
- "_task_type": "BashOperator",
- "_task_module": "airflow.operators.bash",
- "pool": "default_pool",
- "is_setup": False,
- "is_teardown": False,
- "on_failure_fail_dagrun": False,
- "executor_config": {
- "__type": "dict",
- "__var": {
- "pod_override": {
- "__type": "k8s.V1Pod",
- "__var":
PodGenerator.serialize_pod(executor_config_pod),
- }
+ "__type": "operator",
+ "__var": {
+ "task_id": "bash_task",
+ "owner": "airflow",
+ "retries": 1,
+ "retry_delay": 300.0,
+ "max_retry_delay": 600.0,
+ "sla": 100.0,
+ "downstream_task_ids": [],
+ "_is_empty": False,
+ "ui_color": "#f0ede4",
+ "ui_fgcolor": "#000",
+ "template_ext": [".sh", ".bash"],
+ "template_fields": ["bash_command", "env", "cwd"],
+ "template_fields_renderers": {"bash_command": "bash",
"env": "json"},
+ "bash_command": "echo {{ task.task_id }}",
+ "_task_type": "BashOperator",
+ "_task_module": "airflow.operators.bash",
+ "pool": "default_pool",
+ "is_setup": False,
+ "is_teardown": False,
+ "on_failure_fail_dagrun": False,
+ "executor_config": {
+ "__type": "dict",
+ "__var": {
+ "pod_override": {
+ "__type": "k8s.V1Pod",
+ "__var":
PodGenerator.serialize_pod(executor_config_pod),
+ }
+ },
},
+ "doc_md": "### Task Tutorial Documentation",
+ "_log_config_logger_name": "airflow.task.operators",
+ "weight_rule": "downstream",
},
- "doc_md": "### Task Tutorial Documentation",
- "_log_config_logger_name": "airflow.task.operators",
- "weight_rule": "downstream",
},
{
- "task_id": "custom_task",
- "retries": 1,
- "retry_delay": 300.0,
- "max_retry_delay": 600.0,
- "sla": 100.0,
- "downstream_task_ids": [],
- "_is_empty": False,
- "_operator_extra_links":
[{"tests.test_utils.mock_operators.CustomOpLink": {}}],
- "ui_color": "#fff",
- "ui_fgcolor": "#000",
- "template_ext": [],
- "template_fields": ["bash_command"],
- "template_fields_renderers": {},
- "_task_type": "CustomOperator",
- "_operator_name": "@custom",
- "_task_module": "tests.test_utils.mock_operators",
- "pool": "default_pool",
- "is_setup": False,
- "is_teardown": False,
- "on_failure_fail_dagrun": False,
- "_log_config_logger_name": "airflow.task.operators",
- "weight_rule": "downstream",
+ "__type": "operator",
+ "__var": {
+ "task_id": "custom_task",
+ "retries": 1,
+ "retry_delay": 300.0,
+ "max_retry_delay": 600.0,
+ "sla": 100.0,
+ "downstream_task_ids": [],
+ "_is_empty": False,
+ "_operator_extra_links":
[{"tests.test_utils.mock_operators.CustomOpLink": {}}],
+ "ui_color": "#fff",
+ "ui_fgcolor": "#000",
+ "template_ext": [],
+ "template_fields": ["bash_command"],
+ "template_fields_renderers": {},
+ "_task_type": "CustomOperator",
+ "_operator_name": "@custom",
+ "_task_module": "tests.test_utils.mock_operators",
+ "pool": "default_pool",
+ "is_setup": False,
+ "is_teardown": False,
+ "on_failure_fail_dagrun": False,
+ "_log_config_logger_name": "airflow.task.operators",
+ "weight_rule": "downstream",
+ },
},
],
"schedule_interval": {"__type": "timedelta", "__var": 86400.0},
@@ -451,7 +457,7 @@ class TestStringifiedDAGs:
)
for task in actual["dag"]["tasks"]:
for k, v in task.items():
- print(task["task_id"], k, v)
+ print(task["__var"]["task_id"], k, v)
assert actual == expected
@pytest.mark.db_test
@@ -492,7 +498,11 @@ class TestStringifiedDAGs:
items should not matter but assertEqual would fail if the order of
items changes in the dag dictionary
"""
- dag_dict["dag"]["tasks"] = sorted(dag_dict["dag"]["tasks"],
key=sorted)
+ tasks = []
+ for task in sorted(dag_dict["dag"]["tasks"], key=lambda x: x["__var"]["task_id"]):
+ task["__var"] = dict(sorted(task["__var"].items(), key=lambda
x: x[0]))
+ tasks.append(task)
+ dag_dict["dag"]["tasks"] = tasks
dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"]
= sorted(
dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"]
)
@@ -735,9 +745,9 @@ class TestStringifiedDAGs:
if not task_start_date or dag_start_date >= task_start_date:
# If dag.start_date > task.start_date -> task.start_date=dag.start_date
# because of the logic in dag.add_task()
- assert "start_date" not in serialized_dag["dag"]["tasks"][0]
+ assert "start_date" not in
serialized_dag["dag"]["tasks"][0]["__var"]
else:
- assert "start_date" in serialized_dag["dag"]["tasks"][0]
+ assert "start_date" in serialized_dag["dag"]["tasks"][0]["__var"]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict["simple_task"]
@@ -773,9 +783,9 @@ class TestStringifiedDAGs:
if not task_end_date or dag_end_date <= task_end_date:
# If dag.end_date < task.end_date -> task.end_date=dag.end_date
# because of the logic in dag.add_task()
- assert "end_date" not in serialized_dag["dag"]["tasks"][0]
+ assert "end_date" not in serialized_dag["dag"]["tasks"][0]["__var"]
else:
- assert "end_date" in serialized_dag["dag"]["tasks"][0]
+ assert "end_date" in serialized_dag["dag"]["tasks"][0]["__var"]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict["simple_task"]
@@ -993,9 +1003,9 @@ class TestStringifiedDAGs:
serialized_dag = SerializedDAG.to_dict(dag)
if val:
- assert "params" in serialized_dag["dag"]["tasks"][0]
+ assert "params" in serialized_dag["dag"]["tasks"][0]["__var"]
else:
- assert "params" not in serialized_dag["dag"]["tasks"][0]
+ assert "params" not in serialized_dag["dag"]["tasks"][0]["__var"]
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_simple_task = deserialized_dag.task_dict["simple_task"]
@@ -1049,7 +1059,7 @@ class TestStringifiedDAGs:
CustomOperator(task_id="simple_task", bash_command=bash_command)
serialized_dag = SerializedDAG.to_dict(dag)
- assert "bash_command" in serialized_dag["dag"]["tasks"][0]
+ assert "bash_command" in serialized_dag["dag"]["tasks"][0]["__var"]
dag = SerializedDAG.from_dict(serialized_dag)
simple_task = dag.task_dict["simple_task"]
@@ -1059,7 +1069,7 @@ class TestStringifiedDAGs:
# Verify Operator Links work with Serialized Operator
#########################################################
# Check Serialized version of operator link only contains the inbuilt Op Link
- assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == serialized_links
+ assert serialized_dag["dag"]["tasks"][0]["__var"]["_operator_extra_links"] == serialized_links
# Test all the extra_links are set
assert simple_task.extra_links == sorted({*links, "airflow", "github", "google"})
@@ -2162,9 +2172,9 @@ def test_operator_expand_serde():
bash_command=literal
)
- serialized = SerializedBaseOperator.serialize(real_op)
+ serialized = BaseSerialization.serialize(real_op)
- assert serialized == {
+ assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_task_module": "airflow.operators.bash",
@@ -2194,7 +2204,7 @@ def test_operator_expand_serde():
"_expand_input_attr": "expand_input",
}
- op = SerializedBaseOperator.deserialize_operator(serialized)
+ op = BaseSerialization.deserialize(serialized)
assert isinstance(op, MappedOperator)
assert op.deps is MappedOperator.deps_for(BaseOperator)
@@ -2220,8 +2230,8 @@ def test_operator_expand_xcomarg_serde():
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id="task_2").expand(arg2=XComArg(task1))
- serialized = SerializedBaseOperator.serialize(mapped)
- assert serialized == {
+ serialized = BaseSerialization.serialize(mapped)
+ assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_task_module": "tests.test_utils.mock_operators",
@@ -2246,7 +2256,7 @@ def test_operator_expand_xcomarg_serde():
"_expand_input_attr": "expand_input",
}
- op = SerializedBaseOperator.deserialize_operator(serialized)
+ op = BaseSerialization.deserialize(serialized)
assert op.deps is MappedOperator.deps_for(BaseOperator)
# The XComArg can't be deserialized before the DAG is.
@@ -2272,8 +2282,8 @@ def test_operator_expand_kwargs_literal_serde(strict):
strict=strict,
)
- serialized = SerializedBaseOperator.serialize(mapped)
- assert serialized == {
+ serialized = BaseSerialization.serialize(mapped)
+ assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_task_module": "tests.test_utils.mock_operators",
@@ -2301,7 +2311,7 @@ def test_operator_expand_kwargs_literal_serde(strict):
"_expand_input_attr": "expand_input",
}
- op = SerializedBaseOperator.deserialize_operator(serialized)
+ op = BaseSerialization.deserialize(serialized)
assert op.deps is MappedOperator.deps_for(BaseOperator)
assert op._disallow_kwargs_override == strict
@@ -2325,7 +2335,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict):
mapped = MockOperator.partial(task_id="task_2").expand_kwargs(XComArg(task1), strict=strict)
serialized = SerializedBaseOperator.serialize(mapped)
- assert serialized == {
+ assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_task_module": "tests.test_utils.mock_operators",
@@ -2347,7 +2357,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict):
"_expand_input_attr": "expand_input",
}
- op = SerializedBaseOperator.deserialize_operator(serialized)
+ op = BaseSerialization.deserialize(serialized)
assert op.deps is MappedOperator.deps_for(BaseOperator)
assert op._disallow_kwargs_override == strict
@@ -2367,9 +2377,11 @@ def test_operator_expand_deserialized_unmap():
normal = BashOperator(task_id="a", bash_command=[1, 2], executor_config={"a": "b"})
mapped = BashOperator.partial(task_id="a", executor_config={"a": "b"}).expand(bash_command=[1, 2])
- serialize = SerializedBaseOperator.serialize
- deserialize = SerializedBaseOperator.deserialize_operator
- assert deserialize(serialize(mapped)).unmap(None) == deserialize(serialize(normal))
+ ser_mapped = BaseSerialization.serialize(mapped)
+ deser_mapped = BaseSerialization.deserialize(ser_mapped)
+ ser_normal = BaseSerialization.serialize(normal)
+ deser_normal = BaseSerialization.deserialize(ser_normal)
+ assert deser_mapped.unmap(None) == deser_normal
@pytest.mark.db_test
@@ -2380,7 +2392,7 @@ def test_sensor_expand_deserialized_unmap():
serialize = SerializedBaseOperator.serialize
- deserialize = SerializedBaseOperator.deserialize_operator
+ deserialize = SerializedBaseOperator.deserialize
assert deserialize(serialize(mapped)).unmap(None) == deserialize(serialize(normal))
@@ -2395,8 +2407,8 @@ def test_task_resources_serde():
with DAG("test_task_resources", start_date=execution_date) as _:
task = EmptyOperator(task_id=task_id, resources={"cpus": 0.1, "ram": 2048})
- serialized = SerializedBaseOperator.serialize(task)
- assert serialized["resources"] == {
+ serialized = BaseSerialization.serialize(task)
+ assert serialized["__var"]["resources"] == {
"cpus": {"name": "CPU", "qty": 0.1, "units_str": "core(s)"},
"disk": {"name": "Disk", "qty": 512, "units_str": "MB"},
"gpus": {"name": "GPU", "qty": 0, "units_str": "gpu(s)"},
@@ -2421,8 +2433,8 @@ def test_taskflow_expand_serde():
original = dag.get_task("x")
- serialized = SerializedBaseOperator.serialize(original)
- assert serialized == {
+ serialized = BaseSerialization.serialize(original)
+ assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_task_module": "airflow.decorators.python",
@@ -2461,7 +2473,7 @@ def test_taskflow_expand_serde():
"_expand_input_attr": "op_kwargs_expand_input",
}
- deserialized = SerializedBaseOperator.deserialize_operator(serialized)
+ deserialized = BaseSerialization.deserialize(serialized)
assert isinstance(deserialized, MappedOperator)
assert deserialized.deps is MappedOperator.deps_for(BaseOperator)
assert deserialized.upstream_task_ids == set()
@@ -2516,8 +2528,8 @@ def test_taskflow_expand_kwargs_serde(strict):
original = dag.get_task("x")
- serialized = SerializedBaseOperator.serialize(original)
- assert serialized == {
+ serialized = BaseSerialization.serialize(original)
+ assert serialized["__var"] == {
"_is_empty": False,
"_is_mapped": True,
"_task_module": "airflow.decorators.python",
@@ -2553,7 +2565,7 @@ def test_taskflow_expand_kwargs_serde(strict):
"_expand_input_attr": "op_kwargs_expand_input",
}
- deserialized = SerializedBaseOperator.deserialize_operator(serialized)
+ deserialized = BaseSerialization.deserialize(serialized)
assert isinstance(deserialized, MappedOperator)
assert deserialized.deps is MappedOperator.deps_for(BaseOperator)
assert deserialized._disallow_kwargs_override == strict
@@ -2653,7 +2665,7 @@ def test_mapped_task_with_operator_extra_links_property():
with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
_DummyOperator.partial(task_id="task").expand(inputs=[1, 2, 3])
serialized_dag = SerializedBaseOperator.serialize(dag)
- assert serialized_dag[Encoding.VAR]["tasks"][0] == {
+ assert serialized_dag[Encoding.VAR]["tasks"][0]["__var"] == {
"task_id": "task",
"expand_input": {
"type": "dict-of-lists",