This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new df584d3c8f3 Handle 404 from assets/by-XXX in runtime (#47852)
df584d3c8f3 is described below

commit df584d3c8f397dbfec33f383292a9ca7c05dcc07
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Mar 18 14:18:28 2025 +0800

    Handle 404 from assets/by-XXX in runtime (#47852)
    
    With correct 404 handling in place, a failure to fetch an asset now
    results in a cleaner AirflowRuntimeError with some relevant information.
---
 task-sdk/src/airflow/sdk/api/client.py             | 21 ++++++-
 task-sdk/src/airflow/sdk/exceptions.py             |  5 ++
 .../src/airflow/sdk/execution_time/supervisor.py   | 15 +++--
 task-sdk/tests/task_sdk/api/test_client.py         | 64 +++++++++++++++++++++-
 4 files changed, 97 insertions(+), 8 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/api/client.py 
b/task-sdk/src/airflow/sdk/api/client.py
index d5c02b1f7b6..ae22a02603c 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -354,15 +354,30 @@ class AssetOperations:
     def __init__(self, client: Client):
         self.client = client
 
-    def get(self, name: str | None = None, uri: str | None = None) -> 
AssetResponse:
+    def get(self, name: str | None = None, uri: str | None = None) -> 
AssetResponse | ErrorResponse:
         """Get Asset value from the API server."""
         if name:
-            resp = self.client.get("assets/by-name", params={"name": name})
+            endpoint = "assets/by-name"
+            params = {"name": name}
         elif uri:
-            resp = self.client.get("assets/by-uri", params={"uri": uri})
+            endpoint = "assets/by-uri"
+            params = {"uri": uri}
         else:
             raise ValueError("Either `name` or `uri` must be provided")
 
+        try:
+            resp = self.client.get(endpoint, params=params)
+        except ServerResponseError as e:
+            if e.response.status_code == HTTPStatus.NOT_FOUND:
+                log.error(
+                    "Asset not found",
+                    params=params,
+                    detail=e.detail,
+                    status_code=e.response.status_code,
+                )
+                return ErrorResponse(error=ErrorType.ASSET_NOT_FOUIND, 
detail=params)
+            raise
+
         return AssetResponse.model_validate_json(resp.read())
 
 
diff --git a/task-sdk/src/airflow/sdk/exceptions.py 
b/task-sdk/src/airflow/sdk/exceptions.py
index 4dd4ff5910a..6d034227868 100644
--- a/task-sdk/src/airflow/sdk/exceptions.py
+++ b/task-sdk/src/airflow/sdk/exceptions.py
@@ -25,15 +25,20 @@ if TYPE_CHECKING:
 
 
 class AirflowRuntimeError(Exception):
+    """Generic Airflow error raised by runtime functions."""
+
     def __init__(self, error: ErrorResponse):
         self.error = error
         super().__init__(f"{error.error.value}: {error.detail}")
 
 
 class ErrorType(enum.Enum):
+    """Error types used in the API client."""
+
     CONNECTION_NOT_FOUND = "CONNECTION_NOT_FOUND"
     VARIABLE_NOT_FOUND = "VARIABLE_NOT_FOUND"
     XCOM_NOT_FOUND = "XCOM_NOT_FOUND"
+    ASSET_NOT_FOUIND = "ASSET_NOT_FOUND"
     GENERIC_ERROR = "GENERIC_ERROR"
 
 
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index b7706e7735e..d986f57f3c0 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -54,6 +54,7 @@ from pydantic import TypeAdapter
 
 from airflow.sdk.api.client import Client, ServerResponseError
 from airflow.sdk.api.datamodels._generated import (
+    AssetResponse,
     ConnectionResponse,
     IntermediateTIState,
     TaskInstance,
@@ -909,12 +910,18 @@ class ActivitySubprocess(WatchedSubprocess):
             self.client.task_instances.set_rtif(self.id, msg.rendered_fields)
         elif isinstance(msg, GetAssetByName):
             asset_resp = self.client.assets.get(name=msg.name)
-            asset_result = AssetResult.from_asset_response(asset_resp)
-            resp = asset_result.model_dump_json(exclude_unset=True).encode()
+            if isinstance(asset_resp, AssetResponse):
+                asset_result = AssetResult.from_asset_response(asset_resp)
+                resp = 
asset_result.model_dump_json(exclude_unset=True).encode()
+            else:
+                resp = asset_resp.model_dump_json().encode()
         elif isinstance(msg, GetAssetByUri):
             asset_resp = self.client.assets.get(uri=msg.uri)
-            asset_result = AssetResult.from_asset_response(asset_resp)
-            resp = asset_result.model_dump_json(exclude_unset=True).encode()
+            if isinstance(asset_resp, AssetResponse):
+                asset_result = AssetResult.from_asset_response(asset_resp)
+                resp = 
asset_result.model_dump_json(exclude_unset=True).encode()
+            else:
+                resp = asset_resp.model_dump_json().encode()
         elif isinstance(msg, GetAssetEventByAsset):
             asset_event_resp = self.client.asset_events.get(uri=msg.uri, 
name=msg.name)
             asset_event_result = 
AssetEventsResult.from_asset_events_response(asset_event_resp)
diff --git a/task-sdk/tests/task_sdk/api/test_client.py 
b/task-sdk/tests/task_sdk/api/test_client.py
index bcfe92d481c..dd1d66ce6dc 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -26,7 +26,12 @@ import uuid6
 from task_sdk import make_client, make_client_w_dry_run, 
make_client_w_responses
 
 from airflow.sdk.api.client import RemoteValidationError, ServerResponseError
-from airflow.sdk.api.datamodels._generated import ConnectionResponse, 
VariableResponse, XComResponse
+from airflow.sdk.api.datamodels._generated import (
+    AssetResponse,
+    ConnectionResponse,
+    VariableResponse,
+    XComResponse,
+)
 from airflow.sdk.exceptions import ErrorType
 from airflow.sdk.execution_time.comms import DeferTask, ErrorResponse, 
RescheduleTask
 from airflow.utils import timezone
@@ -649,3 +654,60 @@ class TestConnectionOperations:
 
         assert isinstance(result, ErrorResponse)
         assert result.error == ErrorType.CONNECTION_NOT_FOUND
+
+
+class TestAssetOperations:
+    @pytest.mark.parametrize(
+        "request_params",
+        [
+            ({"name": "this_asset"}),
+            ({"uri": "s3://bucket/key"}),
+        ],
+    )
+    def test_by_name_get_success(self, request_params):
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path in ("/assets/by-name", "/assets/by-uri"):
+                return httpx.Response(
+                    status_code=200,
+                    json={
+                        "name": "this_asset",
+                        "uri": "s3://bucket/key",
+                        "group": "asset",
+                        "extra": {"foo": "bar"},
+                    },
+                )
+            return httpx.Response(status_code=400, json={"detail": "Bad 
Request"})
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.assets.get(**request_params)
+
+        assert isinstance(result, AssetResponse)
+        assert result.name == "this_asset"
+        assert result.uri == "s3://bucket/key"
+
+    @pytest.mark.parametrize(
+        "request_params",
+        [
+            ({"name": "this_asset"}),
+            ({"uri": "s3://bucket/key"}),
+        ],
+    )
+    def test_by_name_get_404_not_found(self, request_params):
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path in ("/assets/by-name", "/assets/by-uri"):
+                return httpx.Response(
+                    status_code=404,
+                    json={
+                        "detail": {
+                            "message": "Asset with name non_existent not 
found",
+                            "reason": "not_found",
+                        }
+                    },
+                )
+            return httpx.Response(status_code=400, json={"detail": "Bad 
Request"})
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.assets.get(**request_params)
+
+        assert isinstance(result, ErrorResponse)
+        assert result.error == ErrorType.ASSET_NOT_FOUIND

Reply via email to