This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 92389cf090 Handle json encoding of V1Pod in task callback (#27609)
92389cf090 is described below

commit 92389cf090f336073337517f2460c2914a9f0d4b
Author: Daniel Standish <[email protected]>
AuthorDate: Wed Nov 16 07:43:43 2022 -0800

    Handle json encoding of V1Pod in task callback (#27609)
---
 airflow/callbacks/callback_requests.py         | 14 ++---
 airflow/exceptions.py                          |  2 +-
 airflow/models/taskinstance.py                 | 10 ++++
 airflow/serialization/enums.py                 |  1 +
 airflow/serialization/serialized_objects.py    | 23 ++++++--
 tests/__init__.py                              |  5 ++
 tests/callbacks/test_callback_requests.py      | 36 ++++++++++++
 tests/serialization/test_serialized_objects.py | 78 ++++++++++++++++++++++++++
 8 files changed, 155 insertions(+), 14 deletions(-)

diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py
index d8c36cb753..f0c33e79f8 100644
--- a/airflow/callbacks/callback_requests.py
+++ b/airflow/callbacks/callback_requests.py
@@ -84,17 +84,17 @@ class TaskCallbackRequest(CallbackRequest):
         self.is_failure_callback = is_failure_callback
 
     def to_json(self) -> str:
-        dict_obj = self.__dict__.copy()
-        dict_obj["simple_task_instance"] = self.simple_task_instance.as_dict()
-        return json.dumps(dict_obj)
+        from airflow.serialization.serialized_objects import BaseSerialization
+
+        val = BaseSerialization.serialize(self.__dict__, strict=True)
+        return json.dumps(val)
 
     @classmethod
     def from_json(cls, json_str: str):
-        from airflow.models.taskinstance import SimpleTaskInstance
+        from airflow.serialization.serialized_objects import BaseSerialization
 
-        kwargs = json.loads(json_str)
-        simple_ti = 
SimpleTaskInstance.from_dict(obj_dict=kwargs.pop("simple_task_instance"))
-        return cls(simple_task_instance=simple_ti, **kwargs)
+        val = json.loads(json_str)
+        return cls(**BaseSerialization.deserialize(val))
 
 
 class DagCallbackRequest(CallbackRequest):
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 5815e47647..e3c4e333b9 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -225,7 +225,7 @@ class TaskAlreadyInTaskGroup(AirflowException):
 
 
 class SerializationError(AirflowException):
-    """A problem occurred when trying to serialize a DAG."""
+    """A problem occurred when trying to serialize something."""
 
 
 class ParamValidationError(AirflowException):
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c1957384ef..3f9a587a73 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2574,6 +2574,11 @@ class SimpleTaskInstance:
         return NotImplemented
 
     def as_dict(self):
+        warnings.warn(
+            "This method is deprecated. Use BaseSerialization.serialize.",
+            RemovedInAirflow3Warning,
+            stacklevel=2,
+        )
         new_dict = dict(self.__dict__)
         for key in new_dict:
             if key in ["start_date", "end_date"]:
@@ -2604,6 +2609,11 @@ class SimpleTaskInstance:
 
     @classmethod
     def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance:
+        warnings.warn(
+            "This method is deprecated. Use BaseSerialization.deserialize.",
+            RemovedInAirflow3Warning,
+            stacklevel=2,
+        )
         ti_key = TaskInstanceKey(*obj_dict.pop("key"))
         start_date = None
         end_date = None
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index e798d6646e..f233261613 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -50,3 +50,4 @@ class DagAttributeTypes(str, Enum):
     PARAM = "param"
     XCOM_REF = "xcomref"
     DATASET = "dataset"
+    SIMPLE_TASK_INSTANCE = "simple_task_instance"
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 32b6533c41..d41d70a4e7 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -44,6 +44,7 @@ from airflow.models.expandinput import EXPAND_INPUT_EMPTY, ExpandInput, create_e
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.operator import Operator
 from airflow.models.param import Param, ParamsDict
+from airflow.models.taskinstance import SimpleTaskInstance
 from airflow.models.taskmixin import DAGNode
 from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, 
serialize_xcom_arg
 from airflow.providers_manager import ProvidersManager
@@ -381,7 +382,9 @@ class BaseSerialization:
         return serialized_object
 
     @classmethod
-    def serialize(cls, var: Any) -> Any:  # Unfortunately there is no support 
for recursive types in mypy
+    def serialize(
+        cls, var: Any, *, strict: bool = False
+    ) -> Any:  # Unfortunately there is no support for recursive types in mypy
         """Helper function of depth first search for serialization.
 
         The serialization protocol is:
@@ -400,9 +403,11 @@ class BaseSerialization:
                 return var.value
             return var
         elif isinstance(var, dict):
-            return cls._encode({str(k): cls.serialize(v) for k, v in 
var.items()}, type_=DAT.DICT)
+            return cls._encode(
+                {str(k): cls.serialize(v, strict=strict) for k, v in 
var.items()}, type_=DAT.DICT
+            )
         elif isinstance(var, list):
-            return [cls.serialize(v) for v in var]
+            return [cls.serialize(v, strict=strict) for v in var]
         elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and 
isinstance(var, k8s.V1Pod):
             json_pod = PodGenerator.serialize_pod(var)
             return cls._encode(json_pod, type_=DAT.POD)
@@ -427,12 +432,12 @@ class BaseSerialization:
         elif isinstance(var, set):
             # FIXME: casts set to list in customized serialization in future.
             try:
-                return cls._encode(sorted(cls.serialize(v) for v in var), 
type_=DAT.SET)
+                return cls._encode(sorted(cls.serialize(v, strict=strict) for 
v in var), type_=DAT.SET)
             except TypeError:
-                return cls._encode([cls.serialize(v) for v in var], 
type_=DAT.SET)
+                return cls._encode([cls.serialize(v, strict=strict) for v in 
var], type_=DAT.SET)
         elif isinstance(var, tuple):
             # FIXME: casts tuple to list in customized serialization in future.
-            return cls._encode([cls.serialize(v) for v in var], 
type_=DAT.TUPLE)
+            return cls._encode([cls.serialize(v, strict=strict) for v in var], 
type_=DAT.TUPLE)
         elif isinstance(var, TaskGroup):
             return TaskGroupSerialization.serialize_task_group(var)
         elif isinstance(var, Param):
@@ -441,8 +446,12 @@ class BaseSerialization:
             return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
         elif isinstance(var, Dataset):
             return cls._encode(dict(uri=var.uri, extra=var.extra), 
type_=DAT.DATASET)
+        elif isinstance(var, SimpleTaskInstance):
+            return cls._encode(cls.serialize(var.__dict__, strict=strict), 
type_=DAT.SIMPLE_TASK_INSTANCE)
         else:
             log.debug("Cast type %s to str in serialization.", type(var))
+            if strict:
+                raise SerializationError("Encountered unexpected type")
             return str(var)
 
     @classmethod
@@ -491,6 +500,8 @@ class BaseSerialization:
             return _XComRef(var)  # Delay deserializing XComArg objects until 
we have the entire DAG.
         elif type_ == DAT.DATASET:
             return Dataset(**var)
+        elif type_ == DAT.SIMPLE_TASK_INSTANCE:
+            return SimpleTaskInstance(**cls.deserialize(var))
         else:
             raise TypeError(f"Invalid type {type_!s} in deserialization.")
 
diff --git a/tests/__init__.py b/tests/__init__.py
index 217e5db960..6d87031140 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -15,3 +15,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations
+
+from pathlib import Path
+
+REPO_ROOT = Path(__file__).parent.parent
diff --git a/tests/callbacks/test_callback_requests.py b/tests/callbacks/test_callback_requests.py
index b5c7ceefee..e97d0fc8ae 100644
--- a/tests/callbacks/test_callback_requests.py
+++ b/tests/callbacks/test_callback_requests.py
@@ -92,3 +92,39 @@ class TestCallbackRequest:
         json_str = input.to_json()
         result = TaskCallbackRequest.from_json(json_str)
         assert input == result
+
+    def test_simple_ti_roundtrip_exec_config_pod(self):
+        """A callback request including a TI with an exec config with a V1Pod 
should safely roundtrip."""
+        from kubernetes.client import models as k8s
+
+        from airflow.callbacks.callback_requests import TaskCallbackRequest
+        from airflow.models import TaskInstance
+        from airflow.models.taskinstance import SimpleTaskInstance
+        from airflow.operators.bash import BashOperator
+
+        test_pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="hello", 
namespace="ns"))
+        op = BashOperator(task_id="hi", executor_config={"pod_override": 
test_pod}, bash_command="hi")
+        ti = TaskInstance(task=op)
+        s = SimpleTaskInstance.from_ti(ti)
+        data = TaskCallbackRequest("hi", s).to_json()
+        actual = 
TaskCallbackRequest.from_json(data).simple_task_instance.executor_config["pod_override"]
+        assert actual == test_pod
+
+    def test_simple_ti_roundtrip_dates(self):
+        """A callback request including a TI with an exec config with a V1Pod 
should safely roundtrip."""
+        from unittest.mock import MagicMock
+
+        from airflow.callbacks.callback_requests import TaskCallbackRequest
+        from airflow.models import TaskInstance
+        from airflow.models.taskinstance import SimpleTaskInstance
+        from airflow.operators.bash import BashOperator
+
+        op = BashOperator(task_id="hi", bash_command="hi")
+        ti = TaskInstance(task=op)
+        ti.set_state("SUCCESS", session=MagicMock())
+        start_date = ti.start_date
+        end_date = ti.end_date
+        s = SimpleTaskInstance.from_ti(ti)
+        data = TaskCallbackRequest("hi", s).to_json()
+        assert 
TaskCallbackRequest.from_json(data).simple_task_instance.start_date == 
start_date
+        assert 
TaskCallbackRequest.from_json(data).simple_task_instance.end_date == end_date
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
new file mode 100644
index 0000000000..3298fb6cba
--- /dev/null
+++ b/tests/serialization/test_serialized_objects.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.exceptions import SerializationError
+from tests import REPO_ROOT
+
+
+def test_recursive_serialize_calls_must_forward_kwargs():
+    """Any time we recurse cls.serialize, we must forward all kwargs."""
+    import ast
+
+    valid_recursive_call_count = 0
+    file = REPO_ROOT / "airflow/serialization/serialized_objects.py"
+    content = file.read_text()
+    tree = ast.parse(content)
+
+    class_def = None
+    for stmt in ast.walk(tree):
+        if not isinstance(stmt, ast.ClassDef):
+            continue
+        if stmt.name == "BaseSerialization":
+            class_def = stmt
+
+    method_def = None
+    for elem in ast.walk(class_def):
+        if isinstance(elem, ast.FunctionDef):
+            if elem.name == "serialize":
+                method_def = elem
+                break
+    kwonly_args = [x.arg for x in method_def.args.kwonlyargs]
+
+    for elem in ast.walk(method_def):
+        if isinstance(elem, ast.Call):
+            if getattr(elem.func, "attr", "") == "serialize":
+                kwargs = {y.arg: y.value for y in elem.keywords}
+                for name in kwonly_args:
+                    if name not in kwargs or getattr(kwargs[name], "id", "") 
!= name:
+                        ref = f"{file}:{elem.lineno}"
+                        message = (
+                            f"Error at {ref}; recursive calls to 
`cls.serialize` "
+                            f"must forward the `{name}` argument"
+                        )
+                        raise Exception(message)
+                    valid_recursive_call_count += 1
+    print(f"validated calls: {valid_recursive_call_count}")
+    assert valid_recursive_call_count > 0
+
+
+def test_strict_mode():
+    """If strict=True, serialization should fail when object is not JSON 
serializable."""
+
+    class Test:
+        a = 1
+
+    from airflow.serialization.serialized_objects import BaseSerialization
+
+    obj = [[[Test()]]]  # nested to verify recursive behavior
+    BaseSerialization.serialize(obj)  # does not raise
+    with pytest.raises(SerializationError, match="Encountered unexpected 
type"):
+        BaseSerialization.serialize(obj, strict=True)  # now raises

Reply via email to