This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 41c8e58dee Support serialization to Pydantic models in Internal API (#30282)
41c8e58dee is described below
commit 41c8e58deec2895b0a04879fcde5444b170e679e
Author: mhenc <[email protected]>
AuthorDate: Wed Apr 5 10:54:00 2023 +0200
Support serialization to Pydantic models in Internal API (#30282)
* Support serialization to Pydantic models in Internal API.
* Added BaseJobPydantic support and more tests
---
airflow/api_internal/endpoints/rpc_api_endpoint.py | 4 +-
airflow/api_internal/internal_api_call.py | 4 +-
airflow/serialization/enums.py | 4 ++
airflow/serialization/serialized_objects.py | 74 +++++++++++++++++-----
.../endpoints/test_rpc_api_endpoint.py | 19 ++++++
tests/api_internal/test_internal_api_call.py | 45 +++++++++++++
tests/serialization/test_serialized_objects.py | 18 ++++++
7 files changed, 147 insertions(+), 21 deletions(-)
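
In short, both ends of the Internal API opt into Pydantic-backed
serialization via a new use_pydantic_models flag on
BaseSerialization.serialize/deserialize. A minimal sketch of the round
trip (assuming ti is an ORM TaskInstance, as exercised in the tests below):

    from airflow.serialization.serialized_objects import BaseSerialization

    # ORM objects such as TaskInstance are encoded through their Pydantic
    # counterparts when the flag is set...
    encoded = BaseSerialization.serialize(ti, use_pydantic_models=True)
    # ...and come back as the Pydantic model (e.g. TaskInstancePydantic),
    # not the original ORM class.
    decoded = BaseSerialization.deserialize(encoded, use_pydantic_models=True)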
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 5456d6b9a0..5e3931c10d 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -83,7 +83,7 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
try:
if body.get("params"):
params_json = json.loads(str(body.get("params")))
- params = BaseSerialization.deserialize(params_json)
+ params = BaseSerialization.deserialize(params_json, use_pydantic_models=True)
except Exception as err:
log.error("Error deserializing parameters.")
log.error(err)
@@ -92,7 +92,7 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
log.debug("Calling method %.", {method_name})
try:
output = handler(**params)
- output_json = BaseSerialization.serialize(output)
+ output_json = BaseSerialization.serialize(output, use_pydantic_models=True)
log.debug("Returning response")
return Response(
response=json.dumps(output_json or "{}"), headers={"Content-Type": "application/json"}
diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py
index 4179188baa..e8627ec68b 100644
--- a/airflow/api_internal/internal_api_call.py
+++ b/airflow/api_internal/internal_api_call.py
@@ -117,9 +117,9 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]:
if "cls" in arguments_dict: # used by @classmethod
del arguments_dict["cls"]
- args_json = json.dumps(BaseSerialization.serialize(arguments_dict))
+ args_json = json.dumps(BaseSerialization.serialize(arguments_dict, use_pydantic_models=True))
method_name = f"{func.__module__}.{func.__qualname__}"
result = make_jsonrpc_request(method_name, args_json)
- return BaseSerialization.deserialize(json.loads(result))
+ return BaseSerialization.deserialize(json.loads(result), use_pydantic_models=True)
return wrapper
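
For context, a sketch of what a decorated call looks like once
database_access_isolation is enabled (get_dag_run and its arguments are
illustrative, not part of this commit):

    @internal_api_call
    def get_dag_run(dag_id: str, run_id: str, session=None) -> DagRun | DagRunPydantic:
        ...  # executed locally against the DB, or proxied over JSON-RPC

    # When proxied, arguments are serialized with use_pydantic_models=True
    # and an ORM DagRun result comes back as a DagRunPydantic.
    dag_run = get_dag_run("example_dag", "example_run")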
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index f233261613..1f8dce26dd 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -51,3 +51,7 @@ class DagAttributeTypes(str, Enum):
XCOM_REF = "xcomref"
DATASET = "dataset"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
+ BASE_JOB = "base_job"
+ TASK_INSTANCE = "task_instance"
+ DAG_RUN = "dag_run"
+ DATA_SET = "data_set"
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 2d753b0f48..b3d783ad69 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -38,14 +38,20 @@ from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError
+from airflow.jobs.base_job import BaseJob
+from airflow.jobs.pydantic.base_job import BaseJobPydantic
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.connection import Connection
from airflow.models.dag import DAG, create_timetable
+from airflow.models.dagrun import DagRun
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, ExpandInput, create_expand_input, get_map_type_key
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.param import Param, ParamsDict
-from airflow.models.taskinstance import SimpleTaskInstance
+from airflow.models.pydantic.dag_run import DagRunPydantic
+from airflow.models.pydantic.dataset import DatasetPydantic
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.taskmixin import DAGNode
from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
from airflow.providers_manager import ProvidersManager
@@ -384,7 +390,7 @@ class BaseSerialization:
@classmethod
def serialize(
- cls, var: Any, *, strict: bool = False
+ cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = False
) -> Any: # Unfortunately there is no support for recursive types in mypy
"""Helper function of depth first search for serialization.
@@ -405,10 +411,14 @@ class BaseSerialization:
return var
elif isinstance(var, dict):
return cls._encode(
- {str(k): cls.serialize(v, strict=strict) for k, v in var.items()}, type_=DAT.DICT
+ {
+ str(k): cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models)
+ for k, v in var.items()
+ },
+ type_=DAT.DICT,
)
elif isinstance(var, list):
- return [cls.serialize(v, strict=strict) for v in var]
+ return [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var]
elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and isinstance(var, k8s.V1Pod):
json_pod = PodGenerator.serialize_pod(var)
return cls._encode(json_pod, type_=DAT.POD)
@@ -433,12 +443,23 @@ class BaseSerialization:
elif isinstance(var, set):
# FIXME: casts set to list in customized serialization in future.
try:
- return cls._encode(sorted(cls.serialize(v, strict=strict) for v in var), type_=DAT.SET)
+ return cls._encode(
+ sorted(
+ cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var
+ ),
+ type_=DAT.SET,
+ )
except TypeError:
- return cls._encode([cls.serialize(v, strict=strict) for v in var], type_=DAT.SET)
+ return cls._encode(
+ [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var],
+ type_=DAT.SET,
+ )
elif isinstance(var, tuple):
# FIXME: casts tuple to list in customized serialization in future.
- return cls._encode([cls.serialize(v, strict=strict) for v in var], type_=DAT.TUPLE)
+ return cls._encode(
+ [cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) for v in var],
+ type_=DAT.TUPLE,
+ )
elif isinstance(var, TaskGroup):
return TaskGroupSerialization.serialize_task_group(var)
elif isinstance(var, Param):
@@ -448,7 +469,18 @@ class BaseSerialization:
elif isinstance(var, Dataset):
return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET)
elif isinstance(var, SimpleTaskInstance):
- return cls._encode(cls.serialize(var.__dict__, strict=strict), type_=DAT.SIMPLE_TASK_INSTANCE)
+ return cls._encode(
+ cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
+ type_=DAT.SIMPLE_TASK_INSTANCE,
+ )
+ elif use_pydantic_models and isinstance(var, BaseJob):
+ return cls._encode(BaseJobPydantic.from_orm(var).dict(), type_=DAT.BASE_JOB)
+ elif use_pydantic_models and isinstance(var, TaskInstance):
+ return cls._encode(TaskInstancePydantic.from_orm(var).dict(), type_=DAT.TASK_INSTANCE)
+ elif use_pydantic_models and isinstance(var, DagRun):
+ return cls._encode(DagRunPydantic.from_orm(var).dict(), type_=DAT.DAG_RUN)
+ elif use_pydantic_models and isinstance(var, Dataset):
+ return cls._encode(DatasetPydantic.from_orm(var).dict(), type_=DAT.DATA_SET)
else:
log.debug("Cast type %s to str in serialization.", type(var))
if strict:
@@ -456,7 +488,7 @@ class BaseSerialization:
return str(var)
@classmethod
- def deserialize(cls, encoded_var: Any) -> Any:
+ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
"""Helper function of depth first search for deserialization.
:meta private:
@@ -465,7 +497,7 @@ class BaseSerialization:
if cls._is_primitive(encoded_var):
return encoded_var
elif isinstance(encoded_var, list):
- return [cls.deserialize(v) for v in encoded_var]
+ return [cls.deserialize(v, use_pydantic_models) for v in encoded_var]
if not isinstance(encoded_var, dict):
raise ValueError(f"The encoded_var should be dict and is
{type(encoded_var)}")
@@ -473,7 +505,7 @@ class BaseSerialization:
type_ = encoded_var[Encoding.TYPE]
if type_ == DAT.DICT:
- return {k: cls.deserialize(v) for k, v in var.items()}
+ return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
elif type_ == DAT.DAG:
return SerializedDAG.deserialize_dag(var)
elif type_ == DAT.OP:
@@ -492,9 +524,9 @@ class BaseSerialization:
elif type_ == DAT.RELATIVEDELTA:
return decode_relativedelta(var)
elif type_ == DAT.SET:
- return {cls.deserialize(v) for v in var}
+ return {cls.deserialize(v, use_pydantic_models) for v in var}
elif type_ == DAT.TUPLE:
- return tuple(cls.deserialize(v) for v in var)
+ return tuple(cls.deserialize(v, use_pydantic_models) for v in var)
elif type_ == DAT.PARAM:
return cls._deserialize_param(var)
elif type_ == DAT.XCOM_REF:
@@ -503,6 +535,14 @@ class BaseSerialization:
return Dataset(**var)
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
+ elif use_pydantic_models and type_ == DAT.BASE_JOB:
+ return BaseJobPydantic.parse_obj(var)
+ elif use_pydantic_models and type_ == DAT.TASK_INSTANCE:
+ return TaskInstancePydantic.parse_obj(var)
+ elif use_pydantic_models and type_ == DAT.DAG_RUN:
+ return DagRunPydantic.parse_obj(var)
+ elif use_pydantic_models and type_ == DAT.DATA_SET:
+ return DatasetPydantic.parse_obj(var)
else:
raise TypeError(f"Invalid type {type_!s} in deserialization.")
@@ -1114,13 +1154,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
return serialize_operator_extra_links
@classmethod
- def serialize(cls, var: Any, *, strict: bool = False) -> Any:
+ def serialize(cls, var: Any, *, strict: bool = False, use_pydantic_models: bool = False) -> Any:
# the wonders of multiple inheritance BaseOperator defines an instance method
- return BaseSerialization.serialize(var=var, strict=strict)
+ return BaseSerialization.serialize(var=var, strict=strict, use_pydantic_models=use_pydantic_models)
@classmethod
- def deserialize(cls, encoded_var: Any) -> Any:
- return BaseSerialization.deserialize(encoded_var=encoded_var)
+ def deserialize(cls, encoded_var: Any, use_pydantic_models: bool = False) -> Any:
+ return BaseSerialization.deserialize(encoded_var=encoded_var, use_pydantic_models=use_pydantic_models)
class SerializedDAG(DAG, BaseSerialization):
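
The new branches rely on the pydantic v1 ORM-mode API: from_orm() reads
attributes off the SQLAlchemy object, .dict() yields JSON-safe data, and
parse_obj() rebuilds the model after transport. A self-contained sketch
of that round trip (FakeJob and FakeJobPydantic are stand-ins, not
Airflow classes):

    from pydantic import BaseModel  # pydantic v1 API, as used here

    class FakeJob:  # stand-in for an ORM-mapped row such as BaseJob
        id = 1
        state = "running"

    class FakeJobPydantic(BaseModel):  # stand-in for BaseJobPydantic
        id: int
        state: str

        class Config:
            orm_mode = True  # required for from_orm()

    encoded = FakeJobPydantic.from_orm(FakeJob()).dict()  # {"id": 1, "state": "running"}
    decoded = FakeJobPydantic.parse_obj(encoded)
    assert decoded == FakeJobPydantic.from_orm(FakeJob())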
diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
index a0ec216147..9e09376dd6 100644
--- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py
+++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
@@ -23,7 +23,11 @@ from unittest import mock
import pytest
from flask import Flask
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.models.taskinstance import TaskInstance
+from airflow.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import BaseSerialization
+from airflow.utils.state import State
from airflow.www import app
from tests.test_utils.config import conf_vars
from tests.test_utils.decorators import dont_initialize_flask_app_submodules
@@ -110,6 +114,21 @@ class TestRpcApiEndpoint:
expected_mock.assert_called_once_with(**method_params)
+ def test_method_with_pydantic_serialized_object(self):
+ ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING)
+ mock_test_method.return_value = ti
+
+ response = self.client.post(
+ "/internal_api/v1/rpcapi",
+ headers={"Content-Type": "application/json"},
+ data=json.dumps({"jsonrpc": "2.0", "method": TEST_METHOD_NAME,
"params": ""}),
+ )
+ assert response.status_code == 200
+ print(response.data)
+ response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True)
+ expected_data = TaskInstancePydantic.from_orm(ti)
+ assert response_data == expected_data
+
def test_method_with_exception(self):
mock_test_method.side_effect = ValueError("Error!!!")
data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}
diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py
index c96b2bde32..e7cd488e66 100644
--- a/tests/api_internal/test_internal_api_call.py
+++ b/tests/api_internal/test_internal_api_call.py
@@ -25,7 +25,11 @@ import pytest
import requests
from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.models.taskinstance import TaskInstance
+from airflow.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import BaseSerialization
+from airflow.utils.state import State
from tests.test_utils.config import conf_vars
@@ -81,6 +85,14 @@ class TestInternalApiCall:
def fake_class_method_with_params(cls, dag_id: str, session) -> str:
return f"local-classmethod-call-with-params-{dag_id}"
+ @staticmethod
+ @internal_api_call
+ def fake_class_method_with_serialized_params(
+ ti: TaskInstance | TaskInstancePydantic,
+ session,
+ ) -> str:
+ return f"local-classmethod-call-with-serialized-{ti.task_id}"
+
@conf_vars(
{
("core", "database_access_isolation"): "false",
@@ -200,3 +212,36 @@ class TestInternalApiCall:
data=expected_data,
headers={"Content-Type": "application/json"},
)
+
+ @conf_vars(
+ {
+ ("core", "database_access_isolation"): "true",
+ ("core", "internal_api_url"): "http://localhost:8888",
+ }
+ )
+ @mock.patch("airflow.api_internal.internal_api_call.requests")
+ def test_remote_call_with_serialized_model(self, mock_requests):
+ response = requests.Response()
+ response.status_code = 200
+
+ response._content = json.dumps(BaseSerialization.serialize("remote-call"))
+
+ mock_requests.post.return_value = response
+ ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING)
+
+ result = TestInternalApiCall.fake_class_method_with_serialized_params(ti, session="session")
+
+ assert result == "remote-call"
+ expected_data = json.dumps(
+ {
+ "jsonrpc": "2.0",
+ "method":
"tests.api_internal.test_internal_api_call.TestInternalApiCall."
+ "fake_class_method_with_serialized_params",
+ "params": json.dumps(BaseSerialization.serialize({"ti": ti},
use_pydantic_models=True)),
+ }
+ )
+ mock_requests.post.assert_called_once_with(
+ url="http://localhost:8888/internal_api/v1/rpcapi",
+ data=expected_data,
+ headers={"Content-Type": "application/json"},
+ )
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 3298fb6cba..7ccf16bb6a 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -20,6 +20,10 @@ from __future__ import annotations
import pytest
from airflow.exceptions import SerializationError
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.models.taskinstance import TaskInstance
+from airflow.operators.empty import EmptyOperator
+from airflow.utils.state import State
from tests import REPO_ROOT
@@ -76,3 +80,17 @@ def test_strict_mode():
BaseSerialization.serialize(obj) # does not raise
with pytest.raises(SerializationError, match="Encountered unexpected type"):
BaseSerialization.serialize(obj, strict=True) # now raises
+
+
+def test_use_pydantic_models():
+ """If use_pydantic_models=True the TaskInstance object should be
serialized to TaskInstancePydantic."""
+
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING)
+ obj = [[ti]] # nested to verify recursive behavior
+
+ serialized = BaseSerialization.serialize(obj, use_pydantic_models=True)  # does not raise
+ deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)  # does not raise
+
+ assert isinstance(deserialized[0][0], TaskInstancePydantic)