This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 41c8e58dee Support serialization to Pydantic models in Internal API 
(#30282)
41c8e58dee is described below

commit 41c8e58deec2895b0a04879fcde5444b170e679e
Author: mhenc <[email protected]>
AuthorDate: Wed Apr 5 10:54:00 2023 +0200

    Support serialization to Pydantic models in Internal API (#30282)
    
    * Support serialization to Pydantic models in Internal API.
    
    * Added BaseJobPydantic support and more tests
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py |  4 +-
 airflow/api_internal/internal_api_call.py          |  4 +-
 airflow/serialization/enums.py                     |  4 ++
 airflow/serialization/serialized_objects.py        | 74 +++++++++++++++++-----
 .../endpoints/test_rpc_api_endpoint.py             | 19 ++++++
 tests/api_internal/test_internal_api_call.py       | 45 +++++++++++++
 tests/serialization/test_serialized_objects.py     | 18 ++++++
 7 files changed, 147 insertions(+), 21 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py 
b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 5456d6b9a0..5e3931c10d 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -83,7 +83,7 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
     try:
         if body.get("params"):
             params_json = json.loads(str(body.get("params")))
-            params = BaseSerialization.deserialize(params_json)
+            params = BaseSerialization.deserialize(params_json, 
use_pydantic_models=True)
     except Exception as err:
         log.error("Error deserializing parameters.")
         log.error(err)
@@ -92,7 +92,7 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
     log.debug("Calling method %.", {method_name})
     try:
         output = handler(**params)
-        output_json = BaseSerialization.serialize(output)
+        output_json = BaseSerialization.serialize(output, 
use_pydantic_models=True)
         log.debug("Returning response")
         return Response(
             response=json.dumps(output_json or "{}"), headers={"Content-Type": 
"application/json"}
diff --git a/airflow/api_internal/internal_api_call.py 
b/airflow/api_internal/internal_api_call.py
index 4179188baa..e8627ec68b 100644
--- a/airflow/api_internal/internal_api_call.py
+++ b/airflow/api_internal/internal_api_call.py
@@ -117,9 +117,9 @@ def internal_api_call(func: Callable[PS, RT]) -> 
Callable[PS, RT]:
         if "cls" in arguments_dict:  # used by @classmethod
             del arguments_dict["cls"]
 
-        args_json = json.dumps(BaseSerialization.serialize(arguments_dict))
+        args_json = json.dumps(BaseSerialization.serialize(arguments_dict, 
use_pydantic_models=True))
         method_name = f"{func.__module__}.{func.__qualname__}"
         result = make_jsonrpc_request(method_name, args_json)
-        return BaseSerialization.deserialize(json.loads(result))
+        return BaseSerialization.deserialize(json.loads(result), 
use_pydantic_models=True)
 
     return wrapper
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index f233261613..1f8dce26dd 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -51,3 +51,7 @@ class DagAttributeTypes(str, Enum):
     XCOM_REF = "xcomref"
     DATASET = "dataset"
     SIMPLE_TASK_INSTANCE = "simple_task_instance"
+    BASE_JOB = "base_job"
+    TASK_INSTANCE = "task_instance"
+    DAG_RUN = "dag_run"
+    DATA_SET = "data_set"
diff --git a/airflow/serialization/serialized_objects.py 
b/airflow/serialization/serialized_objects.py
index 2d753b0f48..b3d783ad69 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -38,14 +38,20 @@ from airflow.compat.functools import cache
 from airflow.configuration import conf
 from airflow.datasets import Dataset
 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, 
SerializationError
+from airflow.jobs.base_job import BaseJob
+from airflow.jobs.pydantic.base_job import BaseJobPydantic
 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG, create_timetable
+from airflow.models.dagrun import DagRun
 from airflow.models.expandinput import EXPAND_INPUT_EMPTY, ExpandInput, 
create_expand_input, get_map_type_key
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.operator import Operator
 from airflow.models.param import Param, ParamsDict
-from airflow.models.taskinstance import SimpleTaskInstance
+from airflow.models.pydantic.dag_run import DagRunPydantic
+from airflow.models.pydantic.dataset import DatasetPydantic
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
 from airflow.models.taskmixin import DAGNode
 from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, 
serialize_xcom_arg
 from airflow.providers_manager import ProvidersManager
@@ -384,7 +390,7 @@ class BaseSerialization:
 
     @classmethod
     def serialize(
-        cls, var: Any, *, strict: bool = False
+        cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = 
False
     ) -> Any:  # Unfortunately there is no support for recursive types in mypy
         """Helper function of depth first search for serialization.
 
@@ -405,10 +411,14 @@ class BaseSerialization:
             return var
         elif isinstance(var, dict):
             return cls._encode(
-                {str(k): cls.serialize(v, strict=strict) for k, v in 
var.items()}, type_=DAT.DICT
+                {
+                    str(k): cls.serialize(v, strict=strict, 
use_pydantic_models=use_pydantic_models)
+                    for k, v in var.items()
+                },
+                type_=DAT.DICT,
             )
         elif isinstance(var, list):
-            return [cls.serialize(v, strict=strict) for v in var]
+            return [cls.serialize(v, strict=strict, 
use_pydantic_models=use_pydantic_models) for v in var]
         elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and 
isinstance(var, k8s.V1Pod):
             json_pod = PodGenerator.serialize_pod(var)
             return cls._encode(json_pod, type_=DAT.POD)
@@ -433,12 +443,23 @@ class BaseSerialization:
         elif isinstance(var, set):
             # FIXME: casts set to list in customized serialization in future.
             try:
-                return cls._encode(sorted(cls.serialize(v, strict=strict) for 
v in var), type_=DAT.SET)
+                return cls._encode(
+                    sorted(
+                        cls.serialize(v, strict=strict, 
use_pydantic_models=use_pydantic_models) for v in var
+                    ),
+                    type_=DAT.SET,
+                )
             except TypeError:
-                return cls._encode([cls.serialize(v, strict=strict) for v in 
var], type_=DAT.SET)
+                return cls._encode(
+                    [cls.serialize(v, strict=strict, 
use_pydantic_models=use_pydantic_models) for v in var],
+                    type_=DAT.SET,
+                )
         elif isinstance(var, tuple):
             # FIXME: casts tuple to list in customized serialization in future.
-            return cls._encode([cls.serialize(v, strict=strict) for v in var], 
type_=DAT.TUPLE)
+            return cls._encode(
+                [cls.serialize(v, strict=strict, 
use_pydantic_models=use_pydantic_models) for v in var],
+                type_=DAT.TUPLE,
+            )
         elif isinstance(var, TaskGroup):
             return TaskGroupSerialization.serialize_task_group(var)
         elif isinstance(var, Param):
@@ -448,7 +469,18 @@ class BaseSerialization:
         elif isinstance(var, Dataset):
             return cls._encode(dict(uri=var.uri, extra=var.extra), 
type_=DAT.DATASET)
         elif isinstance(var, SimpleTaskInstance):
-            return cls._encode(cls.serialize(var.__dict__, strict=strict), 
type_=DAT.SIMPLE_TASK_INSTANCE)
+            return cls._encode(
+                cls.serialize(var.__dict__, strict=strict, 
use_pydantic_models=use_pydantic_models),
+                type_=DAT.SIMPLE_TASK_INSTANCE,
+            )
+        elif use_pydantic_models and isinstance(var, BaseJob):
+            return cls._encode(BaseJobPydantic.from_orm(var).dict(), 
type_=DAT.BASE_JOB)
+        elif use_pydantic_models and isinstance(var, TaskInstance):
+            return cls._encode(TaskInstancePydantic.from_orm(var).dict(), 
type_=DAT.TASK_INSTANCE)
+        elif use_pydantic_models and isinstance(var, DagRun):
+            return cls._encode(DagRunPydantic.from_orm(var).dict(), 
type_=DAT.DAG_RUN)
+        elif use_pydantic_models and isinstance(var, Dataset):
+            return cls._encode(DatasetPydantic.from_orm(var).dict(), 
type_=DAT.DATA_SET)
         else:
             log.debug("Cast type %s to str in serialization.", type(var))
             if strict:
@@ -456,7 +488,7 @@ class BaseSerialization:
             return str(var)
 
     @classmethod
-    def deserialize(cls, encoded_var: Any) -> Any:
+    def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
         """Helper function of depth first search for deserialization.
 
         :meta private:
@@ -465,7 +497,7 @@ class BaseSerialization:
         if cls._is_primitive(encoded_var):
             return encoded_var
         elif isinstance(encoded_var, list):
-            return [cls.deserialize(v) for v in encoded_var]
+            return [cls.deserialize(v, use_pydantic_models) for v in 
encoded_var]
 
         if not isinstance(encoded_var, dict):
             raise ValueError(f"The encoded_var should be dict and is 
{type(encoded_var)}")
@@ -473,7 +505,7 @@ class BaseSerialization:
         type_ = encoded_var[Encoding.TYPE]
 
         if type_ == DAT.DICT:
-            return {k: cls.deserialize(v) for k, v in var.items()}
+            return {k: cls.deserialize(v, use_pydantic_models) for k, v in 
var.items()}
         elif type_ == DAT.DAG:
             return SerializedDAG.deserialize_dag(var)
         elif type_ == DAT.OP:
@@ -492,9 +524,9 @@ class BaseSerialization:
         elif type_ == DAT.RELATIVEDELTA:
             return decode_relativedelta(var)
         elif type_ == DAT.SET:
-            return {cls.deserialize(v) for v in var}
+            return {cls.deserialize(v, use_pydantic_models) for v in var}
         elif type_ == DAT.TUPLE:
-            return tuple(cls.deserialize(v) for v in var)
+            return tuple(cls.deserialize(v, use_pydantic_models) for v in var)
         elif type_ == DAT.PARAM:
             return cls._deserialize_param(var)
         elif type_ == DAT.XCOM_REF:
@@ -503,6 +535,14 @@ class BaseSerialization:
             return Dataset(**var)
         elif type_ == DAT.SIMPLE_TASK_INSTANCE:
             return SimpleTaskInstance(**cls.deserialize(var))
+        elif use_pydantic_models and type_ == DAT.BASE_JOB:
+            return BaseJobPydantic.parse_obj(var)
+        elif use_pydantic_models and type_ == DAT.TASK_INSTANCE:
+            return TaskInstancePydantic.parse_obj(var)
+        elif use_pydantic_models and type_ == DAT.DAG_RUN:
+            return DagRunPydantic.parse_obj(var)
+        elif use_pydantic_models and type_ == DAT.DATA_SET:
+            return DatasetPydantic.parse_obj(var)
         else:
             raise TypeError(f"Invalid type {type_!s} in deserialization.")
 
@@ -1114,13 +1154,13 @@ class SerializedBaseOperator(BaseOperator, 
BaseSerialization):
         return serialize_operator_extra_links
 
     @classmethod
-    def serialize(cls, var: Any, *, strict: bool = False) -> Any:
+    def serialize(cls, var: Any, *, strict: bool = False, use_pydantic_models: 
bool = False) -> Any:
         # the wonders of multiple inheritance BaseOperator defines an instance 
method
-        return BaseSerialization.serialize(var=var, strict=strict)
+        return BaseSerialization.serialize(var=var, strict=strict, 
use_pydantic_models=use_pydantic_models)
 
     @classmethod
-    def deserialize(cls, encoded_var: Any) -> Any:
-        return BaseSerialization.deserialize(encoded_var=encoded_var)
+    def deserialize(cls, encoded_var: Any, use_pydantic_models: bool = False) 
-> Any:
+        return BaseSerialization.deserialize(encoded_var=encoded_var, 
use_pydantic_models=use_pydantic_models)
 
 
 class SerializedDAG(DAG, BaseSerialization):
diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py 
b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
index a0ec216147..9e09376dd6 100644
--- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py
+++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
@@ -23,7 +23,11 @@ from unittest import mock
 import pytest
 from flask import Flask
 
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.models.taskinstance import TaskInstance
+from airflow.operators.empty import EmptyOperator
 from airflow.serialization.serialized_objects import BaseSerialization
+from airflow.utils.state import State
 from airflow.www import app
 from tests.test_utils.config import conf_vars
 from tests.test_utils.decorators import dont_initialize_flask_app_submodules
@@ -110,6 +114,21 @@ class TestRpcApiEndpoint:
 
         expected_mock.assert_called_once_with(**method_params)
 
+    def test_method_with_pydantic_serialized_object(self):
+        ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", 
state=State.RUNNING)
+        mock_test_method.return_value = ti
+
+        response = self.client.post(
+            "/internal_api/v1/rpcapi",
+            headers={"Content-Type": "application/json"},
+            data=json.dumps({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, 
"params": ""}),
+        )
+        assert response.status_code == 200
+        print(response.data)
+        response_data = 
BaseSerialization.deserialize(json.loads(response.data), 
use_pydantic_models=True)
+        expected_data = TaskInstancePydantic.from_orm(ti)
+        assert response_data == expected_data
+
     def test_method_with_exception(self):
         mock_test_method.side_effect = ValueError("Error!!!")
         data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}
diff --git a/tests/api_internal/test_internal_api_call.py 
b/tests/api_internal/test_internal_api_call.py
index c96b2bde32..e7cd488e66 100644
--- a/tests/api_internal/test_internal_api_call.py
+++ b/tests/api_internal/test_internal_api_call.py
@@ -25,7 +25,11 @@ import pytest
 import requests
 
 from airflow.api_internal.internal_api_call import InternalApiConfig, 
internal_api_call
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.models.taskinstance import TaskInstance
+from airflow.operators.empty import EmptyOperator
 from airflow.serialization.serialized_objects import BaseSerialization
+from airflow.utils.state import State
 from tests.test_utils.config import conf_vars
 
 
@@ -81,6 +85,14 @@ class TestInternalApiCall:
     def fake_class_method_with_params(cls, dag_id: str, session) -> str:
         return f"local-classmethod-call-with-params-{dag_id}"
 
+    @staticmethod
+    @internal_api_call
+    def fake_class_method_with_serialized_params(
+        ti: TaskInstance | TaskInstancePydantic,
+        session,
+    ) -> str:
+        return f"local-classmethod-call-with-serialized-{ti.task_id}"
+
     @conf_vars(
         {
             ("core", "database_access_isolation"): "false",
@@ -200,3 +212,36 @@ class TestInternalApiCall:
             data=expected_data,
             headers={"Content-Type": "application/json"},
         )
+
+    @conf_vars(
+        {
+            ("core", "database_access_isolation"): "true",
+            ("core", "internal_api_url"): "http://localhost:8888",
+        }
+    )
+    @mock.patch("airflow.api_internal.internal_api_call.requests")
+    def test_remote_call_with_serialized_model(self, mock_requests):
+        response = requests.Response()
+        response.status_code = 200
+
+        response._content = 
json.dumps(BaseSerialization.serialize("remote-call"))
+
+        mock_requests.post.return_value = response
+        ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", 
state=State.RUNNING)
+
+        result = 
TestInternalApiCall.fake_class_method_with_serialized_params(ti, 
session="session")
+
+        assert result == "remote-call"
+        expected_data = json.dumps(
+            {
+                "jsonrpc": "2.0",
+                "method": 
"tests.api_internal.test_internal_api_call.TestInternalApiCall."
+                "fake_class_method_with_serialized_params",
+                "params": json.dumps(BaseSerialization.serialize({"ti": ti}, 
use_pydantic_models=True)),
+            }
+        )
+        mock_requests.post.assert_called_once_with(
+            url="http://localhost:8888/internal_api/v1/rpcapi",
+            data=expected_data,
+            headers={"Content-Type": "application/json"},
+        )
diff --git a/tests/serialization/test_serialized_objects.py 
b/tests/serialization/test_serialized_objects.py
index 3298fb6cba..7ccf16bb6a 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -20,6 +20,10 @@ from __future__ import annotations
 import pytest
 
 from airflow.exceptions import SerializationError
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.models.taskinstance import TaskInstance
+from airflow.operators.empty import EmptyOperator
+from airflow.utils.state import State
 from tests import REPO_ROOT
 
 
@@ -76,3 +80,17 @@ def test_strict_mode():
     BaseSerialization.serialize(obj)  # does not raise
     with pytest.raises(SerializationError, match="Encountered unexpected 
type"):
         BaseSerialization.serialize(obj, strict=True)  # now raises
+
+
+def test_use_pydantic_models():
+    """If use_pydantic_models=True the TaskInstance object should be 
serialized to TaskInstancePydantic."""
+
+    from airflow.serialization.serialized_objects import BaseSerialization
+
+    ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", 
state=State.RUNNING)
+    obj = [[ti]]  # nested to verify recursive behavior
+
+    serialized = BaseSerialization.serialize(obj, use_pydantic_models=True)  # 
does not raise
+    deserialized = BaseSerialization.deserialize(serialized, 
use_pydantic_models=True)  # does not raise
+
+    assert isinstance(deserialized[0][0], TaskInstancePydantic)

Reply via email to