This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 08d0273c1a8 Use Protocol for `OutletEventAccessor` (#45762)
08d0273c1a8 is described below
commit 08d0273c1a88333f504913ae7b35ddb0414f24b1
Author: Kaxil Naik <[email protected]>
AuthorDate: Mon Jan 20 16:14:23 2025 +0530
Use Protocol for `OutletEventAccessor` (#45762)
Follow-up of https://github.com/apache/airflow/pull/45727: use a Protocol
to enable auto-completion in IDEs without introducing a runtime dependency.
---
airflow/models/taskinstance.py | 6 ++---
airflow/serialization/serialized_objects.py | 4 ++--
airflow/utils/context.py | 3 ++-
airflow/utils/operator_helpers.py | 4 ++--
.../providers/edge/example_dags/win_test.py | 2 +-
.../amazon/aws/transfers/google_api_to_s3.py | 2 +-
.../src/airflow/sdk/definitions/asset/__init__.py | 9 ++++++++
task_sdk/src/airflow/sdk/definitions/context.py | 9 +++++---
task_sdk/src/airflow/sdk/execution_time/context.py | 10 +-------
.../src/airflow/sdk/execution_time/task_runner.py | 1 -
.../sdk/{definitions/protocols.py => types.py} | 27 ++++++++++++++++++++++
task_sdk/tests/execution_time/test_context.py | 9 ++++++--
tests/serialization/test_serialized_objects.py | 5 ++--
13 files changed, 63 insertions(+), 28 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 2f3fa4e8fb4..5e0f4001d2d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -163,7 +163,7 @@ if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
- from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol
+ from airflow.sdk.types import OutletEventAccessorsProtocol,
RuntimeTaskInstanceProtocol
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Literal, TypeGuard
from airflow.utils.task_group import TaskGroup
@@ -2730,7 +2730,7 @@ class TaskInstance(Base, LoggingMixin):
)
def _register_asset_changes(
- self, *, events: OutletEventAccessors, session: Session | None = None
+ self, *, events: OutletEventAccessorsProtocol, session: Session | None
= None
) -> None:
if session:
TaskInstance._register_asset_changes_int(ti=self, events=events,
session=session)
@@ -2740,7 +2740,7 @@ class TaskInstance(Base, LoggingMixin):
@staticmethod
@provide_session
def _register_asset_changes_int(
- ti: TaskInstance, *, events: OutletEventAccessors, session: Session =
NEW_SESSION
+ ti: TaskInstance, *, events: OutletEventAccessorsProtocol, session:
Session = NEW_SESSION
) -> None:
if TYPE_CHECKING:
assert ti.task
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index d828a9a5b6b..88e0f200bb2 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -56,6 +56,7 @@ from airflow.providers_manager import ProvidersManager
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
+ AssetAliasEvent,
AssetAliasUniqueKey,
AssetAll,
AssetAny,
@@ -64,7 +65,7 @@ from airflow.sdk.definitions.asset import (
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as
TaskSDKBaseOperator
-from airflow.sdk.execution_time.context import AssetAliasEvent,
OutletEventAccessor
+from airflow.sdk.execution_time.context import OutletEventAccessor,
OutletEventAccessors
from airflow.serialization.dag_dependency import DagDependency
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
@@ -80,7 +81,6 @@ from airflow.utils.code_utils import get_python_source
from airflow.utils.context import (
ConnectionAccessor,
Context,
- OutletEventAccessors,
VariableAccessor,
)
from airflow.utils.db import LazySelectSequence
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 168243290fa..a36202f0793 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -63,6 +63,7 @@ if TYPE_CHECKING:
from sqlalchemy.sql.expression import Select, TextClause
from airflow.models.baseoperator import BaseOperator
+ from airflow.sdk.types import OutletEventAccessorsProtocol
# NOTE: Please keep this in sync with the following:
# * Context in task_sdk/src/airflow/sdk/definitions/context.py
@@ -331,7 +332,7 @@ def context_copy_partial(source: Context, keys:
Container[str]) -> Context:
return cast(Context, new)
-def context_get_outlet_events(context: Context) -> OutletEventAccessors:
+def context_get_outlet_events(context: Context) ->
OutletEventAccessorsProtocol:
try:
return context["outlet_events"]
except KeyError:
diff --git a/airflow/utils/operator_helpers.py
b/airflow/utils/operator_helpers.py
index cb822aa1cc7..ab3c8c89e5e 100644
--- a/airflow/utils/operator_helpers.py
+++ b/airflow/utils/operator_helpers.py
@@ -29,7 +29,7 @@ from airflow.typing_compat import ParamSpec
from airflow.utils.types import NOTSET
if TYPE_CHECKING:
- from airflow.utils.context import OutletEventAccessors
+ from airflow.sdk.types import OutletEventAccessorsProtocol
P = ParamSpec("P")
R = TypeVar("R")
@@ -230,7 +230,7 @@ class _ExecutionCallableRunner(Protocol):
def ExecutionCallableRunner(
func: Callable[P, R],
- outlet_events: OutletEventAccessors,
+ outlet_events: OutletEventAccessorsProtocol,
*,
logger: logging.Logger,
) -> _ExecutionCallableRunner:
diff --git a/providers/edge/src/airflow/providers/edge/example_dags/win_test.py
b/providers/edge/src/airflow/providers/edge/example_dags/win_test.py
index a2727363d64..15735b85d18 100644
--- a/providers/edge/src/airflow/providers/edge/example_dags/win_test.py
+++ b/providers/edge/src/airflow/providers/edge/example_dags/win_test.py
@@ -46,7 +46,7 @@ from airflow.utils.types import ArgNotSet
if TYPE_CHECKING:
try:
- from airflow.sdk.definitions.protocols import
RuntimeTaskInstanceProtocol as TaskInstance
+ from airflow.sdk.types import RuntimeTaskInstanceProtocol as
TaskInstance
except ImportError:
from airflow.models import TaskInstance # type: ignore[assignment]
from airflow.utils.context import Context
diff --git
a/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
b/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
index a3d6bd619ce..157477341b4 100644
--- a/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
+++ b/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
@@ -31,7 +31,7 @@ from airflow.providers.google.common.hooks.discovery_api
import GoogleDiscoveryA
if TYPE_CHECKING:
try:
- from airflow.sdk.definitions.protocols import
RuntimeTaskInstanceProtocol
+ from airflow.sdk.types import RuntimeTaskInstanceProtocol
except ImportError:
from airflow.models import TaskInstance as RuntimeTaskInstanceProtocol
# type: ignore[assignment]
from airflow.utils.context import Context
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index ea89f1b6817..51d4abbeda4 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -660,3 +660,12 @@ class AssetAll(_AssetBooleanCondition):
:meta private:
"""
return {"all": [o.as_expression() for o in self.objects]}
+
+
[email protected]
+class AssetAliasEvent:
+ """Representation of asset event to be triggered by an asset alias."""
+
+ source_alias_name: str
+ dest_asset_key: AssetUniqueKey
+ extra: dict[str, Any]
diff --git a/task_sdk/src/airflow/sdk/definitions/context.py
b/task_sdk/src/airflow/sdk/definitions/context.py
index 46a92ec2beb..b98c1a2e048 100644
--- a/task_sdk/src/airflow/sdk/definitions/context.py
+++ b/task_sdk/src/airflow/sdk/definitions/context.py
@@ -27,7 +27,11 @@ if TYPE_CHECKING:
from airflow.models.operator import Operator
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG
- from airflow.sdk.definitions.protocols import DagRunProtocol,
RuntimeTaskInstanceProtocol
+ from airflow.sdk.types import (
+ DagRunProtocol,
+ OutletEventAccessorsProtocol,
+ RuntimeTaskInstanceProtocol,
+ )
class Context(TypedDict, total=False):
@@ -38,8 +42,7 @@ class Context(TypedDict, total=False):
dag_run: DagRunProtocol
data_interval_end: datetime | None
data_interval_start: datetime | None
- # outlet_events: OutletEventAccessors
- outlet_events: Any
+ outlet_events: OutletEventAccessorsProtocol
ds: str
ds_nodash: str
expanded_ti_count: int | None
diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py
b/task_sdk/src/airflow/sdk/execution_time/context.py
index 918526c3004..a068b53aec7 100644
--- a/task_sdk/src/airflow/sdk/execution_time/context.py
+++ b/task_sdk/src/airflow/sdk/execution_time/context.py
@@ -28,6 +28,7 @@ from airflow.sdk.definitions._internal.types import NOTSET
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
+ AssetAliasEvent,
AssetAliasUniqueKey,
AssetNameRef,
AssetRef,
@@ -174,15 +175,6 @@ class MacrosAccessor:
return True
[email protected]
-class AssetAliasEvent:
- """Representation of asset event to be triggered by an asset alias."""
-
- source_alias_name: str
- dest_asset_key: AssetUniqueKey
- extra: dict[str, Any]
-
-
@attrs.define
class OutletEventAccessor:
"""Wrapper to access an outlet asset event in template."""
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index d252c24be18..d4816c8ae59 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -137,7 +137,6 @@ class RuntimeTaskInstance(TaskInstance):
}
context.update(context_from_server)
- # TODO: We should use/move TypeDict from airflow.utils.context.Context
return context
def render_templates(
diff --git a/task_sdk/src/airflow/sdk/definitions/protocols.py
b/task_sdk/src/airflow/sdk/types.py
similarity index 69%
rename from task_sdk/src/airflow/sdk/definitions/protocols.py
rename to task_sdk/src/airflow/sdk/types.py
index 80dba602ff1..35ee9f8e38c 100644
--- a/task_sdk/src/airflow/sdk/definitions/protocols.py
+++ b/task_sdk/src/airflow/sdk/types.py
@@ -20,8 +20,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING:
+ from collections.abc import Iterator
from datetime import datetime
+ from airflow.sdk.definitions.asset import Asset, AssetAlias,
AssetAliasEvent, BaseAssetUniqueKey
from airflow.sdk.definitions.baseoperator import BaseOperator
@@ -65,3 +67,28 @@ class RuntimeTaskInstanceProtocol(Protocol):
) -> Any: ...
def xcom_push(self, key: str, value: Any) -> None: ...
+
+
+class OutletEventAccessorProtocol(Protocol):
+ """Protocol for managing access to a specific outlet event accessor."""
+
+ key: BaseAssetUniqueKey
+ extra: dict[str, Any]
+ asset_alias_events: list[AssetAliasEvent]
+
+ def __init__(
+ self,
+ *,
+ key: BaseAssetUniqueKey,
+ extra: dict[str, Any],
+ asset_alias_events: list[AssetAliasEvent],
+ ) -> None: ...
+ def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:
...
+
+
+class OutletEventAccessorsProtocol(Protocol):
+ """Protocol for managing access to outlet event accessors."""
+
+ def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
+ def __len__(self) -> int: ...
+ def __getitem__(self, key: Asset | AssetAlias) ->
OutletEventAccessorProtocol: ...
diff --git a/task_sdk/tests/execution_time/test_context.py
b/task_sdk/tests/execution_time/test_context.py
index e3ef15dc934..a155f65a9f5 100644
--- a/task_sdk/tests/execution_time/test_context.py
+++ b/task_sdk/tests/execution_time/test_context.py
@@ -22,13 +22,18 @@ from unittest.mock import MagicMock, patch
import pytest
from airflow.sdk import get_current_context
-from airflow.sdk.definitions.asset import Asset, AssetAlias,
AssetAliasUniqueKey, AssetUniqueKey
+from airflow.sdk.definitions.asset import (
+ Asset,
+ AssetAlias,
+ AssetAliasEvent,
+ AssetAliasUniqueKey,
+ AssetUniqueKey,
+)
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult,
ErrorResponse, VariableResult
from airflow.sdk.execution_time.context import (
- AssetAliasEvent,
ConnectionAccessor,
OutletEventAccessor,
OutletEventAccessors,
diff --git a/tests/serialization/test_serialized_objects.py
b/tests/serialization/test_serialized_objects.py
index 707595b92ff..06bb477becd 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -42,13 +42,12 @@ from airflow.models.taskinstance import SimpleTaskInstance,
TaskInstance
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
-from airflow.sdk.execution_time.context import AssetAliasEvent,
OutletEventAccessor
+from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent,
AssetUniqueKey
+from airflow.sdk.execution_time.context import OutletEventAccessor,
OutletEventAccessors
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
-from airflow.utils.context import OutletEventAccessors
from airflow.utils.db import LazySelectSequence
from airflow.utils.operator_resources import Resources
from airflow.utils.state import DagRunState, State