This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 167669f075b Add test setup and example asset operation tests to task SDK integration tests (#58027)
167669f075b is described below

commit 167669f075b90215038e7b1e293131f26ccda5af
Author: Amogh Desai <[email protected]>
AuthorDate: Sat Nov 8 18:58:08 2025 +0530

    Add test setup and example asset operation tests to task SDK integration tests (#58027)
---
 task-sdk-tests/dags/test_asset_dag.py              |  45 +++
 task-sdk-tests/tests/task_sdk_tests/conftest.py    | 305 ++++++++++++++++-----
 .../tests/task_sdk_tests/test_asset_operations.py  |  65 +++++
 3 files changed, 352 insertions(+), 63 deletions(-)

diff --git a/task-sdk-tests/dags/test_asset_dag.py b/task-sdk-tests/dags/test_asset_dag.py
new file mode 100644
index 00000000000..4ae28268ffc
--- /dev/null
+++ b/task-sdk-tests/dags/test_asset_dag.py
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Test DAG for asset operations.
+
+This file contains:
+- asset_producer_dag: A DAG that produces an asset
+- asset_consumer_dag: A DAG that is triggered by the asset
+"""
+
+from __future__ import annotations
+
+from airflow.sdk import DAG, Asset, task
+
+test_asset = Asset(uri="test://asset1", name="test_asset")
+
+with DAG(
+    dag_id="asset_producer_dag",
+    description="DAG that produces an asset for testing",
+    schedule=None,
+    catchup=False,
+) as producer_dag:
+
+    @task(outlets=[test_asset])
+    def produce_asset():
+        """Task that produces the test asset."""
+        print("Producing test asset")
+        return "asset_produced"
+
+    produce_asset()
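
The module docstring above also lists an asset_consumer_dag, but this 45-line file only defines the
producer. For reference, a minimal sketch of what such a consumer could look like with Airflow's
asset-driven scheduling; it assumes the imports and test_asset object defined above and is illustrative,
not part of this commit:

    # Hypothetical sketch (not in this commit): a DAG scheduled on test_asset,
    # so it runs whenever asset_producer_dag emits an update for the asset.
    # Assumes `DAG`, `task`, and `test_asset` from the file above.
    with DAG(
        dag_id="asset_consumer_dag",
        description="DAG triggered when test_asset is updated",
        schedule=[test_asset],  # asset-driven schedule instead of None
        catchup=False,
    ) as consumer_dag:

        @task
        def consume_asset():
            """Task that runs after the test asset is produced."""
            print("Consuming test asset")

        consume_asset()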
diff --git a/task-sdk-tests/tests/task_sdk_tests/conftest.py b/task-sdk-tests/tests/task_sdk_tests/conftest.py
index 0b3d3efbd1e..f5ba0087f8c 100644
--- a/task-sdk-tests/tests/task_sdk_tests/conftest.py
+++ b/task-sdk-tests/tests/task_sdk_tests/conftest.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 import os
 import subprocess
 import sys
+from collections.abc import Callable
 from pathlib import Path
 
 import pytest
@@ -182,9 +183,55 @@ def pytest_sessionstart(session):
         raise
 
 
[email protected](scope="session")
-def airflow_test_setup(docker_compose_setup):
-    """Fixed session-scoped fixture that matches UI behavior."""
+def setup_dag_and_get_client(
+    *,
+    dag_id: str,
+    headers: dict[str, str],
+    auth_token: str | None = None,
+    task_id_filter: str | Callable[[dict], bool] | None = None,
+    wait_for_task_state: str | None = None,
+    wait_for_dag_state: str | None = None,
+    wait_timeout: int = 60,
+    additional_metadata: dict | None = None,
+) -> dict:
+    """
+    Utility to set up a DAG run and create an SDK client.
+
+    This function handles the common pattern of:
+    1. Getting DAG status
+    2. Unpausing DAG (if needed)
+    3. Triggering a DAG run
+    4. Waiting for task instances or DAG run state
+    5. Acquiring the task instance ID for the triggered DAG run
+    6. Generating JWT token for that task instance
+    7. Creating a task SDK client with that JWT token
+
+    Args:
+        dag_id: The DAG ID to set up
+        headers: Headers for API requests (must include Authorization)
+        auth_token: Auth token string (extracted from headers if not provided)
+        task_id_filter: Task ID to filter for, or callable that takes TI dict and returns bool.
+                       If None, uses first available task instance.
+        wait_for_task_state: If provided, wait for task instance to reach this state.
+                             Mutually exclusive with wait_for_dag_state.
+        wait_for_dag_state: If provided, wait for DAG run to reach this state.
+                            Mutually exclusive with wait_for_task_state.
+        wait_timeout: Maximum number of attempts to wait (each attempt is ~2 seconds)
+        additional_metadata: Additional metadata to include in return dict
+
+    Returns:
+        A dict which contains:
+        - dag_info: dict with dag_id, dag_run_id, logical_date
+        - task_instance_id: UUID string of task instance
+        - sdk_client: Authenticated SDK client
+        - core_api_headers: Headers for Core API requests
+        - auth_token: Auth token string
+        - Any additional metadata from additional_metadata
+
+    Raises:
+        TimeoutError: If waiting for the DAG run or task instance times out
+        RuntimeError: If task instance is not found or DAG run fails
+    """
     import time
 
     import requests
@@ -193,96 +240,202 @@ def airflow_test_setup(docker_compose_setup):
     from airflow.sdk.timezone import utcnow
     from task_sdk_tests.jwt_plugin import generate_jwt_token
 
-    time.sleep(15)
+    if wait_for_task_state and wait_for_dag_state:
+        raise ValueError("Cannot specify both wait_for_task_state and wait_for_dag_state")
 
-    # Step 1: Get auth token
-    auth_url = "http://localhost:8080/auth/token"
-    try:
-        auth_response = requests.get(auth_url, timeout=10)
-        auth_response.raise_for_status()
-        auth_token = auth_response.json()["access_token"]
-        console.print("[green]✅ Got auth token")
-    except Exception as e:
-        raise e
-
-    # Step 2: Check and unpause DAG
-    headers = {"Authorization": f"Bearer {auth_token}", "Content-Type": "application/json"}
-
-    console.print("[yellow]Checking DAG status...")
-    dag_response = requests.get("http://localhost:8080/api/v2/dags/test_dag", headers=headers)
+    # Step 1: Get DAG status
+    console.print(f"[yellow]Checking {dag_id} status...")
+    dag_response = requests.get(f"http://localhost:8080/api/v2/dags/{dag_id}", headers=headers)
     dag_response.raise_for_status()
     dag_data = dag_response.json()
 
+    # Step 2: Unpause DAG if needed
     if dag_data.get("is_paused", True):
-        console.print("[yellow]Unpausing DAG...")
+        console.print(f"[yellow]Unpausing {dag_id}...")
         unpause_response = requests.patch(
-            "http://localhost:8080/api/v2/dags/test_dag";, json={"is_paused": 
False}, headers=headers
+            f"http://localhost:8080/api/v2/dags/{dag_id}";, json={"is_paused": 
False}, headers=headers
         )
         unpause_response.raise_for_status()
-        console.print("[green]✅ DAG unpaused")
+        console.print(f"[green]✅ {dag_id} unpaused")
+
+    # Step 3: Trigger DAG run
     logical_date = utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")[:-3] + "Z"
     payload = {"conf": {}, "logical_date": logical_date}
 
+    console.print(f"[yellow]Triggering {dag_id}...")
     trigger_response = requests.post(
-        "http://localhost:8080/api/v2/dags/test_dag/dagRuns";, json=payload, 
headers=headers, timeout=30
+        f"http://localhost:8080/api/v2/dags/{dag_id}/dagRuns";, json=payload, 
headers=headers, timeout=30
     )
-
-    console.print(f"[blue]Trigger DAG Run response status: 
{trigger_response.status_code}")
-    console.print(f"[blue]Trigger DAG Run response: {trigger_response.text}")
-
     trigger_response.raise_for_status()
     dag_run_data = trigger_response.json()
     dag_run_id = dag_run_data["dag_run_id"]
 
-    console.print(f"[green]✅ DAG triggered: {dag_run_id}")
-
-    # Step 4: Get task instance for testing - wait for long_running_task to be RUNNING
-    console.print("[yellow]Waiting for long_running_task to be RUNNING...")
-    ti_id = None
-
-    for attempt in range(30):  # Increased to 30 attempts (60 seconds)
-        try:
-            ti_url = f"http://localhost:8080/api/v2/dags/test_dag/dagRuns/{dag_run_id}/taskInstances"
-            ti_response = requests.get(ti_url, headers=headers, timeout=10)
-            ti_response.raise_for_status()
-
-            task_instances = ti_response.json().get("task_instances", [])
-
-            # Look specifically for long_running_task that is in RUNNING state
-            for ti in task_instances:
-                if ti.get("task_id") == "long_running_task" and ti.get("state") == "running":
-                    ti_id = ti.get("id")
-                    if ti_id:
-                        console.print(f"[green]✅ Found running task: 
'{ti.get('task_id')}'")
-                        console.print(f"[green]    State: {ti.get('state')}")
-                        console.print(f"[green]    Instance ID: {ti_id}")
-                        break
-
-            if ti_id:
+    console.print(f"[green]✅ {dag_id} triggered: {dag_run_id}")
+
+    # Step 4: Wait for condition
+    if wait_for_dag_state:
+        # Wait for DAG run to reach specific state
+        console.print(f"[yellow]Waiting for {dag_id} to reach state 
'{wait_for_dag_state}'...")
+        final_state = None
+        for attempt in range(wait_timeout):
+            try:
+                dr_response = requests.get(
+                    f"http://localhost:8080/api/v2/dags/{dag_id}/dagRuns/{dag_run_id}", headers=headers
+                )
+                dr_response.raise_for_status()
+                dr_data = dr_response.json()
+                state = dr_data.get("state")
+
+                if state == wait_for_dag_state:
+                    console.print(f"[green]✅ {dag_id} reached state 
'{wait_for_dag_state}'!")
+                    final_state = state
+                    break
+                if state in ["failed", "skipped"]:
+                    raise RuntimeError(f"{dag_id} ended in state: {state}")
+                console.print(
+                    f"[blue]Waiting for {dag_id} to reach 
'{wait_for_dag_state}' "
+                    f"(attempt {attempt + 1}/{wait_timeout}, current state: 
{state})"
+                )
+
+            except Exception as e:
+                console.print(f"[yellow]DAG run check failed: {e}")
+
+            time.sleep(2)
+
+        if final_state != wait_for_dag_state:
+            raise TimeoutError(f"{dag_id} did not reach state '{wait_for_dag_state}' within timeout period")
+
+    # Step 5: Get task instance ID
+    console.print(f"[yellow]Getting task instance ID from {dag_id}...")
+    ti_url = f"http://localhost:8080/api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances"
+    ti_response = requests.get(ti_url, headers=headers, timeout=10)
+    ti_response.raise_for_status()
+
+    task_instances = ti_response.json().get("task_instances", [])
+
+    # Filter task instances
+    if wait_for_task_state:
+        # Wait for specific task to reach specific state
+        console.print(f"[yellow]Waiting for task instance to reach state 
'{wait_for_task_state}'...")
+        ti_id = None
+
+        for attempt in range(wait_timeout):
+            try:
+                ti_response = requests.get(ti_url, headers=headers, timeout=10)
+                ti_response.raise_for_status()
+                task_instances = ti_response.json().get("task_instances", [])
+
+                # Filter task instances
+                for ti in task_instances:
+                    if callable(task_id_filter):
+                        matches_filter = task_id_filter(ti)
+                    elif task_id_filter:
+                        matches_filter = ti.get("task_id") == task_id_filter
+                    else:
+                        matches_filter = True
+
+                    if matches_filter and ti.get("state") == wait_for_task_state:
+                        ti_id = ti.get("id")
+                        if ti_id:
+                            console.print(f"[green]✅ Found task in 
'{wait_for_task_state}' state")
+                            console.print(f"[green]    Task ID: 
{ti.get('task_id')}")
+                            console.print(f"[green]    Instance ID: {ti_id}")
+                            break
+
+                if ti_id:
+                    break
+                console.print(f"[blue]Waiting for task instance (attempt 
{attempt + 1}/{wait_timeout})")
+
+            except Exception as e:
+                console.print(f"[yellow]Task check failed: {e}")
+
+            time.sleep(2)
+
+        if not ti_id:
+            raise TimeoutError(
+                f"Task instance did not reach '{wait_for_task_state}' state 
within timeout period"
+            )
+    else:
+        ti_id = None
+        for ti in task_instances:
+            if callable(task_id_filter):
+                matches_filter = task_id_filter(ti)
+            elif task_id_filter:
+                matches_filter = ti.get("task_id") == task_id_filter
+            else:
+                matches_filter = True
+
+            if matches_filter:
+                ti_id = ti.get("id")
                 break
-            console.print(f"[blue]Waiting for long_running_task to start 
(attempt {attempt + 1}/30)")
 
-        except Exception as e:
-            console.print(f"[yellow]Task check failed: {e}")
+        if not ti_id:
+            raise RuntimeError(f"Could not find task instance ID from {dag_id}")
 
-        time.sleep(2)
+        console.print(f"[green]✅ Found task instance ID: {ti_id}")
 
-    if not ti_id:
-        console.print("[red]❌ long_running_task never reached RUNNING state. 
Final debug info:")
-        raise TimeoutError("long_running_task did not reach RUNNING state 
within timeout period")
-
-    # Step 5: Create SDK client
+    # Step 6: Generate JWT token and create SDK client
     jwt_token = generate_jwt_token(ti_id)
     sdk_client = Client(base_url=f"http://{TASK_SDK_HOST_PORT}/execution", token=jwt_token)
 
-    return {
+    # Extract auth token from headers if not provided
+    if auth_token is None:
+        auth_token = headers.get("Authorization", "").replace("Bearer ", "")
+
+    # Build return dict
+    result = {
         "auth_token": auth_token,
-        "dag_info": {"dag_id": "test_dag", "dag_run_id": dag_run_id, 
"logical_date": logical_date},
+        "dag_info": {"dag_id": dag_id, "dag_run_id": dag_run_id, 
"logical_date": logical_date},
         "task_instance_id": ti_id,
         "sdk_client": sdk_client,
         "core_api_headers": headers,
     }
 
+    if additional_metadata:
+        result.update(additional_metadata)
+
+    return result
+
+
[email protected](scope="session")
+def airflow_ready(docker_compose_setup):
+    """Shared fixture that waits for Airflow to be ready and provides auth 
token to communicate with Airflow."""
+    import time
+
+    import requests
+
+    # Generous sleep for Airflow to be ready
+    time.sleep(15)
+
+    auth_url = "http://localhost:8080/auth/token"
+    try:
+        auth_response = requests.get(auth_url, timeout=10)
+        auth_response.raise_for_status()
+        auth_token = auth_response.json()["access_token"]
+        console.print("[green]✅ Got auth token")
+    except Exception as e:
+        raise e
+
+    headers = {"Authorization": f"Bearer {auth_token}", "Content-Type": "application/json"}
+
+    return {"auth_token": auth_token, "headers": headers}
+
+
[email protected](scope="session")
+def airflow_test_setup(docker_compose_setup, airflow_ready):
+    """Fixed session-scoped fixture that matches UI behavior."""
+    headers = airflow_ready["headers"]
+    auth_token = airflow_ready["auth_token"]
+
+    return setup_dag_and_get_client(
+        dag_id="test_dag",
+        headers=headers,
+        auth_token=auth_token,
+        task_id_filter="long_running_task",
+        wait_for_task_state="running",
+        wait_timeout=30,
+    )
+
 
 @pytest.fixture(scope="session")
 def task_sdk_api_version():
@@ -320,3 +473,29 @@ def sdk_client(airflow_test_setup):
 def core_api_headers(airflow_test_setup):
     """Get Core API headers from setup."""
     return airflow_test_setup["core_api_headers"]
+
+
[email protected](scope="session")
+def sdk_client_for_assets(asset_test_setup):
+    """Get SDK client for asset tests (doesn't require test_dag)."""
+    return asset_test_setup["sdk_client"]
+
+
[email protected](scope="session")
+def asset_test_setup(docker_compose_setup, airflow_ready):
+    """Setup assets for testing by triggering asset_producer_dag."""
+    headers = airflow_ready["headers"]
+    auth_token = airflow_ready["auth_token"]
+
+    return setup_dag_and_get_client(
+        dag_id="asset_producer_dag",
+        headers=headers,
+        auth_token=auth_token,
+        task_id_filter="produce_asset",
+        wait_for_dag_state="success",
+        wait_timeout=60,
+        additional_metadata={
+            "name": "test_asset",
+            "uri": "test://asset1/",
+        },
+    )
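
For orientation, a hedged sketch of how these fixtures compose in a downstream test; the fixture names
and dict keys come from the diff above, while the test body itself is illustrative:

    # Usage sketch, assuming the fixtures defined in conftest.py above.
    def test_asset_setup_metadata(asset_test_setup, sdk_client_for_assets):
        # Keys from setup_dag_and_get_client's return dict
        assert asset_test_setup["dag_info"]["dag_id"] == "asset_producer_dag"
        assert asset_test_setup["task_instance_id"] is not None
        # Keys merged in via additional_metadata
        assert asset_test_setup["name"] == "test_asset"
        # sdk_client_for_assets is already authenticated with a per-task-instance JWT
        response = sdk_client_for_assets.assets.get(name=asset_test_setup["name"])
        assert response is not None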
diff --git a/task-sdk-tests/tests/task_sdk_tests/test_asset_operations.py b/task-sdk-tests/tests/task_sdk_tests/test_asset_operations.py
new file mode 100644
index 00000000000..528b1993bc9
--- /dev/null
+++ b/task-sdk-tests/tests/task_sdk_tests/test_asset_operations.py
@@ -0,0 +1,65 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Integration tests for Asset operations.
+
+These tests validate the Execution API endpoints for Asset operations:
+- get(): Get asset by name
+"""
+
+from __future__ import annotations
+
+from airflow.sdk.api.datamodels._generated import AssetResponse
+from airflow.sdk.execution_time.comms import ErrorResponse
+from task_sdk_tests import console
+
+
+def test_asset_get_by_name(sdk_client_for_assets, asset_test_setup):
+    """Test getting asset by name."""
+    console.print("[yellow]Getting asset by name...")
+
+    response = sdk_client_for_assets.assets.get(name=asset_test_setup["name"])
+
+    console.print(" Asset Get By Name Response ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+    console.print(f"[bright_blue]Name:[/] {response.name}")
+    console.print(f"[bright_blue]URI:[/] {response.uri}")
+    console.print(f"[bright_blue]Group:[/] {response.group}")
+    console.print("=" * 72)
+
+    assert isinstance(response, AssetResponse)
+    assert response.name == asset_test_setup["name"]
+    assert response.uri == asset_test_setup["uri"]
+    console.print("[green]Asset get by name test passed!")
+
+
+def test_asset_get_by_name_not_found(sdk_client_for_assets):
+    """Test getting non-existent asset by name."""
+    console.print("[yellow]Getting non-existent asset by name...")
+
+    response = sdk_client_for_assets.assets.get(name="non_existent_asset_name")
+
+    console.print(" Asset Get (Not Found) Response ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+    console.print(f"[bright_blue]Error Type:[/] {response.error}")
+    console.print(f"[bright_blue]Detail:[/] {response.detail}")
+    console.print("=" * 72)
+
+    assert isinstance(response, ErrorResponse)
+    assert str(response.error).endswith("ASSET_NOT_FOUND")
+    console.print("[green]Asset get by name (not found) test passed!")
