This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c7de9768ff Adding Task SDK integration tests for Xcom operations (#58638)
2c7de9768ff is described below

commit 2c7de9768ffdbef673caff25e8f144ee6d8d35b6
Author: Henry Chen <[email protected]>
AuthorDate: Tue Dec 2 17:00:30 2025 +0800

    Adding Task SDK integration tests for Xcom operations (#58638)
    
    Co-authored-by: Amogh Desai <[email protected]>
---
 task-sdk-integration-tests/dags/test_dag.py        |  11 +-
 .../tests/task_sdk_tests/test_xcom_operations.py   | 235 ++++++++++++++++++---
 2 files changed, 218 insertions(+), 28 deletions(-)

diff --git a/task-sdk-integration-tests/dags/test_dag.py b/task-sdk-integration-tests/dags/test_dag.py
index 0d69d551de6..708f9434056 100644
--- a/task-sdk-integration-tests/dags/test_dag.py
+++ b/task-sdk-integration-tests/dags/test_dag.py
@@ -35,6 +35,14 @@ def return_tuple_task(ti=None):
     return 1, "test_value"


+@task(dag=dag)
+def mapped_task(value, ti=None):
+    """Mapped task that processes individual values for testing XCom sequence operations"""
+    print(f"Processing value: {value} with TI ID: {ti.id}, map_index: {ti.map_index}")
+    # Return a modified value for XCom testing
+    return f"processed_{value}"
+
+
 @task(dag=dag)
 def long_running_task(ti=None):
     """Long-running task that sleeps for 5 minutes to allow testing"""
@@ -49,6 +57,7 @@ def long_running_task(ti=None):

 get_ti_id = get_task_instance_id()
 tuple_task = return_tuple_task()
+mapped_instances = mapped_task.expand(value=["alpha", "beta", "gamma", "delta"])
 long_task = long_running_task()

-get_ti_id >> tuple_task >> long_task
+get_ti_id >> tuple_task >> mapped_instances >> long_task
diff --git a/task-sdk-integration-tests/tests/task_sdk_tests/test_xcom_operations.py b/task-sdk-integration-tests/tests/task_sdk_tests/test_xcom_operations.py
index 684ac070ed4..c6a5c03996e 100644
--- a/task-sdk-integration-tests/tests/task_sdk_tests/test_xcom_operations.py
+++ b/task-sdk-integration-tests/tests/task_sdk_tests/test_xcom_operations.py
@@ -27,8 +27,13 @@ from __future__ import annotations

 import pytest

-from airflow.sdk.api.datamodels._generated import XComResponse
-from airflow.sdk.execution_time.comms import OKResponse
+from airflow.sdk.api.datamodels._generated import (
+    XComResponse,
+    XComSequenceIndexResponse,
+    XComSequenceSliceResponse,
+)
+from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.execution_time.comms import ErrorResponse, OKResponse, XComCountResponse
 from task_sdk_tests import console


@@ -64,7 +69,6 @@ def test_get_xcom(sdk_client, dag_info):
     console.print("[green]✅ XCom get test passed!")


-@pytest.mark.skip(reason="TODO: Implement XCom get (not found) test")
 def test_get_xcom_not_found(sdk_client, dag_info):
     """
     Test getting non-existent XCom value.
@@ -72,8 +76,26 @@ def test_get_xcom_not_found(sdk_client, dag_info):
     Expected: XComResponse with value=None or ErrorResponse
     Endpoint: GET /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}
     """
-    console.print("[yellow]TODO: Implement test_get_xcom_not_found")
-    raise NotImplementedError("test_get_xcom_not_found not implemented")
+    missing_key = "non_existent_xcom_key_for_test"
+    console.print("[yellow]Getting non-existent XCom key...")
+
+    response = sdk_client.xcoms.get(
+        dag_id=dag_info["dag_id"],
+        run_id=dag_info["dag_run_id"],
+        task_id="get_task_instance_id",
+        key=missing_key,
+    )
+
+    console.print(" XCom Get (Not Found) Response ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+    console.print(f"[bright_blue]Key:[/] {response.key}")
+    console.print(f"[bright_blue]Value:[/] {response.value}")
+    console.print("=" * 72)
+
+    assert isinstance(response, XComResponse)
+    assert response.key == missing_key
+    assert response.value is None
+    console.print("[green]✅ XCom not-found test passed!")


 def test_set_xcom(sdk_client, dag_info):
@@ -129,7 +151,6 @@ def test_xcom_delete(sdk_client, dag_info):

     test_key = "test_xcom_key_delete"

-    # Set XCom first
     sdk_client.xcoms.set(
         dag_id=dag_info["dag_id"],
         run_id=dag_info["dag_run_id"],
@@ -138,7 +159,6 @@ def test_xcom_delete(sdk_client, dag_info):
         value="to_be_deleted",
     )

-    # Delete XCom
     delete_response = sdk_client.xcoms.delete(
         dag_id=dag_info["dag_id"],
         run_id=dag_info["dag_run_id"],
@@ -174,55 +194,198 @@ def test_xcom_delete(sdk_client, dag_info):
     console.print("[green]✅ XCom delete test passed!")


 
-@pytest.mark.skip(reason="TODO: Implement XCom head test")
-def test_xcom_head(sdk_client, dag_info):
+def test_xcom_head_unmapped(sdk_client, dag_info):
+    """
+    Test getting count of unmapped XCom values.
+
+    Expected: XComCountResponse with len field in it (should be ideally equal to number of unmapped tasks, since we have None it might throw RuntimeError)
+    Endpoint: HEAD /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}
+    """
+    console.print("[yellow]Testing XCom head for non-mapped task...")
+
+    response_single = sdk_client.xcoms.head(
+        dag_id=dag_info["dag_id"],
+        run_id=dag_info["dag_run_id"],
+        task_id="return_tuple_task",
+        key="return_value",
+    )
+
+    console.print(" XCom Head Response (Non-Mapped) ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response_single).__name__}")
+    console.print(f"[bright_blue]Count:[/] {response_single.len}")
+    console.print("=" * 72)
+
+    assert isinstance(response_single, XComCountResponse)
+    assert response_single.len == 1
+
+    console.print("[green]✅ XCom head non-mapped test passed!")
+
+
+def test_xcom_head_mapped(sdk_client, dag_info):
     """
     Test getting count of mapped XCom values.

     Expected: XComCountResponse with len field in it (should be ideally equal to number of mapped tasks, since we have None it might throw RuntimeError)
     Endpoint: HEAD /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}
     """
-    console.print("[yellow]TODO: Implement test_xcom_head")
-    raise NotImplementedError("test_xcom_head not implemented")
+    console.print("[yellow]Testing XCom head for mapped task...")

+    response_mapped = sdk_client.xcoms.head(
+        dag_id=dag_info["dag_id"],
+        run_id=dag_info["dag_run_id"],
+        task_id="mapped_task",
+        key="return_value",
+    )

-@pytest.mark.skip(reason="TODO: Implement XCom get_sequence_item test")
-def test_xcom_get_sequence_item(sdk_client, dag_info):
+    console.print(" XCom Head Response (Mapped) ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response_mapped).__name__}")
+    console.print(f"[bright_blue]Count:[/] {response_mapped.len}")
+    console.print("=" * 72)
+
+    assert isinstance(response_mapped, XComCountResponse)
+    assert response_mapped.len == 4
+    console.print("[green]✅ XCom head mapped test passed!")
+
+
+@pytest.mark.parametrize(
+    ("offset", "expected_value"),
+    [
+        (0, "processed_alpha"),
+        (-1, "processed_delta"),
+        (2, "processed_gamma"),
+    ],
+)
+def test_xcom_get_sequence_item(sdk_client, dag_info, offset, expected_value):
     """
     Test getting XCom sequence item by offset.

     Expected: XComSequenceIndexResponse with value
     Endpoint: GET /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}
     """
-    console.print("[yellow]TODO: Implement test_xcom_get_sequence_item")
-    raise NotImplementedError("test_xcom_get_sequence_item not implemented")
+    console.print("[yellow]Testing XCom sequence item access...")
+
+    response = sdk_client.xcoms.get_sequence_item(
+        dag_id=dag_info["dag_id"],
+        run_id=dag_info["dag_run_id"],
+        task_id="mapped_task",
+        key="return_value",
+        offset=offset,
+    )

+    console.print(f" XCom Sequence Item [offset={offset}] ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+    console.print(f"[bright_blue]Value:[/] {response.root}")
+    console.print("=" * 72)

-@pytest.mark.skip(reason="TODO: Implement XCom get_sequence_item (not found) test")
-def test_xcom_get_sequence_item_not_found(sdk_client, dag_info):
+    assert isinstance(response, XComSequenceIndexResponse)
+    assert response.root == expected_value
+    console.print(f"[green]✅ XCom get_sequence_item test {offset} passed!")
+
+
+def test_xcom_get_sequence_item_not_found_offset(sdk_client, dag_info):
     """
-    Test getting non-existent XCom sequence item.
+    Test getting non-existent XCom sequence item due to out-of-range offset.

     Expected: ErrorResponse with XCOM_NOT_FOUND error
     Endpoint: GET /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}
     """
-    console.print("[yellow]TODO: Implement test_xcom_get_sequence_item_not_found")
-    raise NotImplementedError("test_xcom_get_sequence_item_not_found not implemented")
+    console.print("[yellow]Testing XCom sequence item not found (offset)...")
+
+    response = sdk_client.xcoms.get_sequence_item(
+        dag_id=dag_info["dag_id"],
+        run_id=dag_info["dag_run_id"],
+        task_id="mapped_task",
+        key="return_value",
+        offset=10,
+    )
+
+    console.print(" XCom Sequence Item Not Found (offset) ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+    console.print(f"[bright_blue]Error:[/] {response.error}")
+    console.print(f"[bright_blue]Detail:[/] {response.detail}")
+    console.print("=" * 72)
+
+    assert isinstance(response, ErrorResponse)
+    assert response.error == ErrorType.XCOM_NOT_FOUND
+    assert response.detail["key"] == "return_value"
+    assert response.detail["offset"] == 10
+
+    console.print("[green]✅ XCom get_sequence_item_not_found (offset) test passed!")
+
+
+def test_xcom_get_sequence_item_not_found_wrong_key(sdk_client, dag_info):
+    """
+    Test getting non-existent XCom sequence item due to wrong key.
+
+    Expected: ErrorResponse with XCOM_NOT_FOUND error
+    Endpoint: GET /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}
+    """
+    console.print("[yellow]Testing XCom sequence item not found (wrong key)...")
+
+    response_bad_key = sdk_client.xcoms.get_sequence_item(
+        dag_id=dag_info["dag_id"],
+        run_id=dag_info["dag_run_id"],
+        task_id="mapped_task",
+        key="non_existent_key",
+        offset=0,
+    )

+    console.print(" XCom Sequence Item Not Found (wrong key) ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response_bad_key).__name__}")
+    console.print(f"[bright_blue]Error:[/] {response_bad_key.error}")
+    console.print("=" * 72)

-@pytest.mark.skip(reason="TODO: Implement XCom get_sequence_slice test")
-def test_xcom_get_sequence_slice(sdk_client, dag_info):
+    assert isinstance(response_bad_key, ErrorResponse)
+    assert response_bad_key.error == ErrorType.XCOM_NOT_FOUND
+
+    console.print("[green]✅ XCom get_sequence_item_not_found (wrong key) test passed!")
+
+
+@pytest.mark.parametrize(
+    ("case_params", "expected_values"),
+    [
+        (
+            {"start": None, "stop": None, "step": None},
+            ["processed_alpha", "processed_beta", "processed_gamma", "processed_delta"],
+        ),
+        ({"start": 1, "stop": 3, "step": None}, ["processed_beta", "processed_gamma"]),
+        ({"start": 0, "stop": 4, "step": 2}, ["processed_alpha", "processed_gamma"]),
+    ],
+    ids=["full_slice", "slice_1_to_3", "slice_step_2"],
+)
+def test_xcom_get_sequence_slice(sdk_client, dag_info, case_params, expected_values):
     """
     Test getting XCom sequence slice.

     Expected: XComSequenceSliceResponse with list of values
     Endpoint: GET /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice
     """
-    console.print("[yellow]TODO: Implement test_xcom_get_sequence_slice")
-    raise NotImplementedError("test_xcom_get_sequence_slice not implemented")
+    start = case_params["start"]
+    stop = case_params["stop"]
+    step = case_params["step"]
+    console.print(f"[yellow]Testing XCom sequence slice access (start={start}, stop={stop}, step={step})...")
+
+    response = sdk_client.xcoms.get_sequence_slice(
+        dag_id=dag_info["dag_id"],
+        run_id=dag_info["dag_run_id"],
+        task_id="mapped_task",
+        key="return_value",
+        start=start,
+        stop=stop,
+        step=step,
+    )
+
+    console.print(f" XCom Sequence Slice [{start}:{stop}:{step}] ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+    console.print(f"[bright_blue]Values:[/] {response.root}")
+    console.print("=" * 72)
+
+    assert isinstance(response, XComSequenceSliceResponse)
+    assert response.root == expected_values
+
+    console.print("[green]✅ XCom get_sequence_slice test passed!")


-@pytest.mark.skip(reason="TODO: Implement XCom get_sequence_slice (not found) test")
 def test_xcom_get_sequence_slice_not_found(sdk_client, dag_info):
     """
     Test getting slice for non-existent XCom key.
@@ -230,5 +393,23 @@ def test_xcom_get_sequence_slice_not_found(sdk_client, dag_info):
     Expected: XComSequenceSliceResponse as empty list
     Endpoint: GET /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice
     """
-    console.print("[yellow]TODO: Implement test_xcom_get_sequence_slice_not_found")
-    raise NotImplementedError("test_xcom_get_sequence_slice_not_found not implemented")
+    console.print("[yellow]Testing XCom sequence slice not found...")
+
+    response = sdk_client.xcoms.get_sequence_slice(
+        dag_id=dag_info["dag_id"],
+        run_id=dag_info["dag_run_id"],
+        task_id="mapped_task",
+        key="non_existent_key",
+        start=0,
+        stop=10,
+        step=None,
+    )
+
+    console.print(" XCom Sequence Slice (Not Found) ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+    console.print(f"[bright_blue]Values:[/] {getattr(response, 'root', None)}")
+    console.print("=" * 72)
+
+    assert isinstance(response, XComSequenceSliceResponse)
+    assert response.root == []
+    console.print("[green]✅ XCom get_sequence_slice_not_found test passed!")

Reply via email to