This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2c7de9768ff Adding Task SDK integration tests for Xcom operations
(#58638)
2c7de9768ff is described below
commit 2c7de9768ffdbef673caff25e8f144ee6d8d35b6
Author: Henry Chen <[email protected]>
AuthorDate: Tue Dec 2 17:00:30 2025 +0800
Adding Task SDK integration tests for Xcom operations (#58638)
Co-authored-by: Amogh Desai <[email protected]>
---
task-sdk-integration-tests/dags/test_dag.py | 11 +-
.../tests/task_sdk_tests/test_xcom_operations.py | 235 ++++++++++++++++++---
2 files changed, 218 insertions(+), 28 deletions(-)
diff --git a/task-sdk-integration-tests/dags/test_dag.py
b/task-sdk-integration-tests/dags/test_dag.py
index 0d69d551de6..708f9434056 100644
--- a/task-sdk-integration-tests/dags/test_dag.py
+++ b/task-sdk-integration-tests/dags/test_dag.py
@@ -35,6 +35,14 @@ def return_tuple_task(ti=None):
return 1, "test_value"
+@task(dag=dag)
+def mapped_task(value, ti=None):
+ """Mapped task that processes individual values for testing XCom sequence
operations"""
+ print(f"Processing value: {value} with TI ID: {ti.id}, map_index:
{ti.map_index}")
+ # Return a modified value for XCom testing
+ return f"processed_{value}"
+
+
@task(dag=dag)
def long_running_task(ti=None):
"""Long-running task that sleeps for 5 minutes to allow testing"""
@@ -49,6 +57,7 @@ def long_running_task(ti=None):
get_ti_id = get_task_instance_id()
tuple_task = return_tuple_task()
+mapped_instances = mapped_task.expand(value=["alpha", "beta", "gamma",
"delta"])
long_task = long_running_task()
-get_ti_id >> tuple_task >> long_task
+get_ti_id >> tuple_task >> mapped_instances >> long_task
diff --git
a/task-sdk-integration-tests/tests/task_sdk_tests/test_xcom_operations.py
b/task-sdk-integration-tests/tests/task_sdk_tests/test_xcom_operations.py
index 684ac070ed4..c6a5c03996e 100644
--- a/task-sdk-integration-tests/tests/task_sdk_tests/test_xcom_operations.py
+++ b/task-sdk-integration-tests/tests/task_sdk_tests/test_xcom_operations.py
@@ -27,8 +27,13 @@ from __future__ import annotations
import pytest
-from airflow.sdk.api.datamodels._generated import XComResponse
-from airflow.sdk.execution_time.comms import OKResponse
+from airflow.sdk.api.datamodels._generated import (
+ XComResponse,
+ XComSequenceIndexResponse,
+ XComSequenceSliceResponse,
+)
+from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.execution_time.comms import ErrorResponse, OKResponse,
XComCountResponse
from task_sdk_tests import console
@@ -64,7 +69,6 @@ def test_get_xcom(sdk_client, dag_info):
console.print("[green]✅ XCom get test passed!")
-@pytest.mark.skip(reason="TODO: Implement XCom get (not found) test")
def test_get_xcom_not_found(sdk_client, dag_info):
"""
Test getting non-existent XCom value.
@@ -72,8 +76,26 @@ def test_get_xcom_not_found(sdk_client, dag_info):
Expected: XComResponse with value=None or ErrorResponse
Endpoint: GET /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}
"""
- console.print("[yellow]TODO: Implement test_get_xcom_not_found")
- raise NotImplementedError("test_get_xcom_not_found not implemented")
+ missing_key = "non_existent_xcom_key_for_test"
+ console.print("[yellow]Getting non-existent XCom key...")
+
+ response = sdk_client.xcoms.get(
+ dag_id=dag_info["dag_id"],
+ run_id=dag_info["dag_run_id"],
+ task_id="get_task_instance_id",
+ key=missing_key,
+ )
+
+ console.print(" XCom Get (Not Found) Response ".center(72, "="))
+ console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+ console.print(f"[bright_blue]Key:[/] {response.key}")
+ console.print(f"[bright_blue]Value:[/] {response.value}")
+ console.print("=" * 72)
+
+ assert isinstance(response, XComResponse)
+ assert response.key == missing_key
+ assert response.value is None
+ console.print("[green]✅ XCom not-found test passed!")
def test_set_xcom(sdk_client, dag_info):
@@ -129,7 +151,6 @@ def test_xcom_delete(sdk_client, dag_info):
test_key = "test_xcom_key_delete"
- # Set XCom first
sdk_client.xcoms.set(
dag_id=dag_info["dag_id"],
run_id=dag_info["dag_run_id"],
@@ -138,7 +159,6 @@ def test_xcom_delete(sdk_client, dag_info):
value="to_be_deleted",
)
- # Delete XCom
delete_response = sdk_client.xcoms.delete(
dag_id=dag_info["dag_id"],
run_id=dag_info["dag_run_id"],
@@ -174,55 +194,198 @@ def test_xcom_delete(sdk_client, dag_info):
console.print("[green]✅ XCom delete test passed!")
-@pytest.mark.skip(reason="TODO: Implement XCom head test")
-def test_xcom_head(sdk_client, dag_info):
+def test_xcom_head_unmapped(sdk_client, dag_info):
+ """
+ Test getting count of unmapped XCom values.
+
+ Expected: XComCountResponse with len field in it (should be ideally equal
to number of unmapped tasks, since we have None it might throw RuntimeError)
+ Endpoint: HEAD /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}
+ """
+ console.print("[yellow]Testing XCom head for non-mapped task...")
+
+ response_single = sdk_client.xcoms.head(
+ dag_id=dag_info["dag_id"],
+ run_id=dag_info["dag_run_id"],
+ task_id="return_tuple_task",
+ key="return_value",
+ )
+
+ console.print(" XCom Head Response (Non-Mapped) ".center(72, "="))
+ console.print(f"[bright_blue]Response Type:[/]
{type(response_single).__name__}")
+ console.print(f"[bright_blue]Count:[/] {response_single.len}")
+ console.print("=" * 72)
+
+ assert isinstance(response_single, XComCountResponse)
+ assert response_single.len == 1
+
+ console.print("[green]✅ XCom head non-mapped test passed!")
+
+
+def test_xcom_head_mapped(sdk_client, dag_info):
"""
Test getting count of mapped XCom values.
Expected: XComCountResponse with len field in it (should be ideally equal
to number of mapped tasks, since we have None it might throw RuntimeError)
Endpoint: HEAD /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}
"""
- console.print("[yellow]TODO: Implement test_xcom_head")
- raise NotImplementedError("test_xcom_head not implemented")
+ console.print("[yellow]Testing XCom head for mapped task...")
+ response_mapped = sdk_client.xcoms.head(
+ dag_id=dag_info["dag_id"],
+ run_id=dag_info["dag_run_id"],
+ task_id="mapped_task",
+ key="return_value",
+ )
-@pytest.mark.skip(reason="TODO: Implement XCom get_sequence_item test")
-def test_xcom_get_sequence_item(sdk_client, dag_info):
+ console.print(" XCom Head Response (Mapped) ".center(72, "="))
+ console.print(f"[bright_blue]Response Type:[/]
{type(response_mapped).__name__}")
+ console.print(f"[bright_blue]Count:[/] {response_mapped.len}")
+ console.print("=" * 72)
+
+ assert isinstance(response_mapped, XComCountResponse)
+ assert response_mapped.len == 4
+ console.print("[green]✅ XCom head mapped test passed!")
+
+
+@pytest.mark.parametrize(
+ ("offset", "expected_value"),
+ [
+ (0, "processed_alpha"),
+ (-1, "processed_delta"),
+ (2, "processed_gamma"),
+ ],
+)
+def test_xcom_get_sequence_item(sdk_client, dag_info, offset, expected_value):
"""
Test getting XCom sequence item by offset.
Expected: XComSequenceIndexResponse with value
Endpoint: GET
/execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}
"""
- console.print("[yellow]TODO: Implement test_xcom_get_sequence_item")
- raise NotImplementedError("test_xcom_get_sequence_item not implemented")
+ console.print("[yellow]Testing XCom sequence item access...")
+
+ response = sdk_client.xcoms.get_sequence_item(
+ dag_id=dag_info["dag_id"],
+ run_id=dag_info["dag_run_id"],
+ task_id="mapped_task",
+ key="return_value",
+ offset=offset,
+ )
+ console.print(f" XCom Sequence Item [offset={offset}] ".center(72, "="))
+ console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+ console.print(f"[bright_blue]Value:[/] {response.root}")
+ console.print("=" * 72)
-@pytest.mark.skip(reason="TODO: Implement XCom get_sequence_item (not found) test")
-def test_xcom_get_sequence_item_not_found(sdk_client, dag_info):
+ assert isinstance(response, XComSequenceIndexResponse)
+ assert response.root == expected_value
+ console.print(f"[green]✅ XCom get_sequence_item test {offset} passed!")
+
+
+def test_xcom_get_sequence_item_not_found_offset(sdk_client, dag_info):
"""
- Test getting non-existent XCom sequence item.
+ Test getting non-existent XCom sequence item due to out-of-range offset.
Expected: ErrorResponse with XCOM_NOT_FOUND error
Endpoint: GET
/execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}
"""
- console.print("[yellow]TODO: Implement
test_xcom_get_sequence_item_not_found")
- raise NotImplementedError("test_xcom_get_sequence_item_not_found not
implemented")
+ console.print("[yellow]Testing XCom sequence item not found (offset)...")
+
+ response = sdk_client.xcoms.get_sequence_item(
+ dag_id=dag_info["dag_id"],
+ run_id=dag_info["dag_run_id"],
+ task_id="mapped_task",
+ key="return_value",
+ offset=10,
+ )
+
+ console.print(" XCom Sequence Item Not Found (offset) ".center(72, "="))
+ console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+ console.print(f"[bright_blue]Error:[/] {response.error}")
+ console.print(f"[bright_blue]Detail:[/] {response.detail}")
+ console.print("=" * 72)
+
+ assert isinstance(response, ErrorResponse)
+ assert response.error == ErrorType.XCOM_NOT_FOUND
+ assert response.detail["key"] == "return_value"
+ assert response.detail["offset"] == 10
+
+ console.print("[green]✅ XCom get_sequence_item_not_found (offset) test
passed!")
+
+
+def test_xcom_get_sequence_item_not_found_wrong_key(sdk_client, dag_info):
+ """
+ Test getting non-existent XCom sequence item due to wrong key.
+
+ Expected: ErrorResponse with XCOM_NOT_FOUND error
+ Endpoint: GET
/execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}
+ """
+ console.print("[yellow]Testing XCom sequence item not found (wrong
key)...")
+
+ response_bad_key = sdk_client.xcoms.get_sequence_item(
+ dag_id=dag_info["dag_id"],
+ run_id=dag_info["dag_run_id"],
+ task_id="mapped_task",
+ key="non_existent_key",
+ offset=0,
+ )
+ console.print(" XCom Sequence Item Not Found (wrong key) ".center(72, "="))
+ console.print(f"[bright_blue]Response Type:[/]
{type(response_bad_key).__name__}")
+ console.print(f"[bright_blue]Error:[/] {response_bad_key.error}")
+ console.print("=" * 72)
-@pytest.mark.skip(reason="TODO: Implement XCom get_sequence_slice test")
-def test_xcom_get_sequence_slice(sdk_client, dag_info):
+ assert isinstance(response_bad_key, ErrorResponse)
+ assert response_bad_key.error == ErrorType.XCOM_NOT_FOUND
+
+ console.print("[green]✅ XCom get_sequence_item_not_found (wrong key) test
passed!")
+
+
+@pytest.mark.parametrize(
+ ("case_params", "expected_values"),
+ [
+ (
+ {"start": None, "stop": None, "step": None},
+ ["processed_alpha", "processed_beta", "processed_gamma",
"processed_delta"],
+ ),
+ ({"start": 1, "stop": 3, "step": None}, ["processed_beta",
"processed_gamma"]),
+ ({"start": 0, "stop": 4, "step": 2}, ["processed_alpha",
"processed_gamma"]),
+ ],
+ ids=["full_slice", "slice_1_to_3", "slice_step_2"],
+)
+def test_xcom_get_sequence_slice(sdk_client, dag_info, case_params,
expected_values):
"""
Test getting XCom sequence slice.
Expected: XComSequenceSliceResponse with list of values
Endpoint: GET /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice
"""
- console.print("[yellow]TODO: Implement test_xcom_get_sequence_slice")
- raise NotImplementedError("test_xcom_get_sequence_slice not implemented")
+ start = case_params["start"]
+ stop = case_params["stop"]
+ step = case_params["step"]
+ console.print(f"[yellow]Testing XCom sequence slice access (start={start},
stop={stop}, step={step})...")
+
+ response = sdk_client.xcoms.get_sequence_slice(
+ dag_id=dag_info["dag_id"],
+ run_id=dag_info["dag_run_id"],
+ task_id="mapped_task",
+ key="return_value",
+ start=start,
+ stop=stop,
+ step=step,
+ )
+
+ console.print(f" XCom Sequence Slice [{start}:{stop}:{step}] ".center(72,
"="))
+ console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+ console.print(f"[bright_blue]Values:[/] {response.root}")
+ console.print("=" * 72)
+
+ assert isinstance(response, XComSequenceSliceResponse)
+ assert response.root == expected_values
+
+ console.print("[green]✅ XCom get_sequence_slice test passed!")
-@pytest.mark.skip(reason="TODO: Implement XCom get_sequence_slice (not found) test")
def test_xcom_get_sequence_slice_not_found(sdk_client, dag_info):
"""
Test getting slice for non-existent XCom key.
@@ -230,5 +393,23 @@ def test_xcom_get_sequence_slice_not_found(sdk_client,
dag_info):
Expected: XComSequenceSliceResponse as empty list
Endpoint: GET /execution/xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice
"""
- console.print("[yellow]TODO: Implement
test_xcom_get_sequence_slice_not_found")
- raise NotImplementedError("test_xcom_get_sequence_slice_not_found not
implemented")
+ console.print("[yellow]Testing XCom sequence slice not found...")
+
+ response = sdk_client.xcoms.get_sequence_slice(
+ dag_id=dag_info["dag_id"],
+ run_id=dag_info["dag_run_id"],
+ task_id="mapped_task",
+ key="non_existent_key",
+ start=0,
+ stop=10,
+ step=None,
+ )
+
+ console.print(" XCom Sequence Slice (Not Found) ".center(72, "="))
+ console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+ console.print(f"[bright_blue]Values:[/] {getattr(response, 'root', None)}")
+ console.print("=" * 72)
+
+ assert isinstance(response, XComSequenceSliceResponse)
+ assert response.root == []
+ console.print("[green]✅ XCom get_sequence_slice_not_found test passed!")