This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 08d0273c1a8 Use Protocol for `OutletEventAccessor` (#45762)
08d0273c1a8 is described below

commit 08d0273c1a88333f504913ae7b35ddb0414f24b1
Author: Kaxil Naik <[email protected]>
AuthorDate: Mon Jan 20 16:14:23 2025 +0530

    Use Protocol for `OutletEventAccessor` (#45762)
    
    Follow-up of https://github.com/apache/airflow/pull/45727 to use Protocol
    to allow auto-completion in IDEs without introducing a runtime dependency.
---
 airflow/models/taskinstance.py                     |  6 ++---
 airflow/serialization/serialized_objects.py        |  4 ++--
 airflow/utils/context.py                           |  3 ++-
 airflow/utils/operator_helpers.py                  |  4 ++--
 .../providers/edge/example_dags/win_test.py        |  2 +-
 .../amazon/aws/transfers/google_api_to_s3.py       |  2 +-
 .../src/airflow/sdk/definitions/asset/__init__.py  |  9 ++++++++
 task_sdk/src/airflow/sdk/definitions/context.py    |  9 +++++---
 task_sdk/src/airflow/sdk/execution_time/context.py | 10 +-------
 .../src/airflow/sdk/execution_time/task_runner.py  |  1 -
 .../sdk/{definitions/protocols.py => types.py}     | 27 ++++++++++++++++++++++
 task_sdk/tests/execution_time/test_context.py      |  9 ++++++--
 tests/serialization/test_serialized_objects.py     |  5 ++--
 13 files changed, 63 insertions(+), 28 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 2f3fa4e8fb4..5e0f4001d2d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -163,7 +163,7 @@ if TYPE_CHECKING:
     from airflow.models.dagrun import DagRun
     from airflow.models.operator import Operator
     from airflow.sdk.definitions.dag import DAG
-    from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol
+    from airflow.sdk.types import OutletEventAccessorsProtocol, RuntimeTaskInstanceProtocol
     from airflow.timetables.base import DataInterval
     from airflow.typing_compat import Literal, TypeGuard
     from airflow.utils.task_group import TaskGroup
@@ -2730,7 +2730,7 @@ class TaskInstance(Base, LoggingMixin):
         )
 
     def _register_asset_changes(
-        self, *, events: OutletEventAccessors, session: Session | None = None
+        self, *, events: OutletEventAccessorsProtocol, session: Session | None = None
     ) -> None:
         if session:
             TaskInstance._register_asset_changes_int(ti=self, events=events, session=session)
@@ -2740,7 +2740,7 @@ class TaskInstance(Base, LoggingMixin):
     @staticmethod
     @provide_session
     def _register_asset_changes_int(
-        ti: TaskInstance, *, events: OutletEventAccessors, session: Session = NEW_SESSION
+        ti: TaskInstance, *, events: OutletEventAccessorsProtocol, session: Session = NEW_SESSION
     ) -> None:
         if TYPE_CHECKING:
             assert ti.task
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index d828a9a5b6b..88e0f200bb2 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -56,6 +56,7 @@ from airflow.providers_manager import ProvidersManager
 from airflow.sdk.definitions.asset import (
     Asset,
     AssetAlias,
+    AssetAliasEvent,
     AssetAliasUniqueKey,
     AssetAll,
     AssetAny,
@@ -64,7 +65,7 @@ from airflow.sdk.definitions.asset import (
     BaseAsset,
 )
 from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
-from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor
+from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
 from airflow.serialization.dag_dependency import DagDependency
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
 from airflow.serialization.helpers import serialize_template_field
@@ -80,7 +81,6 @@ from airflow.utils.code_utils import get_python_source
 from airflow.utils.context import (
     ConnectionAccessor,
     Context,
-    OutletEventAccessors,
     VariableAccessor,
 )
 from airflow.utils.db import LazySelectSequence
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 168243290fa..a36202f0793 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -63,6 +63,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.expression import Select, TextClause
 
     from airflow.models.baseoperator import BaseOperator
+    from airflow.sdk.types import OutletEventAccessorsProtocol
 
 # NOTE: Please keep this in sync with the following:
 # * Context in task_sdk/src/airflow/sdk/definitions/context.py
@@ -331,7 +332,7 @@ def context_copy_partial(source: Context, keys: Container[str]) -> Context:
     return cast(Context, new)
 
 
-def context_get_outlet_events(context: Context) -> OutletEventAccessors:
+def context_get_outlet_events(context: Context) -> OutletEventAccessorsProtocol:
     try:
         return context["outlet_events"]
     except KeyError:
diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py
index cb822aa1cc7..ab3c8c89e5e 100644
--- a/airflow/utils/operator_helpers.py
+++ b/airflow/utils/operator_helpers.py
@@ -29,7 +29,7 @@ from airflow.typing_compat import ParamSpec
 from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
-    from airflow.utils.context import OutletEventAccessors
+    from airflow.sdk.types import OutletEventAccessorsProtocol
 
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -230,7 +230,7 @@ class _ExecutionCallableRunner(Protocol):
 
 def ExecutionCallableRunner(
     func: Callable[P, R],
-    outlet_events: OutletEventAccessors,
+    outlet_events: OutletEventAccessorsProtocol,
     *,
     logger: logging.Logger,
 ) -> _ExecutionCallableRunner:
diff --git a/providers/edge/src/airflow/providers/edge/example_dags/win_test.py b/providers/edge/src/airflow/providers/edge/example_dags/win_test.py
index a2727363d64..15735b85d18 100644
--- a/providers/edge/src/airflow/providers/edge/example_dags/win_test.py
+++ b/providers/edge/src/airflow/providers/edge/example_dags/win_test.py
@@ -46,7 +46,7 @@ from airflow.utils.types import ArgNotSet
 
 if TYPE_CHECKING:
     try:
-        from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol as TaskInstance
+        from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance
     except ImportError:
         from airflow.models import TaskInstance  # type: ignore[assignment]
     from airflow.utils.context import Context
diff --git a/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py b/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
index a3d6bd619ce..157477341b4 100644
--- a/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
+++ b/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
@@ -31,7 +31,7 @@ from airflow.providers.google.common.hooks.discovery_api import GoogleDiscoveryA
 
 if TYPE_CHECKING:
     try:
-        from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol
+        from airflow.sdk.types import RuntimeTaskInstanceProtocol
     except ImportError:
         from airflow.models import TaskInstance as RuntimeTaskInstanceProtocol  # type: ignore[assignment]
     from airflow.utils.context import Context
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index ea89f1b6817..51d4abbeda4 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -660,3 +660,12 @@ class AssetAll(_AssetBooleanCondition):
         :meta private:
         """
         return {"all": [o.as_expression() for o in self.objects]}
+
+
+@attrs.define
+class AssetAliasEvent:
+    """Representation of asset event to be triggered by an asset alias."""
+
+    source_alias_name: str
+    dest_asset_key: AssetUniqueKey
+    extra: dict[str, Any]
diff --git a/task_sdk/src/airflow/sdk/definitions/context.py b/task_sdk/src/airflow/sdk/definitions/context.py
index 46a92ec2beb..b98c1a2e048 100644
--- a/task_sdk/src/airflow/sdk/definitions/context.py
+++ b/task_sdk/src/airflow/sdk/definitions/context.py
@@ -27,7 +27,11 @@ if TYPE_CHECKING:
     from airflow.models.operator import Operator
     from airflow.sdk.definitions.baseoperator import BaseOperator
     from airflow.sdk.definitions.dag import DAG
-    from airflow.sdk.definitions.protocols import DagRunProtocol, RuntimeTaskInstanceProtocol
+    from airflow.sdk.types import (
+        DagRunProtocol,
+        OutletEventAccessorsProtocol,
+        RuntimeTaskInstanceProtocol,
+    )
 
 
 class Context(TypedDict, total=False):
@@ -38,8 +42,7 @@ class Context(TypedDict, total=False):
     dag_run: DagRunProtocol
     data_interval_end: datetime | None
     data_interval_start: datetime | None
-    # outlet_events: OutletEventAccessors
-    outlet_events: Any
+    outlet_events: OutletEventAccessorsProtocol
     ds: str
     ds_nodash: str
     expanded_ti_count: int | None
diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py
index 918526c3004..a068b53aec7 100644
--- a/task_sdk/src/airflow/sdk/execution_time/context.py
+++ b/task_sdk/src/airflow/sdk/execution_time/context.py
@@ -28,6 +28,7 @@ from airflow.sdk.definitions._internal.types import NOTSET
 from airflow.sdk.definitions.asset import (
     Asset,
     AssetAlias,
+    AssetAliasEvent,
     AssetAliasUniqueKey,
     AssetNameRef,
     AssetRef,
@@ -174,15 +175,6 @@ class MacrosAccessor:
         return True
 
 
-@attrs.define
-class AssetAliasEvent:
-    """Representation of asset event to be triggered by an asset alias."""
-
-    source_alias_name: str
-    dest_asset_key: AssetUniqueKey
-    extra: dict[str, Any]
-
-
 @attrs.define
 class OutletEventAccessor:
     """Wrapper to access an outlet asset event in template."""
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index d252c24be18..d4816c8ae59 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -137,7 +137,6 @@ class RuntimeTaskInstance(TaskInstance):
             }
             context.update(context_from_server)
 
-        # TODO: We should use/move TypeDict from airflow.utils.context.Context
         return context
 
     def render_templates(
diff --git a/task_sdk/src/airflow/sdk/definitions/protocols.py b/task_sdk/src/airflow/sdk/types.py
similarity index 69%
rename from task_sdk/src/airflow/sdk/definitions/protocols.py
rename to task_sdk/src/airflow/sdk/types.py
index 80dba602ff1..35ee9f8e38c 100644
--- a/task_sdk/src/airflow/sdk/definitions/protocols.py
+++ b/task_sdk/src/airflow/sdk/types.py
@@ -20,8 +20,10 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Any, Protocol
 
 if TYPE_CHECKING:
+    from collections.abc import Iterator
     from datetime import datetime
 
+    from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, BaseAssetUniqueKey
     from airflow.sdk.definitions.baseoperator import BaseOperator
 
 
@@ -65,3 +67,28 @@ class RuntimeTaskInstanceProtocol(Protocol):
     ) -> Any: ...
 
     def xcom_push(self, key: str, value: Any) -> None: ...
+
+
+class OutletEventAccessorProtocol(Protocol):
+    """Protocol for managing access to a specific outlet event accessor."""
+
+    key: BaseAssetUniqueKey
+    extra: dict[str, Any]
+    asset_alias_events: list[AssetAliasEvent]
+
+    def __init__(
+        self,
+        *,
+        key: BaseAssetUniqueKey,
+        extra: dict[str, Any],
+        asset_alias_events: list[AssetAliasEvent],
+    ) -> None: ...
+    def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ...
+
+
+class OutletEventAccessorsProtocol(Protocol):
+    """Protocol for managing access to outlet event accessors."""
+
+    def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
+    def __len__(self) -> int: ...
+    def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessorProtocol: ...
diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py
index e3ef15dc934..a155f65a9f5 100644
--- a/task_sdk/tests/execution_time/test_context.py
+++ b/task_sdk/tests/execution_time/test_context.py
@@ -22,13 +22,18 @@ from unittest.mock import MagicMock, patch
 import pytest
 
 from airflow.sdk import get_current_context
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasUniqueKey, AssetUniqueKey
+from airflow.sdk.definitions.asset import (
+    Asset,
+    AssetAlias,
+    AssetAliasEvent,
+    AssetAliasUniqueKey,
+    AssetUniqueKey,
+)
 from airflow.sdk.definitions.connection import Connection
 from airflow.sdk.definitions.variable import Variable
 from airflow.sdk.exceptions import ErrorType
 from airflow.sdk.execution_time.comms import AssetResult, ConnectionResult, ErrorResponse, VariableResult
 from airflow.sdk.execution_time.context import (
-    AssetAliasEvent,
     ConnectionAccessor,
     OutletEventAccessor,
     OutletEventAccessors,
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 707595b92ff..06bb477becd 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -42,13 +42,12 @@ from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
 from airflow.models.xcom_arg import XComArg
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.standard.operators.python import PythonOperator
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
-from airflow.sdk.execution_time.context import AssetAliasEvent, OutletEventAccessor
+from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey
+from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
 from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.triggers.base import BaseTrigger
 from airflow.utils import timezone
-from airflow.utils.context import OutletEventAccessors
 from airflow.utils.db import LazySelectSequence
 from airflow.utils.operator_resources import Resources
 from airflow.utils.state import DagRunState, State

Reply via email to