This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new df584d3c8f3 Handle 404 from assets/by-XXX in runtime (#47852)
df584d3c8f3 is described below
commit df584d3c8f397dbfec33f383292a9ca7c05dcc07
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Mar 18 14:18:28 2025 +0800
Handle 404 from assets/by-XXX in runtime (#47852)
With correct 404 handling in place, a failure to fetch an asset now results
in a cleaner AirflowRuntimeError that carries some relevant information.
---
task-sdk/src/airflow/sdk/api/client.py | 21 ++++++-
task-sdk/src/airflow/sdk/exceptions.py | 5 ++
.../src/airflow/sdk/execution_time/supervisor.py | 15 +++--
task-sdk/tests/task_sdk/api/test_client.py | 64 +++++++++++++++++++++-
4 files changed, 97 insertions(+), 8 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index d5c02b1f7b6..ae22a02603c 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -354,15 +354,30 @@ class AssetOperations:
def __init__(self, client: Client):
self.client = client
- def get(self, name: str | None = None, uri: str | None = None) ->
AssetResponse:
+ def get(self, name: str | None = None, uri: str | None = None) ->
AssetResponse | ErrorResponse:
"""Get Asset value from the API server."""
if name:
- resp = self.client.get("assets/by-name", params={"name": name})
+ endpoint = "assets/by-name"
+ params = {"name": name}
elif uri:
- resp = self.client.get("assets/by-uri", params={"uri": uri})
+ endpoint = "assets/by-uri"
+ params = {"uri": uri}
else:
raise ValueError("Either `name` or `uri` must be provided")
+ try:
+ resp = self.client.get(endpoint, params=params)
+ except ServerResponseError as e:
+ if e.response.status_code == HTTPStatus.NOT_FOUND:
+ log.error(
+ "Asset not found",
+ params=params,
+ detail=e.detail,
+ status_code=e.response.status_code,
+ )
+ return ErrorResponse(error=ErrorType.ASSET_NOT_FOUIND,
detail=params)
+ raise
+
return AssetResponse.model_validate_json(resp.read())
diff --git a/task-sdk/src/airflow/sdk/exceptions.py
b/task-sdk/src/airflow/sdk/exceptions.py
index 4dd4ff5910a..6d034227868 100644
--- a/task-sdk/src/airflow/sdk/exceptions.py
+++ b/task-sdk/src/airflow/sdk/exceptions.py
@@ -25,15 +25,20 @@ if TYPE_CHECKING:
class AirflowRuntimeError(Exception):
+ """Generic Airflow error raised by runtime functions."""
+
def __init__(self, error: ErrorResponse):
self.error = error
super().__init__(f"{error.error.value}: {error.detail}")
class ErrorType(enum.Enum):
+ """Error types used in the API client."""
+
CONNECTION_NOT_FOUND = "CONNECTION_NOT_FOUND"
VARIABLE_NOT_FOUND = "VARIABLE_NOT_FOUND"
XCOM_NOT_FOUND = "XCOM_NOT_FOUND"
+ ASSET_NOT_FOUIND = "ASSET_NOT_FOUND"
GENERIC_ERROR = "GENERIC_ERROR"
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index b7706e7735e..d986f57f3c0 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -54,6 +54,7 @@ from pydantic import TypeAdapter
from airflow.sdk.api.client import Client, ServerResponseError
from airflow.sdk.api.datamodels._generated import (
+ AssetResponse,
ConnectionResponse,
IntermediateTIState,
TaskInstance,
@@ -909,12 +910,18 @@ class ActivitySubprocess(WatchedSubprocess):
self.client.task_instances.set_rtif(self.id, msg.rendered_fields)
elif isinstance(msg, GetAssetByName):
asset_resp = self.client.assets.get(name=msg.name)
- asset_result = AssetResult.from_asset_response(asset_resp)
- resp = asset_result.model_dump_json(exclude_unset=True).encode()
+ if isinstance(asset_resp, AssetResponse):
+ asset_result = AssetResult.from_asset_response(asset_resp)
+ resp =
asset_result.model_dump_json(exclude_unset=True).encode()
+ else:
+ resp = asset_resp.model_dump_json().encode()
elif isinstance(msg, GetAssetByUri):
asset_resp = self.client.assets.get(uri=msg.uri)
- asset_result = AssetResult.from_asset_response(asset_resp)
- resp = asset_result.model_dump_json(exclude_unset=True).encode()
+ if isinstance(asset_resp, AssetResponse):
+ asset_result = AssetResult.from_asset_response(asset_resp)
+ resp =
asset_result.model_dump_json(exclude_unset=True).encode()
+ else:
+ resp = asset_resp.model_dump_json().encode()
elif isinstance(msg, GetAssetEventByAsset):
asset_event_resp = self.client.asset_events.get(uri=msg.uri,
name=msg.name)
asset_event_result =
AssetEventsResult.from_asset_events_response(asset_event_resp)
diff --git a/task-sdk/tests/task_sdk/api/test_client.py
b/task-sdk/tests/task_sdk/api/test_client.py
index bcfe92d481c..dd1d66ce6dc 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -26,7 +26,12 @@ import uuid6
from task_sdk import make_client, make_client_w_dry_run,
make_client_w_responses
from airflow.sdk.api.client import RemoteValidationError, ServerResponseError
-from airflow.sdk.api.datamodels._generated import ConnectionResponse,
VariableResponse, XComResponse
+from airflow.sdk.api.datamodels._generated import (
+ AssetResponse,
+ ConnectionResponse,
+ VariableResponse,
+ XComResponse,
+)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import DeferTask, ErrorResponse,
RescheduleTask
from airflow.utils import timezone
@@ -649,3 +654,60 @@ class TestConnectionOperations:
assert isinstance(result, ErrorResponse)
assert result.error == ErrorType.CONNECTION_NOT_FOUND
+
+
+class TestAssetOperations:
+ @pytest.mark.parametrize(
+ "request_params",
+ [
+ ({"name": "this_asset"}),
+ ({"uri": "s3://bucket/key"}),
+ ],
+ )
+ def test_by_name_get_success(self, request_params):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path in ("/assets/by-name", "/assets/by-uri"):
+ return httpx.Response(
+ status_code=200,
+ json={
+ "name": "this_asset",
+ "uri": "s3://bucket/key",
+ "group": "asset",
+ "extra": {"foo": "bar"},
+ },
+ )
+ return httpx.Response(status_code=400, json={"detail": "Bad Request"})
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.assets.get(**request_params)
+
+ assert isinstance(result, AssetResponse)
+ assert result.name == "this_asset"
+ assert result.uri == "s3://bucket/key"
+
+ @pytest.mark.parametrize(
+ "request_params",
+ [
+ ({"name": "this_asset"}),
+ ({"uri": "s3://bucket/key"}),
+ ],
+ )
+ def test_by_name_get_404_not_found(self, request_params):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path in ("/assets/by-name", "/assets/by-uri"):
+ return httpx.Response(
+ status_code=404,
+ json={
+ "detail": {
+ "message": "Asset with name non_existent not found",
+ "reason": "not_found",
+ }
+ },
+ )
+ return httpx.Response(status_code=400, json={"detail": "Bad Request"})
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.assets.get(**request_params)
+
+ assert isinstance(result, ErrorResponse)
+ assert result.error == ErrorType.ASSET_NOT_FOUIND