This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 418b701bbdc AIP-72: Add support for `outlet_events` in Task Context
(#45727)
418b701bbdc is described below
commit 418b701bbdcad58bb0b8bc6d9bd4a0fedca937f1
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Jan 18 01:51:29 2025 +0530
AIP-72: Add support for `outlet_events` in Task Context (#45727)
part of https://github.com/apache/airflow/issues/45717 and
https://github.com/apache/airflow/issues/45752
This PR adds support for `outlet_events` in the Context dict within the Task
SDK by adding an endpoint on the API Server, which is queried when an asset
reference used as an `outlet_events` key is resolved.
---
.../{routes/__init__.py => datamodels/asset.py} | 25 +++--
.../api_fastapi/execution_api/routes/__init__.py | 10 +-
airflow/api_fastapi/execution_api/routes/assets.py | 71 ++++++++++++
airflow/serialization/serialized_objects.py | 3 +-
airflow/utils/context.py | 104 +++--------------
task_sdk/src/airflow/sdk/api/client.py | 25 +++++
.../src/airflow/sdk/api/datamodels/_generated.py | 20 ++++
.../src/airflow/sdk/definitions/asset/__init__.py | 4 +-
task_sdk/src/airflow/sdk/execution_time/comms.py | 34 +++++-
task_sdk/src/airflow/sdk/execution_time/context.py | 123 ++++++++++++++++++++-
.../src/airflow/sdk/execution_time/supervisor.py | 11 ++
.../src/airflow/sdk/execution_time/task_runner.py | 4 +-
task_sdk/tests/execution_time/test_context.py | 103 ++++++++++++++++-
task_sdk/tests/execution_time/test_supervisor.py | 39 ++++++-
task_sdk/tests/execution_time/test_task_runner.py | 9 +-
.../execution_api/routes/test_assets.py | 110 ++++++++++++++++++
tests/serialization/test_serialized_objects.py | 3 +-
tests/utils/test_context.py | 102 -----------------
18 files changed, 584 insertions(+), 216 deletions(-)
diff --git a/airflow/api_fastapi/execution_api/routes/__init__.py
b/airflow/api_fastapi/execution_api/datamodels/asset.py
similarity index 55%
copy from airflow/api_fastapi/execution_api/routes/__init__.py
copy to airflow/api_fastapi/execution_api/datamodels/asset.py
index 0383503f18b..6d3a53c3e4c 100644
--- a/airflow/api_fastapi/execution_api/routes/__init__.py
+++ b/airflow/api_fastapi/execution_api/datamodels/asset.py
@@ -14,14 +14,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from __future__ import annotations
-from airflow.api_fastapi.common.router import AirflowRouter
-from airflow.api_fastapi.execution_api.routes import connections, health,
task_instances, variables, xcoms
+from airflow.api_fastapi.core_api.base import BaseModel
+
+
+class AssetResponse(BaseModel):
+ """Asset schema for responses with fields that are needed for Runtime."""
+
+ name: str
+ uri: str
+ group: str
+ extra: dict | None = None
+
+
+class AssetAliasResponse(BaseModel):
+ """Asset alias schema with fields that are needed for Runtime."""
-execution_api_router = AirflowRouter()
-execution_api_router.include_router(connections.router, prefix="/connections",
tags=["Connections"])
-execution_api_router.include_router(health.router, tags=["Health"])
-execution_api_router.include_router(task_instances.router,
prefix="/task-instances", tags=["Task Instances"])
-execution_api_router.include_router(variables.router, prefix="/variables",
tags=["Variables"])
-execution_api_router.include_router(xcoms.router, prefix="/xcoms",
tags=["XComs"])
+ name: str
+ group: str
diff --git a/airflow/api_fastapi/execution_api/routes/__init__.py
b/airflow/api_fastapi/execution_api/routes/__init__.py
index 0383503f18b..793cd8fe084 100644
--- a/airflow/api_fastapi/execution_api/routes/__init__.py
+++ b/airflow/api_fastapi/execution_api/routes/__init__.py
@@ -17,9 +17,17 @@
from __future__ import annotations
from airflow.api_fastapi.common.router import AirflowRouter
-from airflow.api_fastapi.execution_api.routes import connections, health,
task_instances, variables, xcoms
+from airflow.api_fastapi.execution_api.routes import (
+ assets,
+ connections,
+ health,
+ task_instances,
+ variables,
+ xcoms,
+)
execution_api_router = AirflowRouter()
+execution_api_router.include_router(assets.router, prefix="/assets",
tags=["Assets"])
execution_api_router.include_router(connections.router, prefix="/connections",
tags=["Connections"])
execution_api_router.include_router(health.router, tags=["Health"])
execution_api_router.include_router(task_instances.router,
prefix="/task-instances", tags=["Task Instances"])
diff --git a/airflow/api_fastapi/execution_api/routes/assets.py
b/airflow/api_fastapi/execution_api/routes/assets.py
new file mode 100644
index 00000000000..213c599befb
--- /dev/null
+++ b/airflow/api_fastapi/execution_api/routes/assets.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Annotated
+
+from fastapi import HTTPException, Query, status
+from sqlalchemy import select
+
+from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.common.router import AirflowRouter
+from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse
+from airflow.models.asset import AssetModel
+
+# TODO: Add dependency on JWT token
+router = AirflowRouter(
+ responses={
+ status.HTTP_404_NOT_FOUND: {"description": "Asset not found"},
+ status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
+ },
+)
+
+
[email protected]("/by-name")
+def get_asset_by_name(
+ name: Annotated[str, Query(description="The name of the Asset")],
+ session: SessionDep,
+) -> AssetResponse:
+ """Get an Airflow Asset by `name`."""
+ asset = session.scalar(select(AssetModel).where(AssetModel.name == name,
AssetModel.active.has()))
+ _raise_if_not_found(asset, f"Asset with name {name} not found")
+
+ return AssetResponse.model_validate(asset)
+
+
[email protected]("/by-uri")
+def get_asset_by_uri(
+ uri: Annotated[str, Query(description="The URI of the Asset")],
+ session: SessionDep,
+) -> AssetResponse:
+ """Get an Airflow Asset by `uri`."""
+ asset = session.scalar(select(AssetModel).where(AssetModel.uri == uri,
AssetModel.active.has()))
+ _raise_if_not_found(asset, f"Asset with URI {uri} not found")
+
+ return AssetResponse.model_validate(asset)
+
+
+def _raise_if_not_found(asset, msg):
+ if asset is None:
+ raise HTTPException(
+ status.HTTP_404_NOT_FOUND,
+ detail={
+ "reason": "not_found",
+ "message": msg,
+ },
+ )
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 11c293b531f..d828a9a5b6b 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -64,6 +64,7 @@ from airflow.sdk.definitions.asset import (
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as
TaskSDKBaseOperator
+from airflow.sdk.execution_time.context import AssetAliasEvent,
OutletEventAccessor
from airflow.serialization.dag_dependency import DagDependency
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
@@ -77,10 +78,8 @@ from airflow.task.priority_strategy import (
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.code_utils import get_python_source
from airflow.utils.context import (
- AssetAliasEvent,
ConnectionAccessor,
Context,
- OutletEventAccessor,
OutletEventAccessors,
VariableAccessor,
)
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 1f453457e43..168243290fa 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -19,7 +19,6 @@
from __future__ import annotations
-import contextlib
from collections.abc import (
Container,
Iterator,
@@ -51,9 +50,9 @@ from airflow.sdk.definitions.asset import (
AssetRef,
AssetUniqueKey,
AssetUriRef,
- BaseAssetUniqueKey,
)
from airflow.sdk.definitions.context import Context
+from airflow.sdk.execution_time.context import OutletEventAccessors as
OutletEventAccessorsSDK
from airflow.utils.db import LazySelectSequence
from airflow.utils.session import create_session
from airflow.utils.types import NOTSET
@@ -156,104 +155,29 @@ class ConnectionAccessor:
return default_conn
[email protected]()
-class AssetAliasEvent:
- """
- Represeation of asset event to be triggered by an asset alias.
-
- :meta private:
- """
-
- source_alias_name: str
- dest_asset_key: AssetUniqueKey
- extra: dict[str, Any]
-
-
[email protected]()
-class OutletEventAccessor:
- """
- Wrapper to access an outlet asset event in template.
-
- :meta private:
- """
-
- key: BaseAssetUniqueKey
- extra: dict[str, Any] = attrs.Factory(dict)
- asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)
-
- def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:
- """Add an AssetEvent to an existing Asset."""
- if not isinstance(self.key, AssetAliasUniqueKey):
- return
-
- asset_alias_name = self.key.name
- event = AssetAliasEvent(
- source_alias_name=asset_alias_name,
- dest_asset_key=AssetUniqueKey.from_asset(asset),
- extra=extra or {},
- )
- self.asset_alias_events.append(event)
-
-
-class OutletEventAccessors(Mapping[Union[Asset, AssetAlias],
OutletEventAccessor]):
+class OutletEventAccessors(OutletEventAccessorsSDK):
"""
Lazy mapping of outlet asset event accessors.
:meta private:
"""
- _asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {}
-
- def __init__(self) -> None:
- self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}
-
- def __str__(self) -> str:
- return f"OutletEventAccessors(_dict={self._dict})"
-
- def __iter__(self) -> Iterator[Asset | AssetAlias]:
- return (
- key.to_asset() if isinstance(key, AssetUniqueKey) else
key.to_asset_alias() for key in self._dict
- )
-
- def __len__(self) -> int:
- return len(self._dict)
-
- def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor:
- hashable_key: BaseAssetUniqueKey
- if isinstance(key, Asset):
- hashable_key = AssetUniqueKey.from_asset(key)
- elif isinstance(key, AssetAlias):
- hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
- elif isinstance(key, AssetRef):
- hashable_key = self._resolve_asset_ref(key)
- else:
- raise TypeError(f"Key should be either an asset or an asset alias,
not {type(key)}")
-
- if hashable_key not in self._dict:
- self._dict[hashable_key] = OutletEventAccessor(extra={},
key=hashable_key)
- return self._dict[hashable_key]
-
- def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey:
- with contextlib.suppress(KeyError):
- return self._asset_ref_cache[ref]
-
- refs_to_cache: list[AssetRef]
- with create_session() as session:
- if isinstance(ref, AssetNameRef):
+ @staticmethod
+ def _get_asset_from_db(name: str | None = None, uri: str | None = None) ->
Asset:
+ if name:
+ with create_session() as session:
asset = session.scalar(
- select(AssetModel).where(AssetModel.name == ref.name,
AssetModel.active.has())
+ select(AssetModel).where(AssetModel.name == name,
AssetModel.active.has())
)
- refs_to_cache = [ref, AssetUriRef(asset.uri)]
- elif isinstance(ref, AssetUriRef):
+ elif uri:
+ with create_session() as session:
asset = session.scalar(
- select(AssetModel).where(AssetModel.uri == ref.uri,
AssetModel.active.has())
+ select(AssetModel).where(AssetModel.uri == uri,
AssetModel.active.has())
)
- refs_to_cache = [ref, AssetNameRef(asset.name)]
- else:
- raise TypeError(f"Unimplemented asset ref: {type(ref)}")
- for ref in refs_to_cache:
- self._asset_ref_cache[ref] = unique_key =
AssetUniqueKey.from_asset(asset)
- return unique_key
+ else:
+ raise ValueError("Either name or uri must be provided")
+
+ return asset.to_public()
class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]):
diff --git a/task_sdk/src/airflow/sdk/api/client.py
b/task_sdk/src/airflow/sdk/api/client.py
index 5ee27059148..e73e5aebea6 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -34,6 +34,7 @@ from uuid6 import uuid7
from airflow.sdk import __version__
from airflow.sdk.api.datamodels._generated import (
+ AssetResponse,
ConnectionResponse,
DagRunType,
TerminalTIState,
@@ -267,6 +268,24 @@ class XComOperations:
return {"ok": True}
+class AssetOperations:
+ __slots__ = ("client",)
+
+ def __init__(self, client: Client):
+ self.client = client
+
+ def get(self, name: str | None = None, uri: str | None = None) ->
AssetResponse:
+ """Get Asset value from the API server."""
+ if name:
+ resp = self.client.get("assets/by-name", params={"name": name})
+ elif uri:
+ resp = self.client.get("assets/by-uri", params={"uri": uri})
+ else:
+ raise ValueError("Either `name` or `uri` must be provided")
+
+ return AssetResponse.model_validate_json(resp.read())
+
+
class BearerAuth(httpx.Auth):
def __init__(self, token: str):
self.token: str = token
@@ -374,6 +393,12 @@ class Client(httpx.Client):
"""Operations related to XComs."""
return XComOperations(self)
+ @lru_cache() # type: ignore[misc]
+ @property
+ def assets(self) -> AssetOperations:
+ """Operations related to Assets."""
+ return AssetOperations(self)
+
# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index a8b478d07f0..f0a04da21c8 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -29,6 +29,15 @@ from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
+class AssetAliasResponse(BaseModel):
+ """
+ Asset alias schema with fields that are needed for Runtime.
+ """
+
+ name: Annotated[str, Field(title="Name")]
+ group: Annotated[str, Field(title="Group")]
+
+
class ConnectionResponse(BaseModel):
"""
Connection schema for responses with fields that are needed for Runtime.
@@ -187,6 +196,17 @@ class TaskInstance(BaseModel):
hostname: Annotated[str | None, Field(title="Hostname")] = None
+class AssetResponse(BaseModel):
+ """
+ Asset schema for responses with fields that are needed for Runtime.
+ """
+
+ name: Annotated[str, Field(title="Name")]
+ uri: Annotated[str, Field(title="Uri")]
+ group: Annotated[str, Field(title="Group")]
+ extra: Annotated[dict[str, Any] | None, Field(title="Extra")] = None
+
+
class DagRun(BaseModel):
"""
Schema for DagRun model with minimal required fields needed for Runtime.
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index 5b0cbb4a784..ea89f1b6817 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -488,14 +488,14 @@ class AssetRef(BaseAsset, AttrsInstance):
)
[email protected]()
[email protected](hash=True)
class AssetNameRef(AssetRef):
"""Name reference to an asset."""
name: str
[email protected]()
[email protected](hash=True)
class AssetUriRef(AssetRef):
"""URI reference to an asset."""
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index b6874d47f09..f8aaab65af4 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -50,6 +50,7 @@ from fastapi import Body
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from airflow.sdk.api.datamodels._generated import (
+ AssetResponse,
BundleInfo,
ConnectionResponse,
TaskInstance,
@@ -79,6 +80,25 @@ class StartupDetails(BaseModel):
type: Literal["StartupDetails"] = "StartupDetails"
+class AssetResult(AssetResponse):
+ """Response to GetAssetByName or GetAssetByUri request."""
+
+ type: Literal["AssetResult"] = "AssetResult"
+
+ @classmethod
+ def from_asset_response(cls, asset_response: AssetResponse) -> AssetResult:
+ """
+ Get AssetResult from AssetResponse.
+
+ AssetResponse is autogenerated from the API schema, so we need to
convert it to AssetResult
+ for communication between the Supervisor and the task process.
+ """
+ # Exclude defaults to avoid sending unnecessary data
+ # Pass the type as AssetResult explicitly so we can then call
model_dump_json with exclude_unset=True
+ # to avoid sending unset fields (which are defaults in our case).
+ return cls(**asset_response.model_dump(exclude_defaults=True),
type="AssetResult")
+
+
class XComResult(XComResponse):
"""Response to ReadXCom request."""
@@ -133,7 +153,7 @@ class ErrorResponse(BaseModel):
ToTask = Annotated[
- Union[StartupDetails, XComResult, ConnectionResult, VariableResult,
ErrorResponse],
+ Union[StartupDetails, XComResult, ConnectionResult, VariableResult,
ErrorResponse, AssetResult],
Field(discriminator="type"),
]
@@ -231,12 +251,24 @@ class SetRenderedFields(BaseModel):
type: Literal["SetRenderedFields"] = "SetRenderedFields"
+class GetAssetByName(BaseModel):
+ name: str
+ type: Literal["GetAssetByName"] = "GetAssetByName"
+
+
+class GetAssetByUri(BaseModel):
+ uri: str
+ type: Literal["GetAssetByUri"] = "GetAssetByUri"
+
+
ToSupervisor = Annotated[
Union[
TaskState,
GetXCom,
GetConnection,
GetVariable,
+ GetAssetByName,
+ GetAssetByUri,
DeferTask,
PutVariable,
SetXCom,
diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py
b/task_sdk/src/airflow/sdk/execution_time/context.py
index cdb3880bb36..918526c3004 100644
--- a/task_sdk/src/airflow/sdk/execution_time/context.py
+++ b/task_sdk/src/airflow/sdk/execution_time/context.py
@@ -17,20 +17,31 @@
from __future__ import annotations
import contextlib
-from collections.abc import Generator
-from typing import TYPE_CHECKING, Any
+from collections.abc import Generator, Iterator, Mapping
+from typing import TYPE_CHECKING, Any, Union
+import attrs
import structlog
from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT
from airflow.sdk.definitions._internal.types import NOTSET
+from airflow.sdk.definitions.asset import (
+ Asset,
+ AssetAlias,
+ AssetAliasUniqueKey,
+ AssetNameRef,
+ AssetRef,
+ AssetUniqueKey,
+ AssetUriRef,
+ BaseAssetUniqueKey,
+)
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
if TYPE_CHECKING:
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.variable import Variable
- from airflow.sdk.execution_time.comms import ConnectionResult,
VariableResult
+ from airflow.sdk.execution_time.comms import AssetResult,
ConnectionResult, VariableResult
log = structlog.get_logger(logger_name="task")
@@ -163,6 +174,112 @@ class MacrosAccessor:
return True
[email protected]
+class AssetAliasEvent:
+ """Representation of asset event to be triggered by an asset alias."""
+
+ source_alias_name: str
+ dest_asset_key: AssetUniqueKey
+ extra: dict[str, Any]
+
+
[email protected]
+class OutletEventAccessor:
+ """Wrapper to access an outlet asset event in template."""
+
+ key: BaseAssetUniqueKey
+ extra: dict[str, Any] = attrs.Factory(dict)
+ asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)
+
+ def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:
+ """Add an AssetEvent to an existing Asset."""
+ if not isinstance(self.key, AssetAliasUniqueKey):
+ return
+
+ asset_alias_name = self.key.name
+ event = AssetAliasEvent(
+ source_alias_name=asset_alias_name,
+ dest_asset_key=AssetUniqueKey.from_asset(asset),
+ extra=extra or {},
+ )
+ self.asset_alias_events.append(event)
+
+
+class OutletEventAccessors(Mapping[Union[Asset, AssetAlias],
OutletEventAccessor]):
+ """Lazy mapping of outlet asset event accessors."""
+
+ _asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {}
+
+ def __init__(self) -> None:
+ self._dict: dict[BaseAssetUniqueKey, OutletEventAccessor] = {}
+
+ def __str__(self) -> str:
+ return f"OutletEventAccessors(_dict={self._dict})"
+
+ def __iter__(self) -> Iterator[Asset | AssetAlias]:
+ return (
+ key.to_asset() if isinstance(key, AssetUniqueKey) else
key.to_asset_alias() for key in self._dict
+ )
+
+ def __len__(self) -> int:
+ return len(self._dict)
+
+ def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor:
+ hashable_key: BaseAssetUniqueKey
+ if isinstance(key, Asset):
+ hashable_key = AssetUniqueKey.from_asset(key)
+ elif isinstance(key, AssetAlias):
+ hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
+ elif isinstance(key, AssetRef):
+ hashable_key = self._resolve_asset_ref(key)
+ else:
+ raise TypeError(f"Key should be either an asset or an asset alias,
not {type(key)}")
+
+ if hashable_key not in self._dict:
+ self._dict[hashable_key] = OutletEventAccessor(extra={},
key=hashable_key)
+ return self._dict[hashable_key]
+
+ def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey:
+ with contextlib.suppress(KeyError):
+ return self._asset_ref_cache[ref]
+
+ refs_to_cache: list[AssetRef]
+ if isinstance(ref, AssetNameRef):
+ asset = self._get_asset_from_db(name=ref.name)
+ refs_to_cache = [ref, AssetUriRef(asset.uri)]
+ elif isinstance(ref, AssetUriRef):
+ asset = self._get_asset_from_db(uri=ref.uri)
+ refs_to_cache = [ref, AssetNameRef(asset.name)]
+ else:
+ raise TypeError(f"Unimplemented asset ref: {type(ref)}")
+ unique_key = AssetUniqueKey.from_asset(asset)
+ for ref in refs_to_cache:
+ self._asset_ref_cache[ref] = unique_key
+ return unique_key
+
+ # TODO: This is temporary to avoid code duplication between here &
airflow/models/taskinstance.py
+ @staticmethod
+ def _get_asset_from_db(name: str | None = None, uri: str | None = None) ->
Asset:
+ # Imported lazily: task_runner imports this module, so a top-level import would be circular.
+ from airflow.sdk.execution_time.comms import ErrorResponse,
GetAssetByName, GetAssetByUri
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ if name:
+ SUPERVISOR_COMMS.send_request(log=log,
msg=GetAssetByName(name=name))
+ elif uri:
+ SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByUri(uri=uri))
+ else:
+ raise ValueError("Either name or uri must be provided")
+
+ msg = SUPERVISOR_COMMS.get_message()
+ if isinstance(msg, ErrorResponse):
+ raise AirflowRuntimeError(msg)
+
+ if TYPE_CHECKING:
+ assert isinstance(msg, AssetResult)
+ return Asset(**msg.model_dump(exclude={"type"}))
+
+
@contextlib.contextmanager
def set_current_context(context: Context) -> Generator[Context, None, None]:
"""
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 32895d36524..bd50ee5126b 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -61,8 +61,11 @@ from airflow.sdk.api.datamodels._generated import (
VariableResponse,
)
from airflow.sdk.execution_time.comms import (
+ AssetResult,
ConnectionResult,
DeferTask,
+ GetAssetByName,
+ GetAssetByUri,
GetConnection,
GetVariable,
GetXCom,
@@ -787,6 +790,14 @@ class ActivitySubprocess(WatchedSubprocess):
self.client.variables.set(msg.key, msg.value, msg.description)
elif isinstance(msg, SetRenderedFields):
self.client.task_instances.set_rtif(self.id, msg.rendered_fields)
+ elif isinstance(msg, GetAssetByName):
+ asset_resp = self.client.assets.get(name=msg.name)
+ asset_result = AssetResult.from_asset_response(asset_resp)
+ resp = asset_result.model_dump_json(exclude_unset=True).encode()
+ elif isinstance(msg, GetAssetByUri):
+ asset_resp = self.client.assets.get(uri=msg.uri)
+ asset_result = AssetResult.from_asset_response(asset_resp)
+ resp = asset_result.model_dump_json(exclude_unset=True).encode()
else:
log.error("Unhandled request", msg=msg)
return
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 186faac878a..d252c24be18 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -50,6 +50,7 @@ from airflow.sdk.execution_time.comms import (
from airflow.sdk.execution_time.context import (
ConnectionAccessor,
MacrosAccessor,
+ OutletEventAccessors,
VariableAccessor,
set_current_context,
)
@@ -92,12 +93,13 @@ class RuntimeTaskInstance(TaskInstance):
# TODO: Ensure that ti.log_url and such are available to use in
context
# especially after removal of `conf` from Context.
"ti": self,
- # "outlet_events": OutletEventAccessors(),
+ "outlet_events": OutletEventAccessors(),
# "expanded_ti_count": expanded_ti_count,
"expanded_ti_count": None, # TODO: Implement this
# "inlet_events": InletEventsAccessors(task.inlets,
session=session),
"macros": MacrosAccessor(),
# "params": validated_params,
+ # TODO: Make this go through Public API longer term.
# "prev_data_interval_start_success":
get_prev_data_interval_start_success(),
# "prev_data_interval_end_success":
get_prev_data_interval_end_success(),
# "prev_start_date_success": get_prev_start_date_success(),
diff --git a/task_sdk/tests/execution_time/test_context.py
b/task_sdk/tests/execution_time/test_context.py
index 6527d517e37..e3ef15dc934 100644
--- a/task_sdk/tests/execution_time/test_context.py
+++ b/task_sdk/tests/execution_time/test_context.py
@@ -22,12 +22,16 @@ from unittest.mock import MagicMock, patch
import pytest
from airflow.sdk import get_current_context
+from airflow.sdk.definitions.asset import Asset, AssetAlias,
AssetAliasUniqueKey, AssetUniqueKey
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.exceptions import ErrorType
-from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse,
VariableResult
+from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult,
ErrorResponse, VariableResult
from airflow.sdk.execution_time.context import (
+ AssetAliasEvent,
ConnectionAccessor,
+ OutletEventAccessor,
+ OutletEventAccessors,
VariableAccessor,
_convert_connection_result_conn,
_convert_variable_result_to_variable,
@@ -248,3 +252,100 @@ class TestCurrentContext:
assert ctx["ContextId"] == i
# End of with statement
ctx_list[i].__exit__(None, None, None)
+
+
+class TestOutletEventAccessor:
+ @pytest.mark.parametrize(
+ "key, asset_alias_events",
+ (
+ (AssetUniqueKey.from_asset(Asset("test_uri")), []),
+ (
+ AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
+ [
+ AssetAliasEvent(
+ source_alias_name="test_alias",
+ dest_asset_key=AssetUniqueKey(uri="test_uri",
name="test_uri"),
+ extra={},
+ )
+ ],
+ ),
+ ),
+ )
+ def test_add(self, key, asset_alias_events, mock_supervisor_comms):
+ asset = Asset("test_uri")
+ mock_supervisor_comms.get_message.return_value = asset
+
+ outlet_event_accessor = OutletEventAccessor(key=key, extra={})
+ outlet_event_accessor.add(asset)
+ assert outlet_event_accessor.asset_alias_events == asset_alias_events
+
+ @pytest.mark.parametrize(
+ "key, asset_alias_events",
+ (
+ (AssetUniqueKey.from_asset(Asset("test_uri")), []),
+ (
+ AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
+ [
+ AssetAliasEvent(
+ source_alias_name="test_alias",
+ dest_asset_key=AssetUniqueKey(name="test-asset",
uri="test://asset-uri/"),
+ extra={},
+ )
+ ],
+ ),
+ ),
+ )
+ def test_add_with_db(self, key, asset_alias_events, mock_supervisor_comms):
+ asset = Asset(uri="test://asset-uri", name="test-asset")
+ mock_supervisor_comms.get_message.return_value = asset
+
+ outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""})
+ outlet_event_accessor.add(asset, extra={})
+ assert outlet_event_accessor.asset_alias_events == asset_alias_events
+
+
+class TestOutletEventAccessors:
+ @pytest.mark.parametrize(
+ "access_key, internal_key",
+ (
+ (Asset("test"), AssetUniqueKey.from_asset(Asset("test"))),
+ (
+ Asset(name="test", uri="test://asset"),
+ AssetUniqueKey.from_asset(Asset(name="test",
uri="test://asset")),
+ ),
+ (AssetAlias("test_alias"),
AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias"))),
+ ),
+ )
+ def test__get_item__dict_key_not_exists(self, access_key, internal_key):
+ outlet_event_accessors = OutletEventAccessors()
+ assert len(outlet_event_accessors) == 0
+ outlet_event_accessor = outlet_event_accessors[access_key]
+ assert len(outlet_event_accessors) == 1
+ assert outlet_event_accessor.key == internal_key
+ assert outlet_event_accessor.extra == {}
+
+ @pytest.mark.parametrize(
+ ["access_key", "asset"],
+ (
+ (Asset.ref(name="test"), Asset(name="test")),
+ (Asset.ref(name="test1"), Asset(name="test1",
uri="test://asset-uri")),
+ (Asset.ref(uri="test://asset-uri"), Asset(uri="test://asset-uri")),
+ ),
+ )
+ def test__get_item__asset_ref(self, access_key, asset,
mock_supervisor_comms):
+ """Test accessing OutletEventAccessors with AssetRef resolves to
correct Asset."""
+ internal_key = AssetUniqueKey.from_asset(asset)
+ outlet_event_accessors = OutletEventAccessors()
+ assert len(outlet_event_accessors) == 0
+
+ # Asset from the API Server via the supervisor
+ mock_supervisor_comms.get_message.return_value = AssetResult(
+ name=asset.name,
+ uri=asset.uri,
+ group=asset.group,
+ )
+
+ outlet_event_accessor = outlet_event_accessors[access_key]
+ assert len(outlet_event_accessors) == 1
+ assert outlet_event_accessor.key == internal_key
+ assert outlet_event_accessor.extra == {}
diff --git a/task_sdk/tests/execution_time/test_supervisor.py
b/task_sdk/tests/execution_time/test_supervisor.py
index 59afa26dc2a..5455d0f70cd 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -41,8 +41,11 @@ from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
+ AssetResult,
ConnectionResult,
DeferTask,
+ GetAssetByName,
+ GetAssetByUri,
GetConnection,
GetVariable,
GetXCom,
@@ -805,13 +808,14 @@ class TestHandleRequest:
)
@pytest.mark.parametrize(
- ["message", "expected_buffer", "client_attr_path", "method_arg",
"mock_response"],
+ ["message", "expected_buffer", "client_attr_path", "method_arg",
"method_kwarg", "mock_response"],
[
pytest.param(
GetConnection(conn_id="test_conn"),
b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n',
"connections.get",
("test_conn",),
+ {},
ConnectionResult(conn_id="test_conn", conn_type="mysql"),
id="get_connection",
),
@@ -820,6 +824,7 @@ class TestHandleRequest:
b'{"key":"test_key","value":"test_value","type":"VariableResult"}\n',
"variables.get",
("test_key",),
+ {},
VariableResult(key="test_key", value="test_value"),
id="get_variable",
),
@@ -828,6 +833,7 @@ class TestHandleRequest:
b"",
"variables.set",
("test_key", "test_value", "test_description"),
+ {},
{"ok": True},
id="set_variable",
),
@@ -836,6 +842,7 @@ class TestHandleRequest:
b"",
"task_instances.defer",
(TI_ID, DeferTask(next_method="execute_callback",
classpath="my-classpath")),
+ {},
"",
id="patch_task_instance_to_deferred",
),
@@ -853,6 +860,7 @@ class TestHandleRequest:
end_date=timezone.parse("2024-10-31T12:00:00Z"),
),
),
+ {},
"",
id="patch_task_instance_to_up_for_reschedule",
),
@@ -861,6 +869,7 @@ class TestHandleRequest:
b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", None),
+ {},
XComResult(key="test_key", value="test_value"),
id="get_xcom",
),
@@ -871,6 +880,7 @@ class TestHandleRequest:
b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", 2),
+ {},
XComResult(key="test_key", value="test_value"),
id="get_xcom_map_index",
),
@@ -879,6 +889,7 @@ class TestHandleRequest:
b'{"key":"test_key","value":null,"type":"XComResult"}\n',
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", None),
+ {},
XComResult(key="test_key", value=None, type="XComResult"),
id="get_xcom_not_found",
),
@@ -900,6 +911,7 @@ class TestHandleRequest:
'{"key": "test_key", "value": {"key2": "value2"}}',
None,
),
+ {},
{"ok": True},
id="set_xcom",
),
@@ -922,6 +934,7 @@ class TestHandleRequest:
'{"key": "test_key", "value": {"key2": "value2"}}',
2,
),
+ {},
{"ok": True},
id="set_xcom_with_map_index",
),
@@ -932,6 +945,7 @@ class TestHandleRequest:
b"",
"",
(),
+ {},
"",
id="patch_task_instance_to_skipped",
),
@@ -940,9 +954,28 @@ class TestHandleRequest:
b"",
"task_instances.set_rtif",
(TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}),
+ {},
{"ok": True},
id="set_rtif",
),
+ pytest.param(
+ GetAssetByName(name="asset"),
+ b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n',
+ "assets.get",
+ [],
+ {"name": "asset"},
+ AssetResult(name="asset", uri="s3://bucket/obj", group="asset"),
+ id="get_asset_by_name",
+ ),
+ pytest.param(
+ GetAssetByUri(uri="s3://bucket/obj"),
+ b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n',
+ "assets.get",
+ [],
+ {"uri": "s3://bucket/obj"},
+ AssetResult(name="asset", uri="s3://bucket/obj", group="asset"),
+ id="get_asset_by_uri",
+ ),
],
)
def test_handle_requests(
@@ -953,8 +986,8 @@ class TestHandleRequest:
expected_buffer,
client_attr_path,
method_arg,
+ method_kwarg,
mock_response,
- time_machine,
):
"""
Test handling of different messages to the subprocess. For any new
message type, add a
@@ -980,7 +1013,7 @@ class TestHandleRequest:
# Verify the correct client method was called
if client_attr_path:
- mock_client_method.assert_called_once_with(*method_arg)
+ mock_client_method.assert_called_once_with(*method_arg, **method_kwarg)
# Verify the response was added to the buffer
val = watched_subprocess.stdin.getvalue()
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index 60b39da455c..f7734279b3f 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -52,7 +52,12 @@ from airflow.sdk.execution_time.comms import (
VariableResult,
XComResult,
)
-from airflow.sdk.execution_time.context import ConnectionAccessor, MacrosAccessor, VariableAccessor
+from airflow.sdk.execution_time.context import (
+ ConnectionAccessor,
+ MacrosAccessor,
+ OutletEventAccessors,
+ VariableAccessor,
+)
from airflow.sdk.execution_time.task_runner import (
CommsDecoder,
RuntimeTaskInstance,
@@ -613,6 +618,7 @@ class TestRuntimeTaskInstance:
"inlets": task.inlets,
"macros": MacrosAccessor(),
"map_index_template": task.map_index_template,
+ "outlet_events": OutletEventAccessors(),
"outlets": task.outlets,
"run_id": "test_run",
"task": task,
@@ -645,6 +651,7 @@ class TestRuntimeTaskInstance:
"inlets": task.inlets,
"macros": MacrosAccessor(),
"map_index_template": task.map_index_template,
+ "outlet_events": OutletEventAccessors(),
"outlets": task.outlets,
"run_id": "test_run",
"task": task,
diff --git a/tests/api_fastapi/execution_api/routes/test_assets.py b/tests/api_fastapi/execution_api/routes/test_assets.py
new file mode 100644
index 00000000000..2cf34f8dd7b
--- /dev/null
+++ b/tests/api_fastapi/execution_api/routes/test_assets.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.models.asset import AssetActive, AssetModel
+from airflow.utils import timezone
+
+DEFAULT_DATE = timezone.parse("2021-01-01T00:00:00")
+
+pytestmark = pytest.mark.db_test
+
+
+class TestGetAssetByName:
+ def test_get_asset_by_name(self, client, session):
+ asset = AssetModel(
+ id=1,
+ name="test_get_asset_by_name",
+ uri="s3://bucket/key",
+ group="asset",
+ extra={"foo": "bar"},
+ created_at=DEFAULT_DATE,
+ updated_at=DEFAULT_DATE,
+ )
+
+ asset_active = AssetActive.for_asset(asset)
+
+ session.add_all([asset, asset_active])
+ session.commit()
+
+ response = client.get("/execution/assets/by-name", params={"name": "test_get_asset_by_name"})
+
+ assert response.status_code == 200
+ assert response.json() == {
+ "name": "test_get_asset_by_name",
+ "uri": "s3://bucket/key",
+ "group": "asset",
+ "extra": {"foo": "bar"},
+ }
+
+ session.delete(asset)
+ session.delete(asset_active)
+ session.commit()
+
+ def test_asset_name_not_found(self, client):
+ response = client.get("/execution/assets/by-name", params={"name": "non_existent"})
+
+ assert response.status_code == 404
+ assert response.json() == {
+ "detail": {
+ "message": "Asset with name non_existent not found",
+ "reason": "not_found",
+ }
+ }
+
+
+class TestGetAssetByUri:
+ def test_get_asset_by_uri(self, client, session):
+ asset = AssetModel(
+ name="test_get_asset_by_uri",
+ uri="s3://bucket/key",
+ group="asset",
+ extra={"foo": "bar"},
+ )
+
+ asset_active = AssetActive.for_asset(asset)
+
+ session.add_all([asset, asset_active])
+ session.commit()
+
+ response = client.get("/execution/assets/by-uri", params={"uri": "s3://bucket/key"})
+
+ assert response.status_code == 200
+ assert response.json() == {
+ "name": "test_get_asset_by_uri",
+ "uri": "s3://bucket/key",
+ "group": "asset",
+ "extra": {"foo": "bar"},
+ }
+
+ session.delete(asset)
+ session.delete(asset_active)
+ session.commit()
+
+ def test_asset_uri_not_found(self, client):
+ response = client.get("/execution/assets/by-uri", params={"uri": "non_existent"})
+
+ assert response.status_code == 404
+ assert response.json() == {
+ "detail": {
+ "message": "Asset with URI non_existent not found",
+ "reason": "not_found",
+ }
+ }
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 0faeed038e6..707595b92ff 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -43,11 +43,12 @@ from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
+from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
-from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors
+from airflow.utils.context import OutletEventAccessors
from airflow.utils.db import LazySelectSequence
from airflow.utils.operator_resources import Resources
from airflow.utils.state import DagRunState, State
diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py
deleted file mode 100644
index 0046ca33cc4..00000000000
--- a/tests/utils/test_context.py
+++ /dev/null
@@ -1,102 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import annotations
-
-import pytest
-
-from airflow.models.asset import AssetActive, AssetAliasModel, AssetModel
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey
-from airflow.utils.context import AssetAliasEvent, OutletEventAccessor, OutletEventAccessors
-
-
-class TestOutletEventAccessor:
- @pytest.mark.parametrize(
- "key, asset_alias_events",
- (
- (AssetUniqueKey.from_asset(Asset("test_uri")), []),
- (
- AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
- [
- AssetAliasEvent(
- source_alias_name="test_alias",
- dest_asset_key=AssetUniqueKey(uri="test_uri", name="test_uri"),
- extra={},
- )
- ],
- ),
- ),
- )
- @pytest.mark.db_test
- def test_add(self, key, asset_alias_events, session):
- asset = Asset("test_uri")
- session.add_all([AssetModel.from_public(asset), AssetActive.for_asset(asset)])
- session.flush()
-
- outlet_event_accessor = OutletEventAccessor(key=key, extra={})
- outlet_event_accessor.add(asset)
- assert outlet_event_accessor.asset_alias_events == asset_alias_events
-
- @pytest.mark.db_test
- @pytest.mark.parametrize(
- "key, asset_alias_events",
- (
- (AssetUniqueKey.from_asset(Asset("test_uri")), []),
- (
- AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
- [
- AssetAliasEvent(
- source_alias_name="test_alias",
- dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"),
- extra={},
- )
- ],
- ),
- ),
- )
- def test_add_with_db(self, key, asset_alias_events, session):
- asset = Asset(uri="test://asset-uri", name="test-asset")
- asm = AssetModel.from_public(asset)
- aam = AssetAliasModel(name="test_alias")
- session.add_all([asm, aam, AssetActive.for_asset(asset)])
- session.flush()
-
- outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""})
- outlet_event_accessor.add(asset, extra={})
- assert outlet_event_accessor.asset_alias_events == asset_alias_events
-
-
-class TestOutletEventAccessors:
- @pytest.mark.parametrize(
- "access_key, internal_key",
- (
- (Asset("test"), AssetUniqueKey.from_asset(Asset("test"))),
- (
- Asset(name="test", uri="test://asset"),
- AssetUniqueKey.from_asset(Asset(name="test", uri="test://asset")),
- ),
- (AssetAlias("test_alias"), AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias"))),
- ),
- )
- def test___get_item__dict_key_not_exists(self, access_key, internal_key):
- outlet_event_accessors = OutletEventAccessors()
- assert len(outlet_event_accessors) == 0
- outlet_event_accessor = outlet_event_accessors[access_key]
- assert len(outlet_event_accessors) == 1
- assert outlet_event_accessor.key == internal_key
- assert outlet_event_accessor.extra == {}