This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ff7e70056e5 AIP-72: Extending SET RTIF endpoint to accept all
JSONable types (#44843)
ff7e70056e5 is described below
commit ff7e70056e5b4e6e61d20215eec00054415d4a77
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Dec 13 23:03:42 2024 +0530
AIP-72: Extending SET RTIF endpoint to accept all JSONable types (#44843)
An endpoint to set RTIF was added in #44359. It allowed only `dict[str,
str]` entries to be passed down to the API, which led to issues when running
tests with DAGs like:
```py
from __future__ import annotations
import sys
import time
from datetime import datetime
from airflow import DAG
from airflow.decorators import dag, task
from airflow.operators.bash import BashOperator
@dag(
# every minute on the 30-second mark
catchup=False,
tags=[],
schedule=None,
start_date=datetime(2021, 1, 1),
)
def hello_dag():
"""
### TaskFlow API Tutorial Documentation
This is a simple data pipeline example which demonstrates the use of
the TaskFlow API using three simple tasks for Extract, Transform, and
Load.
Documentation that goes along with the Airflow TaskFlow API tutorial is
located
[here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html)
"""
@task()
def hello():
print("hello")
time.sleep(3)
print("goodbye")
print("err mesg", file=sys.stderr)
hello()
hello_dag()
```
The reason for this is that arguments such as `op_args` and `op_kwargs`
for PythonOperator can hold non-`str` values. We should therefore
accept `str` keys but JSON-able values.
Some points to note for reviewers:
1. The type we store in the table:
https://github.com/apache/airflow/blob/1eb683be3a79c80927e9af1e89dabb5e78ce3136/airflow/models/renderedtifields.py#L76.
Hence we should be able to accept any JSON-able types and store them; for
types that cannot be stored unchanged as JSON, like tuple and set, we should
convert them (to strings) before storing them.
### What does this PR change?
- Get rid of the `RTIFPayload` and consume the payload directly in the api
handler.
- Handle the special case of tuples: they are JSON-serialisable, but because
`json.dumps()` converts them to lists, tuples used to be stored as lists.
`is_jsonable` now reports tuples as not JSON-able, so they are stringified:
```
def is_jsonable(x):
try:
json.dumps(x)
if isinstance(x, tuple):
# Tuple is converted to list in json.dumps
# so while it is jsonable, it changes the type which might
be a surprise
# for the user, so instead we return False here -- which
will convert it to string
return False
```
- Reuse `serialize_template_field` from `airflow.serialization.helpers`,
because copy-pasting the code would be expensive and hard to maintain. We will
revisit it anyway when we port the templating logic to the Task SDK. Discussion:
https://github.com/apache/airflow/pull/44843/files#r1882834039
- Added test cases with different scopes and different types to handle
different cases of templated_fields well.
---
.../execution_api/datamodels/taskinstance.py | 6 +--
.../execution_api/routes/task_instances.py | 6 +--
airflow/serialization/helpers.py | 5 ++
task_sdk/src/airflow/sdk/execution_time/comms.py | 2 +-
.../src/airflow/sdk/execution_time/task_runner.py | 24 +++++----
task_sdk/tests/execution_time/test_task_runner.py | 58 ++++++++++++++++++++++
.../endpoints/test_task_instance_endpoint.py | 6 +--
.../core_api/routes/public/test_task_instances.py | 6 +--
.../execution_api/routes/test_task_instances.py | 48 +++++++-----------
tests/models/test_renderedtifields.py | 36 +++++++++++---
10 files changed, 137 insertions(+), 60 deletions(-)
diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index e0d8f371f09..bbc557d0124 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -21,7 +21,7 @@ import uuid
from datetime import timedelta
from typing import Annotated, Any, Literal, Union
-from pydantic import Discriminator, Field, RootModel, Tag, WithJsonSchema
+from pydantic import Discriminator, Field, Tag, WithJsonSchema
from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.core_api.base import BaseModel
@@ -135,7 +135,3 @@ class TaskInstance(BaseModel):
run_id: str
try_number: int
map_index: int | None = None
-
-
-"""Schema for setting RTIF for a task instance."""
-RTIFPayload = RootModel[dict[str, str]]
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 90bbe1c1d3e..e06798209c5 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -22,6 +22,7 @@ from typing import Annotated
from uuid import UUID
from fastapi import Body, HTTPException, status
+from pydantic import JsonValue
from sqlalchemy import update
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.sql import select
@@ -29,7 +30,6 @@ from sqlalchemy.sql import select
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
- RTIFPayload,
TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
@@ -237,7 +237,7 @@ def ti_heartbeat(
)
def ti_put_rtif(
task_instance_id: UUID,
- put_rtif_payload: RTIFPayload,
+ put_rtif_payload: Annotated[dict[str, JsonValue], Body()],
session: SessionDep,
):
"""Add an RTIF entry for a task instance, sent by the worker."""
@@ -247,6 +247,6 @@ def ti_put_rtif(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
)
- _update_rtif(task_instance, put_rtif_payload.model_dump(), session)
+ _update_rtif(task_instance, put_rtif_payload, session)
return {"message": "Rendered task instance fields successfully set"}
diff --git a/airflow/serialization/helpers.py b/airflow/serialization/helpers.py
index 85bf3a1cc55..dc1aabbca98 100644
--- a/airflow/serialization/helpers.py
+++ b/airflow/serialization/helpers.py
@@ -36,6 +36,11 @@ def serialize_template_field(template_field: Any, name: str)
-> str | dict | lis
def is_jsonable(x):
try:
json.dumps(x)
+ if isinstance(x, tuple):
+ # Tuple is converted to list in json.dumps
+ # so while it is jsonable, it changes the type which might be
a surprise
+ # for the user, so instead we return False here -- which will
convert it to string
+ return False
except (TypeError, OverflowError):
return False
else:
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index 34d6a9e3156..9e6093a092d 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -176,7 +176,7 @@ class SetRenderedFields(BaseModel):
# We are using a BaseModel here compared to server using RootModel because
we
# have a discriminator running with "type", and RootModel doesn't support
type
- rendered_fields: dict[str, str | None]
+ rendered_fields: dict[str, JsonValue]
type: Literal["SetRenderedFields"] = "SetRenderedFields"
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index c01677ce1a7..5aca25f590e 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Any, Generic, TextIO,
TypeVar
import attrs
import structlog
-from pydantic import BaseModel, ConfigDict, TypeAdapter
+from pydantic import BaseModel, ConfigDict, JsonValue, TypeAdapter
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.definitions.baseoperator import BaseOperator
@@ -196,22 +196,26 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]:
# 1. Implementing the part where we pull in the logic to render fields and
add that here
# for all operators, we should do setattr(task, templated_field,
rendered_templated_field)
# task.templated_fields should give all the templated_fields and each of
those fields should
- # give the rendered values.
+ # give the rendered values. task.templated_fields should already be in a
JSONable format and
+ # we should not have to handle that here.
# 2. Once rendered, we call the `set_rtif` API to store the rtif in the
metadata DB
- templated_fields = ti.task.template_fields
- payload = {}
-
- for field in templated_fields:
- if field not in payload:
- payload[field] = getattr(ti.task, field)
# so that we do not call the API unnecessarily
- if payload:
- SUPERVISOR_COMMS.send_request(log=log,
msg=SetRenderedFields(rendered_fields=payload))
+ if rendered_fields := _get_rendered_fields(ti.task):
+ SUPERVISOR_COMMS.send_request(log=log,
msg=SetRenderedFields(rendered_fields=rendered_fields))
return ti, log
+def _get_rendered_fields(task: BaseOperator) -> dict[str, JsonValue]:
+ # TODO: Port one of the following to Task SDK
+ # airflow.serialization.helpers.serialize_template_field or
+ # airflow.models.renderedtifields.get_serialized_template_fields
+ from airflow.serialization.helpers import serialize_template_field
+
+ return {field: serialize_template_field(getattr(task, field), field) for
field in task.template_fields}
+
+
def run(ti: RuntimeTaskInstance, log: Logger):
"""Run the task in this process."""
from airflow.exceptions import (
diff --git a/task_sdk/tests/execution_time/test_task_runner.py
b/task_sdk/tests/execution_time/test_task_runner.py
index 517157e0a7a..c9755c252bb 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -260,3 +260,61 @@ def test_startup_basic_templated_dag(mocked_parse):
),
log=mock.ANY,
)
+
+
[email protected](
+ ["task_params", "expected_rendered_fields"],
+ [
+ pytest.param(
+ {"op_args": [], "op_kwargs": {}, "templates_dict": None},
+ {"op_args": [], "op_kwargs": {}, "templates_dict": None},
+ id="no_templates",
+ ),
+ pytest.param(
+ {
+ "op_args": ["arg1", "arg2", 1, 2, 3.75, {"key": "value"}],
+ "op_kwargs": {"key1": "value1", "key2": 99.0, "key3":
{"nested_key": "nested_value"}},
+ },
+ {
+ "op_args": ["arg1", "arg2", 1, 2, 3.75, {"key": "value"}],
+ "op_kwargs": {"key1": "value1", "key2": 99.0, "key3":
{"nested_key": "nested_value"}},
+ },
+ id="mixed_types",
+ ),
+ pytest.param(
+ {"my_tup": (1, 2), "my_set": {1, 2, 3}},
+ {"my_tup": "(1, 2)", "my_set": "{1, 2, 3}"},
+ id="tuples_and_sets",
+ ),
+ ],
+)
+def test_startup_dag_with_templated_fields(mocked_parse, task_params,
expected_rendered_fields):
+ """Test startup of a DAG with various templated fields."""
+
+ class CustomOperator(BaseOperator):
+ template_fields = tuple(task_params.keys())
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ for key, value in task_params.items():
+ setattr(self, key, value)
+
+ task = CustomOperator(task_id="templated_task")
+
+ what = StartupDetails(
+ ti=TaskInstance(id=uuid7(), task_id="templated_task",
dag_id="basic_dag", run_id="c", try_number=1),
+ file="",
+ requests_fd=0,
+ )
+ mocked_parse(what, "basic_dag", task)
+
+ with mock.patch(
+ "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+ ) as mock_supervisor_comms:
+ mock_supervisor_comms.get_message.return_value = what
+
+ startup()
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
+ log=mock.ANY,
+ )
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index f39cff8ae76..a4089c9785a 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -351,7 +351,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
"try_number": 0,
"unixname": getuser(),
"dag_run_id": "TEST_DAG_RUN_ID",
- "rendered_fields": {"op_args": [], "op_kwargs": {},
"templates_dict": None},
+ "rendered_fields": {"op_args": "()", "op_kwargs": {},
"templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
@@ -403,7 +403,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
"try_number": 0,
"unixname": getuser(),
"dag_run_id": "TEST_DAG_RUN_ID",
- "rendered_fields": {"op_args": [], "op_kwargs": {},
"templates_dict": None},
+ "rendered_fields": {"op_args": "()", "op_kwargs": {},
"templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
@@ -2371,7 +2371,7 @@ class TestSetTaskInstanceNote(TestTaskInstanceEndpoint):
"try_number": 0,
"unixname": getuser(),
"dag_run_id": "TEST_DAG_RUN_ID",
- "rendered_fields": {"op_args": [], "op_kwargs": {},
"templates_dict": None},
+ "rendered_fields": {"op_args": "()", "op_kwargs": {},
"templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py
b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
index 9b427253b29..7ce944a4d47 100644
--- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -344,7 +344,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
"try_number": 0,
"unixname": getuser(),
"dag_run_id": "TEST_DAG_RUN_ID",
- "rendered_fields": {"op_args": [], "op_kwargs": {},
"templates_dict": None},
+ "rendered_fields": {"op_args": "()", "op_kwargs": {},
"templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
@@ -444,7 +444,7 @@ class TestGetMappedTaskInstance(TestTaskInstanceEndpoint):
"try_number": 0,
"unixname": getuser(),
"dag_run_id": "TEST_DAG_RUN_ID",
- "rendered_fields": {"op_args": [], "op_kwargs": {},
"templates_dict": None},
+ "rendered_fields": {"op_args": "()", "op_kwargs": {},
"templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
@@ -3070,7 +3070,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
"try_number": 0,
"unixname": getuser(),
"dag_run_id": self.RUN_ID,
- "rendered_fields": {"op_args": [], "op_kwargs": {},
"templates_dict": None},
+ "rendered_fields": {"op_args": "()", "op_kwargs": {},
"templates_dict": None},
"rendered_map_index": None,
"trigger": None,
"triggerer_job": None,
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index c13effee0bb..15e56bbc587 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -422,16 +422,31 @@ class TestTIPutRTIF:
clear_db_runs()
clear_rendered_ti_fields()
- def test_ti_put_rtif_success(self, client, session, create_task_instance):
+ @pytest.mark.parametrize(
+ "payload",
+ [
+ # string value
+ {"field1": "string_value", "field2": "another_string"},
+ # dictionary value
+ {"field1": {"nested_key": "nested_value"}},
+ # string lists value
+ {"field1": ["123"], "field2": ["a", "b", "c"]},
+ # list of JSON values
+ {"field1": [1, "string", 3.14, True, None, {"nested": "dict"}]},
+ # nested dictionary with mixed types in lists
+ {
+ "field1": {"nested_dict": {"key1": 123, "key2": "value"}},
+ "field2": [3.14, {"sub_key": "sub_value"}, [1, 2]],
+ },
+ ],
+ )
+ def test_ti_put_rtif_success(self, client, session, create_task_instance,
payload):
ti = create_task_instance(
task_id="test_ti_put_rtif_success",
state=State.RUNNING,
session=session,
)
session.commit()
-
- payload = {"field1": "rendered_value1", "field2": "rendered_value2"}
-
response = client.put(f"/execution/task-instances/{ti.id}/rtif",
json=payload)
assert response.status_code == 201
assert response.json() == {"message": "Rendered task instance fields
successfully set"}
@@ -461,28 +476,3 @@ class TestTIPutRTIF:
response = client.put(f"/execution/task-instances/{random_id}/rtif",
json=payload)
assert response.status_code == 404
assert response.json()["detail"] == "Not Found"
-
- def test_ti_put_rtif_extra_fields(self, client, session,
create_task_instance):
- ti = create_task_instance(
- task_id="test_ti_put_rtif_missing_ti",
- state=State.RUNNING,
- session=session,
- )
- session.commit()
-
- payload = {
- "field1": "rendered_value1",
- "field2": "rendered_value2",
- "invalid_key": {"field3": "rendered_value3"},
- }
-
- response = client.put(f"/execution/task-instances/{ti.id}/rtif",
json=payload)
- assert response.status_code == 422
- assert response.json()["detail"] == [
- {
- "input": {"field3": "rendered_value3"},
- "loc": ["body", "invalid_key"],
- "msg": "Input should be a valid string",
- "type": "string_type",
- }
- ]
diff --git a/tests/models/test_renderedtifields.py
b/tests/models/test_renderedtifields.py
index 3f1b13cd1a3..ded755c4d01 100644
--- a/tests/models/test_renderedtifields.py
+++ b/tests/models/test_renderedtifields.py
@@ -19,6 +19,7 @@
from __future__ import annotations
+import ast
import os
from collections import Counter
from datetime import date, timedelta
@@ -100,8 +101,12 @@ class TestRenderedTaskInstanceFields:
(None, None),
([], []),
({}, {}),
+ ((), "()"),
+ (set(), "set()"),
("test-string", "test-string"),
({"foo": "bar"}, {"foo": "bar"}),
+ (("foo", "bar"), "('foo', 'bar')"),
+ ({"foo", "bar"}, "{'foo', 'bar'}"),
("{{ task.task_id }}", "test"),
(date(2018, 12, 6), "2018-12-06"),
(datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00"),
@@ -158,16 +163,35 @@ class TestRenderedTaskInstanceFields:
assert ti.dag_id == rtif.dag_id
assert ti.task_id == rtif.task_id
assert ti.run_id == rtif.run_id
- assert expected_rendered_field ==
rtif.rendered_fields.get("bash_command")
+ if type(templated_field) is set:
+ # the output order of a set is non-deterministic and can change
per process.
+ # this validation can fail if that happens before stringification,
so we convert to set and compare.
+ assert ast.literal_eval(expected_rendered_field) ==
ast.literal_eval(
+ rtif.rendered_fields.get("bash_command")
+ )
+ else:
+ assert expected_rendered_field ==
rtif.rendered_fields.get("bash_command")
session.add(rtif)
session.flush()
- assert RTIF.get_templated_fields(ti=ti, session=session) == {
- "bash_command": expected_rendered_field,
- "env": None,
- "cwd": None,
- }
+ if type(templated_field) is set:
+ # the output order of a set is non-deterministic and can change
per process.
+ # this validation can fail if that happens before stringification,
so we convert to set and compare.
+ expected = RTIF.get_templated_fields(ti=ti, session=session)
+ expected["bash_command"] =
ast.literal_eval(expected["bash_command"])
+ actual = {
+ "bash_command": ast.literal_eval(expected_rendered_field),
+ "env": None,
+ "cwd": None,
+ }
+ assert expected == actual
+ else:
+ assert RTIF.get_templated_fields(ti=ti, session=session) == {
+ "bash_command": expected_rendered_field,
+ "env": None,
+ "cwd": None,
+ }
# Test the else part of get_templated_fields
# i.e. for the TIs that are not stored in RTIF table
# Fetching them will return None