This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b6d33068c84 Allow to_downstream() to return more than one key (#62346)
b6d33068c84 is described below
commit b6d33068c84cef3c920515bfc3c2b2efe72bbf42
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Feb 25 17:33:19 2026 +0800
Allow to_downstream() to return more than one key (#62346)
---
airflow-core/src/airflow/assets/manager.py | 43 ++++++++++++++--------
airflow-core/src/airflow/partition_mappers/base.py | 7 +++-
airflow-core/src/airflow/utils/helpers.py | 6 +--
3 files changed, 35 insertions(+), 21 deletions(-)
diff --git a/airflow-core/src/airflow/assets/manager.py b/airflow-core/src/airflow/assets/manager.py
index 258cd2cfc72..9c778d4a50e 100644
--- a/airflow-core/src/airflow/assets/manager.py
+++ b/airflow-core/src/airflow/assets/manager.py
@@ -41,6 +41,7 @@ from airflow.models.asset import (
DagScheduleAssetUriReference,
PartitionedAssetKeyLog,
)
+from airflow.utils.helpers import is_container
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.sqlalchemy import get_dialect_name, with_row_locks
@@ -416,22 +417,32 @@ class AssetManager(LoggingMixin):
target_key = timetable.get_partition_mapper(
name=asset_model.name, uri=asset_model.uri
).to_downstream(partition_key)
-
- apdr = cls._get_or_create_apdr(
- target_key=target_key,
- target_dag=target_dag,
- asset_id=asset_id,
- session=session,
- )
- log_record = PartitionedAssetKeyLog(
- asset_id=asset_id,
- asset_event_id=event.id,
- asset_partition_dag_run_id=apdr.id,
- source_partition_key=partition_key,
- target_dag_id=target_dag.dag_id,
- target_partition_key=target_key,
- )
- session.add(log_record)
+ if is_container(target_key):
+ # TODO (AIP-76): This never happens now. When we implement
+ # one-to-many partition key mapping, this should also add a
+ # config to cap the iterable size so the scheduler does not
+ # blow up with an incorrectly implemented PartitionMapper.
+ target_keys: Iterable[str] = target_key
+ else:
+ target_keys = [target_key]
+ del target_key
+
+ for target_key in target_keys:
+ apdr = cls._get_or_create_apdr(
+ target_key=target_key,
+ target_dag=target_dag,
+ asset_id=asset_id,
+ session=session,
+ )
+ log_record = PartitionedAssetKeyLog(
+ asset_id=asset_id,
+ asset_event_id=event.id,
+ asset_partition_dag_run_id=apdr.id,
+ source_partition_key=partition_key,
+ target_dag_id=target_dag.dag_id,
+ target_partition_key=target_key,
+ )
+ session.add(log_record)
@classmethod
def _get_or_create_apdr(
diff --git a/airflow-core/src/airflow/partition_mappers/base.py b/airflow-core/src/airflow/partition_mappers/base.py
index 92e87c0bd28..7c64d056258 100644
--- a/airflow-core/src/airflow/partition_mappers/base.py
+++ b/airflow-core/src/airflow/partition_mappers/base.py
@@ -18,7 +18,10 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import Any
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from collections.abc import Iterable
class PartitionMapper(ABC):
@@ -29,7 +32,7 @@ class PartitionMapper(ABC):
"""
@abstractmethod
- def to_downstream(self, key: str) -> str:
+ def to_downstream(self, key: str) -> str | Iterable[str]:
"""Return the target key that the given source partition key maps
to."""
def serialize(self) -> dict[str, Any]:
diff --git a/airflow-core/src/airflow/utils/helpers.py b/airflow-core/src/airflow/utils/helpers.py
index a48780022a6..50bd8b82a46 100644
--- a/airflow-core/src/airflow/utils/helpers.py
+++ b/airflow-core/src/airflow/utils/helpers.py
@@ -34,9 +34,9 @@ from airflow.serialization.definitions.notset import is_arg_set
if TYPE_CHECKING:
from datetime import datetime
- from typing import TypeGuard
import jinja2
+ from typing_extensions import TypeIs
from airflow.models.taskinstance import TaskInstance
@@ -95,11 +95,11 @@ def prompt_with_timeout(question: str, timeout: int, default: bool | None = None
@overload
-def is_container(obj: None | int | Iterable[int] | range) -> TypeGuard[Iterable[int]]: ...
+def is_container(obj: None | int | Iterable[int] | range) -> TypeIs[Iterable[int]]: ...
@overload
-def is_container(obj: None | CT | Iterable[CT]) -> TypeGuard[Iterable[CT]]: ...
+def is_container(obj: None | CT | Iterable[CT]) -> TypeIs[Iterable[CT]]: ...
def is_container(obj) -> bool: