This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 167669f075b Add test setup and example asset operation tests to task SDK integration tests (#58027)
167669f075b is described below
commit 167669f075b90215038e7b1e293131f26ccda5af
Author: Amogh Desai <[email protected]>
AuthorDate: Sat Nov 8 18:58:08 2025 +0530
Add test setup and example asset operation tests to task SDK integration tests (#58027)
---
task-sdk-tests/dags/test_asset_dag.py | 45 +++
task-sdk-tests/tests/task_sdk_tests/conftest.py | 305 ++++++++++++++++-----
.../tests/task_sdk_tests/test_asset_operations.py | 65 +++++
3 files changed, 352 insertions(+), 63 deletions(-)
diff --git a/task-sdk-tests/dags/test_asset_dag.py b/task-sdk-tests/dags/test_asset_dag.py
new file mode 100644
index 00000000000..4ae28268ffc
--- /dev/null
+++ b/task-sdk-tests/dags/test_asset_dag.py
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Test DAG for asset operations.
+
+This file contains:
+- asset_producer_dag: A DAG that produces an asset
+- test_asset: The Asset produced by that DAG, used by the asset operation tests
+"""
+
+from __future__ import annotations
+
+from airflow.sdk import DAG, Asset, task
+
+test_asset = Asset(uri="test://asset1", name="test_asset")
+
+with DAG(
+ dag_id="asset_producer_dag",
+ description="DAG that produces an asset for testing",
+ schedule=None,
+ catchup=False,
+) as producer_dag:
+
+ @task(outlets=[test_asset])
+ def produce_asset():
+ """Task that produces the test asset."""
+ print("Producing test asset")
+ return "asset_produced"
+
+ produce_asset()
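
For reference, a downstream DAG consuming this asset would be scheduled on the Asset object itself rather than on a time-based schedule. A minimal sketch of such a consumer, assuming the same test_asset definition (illustrative only, not part of this commit):

    from airflow.sdk import DAG, Asset, task

    test_asset = Asset(uri="test://asset1", name="test_asset")

    with DAG(
        dag_id="asset_consumer_dag",
        description="Runs whenever an asset event is recorded for test_asset",
        schedule=[test_asset],  # asset-driven schedule: triggered by asset events, not time
        catchup=False,
    ) as consumer_dag:

        @task
        def consume_asset():
            """Task that runs after the test asset is produced."""
            print("Consuming test asset")

        consume_asset()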
diff --git a/task-sdk-tests/tests/task_sdk_tests/conftest.py b/task-sdk-tests/tests/task_sdk_tests/conftest.py
index 0b3d3efbd1e..f5ba0087f8c 100644
--- a/task-sdk-tests/tests/task_sdk_tests/conftest.py
+++ b/task-sdk-tests/tests/task_sdk_tests/conftest.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import os
import subprocess
import sys
+from collections.abc import Callable
from pathlib import Path
import pytest
@@ -182,9 +183,55 @@ def pytest_sessionstart(session):
raise
[email protected](scope="session")
-def airflow_test_setup(docker_compose_setup):
- """Fixed session-scoped fixture that matches UI behavior."""
+def setup_dag_and_get_client(
+ *,
+ dag_id: str,
+ headers: dict[str, str],
+ auth_token: str | None = None,
+ task_id_filter: str | Callable[[dict], bool] | None = None,
+ wait_for_task_state: str | None = None,
+ wait_for_dag_state: str | None = None,
+ wait_timeout: int = 60,
+ additional_metadata: dict | None = None,
+) -> dict:
+ """
+ Utility to set up a DAG run and create an SDK client.
+
+ This function handles the common pattern of:
+ 1. Getting DAG status
+ 2. Unpausing DAG (if needed)
+ 3. Triggering a DAG run
+ 4. Waiting for task instances or DAG run state
+ 5. Acquiring the task instance ID for the triggered DAG run
+ 6. Generating JWT token for that task instance
+ 7. Creating a task SDK client with that JWT token
+
+ Args:
+ dag_id: The DAG ID to set up
+ headers: Headers for API requests (must include Authorization)
+ auth_token: Auth token string (extracted from headers if not provided)
+        task_id_filter: Task ID to filter for, or callable that takes TI dict and returns bool.
+ If None, uses first available task instance.
+        wait_for_task_state: If provided, wait for task instance to reach this state.
+ Mutually exclusive with wait_for_dag_state.
+ wait_for_dag_state: If provided, wait for DAG run to reach this state.
+ Mutually exclusive with wait_for_task_state.
+        wait_timeout: Maximum number of attempts to wait (each attempt is ~2 seconds)
+ additional_metadata: Additional metadata to include in return dict
+
+ Returns:
+ A dict which contains:
+ - dag_info: dict with dag_id, dag_run_id, logical_date
+ - task_instance_id: UUID string of task instance
+ - sdk_client: Authenticated SDK client
+ - core_api_headers: Headers for Core API requests
+ - auth_token: Auth token string
+ - Any additional metadata from additional_metadata
+
+ Raises:
+ TimeoutError: If waiting for the DAG run or task instance times out
+ RuntimeError: If task instance is not found or DAG run fails
+ """
import time
import requests
@@ -193,96 +240,202 @@ def airflow_test_setup(docker_compose_setup):
from airflow.sdk.timezone import utcnow
from task_sdk_tests.jwt_plugin import generate_jwt_token
- time.sleep(15)
+ if wait_for_task_state and wait_for_dag_state:
+        raise ValueError("Cannot specify both wait_for_task_state and wait_for_dag_state")
- # Step 1: Get auth token
- auth_url = "http://localhost:8080/auth/token"
- try:
- auth_response = requests.get(auth_url, timeout=10)
- auth_response.raise_for_status()
- auth_token = auth_response.json()["access_token"]
- console.print("[green]✅ Got auth token")
- except Exception as e:
- raise e
-
- # Step 2: Check and unpause DAG
-    headers = {"Authorization": f"Bearer {auth_token}", "Content-Type": "application/json"}
-
- console.print("[yellow]Checking DAG status...")
-    dag_response = requests.get("http://localhost:8080/api/v2/dags/test_dag", headers=headers)
+ # Step 1: Get DAG status
+ console.print(f"[yellow]Checking {dag_id} status...")
+ dag_response = requests.get(f"http://localhost:8080/api/v2/dags/{dag_id}",
headers=headers)
dag_response.raise_for_status()
dag_data = dag_response.json()
+ # Step 2: Unpause DAG if needed
if dag_data.get("is_paused", True):
- console.print("[yellow]Unpausing DAG...")
+ console.print(f"[yellow]Unpausing {dag_id}...")
unpause_response = requests.patch(
-            "http://localhost:8080/api/v2/dags/test_dag", json={"is_paused": False}, headers=headers
+            f"http://localhost:8080/api/v2/dags/{dag_id}", json={"is_paused": False}, headers=headers
)
unpause_response.raise_for_status()
- console.print("[green]✅ DAG unpaused")
+ console.print(f"[green]✅ {dag_id} unpaused")
+
+ # Step 3: Trigger DAG run
logical_date = utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")[:-3] + "Z"
payload = {"conf": {}, "logical_date": logical_date}
+ console.print(f"[yellow]Triggering {dag_id}...")
trigger_response = requests.post(
-        "http://localhost:8080/api/v2/dags/test_dag/dagRuns", json=payload, headers=headers, timeout=30
+        f"http://localhost:8080/api/v2/dags/{dag_id}/dagRuns", json=payload, headers=headers, timeout=30
)
-
-    console.print(f"[blue]Trigger DAG Run response status: {trigger_response.status_code}")
- console.print(f"[blue]Trigger DAG Run response: {trigger_response.text}")
-
trigger_response.raise_for_status()
dag_run_data = trigger_response.json()
dag_run_id = dag_run_data["dag_run_id"]
- console.print(f"[green]✅ DAG triggered: {dag_run_id}")
-
-    # Step 4: Get task instance for testing - wait for long_running_task to be RUNNING
- console.print("[yellow]Waiting for long_running_task to be RUNNING...")
- ti_id = None
-
- for attempt in range(30): # Increased to 30 attempts (60 seconds)
- try:
-            ti_url = f"http://localhost:8080/api/v2/dags/test_dag/dagRuns/{dag_run_id}/taskInstances"
- ti_response = requests.get(ti_url, headers=headers, timeout=10)
- ti_response.raise_for_status()
-
- task_instances = ti_response.json().get("task_instances", [])
-
- # Look specifically for long_running_task that is in RUNNING state
- for ti in task_instances:
-                if ti.get("task_id") == "long_running_task" and ti.get("state") == "running":
- ti_id = ti.get("id")
- if ti_id:
-                        console.print(f"[green]✅ Found running task: '{ti.get('task_id')}'")
- console.print(f"[green] State: {ti.get('state')}")
- console.print(f"[green] Instance ID: {ti_id}")
- break
-
- if ti_id:
+ console.print(f"[green]✅ {dag_id} triggered: {dag_run_id}")
+
+ # Step 4: Wait for condition
+ if wait_for_dag_state:
+ # Wait for DAG run to reach specific state
+        console.print(f"[yellow]Waiting for {dag_id} to reach state '{wait_for_dag_state}'...")
+ final_state = None
+ for attempt in range(wait_timeout):
+ try:
+ dr_response = requests.get(
+                    f"http://localhost:8080/api/v2/dags/{dag_id}/dagRuns/{dag_run_id}", headers=headers
+ )
+ dr_response.raise_for_status()
+ dr_data = dr_response.json()
+ state = dr_data.get("state")
+
+ if state == wait_for_dag_state:
+                    console.print(f"[green]✅ {dag_id} reached state '{wait_for_dag_state}'!")
+ final_state = state
+ break
+ if state in ["failed", "skipped"]:
+ raise RuntimeError(f"{dag_id} ended in state: {state}")
+                console.print(
+                    f"[blue]Waiting for {dag_id} to reach '{wait_for_dag_state}' "
+                    f"(attempt {attempt + 1}/{wait_timeout}, current state: {state})"
+                )
+
+ except Exception as e:
+ console.print(f"[yellow]DAG run check failed: {e}")
+
+ time.sleep(2)
+
+ if final_state != wait_for_dag_state:
+            raise TimeoutError(f"{dag_id} did not reach state '{wait_for_dag_state}' within timeout period")
+
+ # Step 5: Get task instance ID
+ console.print(f"[yellow]Getting task instance ID from {dag_id}...")
+    ti_url = f"http://localhost:8080/api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances"
+ ti_response = requests.get(ti_url, headers=headers, timeout=10)
+ ti_response.raise_for_status()
+
+ task_instances = ti_response.json().get("task_instances", [])
+
+ # Filter task instances
+ if wait_for_task_state:
+ # Wait for specific task to reach specific state
+        console.print(f"[yellow]Waiting for task instance to reach state '{wait_for_task_state}'...")
+ ti_id = None
+
+ for attempt in range(wait_timeout):
+ try:
+ ti_response = requests.get(ti_url, headers=headers, timeout=10)
+ ti_response.raise_for_status()
+ task_instances = ti_response.json().get("task_instances", [])
+
+ # Filter task instances
+ for ti in task_instances:
+ if callable(task_id_filter):
+ matches_filter = task_id_filter(ti)
+ elif task_id_filter:
+ matches_filter = ti.get("task_id") == task_id_filter
+ else:
+ matches_filter = True
+
+                    if matches_filter and ti.get("state") == wait_for_task_state:
+ ti_id = ti.get("id")
+ if ti_id:
+                            console.print(f"[green]✅ Found task in '{wait_for_task_state}' state")
+                            console.print(f"[green]  Task ID: {ti.get('task_id')}")
+ console.print(f"[green] Instance ID: {ti_id}")
+ break
+
+ if ti_id:
+ break
+                console.print(f"[blue]Waiting for task instance (attempt {attempt + 1}/{wait_timeout})")
+
+ except Exception as e:
+ console.print(f"[yellow]Task check failed: {e}")
+
+ time.sleep(2)
+
+ if not ti_id:
+ raise TimeoutError(
+                f"Task instance did not reach '{wait_for_task_state}' state within timeout period"
+ )
+ else:
+ ti_id = None
+ for ti in task_instances:
+ if callable(task_id_filter):
+ matches_filter = task_id_filter(ti)
+ elif task_id_filter:
+ matches_filter = ti.get("task_id") == task_id_filter
+ else:
+ matches_filter = True
+
+ if matches_filter:
+ ti_id = ti.get("id")
break
-            console.print(f"[blue]Waiting for long_running_task to start (attempt {attempt + 1}/30)")
- except Exception as e:
- console.print(f"[yellow]Task check failed: {e}")
+ if not ti_id:
+            raise RuntimeError(f"Could not find task instance ID from {dag_id}")
- time.sleep(2)
+ console.print(f"[green]✅ Found task instance ID: {ti_id}")
- if not ti_id:
-        console.print("[red]❌ long_running_task never reached RUNNING state. Final debug info:")
-        raise TimeoutError("long_running_task did not reach RUNNING state within timeout period")
-
- # Step 5: Create SDK client
+ # Step 6: Generate JWT token and create SDK client
jwt_token = generate_jwt_token(ti_id)
    sdk_client = Client(base_url=f"http://{TASK_SDK_HOST_PORT}/execution", token=jwt_token)
- return {
+ # Extract auth token from headers if not provided
+ if auth_token is None:
+ auth_token = headers.get("Authorization", "").replace("Bearer ", "")
+
+ # Build return dict
+ result = {
"auth_token": auth_token,
-        "dag_info": {"dag_id": "test_dag", "dag_run_id": dag_run_id, "logical_date": logical_date},
+        "dag_info": {"dag_id": dag_id, "dag_run_id": dag_run_id, "logical_date": logical_date},
"task_instance_id": ti_id,
"sdk_client": sdk_client,
"core_api_headers": headers,
}
+ if additional_metadata:
+ result.update(additional_metadata)
+
+ return result
+
+
[email protected](scope="session")
+def airflow_ready(docker_compose_setup):
+    """Shared fixture that waits for Airflow to be ready and provides an auth token to communicate with Airflow."""
+ import time
+
+ import requests
+
+ # Generous sleep for Airflow to be ready
+ time.sleep(15)
+
+ auth_url = "http://localhost:8080/auth/token"
+ try:
+ auth_response = requests.get(auth_url, timeout=10)
+ auth_response.raise_for_status()
+ auth_token = auth_response.json()["access_token"]
+ console.print("[green]✅ Got auth token")
+    except Exception:
+        raise
+
+    headers = {"Authorization": f"Bearer {auth_token}", "Content-Type": "application/json"}
+
+ return {"auth_token": auth_token, "headers": headers}
+
+
[email protected](scope="session")
+def airflow_test_setup(docker_compose_setup, airflow_ready):
+ """Fixed session-scoped fixture that matches UI behavior."""
+ headers = airflow_ready["headers"]
+ auth_token = airflow_ready["auth_token"]
+
+ return setup_dag_and_get_client(
+ dag_id="test_dag",
+ headers=headers,
+ auth_token=auth_token,
+ task_id_filter="long_running_task",
+ wait_for_task_state="running",
+ wait_timeout=30,
+ )
+
@pytest.fixture(scope="session")
def task_sdk_api_version():
@@ -320,3 +473,29 @@ def sdk_client(airflow_test_setup):
def core_api_headers(airflow_test_setup):
"""Get Core API headers from setup."""
return airflow_test_setup["core_api_headers"]
+
+
[email protected](scope="session")
+def sdk_client_for_assets(asset_test_setup):
+ """Get SDK client for asset tests (doesn't require test_dag)."""
+ return asset_test_setup["sdk_client"]
+
+
[email protected](scope="session")
+def asset_test_setup(docker_compose_setup, airflow_ready):
+ """Setup assets for testing by triggering asset_producer_dag."""
+ headers = airflow_ready["headers"]
+ auth_token = airflow_ready["auth_token"]
+
+ return setup_dag_and_get_client(
+ dag_id="asset_producer_dag",
+ headers=headers,
+ auth_token=auth_token,
+ task_id_filter="produce_asset",
+ wait_for_dag_state="success",
+ wait_timeout=60,
+ additional_metadata={
+ "name": "test_asset",
+ "uri": "test://asset1/",
+ },
+ )
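
Only the string form of task_id_filter is exercised by the fixtures above; the parameter also accepts a callable over the raw task-instance dict from the Core API, which allows filtering on any TI field. A hedged sketch of that form (the DAG id, task id, and fixture name here are hypothetical):

    @pytest.fixture(scope="session")
    def mapped_task_setup(docker_compose_setup, airflow_ready):
        """Hypothetical fixture: target the first mapped expansion of a task."""
        return setup_dag_and_get_client(
            dag_id="some_mapped_dag",  # hypothetical DAG id
            headers=airflow_ready["headers"],
            auth_token=airflow_ready["auth_token"],
            # Callable filter: match on task_id AND map_index, not just task_id.
            task_id_filter=lambda ti: ti.get("task_id") == "some_task" and ti.get("map_index") == 0,
            wait_for_task_state="running",
            wait_timeout=30,
        )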
diff --git a/task-sdk-tests/tests/task_sdk_tests/test_asset_operations.py b/task-sdk-tests/tests/task_sdk_tests/test_asset_operations.py
new file mode 100644
index 00000000000..528b1993bc9
--- /dev/null
+++ b/task-sdk-tests/tests/task_sdk_tests/test_asset_operations.py
@@ -0,0 +1,65 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Integration tests for Asset operations.
+
+These tests validate the Execution API endpoints for Asset operations:
+- get(): Get asset by name
+"""
+
+from __future__ import annotations
+
+from airflow.sdk.api.datamodels._generated import AssetResponse
+from airflow.sdk.execution_time.comms import ErrorResponse
+from task_sdk_tests import console
+
+
+def test_asset_get_by_name(sdk_client_for_assets, asset_test_setup):
+ """Test getting asset by name."""
+ console.print("[yellow]Getting asset by name...")
+
+ response = sdk_client_for_assets.assets.get(name=asset_test_setup["name"])
+
+ console.print(" Asset Get By Name Response ".center(72, "="))
+ console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+ console.print(f"[bright_blue]Name:[/] {response.name}")
+ console.print(f"[bright_blue]URI:[/] {response.uri}")
+ console.print(f"[bright_blue]Group:[/] {response.group}")
+ console.print("=" * 72)
+
+ assert isinstance(response, AssetResponse)
+ assert response.name == asset_test_setup["name"]
+ assert response.uri == asset_test_setup["uri"]
+ console.print("[green]Asset get by name test passed!")
+
+
+def test_asset_get_by_name_not_found(sdk_client_for_assets):
+ """Test getting non-existent asset by name."""
+ console.print("[yellow]Getting non-existent asset by name...")
+
+ response = sdk_client_for_assets.assets.get(name="non_existent_asset_name")
+
+ console.print(" Asset Get (Not Found) Response ".center(72, "="))
+ console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+ console.print(f"[bright_blue]Error Type:[/] {response.error}")
+ console.print(f"[bright_blue]Detail:[/] {response.detail}")
+ console.print("=" * 72)
+
+ assert isinstance(response, ErrorResponse)
+ assert str(response.error).endswith("ASSET_NOT_FOUND")
+ console.print("[green]Asset get by name (not found) test passed!")
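
One API property these tests pin down: assets.get() does not raise for a missing asset, it returns an ErrorResponse, so callers branch on the response type. A minimal sketch of consuming code under that assumption, where sdk_client is an authenticated client as built in conftest.py:

    from airflow.sdk.api.datamodels._generated import AssetResponse
    from airflow.sdk.execution_time.comms import ErrorResponse

    def describe_asset(sdk_client, name: str) -> str:
        """Summarize an asset lookup, covering both response types."""
        response = sdk_client.assets.get(name=name)
        if isinstance(response, AssetResponse):
            # Success path: the Execution API found the asset.
            return f"{response.name} -> {response.uri}"
        assert isinstance(response, ErrorResponse)
        # Failure path: error names the reason, e.g. ...ASSET_NOT_FOUND.
        return f"lookup failed for {name}: {response.error}"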