This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ff7e70056e5  AIP-72: Extending SET RTIF endpoint to accept all 
JSONable types (#44843)
ff7e70056e5 is described below

commit ff7e70056e5b4e6e61d20215eec00054415d4a77
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Dec 13 23:03:42 2024 +0530

     AIP-72: Extending SET RTIF endpoint to accept all JSONable types (#44843)
    
    
    An endpoint to set RTIF was added in #44359. It only allowed `dict[str, 
str]` entries to be passed down to the API, which led to issues when running 
tests with DAGs like:
    ```py
    from __future__ import annotations
    
    import sys
    import time
    from datetime import datetime
    
    from airflow import DAG
    from airflow.decorators import dag, task
    from airflow.operators.bash import BashOperator
    
    
    @dag(
        # every minute on the 30-second mark
        catchup=False,
        tags=[],
        schedule=None,
        start_date=datetime(2021, 1, 1),
    )
    def hello_dag():
        """
        ### TaskFlow API Tutorial Documentation
        This is a simple data pipeline example which demonstrates the use of
        the TaskFlow API using three simple tasks for Extract, Transform, and 
Load.
        Documentation that goes along with the Airflow TaskFlow API tutorial is
        located
        
[here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html)
        """
    
        @task()
        def hello():
            print("hello")
            time.sleep(3)
            print("goodbye")
            print("err mesg", file=sys.stderr)
    
        hello()
    
    
    hello_dag()
    ```
    
    The reason for this is that arguments such as `op_args` and `op_kwargs` 
for PythonOperator can hold non-`str` values. We should therefore accept 
`str` keys but any JSON-able (`JsonValue`) values.
    
    Some points to note for reviewers:
    1. Type we store in the table: 
https://github.com/apache/airflow/blob/1eb683be3a79c80927e9af1e89dabb5e78ce3136/airflow/models/renderedtifields.py#L76.
 Hence we should be able to accept any JSON-able value and store it as is. 
Values that are not JSON-able (e.g. `set`), or that would change type when 
serialised (e.g. `tuple`, which `json.dumps()` turns into a list), are 
converted to strings before being stored.
    
    
    ### What does this PR change?
    - Get rid of the `RTIFPayload` and consume the payload directly in the api 
handler.
    - Handling the special case of tuples: they are JSON-serialisable, but 
`json.dumps()` converts them to lists, so a value passed as a tuple used to be 
stored as a list. `is_jsonable` now reports tuples as not JSON-able, so they 
are converted to strings instead:
    ```
        def is_jsonable(x):
            try:
                json.dumps(x)
                if isinstance(x, tuple):
                    # Tuple is converted to list in json.dumps
                    # so while it is jsonable, it changes the type which might 
be a surprise
                    # for the user, so instead we return False here -- which 
will convert it to string
                    return False
    ```
    - Reusing `serialize_template_field` from `airflow.serialization.helpers`, 
because copy-pasting the code would be expensive and hard to maintain. We will 
revisit it anyway when we port the templating logic to the Task SDK. Discussion: 
https://github.com/apache/airflow/pull/44843/files#r1882834039
    - Added test cases with different scopes and different value types to 
cover the various kinds of templated fields.
---
 .../execution_api/datamodels/taskinstance.py       |  6 +--
 .../execution_api/routes/task_instances.py         |  6 +--
 airflow/serialization/helpers.py                   |  5 ++
 task_sdk/src/airflow/sdk/execution_time/comms.py   |  2 +-
 .../src/airflow/sdk/execution_time/task_runner.py  | 24 +++++----
 task_sdk/tests/execution_time/test_task_runner.py  | 58 ++++++++++++++++++++++
 .../endpoints/test_task_instance_endpoint.py       |  6 +--
 .../core_api/routes/public/test_task_instances.py  |  6 +--
 .../execution_api/routes/test_task_instances.py    | 48 +++++++-----------
 tests/models/test_renderedtifields.py              | 36 +++++++++++---
 10 files changed, 137 insertions(+), 60 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py 
b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index e0d8f371f09..bbc557d0124 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -21,7 +21,7 @@ import uuid
 from datetime import timedelta
 from typing import Annotated, Any, Literal, Union
 
-from pydantic import Discriminator, Field, RootModel, Tag, WithJsonSchema
+from pydantic import Discriminator, Field, Tag, WithJsonSchema
 
 from airflow.api_fastapi.common.types import UtcDateTime
 from airflow.api_fastapi.core_api.base import BaseModel
@@ -135,7 +135,3 @@ class TaskInstance(BaseModel):
     run_id: str
     try_number: int
     map_index: int | None = None
-
-
-"""Schema for setting RTIF for a task instance."""
-RTIFPayload = RootModel[dict[str, str]]
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 90bbe1c1d3e..e06798209c5 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -22,6 +22,7 @@ from typing import Annotated
 from uuid import UUID
 
 from fastapi import Body, HTTPException, status
+from pydantic import JsonValue
 from sqlalchemy import update
 from sqlalchemy.exc import NoResultFound, SQLAlchemyError
 from sqlalchemy.sql import select
@@ -29,7 +30,6 @@ from sqlalchemy.sql import select
 from airflow.api_fastapi.common.db.common import SessionDep
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
-    RTIFPayload,
     TIDeferredStatePayload,
     TIEnterRunningPayload,
     TIHeartbeatInfo,
@@ -237,7 +237,7 @@ def ti_heartbeat(
 )
 def ti_put_rtif(
     task_instance_id: UUID,
-    put_rtif_payload: RTIFPayload,
+    put_rtif_payload: Annotated[dict[str, JsonValue], Body()],
     session: SessionDep,
 ):
     """Add an RTIF entry for a task instance, sent by the worker."""
@@ -247,6 +247,6 @@ def ti_put_rtif(
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
         )
-    _update_rtif(task_instance, put_rtif_payload.model_dump(), session)
+    _update_rtif(task_instance, put_rtif_payload, session)
 
     return {"message": "Rendered task instance fields successfully set"}
diff --git a/airflow/serialization/helpers.py b/airflow/serialization/helpers.py
index 85bf3a1cc55..dc1aabbca98 100644
--- a/airflow/serialization/helpers.py
+++ b/airflow/serialization/helpers.py
@@ -36,6 +36,11 @@ def serialize_template_field(template_field: Any, name: str) 
-> str | dict | lis
     def is_jsonable(x):
         try:
             json.dumps(x)
+            if isinstance(x, tuple):
+                # Tuple is converted to list in json.dumps
+                # so while it is jsonable, it changes the type which might be 
a surprise
+                # for the user, so instead we return False here -- which will 
convert it to string
+                return False
         except (TypeError, OverflowError):
             return False
         else:
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py 
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index 34d6a9e3156..9e6093a092d 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -176,7 +176,7 @@ class SetRenderedFields(BaseModel):
     # We are using a BaseModel here compared to server using RootModel because 
we
     # have a discriminator running with "type", and RootModel doesn't support 
type
 
-    rendered_fields: dict[str, str | None]
+    rendered_fields: dict[str, JsonValue]
     type: Literal["SetRenderedFields"] = "SetRenderedFields"
 
 
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index c01677ce1a7..5aca25f590e 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Any, Generic, TextIO, 
TypeVar
 
 import attrs
 import structlog
-from pydantic import BaseModel, ConfigDict, TypeAdapter
+from pydantic import BaseModel, ConfigDict, JsonValue, TypeAdapter
 
 from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
 from airflow.sdk.definitions.baseoperator import BaseOperator
@@ -196,22 +196,26 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]:
     # 1. Implementing the part where we pull in the logic to render fields and 
add that here
     # for all operators, we should do setattr(task, templated_field, 
rendered_templated_field)
     # task.templated_fields should give all the templated_fields and each of 
those fields should
-    # give the rendered values.
+    # give the rendered values. task.templated_fields should already be in a 
JSONable format and
+    # we should not have to handle that here.
 
     # 2. Once rendered, we call the `set_rtif` API to store the rtif in the 
metadata DB
-    templated_fields = ti.task.template_fields
-    payload = {}
-
-    for field in templated_fields:
-        if field not in payload:
-            payload[field] = getattr(ti.task, field)
 
     # so that we do not call the API unnecessarily
-    if payload:
-        SUPERVISOR_COMMS.send_request(log=log, 
msg=SetRenderedFields(rendered_fields=payload))
+    if rendered_fields := _get_rendered_fields(ti.task):
+        SUPERVISOR_COMMS.send_request(log=log, 
msg=SetRenderedFields(rendered_fields=rendered_fields))
     return ti, log
 
 
+def _get_rendered_fields(task: BaseOperator) -> dict[str, JsonValue]:
+    # TODO: Port one of the following to Task SDK
+    #   airflow.serialization.helpers.serialize_template_field or
+    #   airflow.models.renderedtifields.get_serialized_template_fields
+    from airflow.serialization.helpers import serialize_template_field
+
+    return {field: serialize_template_field(getattr(task, field), field) for 
field in task.template_fields}
+
+
 def run(ti: RuntimeTaskInstance, log: Logger):
     """Run the task in this process."""
     from airflow.exceptions import (
diff --git a/task_sdk/tests/execution_time/test_task_runner.py 
b/task_sdk/tests/execution_time/test_task_runner.py
index 517157e0a7a..c9755c252bb 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -260,3 +260,61 @@ def test_startup_basic_templated_dag(mocked_parse):
             ),
             log=mock.ANY,
         )
+
+
[email protected](
+    ["task_params", "expected_rendered_fields"],
+    [
+        pytest.param(
+            {"op_args": [], "op_kwargs": {}, "templates_dict": None},
+            {"op_args": [], "op_kwargs": {}, "templates_dict": None},
+            id="no_templates",
+        ),
+        pytest.param(
+            {
+                "op_args": ["arg1", "arg2", 1, 2, 3.75, {"key": "value"}],
+                "op_kwargs": {"key1": "value1", "key2": 99.0, "key3": 
{"nested_key": "nested_value"}},
+            },
+            {
+                "op_args": ["arg1", "arg2", 1, 2, 3.75, {"key": "value"}],
+                "op_kwargs": {"key1": "value1", "key2": 99.0, "key3": 
{"nested_key": "nested_value"}},
+            },
+            id="mixed_types",
+        ),
+        pytest.param(
+            {"my_tup": (1, 2), "my_set": {1, 2, 3}},
+            {"my_tup": "(1, 2)", "my_set": "{1, 2, 3}"},
+            id="tuples_and_sets",
+        ),
+    ],
+)
+def test_startup_dag_with_templated_fields(mocked_parse, task_params, 
expected_rendered_fields):
+    """Test startup of a DAG with various templated fields."""
+
+    class CustomOperator(BaseOperator):
+        template_fields = tuple(task_params.keys())
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            for key, value in task_params.items():
+                setattr(self, key, value)
+
+    task = CustomOperator(task_id="templated_task")
+
+    what = StartupDetails(
+        ti=TaskInstance(id=uuid7(), task_id="templated_task", 
dag_id="basic_dag", run_id="c", try_number=1),
+        file="",
+        requests_fd=0,
+    )
+    mocked_parse(what, "basic_dag", task)
+
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+    ) as mock_supervisor_comms:
+        mock_supervisor_comms.get_message.return_value = what
+
+        startup()
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
+            log=mock.ANY,
+        )
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py 
b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index f39cff8ae76..a4089c9785a 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -351,7 +351,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
             "try_number": 0,
             "unixname": getuser(),
             "dag_run_id": "TEST_DAG_RUN_ID",
-            "rendered_fields": {"op_args": [], "op_kwargs": {}, 
"templates_dict": None},
+            "rendered_fields": {"op_args": "()", "op_kwargs": {}, 
"templates_dict": None},
             "rendered_map_index": None,
             "trigger": None,
             "triggerer_job": None,
@@ -403,7 +403,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
                 "try_number": 0,
                 "unixname": getuser(),
                 "dag_run_id": "TEST_DAG_RUN_ID",
-                "rendered_fields": {"op_args": [], "op_kwargs": {}, 
"templates_dict": None},
+                "rendered_fields": {"op_args": "()", "op_kwargs": {}, 
"templates_dict": None},
                 "rendered_map_index": None,
                 "trigger": None,
                 "triggerer_job": None,
@@ -2371,7 +2371,7 @@ class TestSetTaskInstanceNote(TestTaskInstanceEndpoint):
                 "try_number": 0,
                 "unixname": getuser(),
                 "dag_run_id": "TEST_DAG_RUN_ID",
-                "rendered_fields": {"op_args": [], "op_kwargs": {}, 
"templates_dict": None},
+                "rendered_fields": {"op_args": "()", "op_kwargs": {}, 
"templates_dict": None},
                 "rendered_map_index": None,
                 "trigger": None,
                 "triggerer_job": None,
diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py 
b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
index 9b427253b29..7ce944a4d47 100644
--- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -344,7 +344,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
             "try_number": 0,
             "unixname": getuser(),
             "dag_run_id": "TEST_DAG_RUN_ID",
-            "rendered_fields": {"op_args": [], "op_kwargs": {}, 
"templates_dict": None},
+            "rendered_fields": {"op_args": "()", "op_kwargs": {}, 
"templates_dict": None},
             "rendered_map_index": None,
             "trigger": None,
             "triggerer_job": None,
@@ -444,7 +444,7 @@ class TestGetMappedTaskInstance(TestTaskInstanceEndpoint):
                 "try_number": 0,
                 "unixname": getuser(),
                 "dag_run_id": "TEST_DAG_RUN_ID",
-                "rendered_fields": {"op_args": [], "op_kwargs": {}, 
"templates_dict": None},
+                "rendered_fields": {"op_args": "()", "op_kwargs": {}, 
"templates_dict": None},
                 "rendered_map_index": None,
                 "trigger": None,
                 "triggerer_job": None,
@@ -3070,7 +3070,7 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
                 "try_number": 0,
                 "unixname": getuser(),
                 "dag_run_id": self.RUN_ID,
-                "rendered_fields": {"op_args": [], "op_kwargs": {}, 
"templates_dict": None},
+                "rendered_fields": {"op_args": "()", "op_kwargs": {}, 
"templates_dict": None},
                 "rendered_map_index": None,
                 "trigger": None,
                 "triggerer_job": None,
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py 
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index c13effee0bb..15e56bbc587 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -422,16 +422,31 @@ class TestTIPutRTIF:
         clear_db_runs()
         clear_rendered_ti_fields()
 
-    def test_ti_put_rtif_success(self, client, session, create_task_instance):
+    @pytest.mark.parametrize(
+        "payload",
+        [
+            # string value
+            {"field1": "string_value", "field2": "another_string"},
+            # dictionary value
+            {"field1": {"nested_key": "nested_value"}},
+            # string lists value
+            {"field1": ["123"], "field2": ["a", "b", "c"]},
+            # list of JSON values
+            {"field1": [1, "string", 3.14, True, None, {"nested": "dict"}]},
+            # nested dictionary with mixed types in lists
+            {
+                "field1": {"nested_dict": {"key1": 123, "key2": "value"}},
+                "field2": [3.14, {"sub_key": "sub_value"}, [1, 2]],
+            },
+        ],
+    )
+    def test_ti_put_rtif_success(self, client, session, create_task_instance, 
payload):
         ti = create_task_instance(
             task_id="test_ti_put_rtif_success",
             state=State.RUNNING,
             session=session,
         )
         session.commit()
-
-        payload = {"field1": "rendered_value1", "field2": "rendered_value2"}
-
         response = client.put(f"/execution/task-instances/{ti.id}/rtif", 
json=payload)
         assert response.status_code == 201
         assert response.json() == {"message": "Rendered task instance fields 
successfully set"}
@@ -461,28 +476,3 @@ class TestTIPutRTIF:
         response = client.put(f"/execution/task-instances/{random_id}/rtif", 
json=payload)
         assert response.status_code == 404
         assert response.json()["detail"] == "Not Found"
-
-    def test_ti_put_rtif_extra_fields(self, client, session, 
create_task_instance):
-        ti = create_task_instance(
-            task_id="test_ti_put_rtif_missing_ti",
-            state=State.RUNNING,
-            session=session,
-        )
-        session.commit()
-
-        payload = {
-            "field1": "rendered_value1",
-            "field2": "rendered_value2",
-            "invalid_key": {"field3": "rendered_value3"},
-        }
-
-        response = client.put(f"/execution/task-instances/{ti.id}/rtif", 
json=payload)
-        assert response.status_code == 422
-        assert response.json()["detail"] == [
-            {
-                "input": {"field3": "rendered_value3"},
-                "loc": ["body", "invalid_key"],
-                "msg": "Input should be a valid string",
-                "type": "string_type",
-            }
-        ]
diff --git a/tests/models/test_renderedtifields.py 
b/tests/models/test_renderedtifields.py
index 3f1b13cd1a3..ded755c4d01 100644
--- a/tests/models/test_renderedtifields.py
+++ b/tests/models/test_renderedtifields.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import ast
 import os
 from collections import Counter
 from datetime import date, timedelta
@@ -100,8 +101,12 @@ class TestRenderedTaskInstanceFields:
             (None, None),
             ([], []),
             ({}, {}),
+            ((), "()"),
+            (set(), "set()"),
             ("test-string", "test-string"),
             ({"foo": "bar"}, {"foo": "bar"}),
+            (("foo", "bar"), "('foo', 'bar')"),
+            ({"foo", "bar"}, "{'foo', 'bar'}"),
             ("{{ task.task_id }}", "test"),
             (date(2018, 12, 6), "2018-12-06"),
             (datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00"),
@@ -158,16 +163,35 @@ class TestRenderedTaskInstanceFields:
         assert ti.dag_id == rtif.dag_id
         assert ti.task_id == rtif.task_id
         assert ti.run_id == rtif.run_id
-        assert expected_rendered_field == 
rtif.rendered_fields.get("bash_command")
+        if type(templated_field) is set:
+            # the output order of a set is non-deterministic and can change 
per process.
+            # this validation can fail if that happens before stringification, 
so we convert to set and compare.
+            assert ast.literal_eval(expected_rendered_field) == 
ast.literal_eval(
+                rtif.rendered_fields.get("bash_command")
+            )
+        else:
+            assert expected_rendered_field == 
rtif.rendered_fields.get("bash_command")
 
         session.add(rtif)
         session.flush()
 
-        assert RTIF.get_templated_fields(ti=ti, session=session) == {
-            "bash_command": expected_rendered_field,
-            "env": None,
-            "cwd": None,
-        }
+        if type(templated_field) is set:
+            # the output order of a set is non-deterministic and can change 
per process.
+            # this validation can fail if that happens before stringification, 
so we convert to set and compare.
+            expected = RTIF.get_templated_fields(ti=ti, session=session)
+            expected["bash_command"] = 
ast.literal_eval(expected["bash_command"])
+            actual = {
+                "bash_command": ast.literal_eval(expected_rendered_field),
+                "env": None,
+                "cwd": None,
+            }
+            assert expected == actual
+        else:
+            assert RTIF.get_templated_fields(ti=ti, session=session) == {
+                "bash_command": expected_rendered_field,
+                "env": None,
+                "cwd": None,
+            }
         # Test the else part of get_templated_fields
         # i.e. for the TIs that are not stored in RTIF table
         # Fetching them will return None

Reply via email to