This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 92389cf090 Handle json encoding of V1Pod in task callback (#27609)
92389cf090 is described below
commit 92389cf090f336073337517f2460c2914a9f0d4b
Author: Daniel Standish <[email protected]>
AuthorDate: Wed Nov 16 07:43:43 2022 -0800
Handle json encoding of V1Pod in task callback (#27609)
---
airflow/callbacks/callback_requests.py | 14 ++---
airflow/exceptions.py | 2 +-
airflow/models/taskinstance.py | 10 ++++
airflow/serialization/enums.py | 1 +
airflow/serialization/serialized_objects.py | 23 ++++++--
tests/__init__.py | 5 ++
tests/callbacks/test_callback_requests.py | 36 ++++++++++++
tests/serialization/test_serialized_objects.py | 78 ++++++++++++++++++++++++++
8 files changed, 155 insertions(+), 14 deletions(-)
diff --git a/airflow/callbacks/callback_requests.py
b/airflow/callbacks/callback_requests.py
index d8c36cb753..f0c33e79f8 100644
--- a/airflow/callbacks/callback_requests.py
+++ b/airflow/callbacks/callback_requests.py
@@ -84,17 +84,17 @@ class TaskCallbackRequest(CallbackRequest):
self.is_failure_callback = is_failure_callback
def to_json(self) -> str:
- dict_obj = self.__dict__.copy()
- dict_obj["simple_task_instance"] = self.simple_task_instance.as_dict()
- return json.dumps(dict_obj)
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ val = BaseSerialization.serialize(self.__dict__, strict=True)
+ return json.dumps(val)
@classmethod
def from_json(cls, json_str: str):
- from airflow.models.taskinstance import SimpleTaskInstance
+ from airflow.serialization.serialized_objects import BaseSerialization
- kwargs = json.loads(json_str)
- simple_ti =
SimpleTaskInstance.from_dict(obj_dict=kwargs.pop("simple_task_instance"))
- return cls(simple_task_instance=simple_ti, **kwargs)
+ val = json.loads(json_str)
+ return cls(**BaseSerialization.deserialize(val))
class DagCallbackRequest(CallbackRequest):
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 5815e47647..e3c4e333b9 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -225,7 +225,7 @@ class TaskAlreadyInTaskGroup(AirflowException):
class SerializationError(AirflowException):
- """A problem occurred when trying to serialize a DAG."""
+ """A problem occurred when trying to serialize something."""
class ParamValidationError(AirflowException):
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c1957384ef..3f9a587a73 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2574,6 +2574,11 @@ class SimpleTaskInstance:
return NotImplemented
def as_dict(self):
+ warnings.warn(
+ "This method is deprecated. Use BaseSerialization.serialize.",
+ RemovedInAirflow3Warning,
+ stacklevel=2,
+ )
new_dict = dict(self.__dict__)
for key in new_dict:
if key in ["start_date", "end_date"]:
@@ -2604,6 +2609,11 @@ class SimpleTaskInstance:
@classmethod
def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance:
+ warnings.warn(
+ "This method is deprecated. Use BaseSerialization.deserialize.",
+ RemovedInAirflow3Warning,
+ stacklevel=2,
+ )
ti_key = TaskInstanceKey(*obj_dict.pop("key"))
start_date = None
end_date = None
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index e798d6646e..f233261613 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -50,3 +50,4 @@ class DagAttributeTypes(str, Enum):
PARAM = "param"
XCOM_REF = "xcomref"
DATASET = "dataset"
+ SIMPLE_TASK_INSTANCE = "simple_task_instance"
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 32b6533c41..d41d70a4e7 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -44,6 +44,7 @@ from airflow.models.expandinput import EXPAND_INPUT_EMPTY,
ExpandInput, create_e
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.param import Param, ParamsDict
+from airflow.models.taskinstance import SimpleTaskInstance
from airflow.models.taskmixin import DAGNode
from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg,
serialize_xcom_arg
from airflow.providers_manager import ProvidersManager
@@ -381,7 +382,9 @@ class BaseSerialization:
return serialized_object
@classmethod
- def serialize(cls, var: Any) -> Any: # Unfortunately there is no support
for recursive types in mypy
+ def serialize(
+ cls, var: Any, *, strict: bool = False
+ ) -> Any: # Unfortunately there is no support for recursive types in mypy
"""Helper function of depth first search for serialization.
The serialization protocol is:
@@ -400,9 +403,11 @@ class BaseSerialization:
return var.value
return var
elif isinstance(var, dict):
- return cls._encode({str(k): cls.serialize(v) for k, v in
var.items()}, type_=DAT.DICT)
+ return cls._encode(
+ {str(k): cls.serialize(v, strict=strict) for k, v in
var.items()}, type_=DAT.DICT
+ )
elif isinstance(var, list):
- return [cls.serialize(v) for v in var]
+ return [cls.serialize(v, strict=strict) for v in var]
elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and
isinstance(var, k8s.V1Pod):
json_pod = PodGenerator.serialize_pod(var)
return cls._encode(json_pod, type_=DAT.POD)
@@ -427,12 +432,12 @@ class BaseSerialization:
elif isinstance(var, set):
# FIXME: casts set to list in customized serialization in future.
try:
- return cls._encode(sorted(cls.serialize(v) for v in var),
type_=DAT.SET)
+ return cls._encode(sorted(cls.serialize(v, strict=strict) for
v in var), type_=DAT.SET)
except TypeError:
- return cls._encode([cls.serialize(v) for v in var],
type_=DAT.SET)
+ return cls._encode([cls.serialize(v, strict=strict) for v in
var], type_=DAT.SET)
elif isinstance(var, tuple):
# FIXME: casts tuple to list in customized serialization in future.
- return cls._encode([cls.serialize(v) for v in var],
type_=DAT.TUPLE)
+ return cls._encode([cls.serialize(v, strict=strict) for v in var],
type_=DAT.TUPLE)
elif isinstance(var, TaskGroup):
return TaskGroupSerialization.serialize_task_group(var)
elif isinstance(var, Param):
@@ -441,8 +446,12 @@ class BaseSerialization:
return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
elif isinstance(var, Dataset):
return cls._encode(dict(uri=var.uri, extra=var.extra),
type_=DAT.DATASET)
+ elif isinstance(var, SimpleTaskInstance):
+ return cls._encode(cls.serialize(var.__dict__, strict=strict),
type_=DAT.SIMPLE_TASK_INSTANCE)
else:
log.debug("Cast type %s to str in serialization.", type(var))
+ if strict:
+ raise SerializationError("Encountered unexpected type")
return str(var)
@classmethod
@@ -491,6 +500,8 @@ class BaseSerialization:
return _XComRef(var) # Delay deserializing XComArg objects until
we have the entire DAG.
elif type_ == DAT.DATASET:
return Dataset(**var)
+ elif type_ == DAT.SIMPLE_TASK_INSTANCE:
+ return SimpleTaskInstance(**cls.deserialize(var))
else:
raise TypeError(f"Invalid type {type_!s} in deserialization.")
diff --git a/tests/__init__.py b/tests/__init__.py
index 217e5db960..6d87031140 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -15,3 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
+from pathlib import Path
+
+REPO_ROOT = Path(__file__).parent.parent
diff --git a/tests/callbacks/test_callback_requests.py
b/tests/callbacks/test_callback_requests.py
index b5c7ceefee..e97d0fc8ae 100644
--- a/tests/callbacks/test_callback_requests.py
+++ b/tests/callbacks/test_callback_requests.py
@@ -92,3 +92,39 @@ class TestCallbackRequest:
json_str = input.to_json()
result = TaskCallbackRequest.from_json(json_str)
assert input == result
+
def test_simple_ti_roundtrip_exec_config_pod(self):
    """Round-tripping a callback request whose TI has a V1Pod in its executor config preserves the pod."""
    from kubernetes.client import models as k8s

    from airflow.callbacks.callback_requests import TaskCallbackRequest
    from airflow.models import TaskInstance
    from airflow.models.taskinstance import SimpleTaskInstance
    from airflow.operators.bash import BashOperator

    expected_pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="hello", namespace="ns"))
    task = BashOperator(
        task_id="hi",
        executor_config={"pod_override": expected_pod},
        bash_command="hi",
    )
    simple_ti = SimpleTaskInstance.from_ti(TaskInstance(task=task))

    # Serialize to JSON and back; the pod must survive the trip intact.
    request_json = TaskCallbackRequest("hi", simple_ti).to_json()
    restored = TaskCallbackRequest.from_json(request_json)

    assert restored.simple_task_instance.executor_config["pod_override"] == expected_pod
+
def test_simple_ti_roundtrip_dates(self):
    """A callback request including a TI with start and end dates should safely roundtrip."""
    from unittest.mock import MagicMock

    from airflow.callbacks.callback_requests import TaskCallbackRequest
    from airflow.models import TaskInstance
    from airflow.models.taskinstance import SimpleTaskInstance
    from airflow.operators.bash import BashOperator
    from airflow.utils.state import State

    op = BashOperator(task_id="hi", bash_command="hi")
    ti = TaskInstance(task=op)
    # NOTE(review): set_state is expected to fill end_date only for finished
    # states, and Airflow state values are lower-case. The literal "SUCCESS"
    # would likely leave end_date unset, so use the real state constant.
    ti.set_state(State.SUCCESS, session=MagicMock())
    start_date = ti.start_date
    end_date = ti.end_date
    # Guard against a vacuous None == None comparison below.
    assert start_date is not None
    assert end_date is not None

    s = SimpleTaskInstance.from_ti(ti)
    data = TaskCallbackRequest("hi", s).to_json()
    # Deserialize once and check both dates on the same result.
    result = TaskCallbackRequest.from_json(data).simple_task_instance
    assert result.start_date == start_date
    assert result.end_date == end_date
diff --git a/tests/serialization/test_serialized_objects.py
b/tests/serialization/test_serialized_objects.py
new file mode 100644
index 0000000000..3298fb6cba
--- /dev/null
+++ b/tests/serialization/test_serialized_objects.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.exceptions import SerializationError
+from tests import REPO_ROOT
+
+
def test_recursive_serialize_calls_must_forward_kwargs():
    """Any time we recurse cls.serialize, we must forward all kwargs.

    Statically inspects ``BaseSerialization.serialize`` in
    ``airflow/serialization/serialized_objects.py``. Every call to an attribute
    named ``serialize`` inside that method must pass each of the method's
    keyword-only arguments through as ``name=name``. All violations are
    reported together rather than stopping at the first one.
    """
    import ast

    file = REPO_ROOT / "airflow/serialization/serialized_objects.py"
    tree = ast.parse(file.read_text(encoding="utf-8"))

    class_def = next(
        (
            node
            for node in ast.walk(tree)
            if isinstance(node, ast.ClassDef) and node.name == "BaseSerialization"
        ),
        None,
    )
    # Fail with a clear message instead of crashing inside ast.walk(None).
    assert class_def is not None, f"class BaseSerialization not found in {file}"

    # ``serialize`` is a method defined directly in the class body.
    method_def = next(
        (
            node
            for node in class_def.body
            if isinstance(node, ast.FunctionDef) and node.name == "serialize"
        ),
        None,
    )
    assert method_def is not None, f"BaseSerialization.serialize not found in {file}"
    kwonly_args = [arg.arg for arg in method_def.args.kwonlyargs]

    errors = []
    valid_recursive_call_count = 0
    for node in ast.walk(method_def):
        if not isinstance(node, ast.Call) or getattr(node.func, "attr", "") != "serialize":
            continue
        forwarded = {kw.arg: kw.value for kw in node.keywords}
        # A kwarg counts as forwarded only when passed as a bare name identical
        # to the parameter (e.g. ``strict=strict``). Missing or renamed values fail.
        unforwarded = [name for name in kwonly_args if getattr(forwarded.get(name), "id", "") != name]
        if unforwarded:
            ref = f"{file}:{node.lineno}"
            errors.extend(
                f"Error at {ref}; recursive calls to `cls.serialize` must forward the `{name}` argument"
                for name in unforwarded
            )
        else:
            valid_recursive_call_count += 1

    assert not errors, "\n".join(errors)
    # Ensure the check actually inspected something.
    assert valid_recursive_call_count > 0
+
+
def test_strict_mode():
    """Strict serialization must raise on a non-JSON-serializable object; lenient mode must not."""
    from airflow.serialization.serialized_objects import BaseSerialization

    class Test:
        a = 1

    # Bury the unserializable object under several list layers so the strict
    # flag has to survive recursion to reach it.
    nested = [[[Test()]]]

    # Default (non-strict) mode falls back to str() and does not raise.
    BaseSerialization.serialize(nested)

    with pytest.raises(SerializationError, match="Encountered unexpected type"):
        BaseSerialization.serialize(nested, strict=True)