amoghrajesh commented on code in PR #58027:
URL: https://github.com/apache/airflow/pull/58027#discussion_r2504815052


##########
task-sdk-tests/tests/task_sdk_tests/conftest.py:
##########
@@ -320,3 +333,111 @@ def sdk_client(airflow_test_setup):
 def core_api_headers(airflow_test_setup):
     """Get Core API headers from setup."""
     return airflow_test_setup["core_api_headers"]
+
+
+@pytest.fixture(scope="session")
+def sdk_client_for_assets(asset_test_setup):
+    """Get SDK client for asset tests (doesn't require test_dag)."""
+    return asset_test_setup["sdk_client"]
+
+
+@pytest.fixture(scope="session")
+def asset_test_setup(docker_compose_setup, airflow_ready):
+    """Setup assets for testing by triggering asset_producer_dag."""
+    import time
+
+    import requests
+
+    from airflow.sdk.api.client import Client
+    from airflow.sdk.timezone import utcnow
+    from task_sdk_tests.jwt_plugin import generate_jwt_token
+
+    headers = airflow_ready["headers"]
+
+    # For asset test setup, we need to trigger asset_producer_dag to create test asset
+    console.print("[yellow]Checking asset_producer_dag status...")
+    dag_response = requests.get("http://localhost:8080/api/v2/dags/asset_producer_dag", headers=headers)
+    dag_response.raise_for_status()
+    dag_data = dag_response.json()
+
+    if dag_data.get("is_paused", True):
+        console.print("[yellow]Unpausing asset_producer_dag...")
+        unpause_response = requests.patch(
+            "http://localhost:8080/api/v2/dags/asset_producer_dag", json={"is_paused": False}, headers=headers
+        )
+        unpause_response.raise_for_status()
+        console.print("[green]asset_producer_dag unpaused")
+
+    logical_date = utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")[:-3] + "Z"
+    payload = {"conf": {}, "logical_date": logical_date}
+
+    console.print("[yellow]Triggering asset_producer_dag...")
+    trigger_response = requests.post(
+        "http://localhost:8080/api/v2/dags/asset_producer_dag/dagRuns",
+        json=payload,
+        headers=headers,
+        timeout=30,
+    )
+
+    trigger_response.raise_for_status()
+    dag_run_data = trigger_response.json()
+    dag_run_id = dag_run_data["dag_run_id"]
+
+    console.print(f"[green]asset_producer_dag triggered: {dag_run_id}")
+
+    console.print("[yellow]Waiting for asset_producer_dag to complete...")
+    final_state = None
+    for attempt in range(60):
+        try:
+            dr_response = requests.get(
+                f"http://localhost:8080/api/v2/dags/asset_producer_dag/dagRuns/{dag_run_id}", headers=headers
+            )
+            dr_response.raise_for_status()
+            dr_data = dr_response.json()
+            state = dr_data.get("state")
+
+            if state == "success":
+                console.print("[green]asset_producer_dag completed successfully!")
+                final_state = state
+                break
+            if state in ["failed", "skipped"]:
+                raise RuntimeError(f"asset_producer_dag ended in state: {state}")
+            console.print(
+                f"[blue]Waiting for asset_producer_dag to complete (attempt {attempt + 1}/60, state: {state})"
+            )
+
+        except Exception as e:
+            console.print(f"[yellow]DAG run check failed: {e}")
+
+        time.sleep(2)
+
+    if final_state != "success":
+        raise TimeoutError("asset_producer_dag did not complete successfully within timeout period")
+
+    # Get task instance for testing - wait for long_running_task to be RUNNING
+    console.print("[yellow]Getting task instance ID from asset_producer_dag...")
+    ti_url = f"http://localhost:8080/api/v2/dags/asset_producer_dag/dagRuns/{dag_run_id}/taskInstances"
+    ti_response = requests.get(ti_url, headers=headers, timeout=10)
+    ti_response.raise_for_status()
+
+    task_instances = ti_response.json().get("task_instances", [])
+    ti_id = None
+    for ti in task_instances:
+        if ti.get("task_id") == "produce_asset":
+            ti_id = ti.get("id")
+            break
+
+    if not ti_id:
+        raise RuntimeError("Could not find task instance ID from asset_producer_dag")
+
+    console.print(f"[green]Found task instance ID: {ti_id}")
+
+    jwt_token = generate_jwt_token(ti_id)
+    sdk_client = Client(base_url=f"http://{TASK_SDK_HOST_PORT}/execution", token=jwt_token)
+

Review Comment:
   Sure, not a bad idea at all.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to