This is an automated email from the ASF dual-hosted git repository.
Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f690474a26a Add FixedKeyMapper and SegmentWindow for categorical
asset-partition rollup (#67716)
f690474a26a is described below
commit f690474a26a18f34e95469a5c28369b87c8f3838
Author: Wei Lee <[email protected]>
AuthorDate: Wed Jun 10 18:06:05 2026 +0800
Add FixedKeyMapper and SegmentWindow for categorical asset-partition rollup
(#67716)
---
.pre-commit-config.yaml | 8 +-
.../docs/authoring-and-scheduling/assets.rst | 83 +++++
airflow-core/newsfragments/67716.feature.rst | 1 +
.../example_dags/example_asset_partition.py | 40 +++
.../src/airflow/partition_mappers/__init__.py | 4 +
.../src/airflow/partition_mappers/fixed_key.py | 65 ++++
.../src/airflow/partition_mappers/window.py | 62 ++++
airflow-core/src/airflow/serialization/encoders.py | 12 +
airflow-core/tests/unit/jobs/test_scheduler_job.py | 70 ++++
.../tests/unit/partition_mappers/test_fixed_key.py | 120 +++++++
.../tests/unit/partition_mappers/test_window.py | 42 +++
.../check_partition_mapper_defaults_in_sync.py | 291 ++++++++++++++--
...test_check_partition_mapper_defaults_in_sync.py | 388 +++++++++++++++++++++
task-sdk/docs/api.rst | 4 +
task-sdk/src/airflow/sdk/__init__.py | 6 +
task-sdk/src/airflow/sdk/__init__.pyi | 4 +
.../sdk/definitions/partition_mappers/fixed_key.py | 52 +++
.../sdk/definitions/partition_mappers/window.py | 56 ++-
.../task_sdk/definitions/test_partition_mappers.py | 66 ++++
19 files changed, 1344 insertions(+), 30 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1de10325dd6..2be1cba4bd2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -249,13 +249,17 @@ repos:
pass_filenames: false
require_serial: true
- id: check-partition-mapper-defaults-in-sync
- name: Check FanOutMapper default mapper table stays in sync (core/SDK)
+ name: Check partition-mapper core/SDK sync (FanOutMapper table +
SegmentWindow/FixedKeyMapper)
entry: ./scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
language: python
files: >
(?x)
^airflow-core/src/airflow/partition_mappers/temporal\.py$|
- ^task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal\.py$
+ ^airflow-core/src/airflow/partition_mappers/window\.py$|
+ ^airflow-core/src/airflow/partition_mappers/fixed_key\.py$|
+
^task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal\.py$|
+ ^task-sdk/src/airflow/sdk/definitions/partition_mappers/window\.py$|
+
^task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key\.py$
pass_filenames: false
require_serial: true
- id: check-window-in-sync
diff --git a/airflow-core/docs/authoring-and-scheduling/assets.rst
b/airflow-core/docs/authoring-and-scheduling/assets.rst
index 7bf609e9d98..2bc06a30117 100644
--- a/airflow-core/docs/authoring-and-scheduling/assets.rst
+++ b/airflow-core/docs/authoring-and-scheduling/assets.rst
@@ -565,6 +565,12 @@ downstream Dag partition key:
passes the key through unchanged if valid.
For example, ``AllowedKeyMapper(["us", "eu", "apac"])`` accepts only those
region keys and rejects all others.
+* ``FixedKeyMapper`` collapses every upstream key onto a fixed downstream key,
+ regardless of the upstream value.
+* ``SegmentWindow`` declares a fixed categorical set of string keys (e.g.
regions,
+ tenants) that constitute one downstream period; paired with
``FixedKeyMapper``
+ inside a ``RollupMapper`` it holds the downstream run until every declared
segment
+ has arrived (see :ref:`segment-rollup <segment-categorical-rollup>`).
Example of per-asset mapper configuration and composite-key mapping:
@@ -733,6 +739,83 @@ so the run is held indefinitely) and the fall-back day has
twenty-five (the repe
hour is dropped). Use a UTC-based upstream mapper for any rollup that crosses
a DST
boundary; see the ``DayWindow`` class docstring for the full discussion.
+.. _segment-categorical-rollup:
+
+Segment (categorical) rollup
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. versionadded:: 3.3.0
+
+For categorical partitioning — regions, tenants, experiment variants — compose
a
+``RollupMapper`` from two primitives:
+
+* ``SegmentWindow(["us", "eu", "apac"])`` declares the fixed set of string keys
+ that constitute one downstream period; ``to_upstream`` returns the full set
+ regardless of the downstream anchor.
+* ``FixedKeyMapper("all_regions")`` collapses every upstream key onto the
single
+ downstream partition key ``"all_regions"``.
+
+The scheduler holds the downstream Dag run until every declared segment has
arrived
+from the upstream producer, then fires once. All the segment events accumulate
into
+one ``AssetPartitionDagRun``; the fired run's ``partition_key`` is the value
passed
+to ``FixedKeyMapper``. This composition only makes sense under ``WAIT_FOR_ALL``
+semantics (the default).
+
+.. code-block:: python
+
+ from airflow.sdk import (
+ DAG,
+ Asset,
+ FixedKeyMapper,
+ PartitionAtRuntime,
+ PartitionedAssetTimetable,
+ RollupMapper,
+ SegmentWindow,
+ asset,
+ task,
+ )
+
+
+ @asset(
+ uri="file://incoming/player-stats/multi-region.csv",
+ schedule=PartitionAtRuntime(),
+ )
+ def multi_region_player_stats(self, outlet_events):
+ # Emit one event per region in a single run.
+ outlet_events[self].add_partitions(["us", "eu", "apac"])
+
+
+ # Consumer: fires once all three region partitions have arrived.
+ with DAG(
+ dag_id="segment_region_stats_rollup",
+ schedule=PartitionedAssetTimetable(
+ assets=Asset.ref(name="multi_region_player_stats"),
+ default_partition_mapper=RollupMapper(
+ upstream_mapper=FixedKeyMapper("all_regions"),
+ window=SegmentWindow(["us", "eu", "apac"]),
+ ),
+ ),
+ catchup=False,
+ ):
+
+ @task
+ def aggregate_all_regions(dag_run=None):
+ # dag_run.partition_key is the downstream key once all segments
arrive.
+ print(dag_run.partition_key)
+
+ aggregate_all_regions()
+
+Construction validates both components: ``SegmentWindow`` raises
``ValueError`` for
+an empty list, non-string items, or empty-string keys; duplicate entries are
silently
+deduplicated. ``FixedKeyMapper`` raises ``ValueError`` unless its argument is a
+non-empty string. When one consumer Dag rolls up more than one asset, give each
+rollup its own ``FixedKeyMapper`` key, so every rollup accumulates into a distinct
+bucket and the rollups do not collide on the same ``(target_dag_id, partition_key)``.
+
+For a segment set that must be computed at runtime, do not encode it here —
evaluate
+completeness in a consumer-side task instead (the scheduler must not run user
code to
+decide a partition set).
+
Setting partition keys at runtime
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/airflow-core/newsfragments/67716.feature.rst
b/airflow-core/newsfragments/67716.feature.rst
new file mode 100644
index 00000000000..63daeda113c
--- /dev/null
+++ b/airflow-core/newsfragments/67716.feature.rst
@@ -0,0 +1 @@
+Add ``FixedKeyMapper`` and ``SegmentWindow`` for categorical asset-partition
rollup. ``FixedKeyMapper`` collapses any upstream partition key onto a single
fixed downstream key, and ``SegmentWindow`` enumerates a fixed categorical
segment set the scheduler waits for. Composing
``RollupMapper(FixedKeyMapper(...), SegmentWindow(...))`` expresses a
categorical rollup, mirroring the temporal rollup shape, and ``SegmentWindow``
also composes with ``FanOutMapper`` for categorical scatter. Both [...]
diff --git a/airflow-core/src/airflow/example_dags/example_asset_partition.py
b/airflow-core/src/airflow/example_dags/example_asset_partition.py
index 775fea80ec3..65a078d3c3f 100644
--- a/airflow-core/src/airflow/example_dags/example_asset_partition.py
+++ b/airflow-core/src/airflow/example_dags/example_asset_partition.py
@@ -26,12 +26,14 @@ from airflow.sdk import (
CronPartitionTimetable,
DayWindow,
FanOutMapper,
+ FixedKeyMapper,
IdentityMapper,
MonthWindow,
PartitionAtRuntime,
PartitionedAssetTimetable,
ProductMapper,
RollupMapper,
+ SegmentWindow,
StartOfDayMapper,
StartOfHourMapper,
StartOfMonthMapper,
@@ -409,3 +411,41 @@ with DAG(
print(dag_run.partition_key)
run_inference()
+
+
+# --- Segment (categorical) rollup -------------------------------------------
+# ``multi_region_player_stats`` (defined above) emits one partition per region
+# (``us``, ``eu``, ``apac``) from a single run. The Dag below holds a
downstream
+# run until every declared region key has arrived.
+
+with DAG(
+ dag_id="segment_region_stats_rollup",
+ schedule=PartitionedAssetTimetable(
+ assets=Asset.ref(name="multi_region_player_stats"),
+ default_partition_mapper=RollupMapper(
+ upstream_mapper=FixedKeyMapper("all_regions"),
+ window=SegmentWindow(["us", "eu", "apac"]),
+ ),
+ ),
+ catchup=False,
+ tags=["example", "player-stats", "rollup", "segment"],
+):
+ """
+ Categorical rollup: hold until all three region partitions arrive.
+
+ ``RollupMapper(upstream_mapper=FixedKeyMapper("all_regions"),
window=SegmentWindow([...]))``
+ declares the fixed set of region keys required for one downstream run and
collapses every
+ region key onto a single ``all_regions`` partition, so the three region
events accumulate
+ into one downstream run. The run is held until ``us``, ``eu``, and
``apac`` have all
+ arrived from ``multi_region_player_stats``; partial arrivals remain
pending in the
+ next-run-assets view so operators can track progress.
+ """
+
+ @task
+ def aggregate_all_regions(dag_run=None):
+ """Produce the cross-region summary once every region partition has
arrived."""
+ if TYPE_CHECKING:
+ assert dag_run
+ print(f"All region partitions received. Partition:
{dag_run.partition_key}")
+
+ aggregate_all_regions()
diff --git a/airflow-core/src/airflow/partition_mappers/__init__.py
b/airflow-core/src/airflow/partition_mappers/__init__.py
index 9b66876e0df..f7806038707 100644
--- a/airflow-core/src/airflow/partition_mappers/__init__.py
+++ b/airflow-core/src/airflow/partition_mappers/__init__.py
@@ -19,6 +19,7 @@ from __future__ import annotations
from airflow.partition_mappers.allowed_key import AllowedKeyMapper
from airflow.partition_mappers.base import PartitionMapper, RollupMapper
from airflow.partition_mappers.chain import ChainMapper
+from airflow.partition_mappers.fixed_key import FixedKeyMapper
from airflow.partition_mappers.identity import IdentityMapper
from airflow.partition_mappers.product import ProductMapper
from airflow.partition_mappers.temporal import (
@@ -34,6 +35,7 @@ from airflow.partition_mappers.window import (
HourWindow,
MonthWindow,
QuarterWindow,
+ SegmentWindow,
WeekWindow,
Window,
YearWindow,
@@ -43,6 +45,7 @@ __all__ = [
"AllowedKeyMapper",
"ChainMapper",
"DayWindow",
+ "FixedKeyMapper",
"HourWindow",
"IdentityMapper",
"MonthWindow",
@@ -50,6 +53,7 @@ __all__ = [
"ProductMapper",
"QuarterWindow",
"RollupMapper",
+ "SegmentWindow",
"StartOfDayMapper",
"StartOfHourMapper",
"StartOfMonthMapper",
diff --git a/airflow-core/src/airflow/partition_mappers/fixed_key.py
b/airflow-core/src/airflow/partition_mappers/fixed_key.py
new file mode 100644
index 00000000000..89388bfed12
--- /dev/null
+++ b/airflow-core/src/airflow/partition_mappers/fixed_key.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+
+import attrs
+
+from airflow.partition_mappers.base import PartitionMapper
+
+
[email protected]
+class FixedKeyMapper(PartitionMapper):
+ """
+ Collapse every upstream partition key onto one fixed downstream key.
+
+ Returns the same *downstream_key* for any upstream key passed to
+ ``to_downstream``. Does not override ``decode_downstream`` or
+ ``encode_upstream``, so it works with the string-based identity path and
+ satisfies :class:`~airflow.partition_mappers.base.RollupMapper`'s guard
+ when paired with :class:`~airflow.partition_mappers.window.SegmentWindow`.
+
+ Typical use is as the ``upstream_mapper`` inside a categorical rollup::
+
+ RollupMapper(
+ upstream_mapper=FixedKeyMapper("all_regions"),
+ window=SegmentWindow(["us", "eu", "apac"]),
+ )
+
+ :param downstream_key: The fixed downstream partition key every upstream
key
+ maps to. Must be a non-empty string.
+ :raises ValueError: if *downstream_key* is not a non-empty ``str``.
+ """
+
+ downstream_key: str = attrs.field()
+
+ @downstream_key.validator
+ def _validate_downstream_key(self, attribute: attrs.Attribute, value: str)
-> None:
+ if not isinstance(value, str) or value == "":
+ raise ValueError(f"FixedKeyMapper downstream_key must be a
non-empty str; got {value!r}.")
+
+ def to_downstream(self, key: str) -> str:
+ """Return the fixed downstream key regardless of *key*."""
+ return self.downstream_key
+
+ def serialize(self) -> dict[str, Any]:
+ return {"downstream_key": self.downstream_key}
+
+ @classmethod
+ def deserialize(cls, data: dict[str, Any]) -> FixedKeyMapper:
+ return cls(data["downstream_key"])
diff --git a/airflow-core/src/airflow/partition_mappers/window.py
b/airflow-core/src/airflow/partition_mappers/window.py
index 087b6180f2a..ffcf0e85ab0 100644
--- a/airflow-core/src/airflow/partition_mappers/window.py
+++ b/airflow-core/src/airflow/partition_mappers/window.py
@@ -21,6 +21,8 @@ from datetime import datetime, timedelta
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar
+import attrs
+
if TYPE_CHECKING:
from collections.abc import Callable, Iterable
@@ -246,3 +248,63 @@ class YearWindow(Window):
def to_upstream(self, period_start: datetime) -> Iterable[datetime]:
_require_day_one(period_start, type(self))
return _build_directional_steps(period_start, 12, _shift_months,
self.direction)
+
+
+def _convert_segments(segments: Iterable[str]) -> frozenset[str]:
+ """
+ Validate and convert *segments* to a ``frozenset[str]``.
+
+ Validates each element for type and non-emptiness (with index reporting)
+ before collapsing into a frozenset, then checks the result is non-empty.
+ """
+ validated: list[str] = []
+ for i, item in enumerate(segments):
+ if not isinstance(item, str):
+ raise ValueError(
+ f"SegmentWindow segment keys must be str; got
{type(item).__name__!r} at index {i}: {item!r}"
+ )
+ if not item:
+ raise ValueError(
+ f"SegmentWindow segment keys must be non-empty; got an empty
string at index {i}."
+ )
+ validated.append(item)
+ result = frozenset(validated)
+ if not result:
+ raise ValueError("SegmentWindow requires at least one segment key; got
an empty iterable.")
+ return result
+
+
[email protected]
+class SegmentWindow(Window):
+ """
+ A fixed categorical set of string keys that constitute one downstream
period.
+
+ Paired with :class:`~airflow.partition_mappers.fixed_key.FixedKeyMapper`
inside a
+ :class:`~airflow.partition_mappers.base.RollupMapper` to express a
categorical
+ rollup: the scheduler holds the downstream run until every declared
segment key
+ has arrived from the upstream producer, then fires once.
+
+ ``to_upstream`` returns the complete segment set regardless of the
downstream
+ anchor value — the anchor is intentionally ignored because all segments
map onto
+ a single downstream partition key, not a time-based period.
+
+ :param segments: Non-empty iterable of non-empty string segment keys.
Duplicates
+ are silently de-duplicated.
+ :raises ValueError: if *segments* is empty, contains a non-``str``
element, or
+ contains an empty-string element.
+ """
+
+ expected_decoded_type: ClassVar[type] = str
+
+ _segments: frozenset[str] = attrs.field(converter=_convert_segments)
+
+ def to_upstream(self, decoded_downstream: Any) -> frozenset[str]:
+ """Return the full declared segment set, ignoring the downstream
anchor."""
+ return self._segments
+
+ def serialize(self) -> dict[str, Any]:
+ return {"segments": sorted(self._segments)}
+
+ @classmethod
+ def deserialize(cls, data: dict[str, Any]) -> SegmentWindow:
+ return cls(data["segments"])
diff --git a/airflow-core/src/airflow/serialization/encoders.py
b/airflow-core/src/airflow/serialization/encoders.py
index f4ec081df1b..0c2f030da8f 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -43,6 +43,7 @@ from airflow.sdk import (
DeltaTriggerTimetable,
EventsTimetable,
FanOutMapper,
+ FixedKeyMapper,
HourWindow,
IdentityMapper,
MonthWindow,
@@ -51,6 +52,7 @@ from airflow.sdk import (
ProductMapper,
QuarterWindow,
RollupMapper,
+ SegmentWindow,
StartOfDayMapper,
StartOfHourMapper,
StartOfMonthMapper,
@@ -437,6 +439,7 @@ class _Serializer:
AllowedKeyMapper:
"airflow.partition_mappers.allowed_key.AllowedKeyMapper",
ChainMapper: "airflow.partition_mappers.chain.ChainMapper",
FanOutMapper: "airflow.partition_mappers.temporal.FanOutMapper",
+ FixedKeyMapper: "airflow.partition_mappers.fixed_key.FixedKeyMapper",
IdentityMapper: "airflow.partition_mappers.identity.IdentityMapper",
ProductMapper: "airflow.partition_mappers.product.ProductMapper",
RollupMapper: "airflow.partition_mappers.base.RollupMapper",
@@ -464,6 +467,10 @@ class _Serializer:
def _(self, partition_mapper: IdentityMapper) -> dict[str, Any]:
return {}
+ @serialize_partition_mapper.register
+ def _(self, partition_mapper: FixedKeyMapper) -> dict[str, Any]:
+ return {"downstream_key": partition_mapper.downstream_key}
+
@serialize_partition_mapper.register(StartOfHourMapper)
@serialize_partition_mapper.register(StartOfDayMapper)
@serialize_partition_mapper.register(StartOfWeekMapper)
@@ -517,6 +524,7 @@ class _Serializer:
WeekWindow: "airflow.partition_mappers.window.WeekWindow",
MonthWindow: "airflow.partition_mappers.window.MonthWindow",
QuarterWindow: "airflow.partition_mappers.window.QuarterWindow",
+ SegmentWindow: "airflow.partition_mappers.window.SegmentWindow",
YearWindow: "airflow.partition_mappers.window.YearWindow",
}
@@ -538,6 +546,10 @@ class _Serializer:
) -> dict[str, Any]:
return window.serialize()
+ @serialize_window.register
+ def _(self, window: SegmentWindow) -> dict[str, Any]:
+ return {"segments": sorted(window._segments)}
+
_serializer = _Serializer()
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 89383233fad..d08628a51fb 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -104,9 +104,11 @@ from airflow.sdk import (
AssetAlias,
AssetWatcher,
CronPartitionTimetable,
+ FixedKeyMapper,
HourWindow,
IdentityMapper,
RollupMapper,
+ SegmentWindow,
StartOfHourMapper,
task,
)
@@ -10273,6 +10275,74 @@ def
test_partitioned_dag_run_rollup_holds_until_window_complete(
assert partition_dags == {"rollup-consumer"}
[email protected]_serialized_dag
[email protected]("clear_asset_partition_rows")
+def test_partitioned_dag_run_segment_rollup_holds_until_all_segments_arrive(
+ dag_maker: DagMaker,
+ session: Session,
+):
+ """
+ A categorical (segment) rollup fires once every declared segment has
arrived.
+
+ ``RollupMapper(FixedKeyMapper("all_regions"), SegmentWindow([...]))``
collapses
+ each region key onto a single ``all_regions`` partition, so all three
events
+ accumulate into one APDR, and holds the downstream run until ``us``,
``eu``,
+ and ``apac`` are all present.
+ """
+ asset_1 = Asset(name="asset-1")
+ with dag_maker(
+ dag_id="segment-rollup-consumer",
+ schedule=PartitionedAssetTimetable(
+ assets=asset_1,
+ default_partition_mapper=RollupMapper(
+ upstream_mapper=FixedKeyMapper("all_regions"),
+ window=SegmentWindow(["us", "eu", "apac"]),
+ ),
+ ),
+ session=session,
+ ):
+ EmptyOperator(task_id="hi")
+ session.commit()
+
+ runner = SchedulerJobRunner(
+ job=Job(job_type=SchedulerJobRunner.job_type),
executors=[MockExecutor(do_update=False)]
+ )
+
+ # First region arrives — only 1 / 3 segments, so the APDR must not fire.
+ # Every region collapses onto the single ``all_regions`` partition.
+ apdr = _produce_and_register_asset_event(
+ dag_id="segment-producer-us",
+ asset=asset_1,
+ partition_key="us",
+ session=session,
+ dag_maker=dag_maker,
+ expected_partition_key="all_regions",
+ )
+ partition_dags =
runner._create_dagruns_for_partitioned_asset_dags(session=session)
+ session.refresh(apdr)
+ assert apdr.created_dag_run_id is None
+ assert partition_dags == set()
+
+ # The remaining two regions arrive — once all three segments are present
the
+ # rollup is satisfied and the APDR creates its Dag run on the next tick.
All
+ # three events share the one ``all_regions`` APDR.
+ for region in ("eu", "apac"):
+ sibling = _produce_and_register_asset_event(
+ dag_id=f"segment-producer-{region}",
+ asset=asset_1,
+ partition_key=region,
+ session=session,
+ dag_maker=dag_maker,
+ expected_partition_key="all_regions",
+ )
+ assert sibling.id == apdr.id
+ partition_dags =
runner._create_dagruns_for_partitioned_asset_dags(session=session)
+ session.refresh(apdr)
+ assert apdr.created_dag_run_id is not None
+ assert apdr.partition_key == "all_regions"
+ assert partition_dags == {"segment-rollup-consumer"}
+
+
@pytest.mark.need_serialized_dag
@pytest.mark.usefixtures("clear_asset_partition_rows")
def test_partitioned_dag_run_rollup_treats_mapper_exception_as_not_satisfied(
diff --git a/airflow-core/tests/unit/partition_mappers/test_fixed_key.py
b/airflow-core/tests/unit/partition_mappers/test_fixed_key.py
new file mode 100644
index 00000000000..efdd5be3313
--- /dev/null
+++ b/airflow-core/tests/unit/partition_mappers/test_fixed_key.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.partition_mappers.base import PartitionMapper, RollupMapper
+from airflow.partition_mappers.fixed_key import FixedKeyMapper
+from airflow.partition_mappers.window import DayWindow, SegmentWindow
+from airflow.sdk import (
+ FixedKeyMapper as SdkFixedKeyMapper,
+ RollupMapper as SdkRollupMapper,
+ SegmentWindow as SdkSegmentWindow,
+)
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+
+
+class TestFixedKeyMapper:
+ @pytest.mark.parametrize("key", ["us", "eu", "apac", "anything-else"])
+ def test_to_downstream_returns_constant_for_any_key(self, key):
+ assert FixedKeyMapper("all").to_downstream(key) == "all"
+
+ def test_is_rollup_false(self):
+ # A bare FixedKeyMapper is not a rollup; rollup-ness comes from
RollupMapper.
+ assert FixedKeyMapper("all").is_rollup is False
+
+ def test_does_not_override_decode_encode(self):
+ m = FixedKeyMapper("all")
+ assert type(m).decode_downstream is PartitionMapper.decode_downstream
+ assert type(m).encode_upstream is PartitionMapper.encode_upstream
+
+ @pytest.mark.parametrize(
+ ("downstream_key", "match"),
+ [
+ pytest.param("", "non-empty str", id="empty-string"),
+ pytest.param(None, "non-empty str", id="none"),
+ pytest.param(1, "non-empty str", id="int"),
+ ],
+ )
+ def test_rejects_invalid_downstream_key(self, downstream_key, match):
+ with pytest.raises(ValueError, match=match):
+ FixedKeyMapper(downstream_key)
+
+ def test_requires_downstream_key(self):
+ with pytest.raises(TypeError):
+ FixedKeyMapper()
+
+ def test_serialize_round_trip(self):
+ m = FixedKeyMapper("bucket")
+ restored = FixedKeyMapper.deserialize(m.serialize())
+ assert isinstance(restored, FixedKeyMapper)
+ assert restored.downstream_key == "bucket"
+
+
+class TestCategoricalRollupEquivalence:
+ """RollupMapper(FixedKeyMapper, SegmentWindow) behaves like old
SegmentMapper."""
+
+ def setup_method(self):
+ self.m = RollupMapper(
+ upstream_mapper=FixedKeyMapper("all"),
+ window=SegmentWindow(["us", "eu", "apac"]),
+ )
+
+ def test_is_rollup_flag(self):
+ assert self.m.is_rollup is True
+
+ def test_to_downstream_collapses_every_segment_onto_downstream_key(self):
+ # Full-sequence equality: every declared segment key maps to the
constant key.
+ assert [self.m.to_downstream(s) for s in ("us", "eu", "apac")] ==
["all", "all", "all"]
+
+ @pytest.mark.parametrize("anchor", ["all", "anything"])
+ def test_to_upstream_returns_full_set_ignoring_anchor(self, anchor):
+ assert self.m.to_upstream(anchor) == frozenset({"us", "eu", "apac"})
+
+ def test_core_encode_decode_round_trip(self):
+ restored = decode_partition_mapper(encode_partition_mapper(self.m))
+ assert isinstance(restored, RollupMapper)
+ assert restored.is_rollup is True
+ assert restored.to_downstream("us") == "all"
+ assert restored.to_upstream("all") == frozenset({"us", "eu", "apac"})
+
+ def test_sdk_encode_decode_round_trip(self):
+ # User code authors with SDK classes; the scheduler serializes and
deserializes
+ # into core classes.
+ sdk_mapper = SdkRollupMapper(
+ upstream_mapper=SdkFixedKeyMapper("all_regions"),
+ window=SdkSegmentWindow(["us", "eu", "apac"]),
+ )
+ restored = decode_partition_mapper(encode_partition_mapper(sdk_mapper))
+ assert isinstance(restored, RollupMapper)
+ assert restored.to_upstream("all_regions") == frozenset({"us", "eu",
"apac"})
+
+
+class TestCategoricalRollupTypeGuard:
+ """Core-side RollupMapper guard: FixedKeyMapper(str) + SegmentWindow(str)
must pass."""
+
+ def test_fixed_key_with_segment_window_does_not_raise(self):
+ # Core guard: FixedKeyMapper does not override decode_downstream,
+ # SegmentWindow.expected_decoded_type is str -> guard passes.
+ RollupMapper(upstream_mapper=FixedKeyMapper("all"),
window=SegmentWindow(["us", "eu"]))
+
+ def test_str_mapper_with_datetime_window_raises(self):
+ # Core guard: FixedKeyMapper (no decode override) + DayWindow
(datetime) -> raise.
+ with pytest.raises(TypeError, match="DayWindow expects decoded values
of type 'datetime'"):
+ RollupMapper(upstream_mapper=FixedKeyMapper("all"),
window=DayWindow())
diff --git a/airflow-core/tests/unit/partition_mappers/test_window.py
b/airflow-core/tests/unit/partition_mappers/test_window.py
index ddfd44cbd76..b4956761c2c 100644
--- a/airflow-core/tests/unit/partition_mappers/test_window.py
+++ b/airflow-core/tests/unit/partition_mappers/test_window.py
@@ -34,6 +34,7 @@ from airflow.partition_mappers.window import (
HourWindow,
MonthWindow,
QuarterWindow,
+ SegmentWindow,
WeekWindow,
Window,
YearWindow,
@@ -465,6 +466,47 @@ class TestDirectionValidation:
WeekWindow(direction=bad_value)
+class TestSegmentWindow:
+ def test_to_upstream_returns_full_set_ignoring_anchor(self):
+ w = SegmentWindow(["us", "eu", "apac"])
+ result_a = frozenset(w.to_upstream("any-anchor"))
+ result_b = frozenset(w.to_upstream("different-anchor"))
+ assert result_a == frozenset({"us", "eu", "apac"})
+ assert result_a == result_b
+
+ def test_expected_decoded_type_is_str(self):
+ assert SegmentWindow.expected_decoded_type is str
+
+ @pytest.mark.parametrize(
+ ("segments", "match"),
+ [
+ pytest.param([], "at least one segment key", id="empty-list"),
+ pytest.param(iter([]), "at least one segment key",
id="empty-iterator"),
+ pytest.param([1, "b"], "must be str", id="int-element"),
+ pytest.param([None, "b"], "must be str", id="none-element"),
+ pytest.param(["", "b"], "non-empty", id="empty-string-first"),
+ pytest.param(["a", ""], "non-empty", id="empty-string-second"),
+ ],
+ )
+ def test_rejects_invalid_segments(self, segments, match):
+ with pytest.raises(ValueError, match=match):
+ SegmentWindow(segments)
+
+ def test_deduplication(self):
+ w = SegmentWindow(["us", "us", "eu"])
+ assert frozenset(w.to_upstream("any")) == frozenset({"us", "eu"})
+
+ def test_serialize_uses_sorted_order(self):
+ w = SegmentWindow(["z", "a", "m"])
+ assert w.serialize() == {"segments": ["a", "m", "z"]}
+
+ def test_deserialize_round_trip(self):
+ w = SegmentWindow(["us", "eu", "apac"])
+ restored = SegmentWindow.deserialize(w.serialize())
+ assert isinstance(restored, SegmentWindow)
+ assert frozenset(restored.to_upstream("any")) == frozenset({"us",
"eu", "apac"})
+
+
class TestWindowSerializationGate:
"""``encode_window`` / ``decode_window`` must reject non-built-in Windows.
diff --git a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
index 0c1adea9669..24f7520f1b5 100755
--- a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
+++ b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
@@ -23,24 +23,25 @@
# ]
# ///
"""
-Verify ``FanOutMapper.default_downstream_mapper_by_window_name`` stays in sync
-between core and the Task SDK.
+Verify partition-mapper definitions stay in sync between core and the Task SDK.
-The default downstream-mapper table is defined twice — once in the core class
-hierarchy and once in the SDK copy — because the two hierarchies are
-independent (the SDK cannot import core) and the lookup is by ``Window`` class
-*name*. Both copies must list the same ``Window`` name -> mapper class mapping,
-otherwise a ``FanOutMapper`` resolves a different default depending on whether
-it runs in Dag-author code (SDK) or after deserialization (core).
+Checks two things:
-This check parses the ``default_downstream_mapper_by_window_name`` class
-attribute from both files via AST and asserts the two mappings are identical.
+1. ``FanOutMapper.default_downstream_mapper_by_window_name`` — the default
+ downstream-mapper table is defined twice (core and SDK) because the two
+ hierarchies are independent. Both copies must list the same
+ ``Window`` name -> mapper class mapping.
+
+2. ``SegmentWindow`` and ``FixedKeyMapper`` — for each class, the field-name
+ set and the ``raise ValueError(...)`` message-template set must be identical
+ between core and the SDK. This catches wording drift (e.g. "non-empty" vs
+ "non-empty strings") that would otherwise go unnoticed.
Run from the repo root:
uv run --project scripts python
scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
-Exits 0 if the two tables match, 1 (with a diff) otherwise.
+Exits 0 if everything matches, 1 (with a diff) otherwise.
"""
from __future__ import annotations
@@ -65,6 +66,15 @@ SDK_FILE = (
AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" /
"partition_mappers" / "temporal.py"
)
+CORE_WINDOW_FILE = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "partition_mappers"
/ "window.py"
+SDK_WINDOW_FILE = (
+ AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" /
"partition_mappers" / "window.py"
+)
+CORE_FIXED_KEY_FILE = AIRFLOW_CORE_SOURCES_PATH / "airflow" /
"partition_mappers" / "fixed_key.py"
+SDK_FIXED_KEY_FILE = (
+ AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" /
"partition_mappers" / "fixed_key.py"
+)
+
def _find_attr_value(file_path: Path) -> ast.Dict:
"""Return the AST node assigned to
``FanOutMapper.default_downstream_mapper_by_window_name``."""
@@ -113,7 +123,200 @@ def _extract_mapping(file_path: Path) -> dict[str, str]:
return mapping
+def _joinedstr_to_template(node: ast.JoinedStr) -> str:
+    """
+    Render an f-string AST node as a placeholder template.
+
+    Constant pieces of the f-string are emitted through ``str()``; every other
+    piece (an interpolated ``{...}`` expression) collapses to a bare ``{}``.
+    Two f-strings therefore compare equal whenever their literal wording
+    matches, whatever expressions they interpolate.
+
+    :param node: The ``ast.JoinedStr`` to render.
+    :return: The literal fragments and ``{}`` placeholders, in source order.
+    """
+    return "".join(
+        str(piece.value) if isinstance(piece, ast.Constant) else "{}" for piece in node.values
+    )
+
+
+def _extract_raise_messages(subtree: ast.AST) -> set[str]:
+    """
+    Gather the message templates of every ``raise ValueError(...)`` under *subtree*.
+
+    A raise is considered only when its exception is a direct call to the bare
+    name ``ValueError`` with at least one positional argument. Only the first
+    argument is read:
+
+    - a ``str`` constant is taken verbatim;
+    - an f-string (``ast.JoinedStr``) is rendered by ``_joinedstr_to_template``;
+    - an explicit ``+`` concatenation (``ast.BinOp`` with ``ast.Add``) is
+      rendered operand by operand and joined. An operand that is none of the
+      above becomes a ``{}`` placeholder.
+
+    Any other first argument (a name, a call, ...) contributes nothing.
+
+    :param subtree: Any AST node; it is searched at all depths.
+    :return: The set of extracted templates.
+    """
+
+    def _render(expr: ast.expr) -> str:
+        if isinstance(expr, ast.JoinedStr):
+            return _joinedstr_to_template(expr)
+        if isinstance(expr, ast.BinOp) and isinstance(expr.op, ast.Add):
+            return _render(expr.left) + _render(expr.right)
+        if isinstance(expr, ast.Constant) and isinstance(expr.value, str):
+            return expr.value
+        # Not statically text: keep its position in the template as a placeholder.
+        return "{}"
+
+    def _is_message_expr(expr: ast.expr) -> bool:
+        if isinstance(expr, ast.Constant):
+            return isinstance(expr.value, str)
+        if isinstance(expr, ast.BinOp):
+            return isinstance(expr.op, ast.Add)
+        return isinstance(expr, ast.JoinedStr)
+
+    templates: set[str] = set()
+    for node in ast.walk(subtree):
+        if not isinstance(node, ast.Raise):
+            continue
+        call = node.exc  # None for a bare ``raise``; rejected by the isinstance check below
+        if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Name)):
+            continue
+        if call.func.id != "ValueError" or not call.args:
+            continue
+        first_arg = call.args[0]
+        if _is_message_expr(first_arg):
+            templates.add(_render(first_arg))
+
+    return templates
+
+
+def _collect_converter_names(class_node: ast.ClassDef) -> set[str]:
+    """
+    Return function names referenced via ``converter=<Name>`` in attrs field definitions.
+
+    Only the top-level statements of *class_node* are inspected. A statement
+    qualifies when it is an annotated or plain assignment whose right-hand side
+    is a call to ``<anything>.field(...)`` (e.g. ``attrs.field`` / ``attr.field``)
+    or to a bare ``field(...)``. From such a call, a ``converter`` keyword is
+    collected only when its value is a bare name; lambdas and attribute lookups
+    are ignored because they cannot be resolved to a module-level function by
+    name.
+
+    :param class_node: The class definition to scan.
+    :return: The set of converter function names found (possibly empty).
+    """
+    names: set[str] = set()
+    for stmt in class_node.body:
+        if not isinstance(stmt, (ast.AnnAssign, ast.Assign)):
+            continue
+        call = stmt.value
+        # Also skips a declaration-only ``name: type`` whose value is None.
+        if not isinstance(call, ast.Call):
+            continue
+        func = call.func
+        is_field_call = (isinstance(func, ast.Attribute) and func.attr == "field") or (
+            isinstance(func, ast.Name) and func.id == "field"
+        )
+        if not is_field_call:
+            continue
+        names.update(
+            kw.value.id for kw in call.keywords if kw.arg == "converter" and isinstance(kw.value, ast.Name)
+        )
+    return names
+
+
+def extract_class_error_messages(file_path: Path, class_name: str) -> set[str]:
+    """
+    Return the set of ``raise ValueError(...)`` message templates for *class_name*.
+
+    The file is parsed and the first class named *class_name* (in ``ast.walk``
+    order) is used. Templates come from two places:
+
+    1. The class itself, at any depth: methods, validators and helpers nested
+       inside them.
+    2. Module-level functions that the class body names through
+       ``converter=<Name>`` in a ``field(...)`` declaration. Such converters
+       live outside the class but run as part of its construction.
+
+    A template is the first argument of the ``ValueError`` call: a string
+    constant verbatim, an f-string with each interpolation replaced by ``{}``,
+    or a ``+`` concatenation of those.
+
+    :param file_path: Python source file to parse (read as UTF-8).
+    :param class_name: Name of the class to inspect.
+    :return: The collected templates; an empty set when the class is not found.
+    """
+    source = file_path.read_text(encoding="utf-8")
+    tree = ast.parse(source, filename=str(file_path))
+
+    matching_classes = (
+        node for node in ast.walk(tree) if isinstance(node, ast.ClassDef) and node.name == class_name
+    )
+    target_class = next(matching_classes, None)
+    if target_class is None:
+        return set()
+
+    templates = _extract_raise_messages(target_class)
+
+    # Restrict the converter lookup to ``tree.body`` so a same-named method
+    # inside some other class is never picked up by accident.
+    referenced = _collect_converter_names(target_class)
+    for top_level in tree.body:
+        if isinstance(top_level, ast.FunctionDef) and top_level.name in referenced:
+            templates.update(_extract_raise_messages(top_level))
+
+    return templates
+
+
+def extract_class_field_names(file_path: Path, class_name: str) -> set[str]:
+    """
+    Return the set of annotated attribute names declared in *class_name*'s body.
+
+    Only top-level annotated assignments of the first class named *class_name*
+    (in ``ast.walk`` order) are considered, with or without a value
+    (``name: type = ...`` or ``name: type``). Names annotated as ``ClassVar``
+    are left out because they are class-level constants, not instance fields.
+
+    :param file_path: Python source file to parse (read as UTF-8).
+    :param class_name: Name of the class to inspect.
+    :return: The field names; an empty set when the class is not found.
+    """
+
+    def _is_classvar(annotation: ast.expr) -> bool:
+        # Subscripted forms: ``ClassVar[...]`` and ``<module>.ClassVar[...]``.
+        if isinstance(annotation, ast.Subscript):
+            head = annotation.value
+            if isinstance(head, ast.Name):
+                return head.id == "ClassVar"
+            return isinstance(head, ast.Attribute) and head.attr == "ClassVar"
+        # Unsubscripted form: only the bare ``ClassVar`` name is recognised.
+        return isinstance(annotation, ast.Name) and annotation.id == "ClassVar"
+
+    tree = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path))
+    matching_classes = (
+        node for node in ast.walk(tree) if isinstance(node, ast.ClassDef) and node.name == class_name
+    )
+    target_class = next(matching_classes, None)
+    if target_class is None:
+        return set()
+
+    return {
+        stmt.target.id
+        for stmt in target_class.body
+        if isinstance(stmt, ast.AnnAssign)
+        and isinstance(stmt.target, ast.Name)
+        and not _is_classvar(stmt.annotation)
+    }
+
+
def main() -> int:
+ failed = False
+
try:
core_mapping = _extract_mapping(CORE_FILE)
sdk_mapping = _extract_mapping(SDK_FILE)
@@ -121,22 +324,56 @@ def main() -> int:
console.print(f"[red]Could not read the default mapper table:[/red]
{exc}")
return 1
- if core_mapping == sdk_mapping:
- return 0
-
- console.print(f"[red]{CLASS_NAME}.{ATTR_NAME} is out of sync between core
and the Task SDK.[/red]\n")
- all_windows = sorted(core_mapping.keys() | sdk_mapping.keys())
- for window in all_windows:
- core_val = core_mapping.get(window, "<missing>")
- sdk_val = sdk_mapping.get(window, "<missing>")
- marker = " " if core_val == sdk_val else "->"
- color = "" if core_val == sdk_val else "[red]"
- end = "" if core_val == sdk_val else "[/red]"
- console.print(f"{color}{marker} {window}: core={core_val}
sdk={sdk_val}{end}")
- console.print(
- f"\nMake both tables list the same window -> mapper entries:\n core:
{CORE_FILE}\n sdk: {SDK_FILE}"
- )
- return 1
+ if core_mapping != sdk_mapping:
+ console.print(f"[red]{CLASS_NAME}.{ATTR_NAME} is out of sync between
core and the Task SDK.[/red]\n")
+ all_windows = sorted(core_mapping.keys() | sdk_mapping.keys())
+ for window in all_windows:
+ core_val = core_mapping.get(window, "<missing>")
+ sdk_val = sdk_mapping.get(window, "<missing>")
+ marker = " " if core_val == sdk_val else "->"
+ color = "" if core_val == sdk_val else "[red]"
+ end = "" if core_val == sdk_val else "[/red]"
+ console.print(f"{color}{marker} {window}: core={core_val}
sdk={sdk_val}{end}")
+ console.print(
+ f"\nMake both tables list the same window -> mapper entries:\n
core: {CORE_FILE}\n sdk: {SDK_FILE}"
+ )
+ failed = True
+
+ checks = [
+ ("SegmentWindow", CORE_WINDOW_FILE, SDK_WINDOW_FILE),
+ ("FixedKeyMapper", CORE_FIXED_KEY_FILE, SDK_FIXED_KEY_FILE),
+ ]
+
+ for class_name, core_file, sdk_file in checks:
+ core_fields = extract_class_field_names(core_file, class_name)
+ sdk_fields = extract_class_field_names(sdk_file, class_name)
+ if core_fields != sdk_fields:
+ console.print(f"[red]{class_name}: field names are out of sync
between core and SDK.[/red]")
+ core_only = core_fields - sdk_fields
+ sdk_only = sdk_fields - core_fields
+ if core_only:
+ console.print(f" core-only fields: {sorted(core_only)}")
+ if sdk_only:
+ console.print(f" sdk-only fields: {sorted(sdk_only)}")
+ console.print(f" core: {core_file}\n sdk: {sdk_file}")
+ failed = True
+
+ core_msgs = extract_class_error_messages(core_file, class_name)
+ sdk_msgs = extract_class_error_messages(sdk_file, class_name)
+ if core_msgs != sdk_msgs:
+ console.print(
+ f"[red]{class_name}: raise ValueError(...) message templates
are out of sync.[/red]"
+ )
+ core_only = core_msgs - sdk_msgs
+ sdk_only = sdk_msgs - core_msgs
+ if core_only:
+ console.print(f" core-only messages: {sorted(core_only)}")
+ if sdk_only:
+ console.print(f" sdk-only messages: {sorted(sdk_only)}")
+ console.print(f" core: {core_file}\n sdk: {sdk_file}")
+ failed = True
+
+ return 1 if failed else 0
if __name__ == "__main__":
diff --git
a/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py
b/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py
new file mode 100644
index 00000000000..aa1742d33a0
--- /dev/null
+++ b/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py
@@ -0,0 +1,388 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import textwrap
+from pathlib import Path
+
+import pytest
+from check_partition_mapper_defaults_in_sync import (
+ extract_class_error_messages,
+ extract_class_field_names,
+)
+
+
+class TestExtractClassErrorMessages:
+ @pytest.mark.parametrize(
+ ("raise_stmt", "expected"),
+ [
+ pytest.param('raise ValueError("msg a")', {"msg a"},
id="plain-string"),
+ pytest.param('raise ValueError(f"msg b at {i}")', {"msg b at {}"},
id="fstring-template"),
+ ],
+ )
+ def test_extracts_message(self, tmp_path: Path, raise_stmt: str, expected:
set[str]):
+ """Plain strings are kept verbatim; f-string interpolations become {}
placeholders."""
+ f = tmp_path / "code.py"
+ f.write_text(
+ textwrap.dedent(f"""\
+ class MyClass:
+ def validate(self, x, i):
+ {raise_stmt}
+ """)
+ )
+ assert extract_class_error_messages(f, "MyClass") == expected
+
+ def test_extracts_both_plain_and_fstring(self, tmp_path: Path):
+        """A plain-string message and an f-string message raised in one class are both collected."""
+ f = tmp_path / "code.py"
+ f.write_text(
+ textwrap.dedent("""\
+ class MyClass:
+ def check(self, item, i):
+ raise ValueError("msg a")
+
+ def check2(self, item, i):
+ raise ValueError(f"msg b at {i}")
+ """)
+ )
+ result = extract_class_error_messages(f, "MyClass")
+ assert result == {"msg a", "msg b at {}"}
+
+ def test_ignores_other_exception_types(self, tmp_path: Path):
+ f = tmp_path / "code.py"
+ f.write_text(
+ textwrap.dedent("""\
+ class MyClass:
+ def check(self):
+ raise TypeError("not a value error")
+ """)
+ )
+ result = extract_class_error_messages(f, "MyClass")
+ assert result == set()
+
+ def test_returns_empty_for_missing_class(self, tmp_path: Path):
+ f = tmp_path / "code.py"
+ f.write_text("x = 1\n")
+ result = extract_class_error_messages(f, "Missing")
+ assert result == set()
+
+ def test_nested_function_messages_included(self, tmp_path: Path):
+ """Messages raised inside nested helpers within the class body are
collected."""
+ f = tmp_path / "code.py"
+ f.write_text(
+ textwrap.dedent("""\
+ class MyClass:
+ def __init__(self, items):
+ def _check(i, item):
+ if not isinstance(item, str):
+ raise ValueError(f"must be str; got
{type(item).__name__!r} at {i}")
+ for i, item in enumerate(items):
+ _check(i, item)
+ """)
+ )
+ result = extract_class_error_messages(f, "MyClass")
+ assert result == {"must be str; got {} at {}"}
+
+ def test_multipart_fstring_message(self, tmp_path: Path):
+ """f-string with multiple literal parts and multiple interpolations."""
+ f = tmp_path / "code.py"
+ f.write_text(
+ textwrap.dedent("""\
+ class C:
+ def v(self, item, i):
+ raise ValueError(
+ f"Prefix segment keys must be str; "
+ f"got {type(item).__name__!r} at index {i}:
{item!r}"
+ )
+ """)
+ )
+ result = extract_class_error_messages(f, "C")
+        # The two adjacent f-string literals are implicitly concatenated by the
+        # parser into a single JoinedStr node, so they yield one template.
+ assert "Prefix segment keys must be str; got {} at index {}: {}" in
result
+
+
+class TestExtractClassFieldNames:
+ def test_extracts_annotated_instance_fields(self, tmp_path: Path):
+ f = tmp_path / "code.py"
+ f.write_text(
+ textwrap.dedent("""\
+ from typing import ClassVar
+ import attrs
+
+ @attrs.define
+ class MyClass:
+ expected_decoded_type: ClassVar[type] = str
+ _segments: frozenset[str] = attrs.field()
+ """)
+ )
+ result = extract_class_field_names(f, "MyClass")
+ # ClassVar is excluded; _segments is included
+ assert result == {"_segments"}
+
+ def test_excludes_classvar_fields(self, tmp_path: Path):
+ f = tmp_path / "code.py"
+ f.write_text(
+ textwrap.dedent("""\
+ from typing import ClassVar
+
+ class MyClass:
+ flag: ClassVar[bool] = False
+ name: str
+ """)
+ )
+ result = extract_class_field_names(f, "MyClass")
+ assert result == {"name"}
+
+ def test_returns_empty_for_missing_class(self, tmp_path: Path):
+ f = tmp_path / "code.py"
+ f.write_text("x = 1\n")
+ result = extract_class_field_names(f, "Missing")
+ assert result == set()
+
+ def test_private_field_name_preserved(self, tmp_path: Path):
+ """Leading underscore in field name is kept verbatim."""
+ f = tmp_path / "code.py"
+ f.write_text(
+ textwrap.dedent("""\
+ import attrs
+
+ @attrs.define
+ class Container:
+ _segments: frozenset[str] = attrs.field()
+ """)
+ )
+ result = extract_class_field_names(f, "Container")
+ assert "_segments" in result
+
+
+class TestInSyncPasses:
+ def test_in_sync_passes(self, tmp_path: Path):
+ """Two files with identical field names and messages compare equal."""
+ code = textwrap.dedent("""\
+ import attrs
+
+ @attrs.define
+ class MyMapper:
+ downstream_key: str = attrs.field()
+
+ def validate(self, x):
+ raise ValueError(f"must be non-empty str; got {x!r}.")
+ """)
+ core_file = tmp_path / "core.py"
+ sdk_file = tmp_path / "sdk.py"
+ core_file.write_text(code)
+ sdk_file.write_text(code)
+
+ core_fields = extract_class_field_names(core_file, "MyMapper")
+ sdk_fields = extract_class_field_names(sdk_file, "MyMapper")
+ assert core_fields == sdk_fields
+
+ core_msgs = extract_class_error_messages(core_file, "MyMapper")
+ sdk_msgs = extract_class_error_messages(sdk_file, "MyMapper")
+ assert core_msgs == sdk_msgs
+
+
+class TestConverterPatternExtraction:
+ """Tests for the module-level converter= follow-through in
extract_class_error_messages."""
+
+ def test_extracts_messages_from_module_level_converter(self, tmp_path:
Path):
+ """Converter function outside the class body is followed and its
messages are collected."""
+ f = tmp_path / "code.py"
+ f.write_text(
+ textwrap.dedent("""\
+ import attrs
+
+ def _my_convert(items):
+ for i, item in enumerate(items):
+ if not isinstance(item, str):
+ raise ValueError(f"must be str; got
{type(item).__name__!r} at {i}")
+ if not items:
+ raise ValueError("must not be empty")
+ return frozenset(items)
+
+ @attrs.define
+ class MyClass:
+ _data: frozenset = attrs.field(converter=_my_convert)
+ """)
+ )
+ result = extract_class_error_messages(f, "MyClass")
+ assert "must be str; got {} at {}" in result
+ assert "must not be empty" in result
+
+ def test_converter_fstring_template_placeholders(self, tmp_path: Path):
+ """f-string expressions in the converter become {} placeholders in the
template."""
+ f = tmp_path / "code.py"
+ f.write_text(
+ textwrap.dedent("""\
+ import attrs
+
+ def _convert(items):
+ for i, item in enumerate(items):
+ raise ValueError(f"bad item {item!r} at index {i}")
+ return frozenset(items)
+
+ @attrs.define
+ class Widget:
+ _items: frozenset = attrs.field(converter=_convert)
+ """)
+ )
+ result = extract_class_error_messages(f, "Widget")
+ assert "bad item {} at index {}" in result
+
+ def test_converter_messages_not_collected_for_unrelated_class(self,
tmp_path: Path):
+ """The converter is only followed for the class that references it."""
+ f = tmp_path / "code.py"
+ f.write_text(
+ textwrap.dedent("""\
+ import attrs
+
+ def _my_convert(items):
+ raise ValueError("converter error")
+ return frozenset(items)
+
+ @attrs.define
+ class ClassA:
+ _data: frozenset = attrs.field(converter=_my_convert)
+
+ @attrs.define
+ class ClassB:
+ _name: str = attrs.field()
+ """)
+ )
+ # ClassA references the converter — should see the message
+ result_a = extract_class_error_messages(f, "ClassA")
+ assert "converter error" in result_a
+
+ # ClassB does not reference the converter — should NOT see the message
+ result_b = extract_class_error_messages(f, "ClassB")
+ assert "converter error" not in result_b
+
+
+class TestConverterDivergenceDetected:
+ """Verify that message drift in a module-level converter is caught by the
extractor."""
+
+ def test_divergent_converter_message_detected(self, tmp_path: Path):
+ """Changing 'non-empty' to 'non-empty strings' in a converter is
detected as drift."""
+ core_code = textwrap.dedent("""\
+ import attrs
+
+ def _convert_segments(segments):
+ for i, item in enumerate(segments):
+ if not item:
+ raise ValueError(f"keys must be non-empty; got empty
at {i}.")
+ return frozenset(segments)
+
+ @attrs.define
+ class SegmentWindow:
+ _segments: frozenset = attrs.field(converter=_convert_segments)
+ """)
+ sdk_code_diverged = textwrap.dedent("""\
+ import attrs
+
+ def _convert_segments(segments):
+ for i, item in enumerate(segments):
+ if not item:
+ raise ValueError(f"keys must be non-empty strings; got
empty at {i}.")
+ return frozenset(segments)
+
+ @attrs.define
+ class SegmentWindow:
+ _segments: frozenset = attrs.field(converter=_convert_segments)
+ """)
+ core_file = tmp_path / "core.py"
+ sdk_file = tmp_path / "sdk.py"
+ core_file.write_text(core_code)
+ sdk_file.write_text(sdk_code_diverged)
+
+ core_msgs = extract_class_error_messages(core_file, "SegmentWindow")
+ sdk_msgs = extract_class_error_messages(sdk_file, "SegmentWindow")
+ assert core_msgs != sdk_msgs
+
+
+class TestDivergentMessageFails:
+ def test_divergent_message_detected(self, tmp_path: Path):
+ """Changing 'non-empty' to 'non-empty strings' on one side is
detected."""
+ core_code = textwrap.dedent("""\
+ import attrs
+
+ @attrs.define
+ class SegmentWindow:
+ _segments: frozenset[str] = attrs.field()
+
+ def validate(self, item, i):
+ raise ValueError(f"keys must be non-empty; got empty at
{i}.")
+ """)
+ sdk_code_diverged = textwrap.dedent("""\
+ import attrs
+
+ @attrs.define
+ class SegmentWindow:
+ _segments: frozenset[str] = attrs.field()
+
+ def validate(self, item, i):
+ raise ValueError(f"keys must be non-empty strings; got
empty at {i}.")
+ """)
+ core_file = tmp_path / "core.py"
+ sdk_file = tmp_path / "sdk.py"
+ core_file.write_text(core_code)
+ sdk_file.write_text(sdk_code_diverged)
+
+ core_msgs = extract_class_error_messages(core_file, "SegmentWindow")
+ sdk_msgs = extract_class_error_messages(sdk_file, "SegmentWindow")
+ assert core_msgs != sdk_msgs
+
+ @pytest.mark.parametrize(
+ ("core_msg", "sdk_msg"),
+ [
+ pytest.param(
+ "keys must be non-empty; got empty at {}.",
+ "keys must be non-empty strings; got empty at {}.",
+ id="non-empty-vs-non-empty-strings",
+ ),
+ pytest.param(
+ "requires at least one key; got an empty iterable.",
+ "requires at least one key.",
+ id="different-constant-wording",
+ ),
+ ],
+ )
+ def test_parametrized_divergence(self, tmp_path: Path, core_msg: str,
sdk_msg: str):
+ def _make_file(path: Path, msg: str) -> None:
+ # Use a plain f-string if the message contains '{}', else a
constant.
+ if "{}" in msg:
+ # Reconstruct as f-string source: replace {} with {i}
+ src_msg = msg.replace("{}", "{i}")
+ stmt = f'raise ValueError(f"{src_msg}")'
+ else:
+ stmt = f'raise ValueError("{msg}")'
+ path.write_text(
+ textwrap.dedent(f"""\
+ class C:
+ def v(self, i):
+ {stmt}
+ """)
+ )
+
+ core_file = tmp_path / "core.py"
+ sdk_file = tmp_path / "sdk.py"
+ _make_file(core_file, core_msg)
+ _make_file(sdk_file, sdk_msg)
+
+ core_msgs = extract_class_error_messages(core_file, "C")
+ sdk_msgs = extract_class_error_messages(sdk_file, "C")
+ assert core_msgs != sdk_msgs
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 8d878978980..cd9cec283c5 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -247,6 +247,8 @@ Partition Mapper
.. autoapiclass:: airflow.sdk.FanOutMapper
+.. autoapiclass:: airflow.sdk.FixedKeyMapper
+
Rollup Windows
~~~~~~~~~~~~~~
@@ -264,6 +266,8 @@ Rollup Windows
.. autoapiclass:: airflow.sdk.YearWindow
+.. autoapiclass:: airflow.sdk.SegmentWindow
+
I/O Helpers
-----------
.. autoapiclass:: airflow.sdk.ObjectStoragePath
diff --git a/task-sdk/src/airflow/sdk/__init__.py
b/task-sdk/src/airflow/sdk/__init__.py
index ce5b8ab2c2e..00d724cb42d 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -55,6 +55,7 @@ __all__ = [
"EventsTimetable",
"ExceptionRetryPolicy",
"FanOutMapper",
+ "FixedKeyMapper",
"HourWindow",
"IdentityMapper",
"Label",
@@ -77,6 +78,7 @@ __all__ = [
"RetryPolicy",
"RetryRule",
"RollupMapper",
+ "SegmentWindow",
"SkipMixin",
"SyncCallback",
"StartOfDayMapper",
@@ -154,6 +156,7 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.partition_mappers.allowed_key import
AllowedKeyMapper
from airflow.sdk.definitions.partition_mappers.base import
PartitionMapper, RollupMapper
from airflow.sdk.definitions.partition_mappers.chain import ChainMapper
+ from airflow.sdk.definitions.partition_mappers.fixed_key import
FixedKeyMapper
from airflow.sdk.definitions.partition_mappers.identity import
IdentityMapper
from airflow.sdk.definitions.partition_mappers.product import ProductMapper
from airflow.sdk.definitions.partition_mappers.temporal import (
@@ -170,6 +173,7 @@ if TYPE_CHECKING:
HourWindow,
MonthWindow,
QuarterWindow,
+ SegmentWindow,
WeekWindow,
Window,
YearWindow,
@@ -244,6 +248,7 @@ __lazy_imports: dict[str, str] = {
"EventsTimetable": ".definitions.timetables.events",
"ExceptionRetryPolicy": ".definitions.retry_policy",
"FanOutMapper": ".definitions.partition_mappers.temporal",
+ "FixedKeyMapper": ".definitions.partition_mappers.fixed_key",
"HourWindow": ".definitions.partition_mappers.window",
"IdentityMapper": ".definitions.partition_mappers.identity",
"Label": ".definitions.edges",
@@ -266,6 +271,7 @@ __lazy_imports: dict[str, str] = {
"RetryRule": ".definitions.retry_policy",
"RollupMapper": ".definitions.partition_mappers.base",
"SecretCache": ".execution_time.cache",
+ "SegmentWindow": ".definitions.partition_mappers.window",
"SkipMixin": ".bases.skipmixin",
"SyncCallback": ".definitions.callback",
"StartOfDayMapper": ".definitions.partition_mappers.temporal",
diff --git a/task-sdk/src/airflow/sdk/__init__.pyi
b/task-sdk/src/airflow/sdk/__init__.pyi
index 1bb975e2202..d6fa8bc2a0d 100644
--- a/task-sdk/src/airflow/sdk/__init__.pyi
+++ b/task-sdk/src/airflow/sdk/__init__.pyi
@@ -67,6 +67,7 @@ from airflow.sdk.definitions.param import Param as Param
from airflow.sdk.definitions.partition_mappers.allowed_key import
AllowedKeyMapper
from airflow.sdk.definitions.partition_mappers.base import PartitionMapper,
RollupMapper
from airflow.sdk.definitions.partition_mappers.chain import ChainMapper
+from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper
from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper
from airflow.sdk.definitions.partition_mappers.product import ProductMapper
from airflow.sdk.definitions.partition_mappers.temporal import (
@@ -83,6 +84,7 @@ from airflow.sdk.definitions.partition_mappers.window import (
HourWindow,
MonthWindow,
QuarterWindow,
+ SegmentWindow,
WeekWindow,
Window,
YearWindow,
@@ -154,6 +156,7 @@ __all__ = [
"EventsTimetable",
"ExceptionRetryPolicy",
"FanOutMapper",
+ "FixedKeyMapper",
"HourWindow",
"IdentityMapper",
"Label",
@@ -175,6 +178,7 @@ __all__ = [
"ResumableJobMixin",
"RollupMapper",
"SecretCache",
+ "SegmentWindow",
"SkipMixin",
"StartOfDayMapper",
"StartOfHourMapper",
diff --git
a/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py
new file mode 100644
index 00000000000..bf2b07905a8
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import attrs
+
+from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
+
+
[email protected]
+class FixedKeyMapper(PartitionMapper):
+    """
+    Map every upstream partition key to a single, constant downstream key.
+
+    This is the authoring-side marker for the scheduler-side
+    :class:`airflow.partition_mappers.fixed_key.FixedKeyMapper`. Combine it with
+    :class:`~airflow.sdk.definitions.partition_mappers.window.SegmentWindow` inside a
+    :class:`~airflow.sdk.definitions.partition_mappers.base.RollupMapper` to express a
+    categorical rollup.
+
+    *downstream_key* is validated at construction so a bad value fails at Dag
+    parse time instead of at scheduler deserialization.
+
+    :param downstream_key: The constant downstream partition key. Must be a
+        non-empty string.
+    :raises ValueError: if *downstream_key* is not a non-empty ``str``.
+    """
+
+    downstream_key: str = attrs.field()
+
+    @downstream_key.validator
+    def _validate_downstream_key(self, attribute: attrs.Attribute, value: str) -> None:
+        # Guard clause: anything that is a str with at least one character is accepted.
+        if isinstance(value, str) and value:
+            return
+        raise ValueError(f"FixedKeyMapper downstream_key must be a non-empty str; got {value!r}.")
+
+    def to_downstream(self, key: str) -> str:
+        """Return the configured downstream key; *key* is not consulted."""
+        return self.downstream_key
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py
index 4ea008ace6f..f491d4279ec 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py
@@ -22,7 +22,12 @@ from __future__ import annotations
from datetime import datetime
from enum import Enum
-from typing import Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar
+
+import attrs
+
+if TYPE_CHECKING:
+ from collections.abc import Iterable
class Window:
@@ -114,3 +119,52 @@ class YearWindow(Window):
"""Twelve consecutive monthly keys making up one calendar year."""
expected_decoded_type: ClassVar[type] = datetime
+
+
+def _convert_segments(segments: Iterable[str]) -> frozenset[str]:
+    """
+    Check *segments* element by element and freeze them into a ``frozenset[str]``.
+
+    Elements are checked in iteration order so the error can report the index
+    of the first offending element. The emptiness check runs last, after every
+    element has been accepted.
+
+    :param segments: Iterable of segment keys.
+    :return: The accepted keys, with duplicates collapsed.
+    :raises ValueError: if an element is not a ``str``, an element is the empty
+        string, or the iterable yields nothing.
+    """
+    accepted: list[str] = []
+    for position, segment in enumerate(segments):
+        if not isinstance(segment, str):
+            raise ValueError(
+                f"SegmentWindow segment keys must be str; got {type(segment).__name__!r} at index {position}: {segment!r}"
+            )
+        if segment == "":
+            raise ValueError(
+                f"SegmentWindow segment keys must be non-empty; got an empty string at index {position}."
+            )
+        accepted.append(segment)
+    if not accepted:
+        raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.")
+    return frozenset(accepted)
+
+
[email protected]
+class SegmentWindow(Window):
+    """
+    A fixed categorical set of string keys that constitute one downstream period.
+
+    Authoring marker for the scheduler-side
+    :class:`airflow.partition_mappers.window.SegmentWindow`. Paired with
+    :class:`~airflow.sdk.definitions.partition_mappers.fixed_key.FixedKeyMapper` inside a
+    :class:`~airflow.sdk.definitions.partition_mappers.base.RollupMapper` to express a
+    categorical rollup.
+
+    Construction validates the segment list so Dag parse errors surface
+    immediately rather than deferring to scheduler deserialization.
+
+    :param segments: Non-empty iterable of non-empty string segment keys. Duplicates
+        are silently de-duplicated.
+    :raises ValueError: if *segments* is empty, contains a non-``str`` element, or
+        contains an empty-string element.
+    """
+
+    # Segment keys are plain strings, so the decoded type is ``str``.
+    expected_decoded_type: ClassVar[type] = str
+
+    # Validated and frozen by ``_convert_segments`` when the instance is built.
+    # NOTE(review): attrs is expected to expose this private field as the
+    # ``segments`` init argument, matching the ``:param segments:`` entry above.
+    _segments: frozenset[str] = attrs.field(converter=_convert_segments)
diff --git a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
index 8b902820da0..7114b65d6c4 100644
--- a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
+++ b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
@@ -22,12 +22,14 @@ from typing import ClassVar
import pytest
from airflow.sdk.definitions.partition_mappers.base import PartitionMapper,
RollupMapper
+from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper
from airflow.sdk.definitions.partition_mappers.temporal import StartOfDayMapper
from airflow.sdk.definitions.partition_mappers.window import (
DayWindow,
HourWindow,
MonthWindow,
QuarterWindow,
+ SegmentWindow,
WeekWindow,
Window,
YearWindow,
@@ -103,3 +105,67 @@ class TestSdkWindowExpectedDecodedType:
)
def test_temporal_windows_declare_datetime(self, window_cls):
assert window_cls.expected_decoded_type is datetime
+
+
+class TestSdkFixedKeyMapper:
+ """SDK-side FixedKeyMapper construction and validation."""
+
+ @pytest.mark.parametrize("key", ["us", "eu", "apac"])
+ def test_to_downstream_returns_constant_for_any_key(self, key):
+ assert FixedKeyMapper("all_regions").to_downstream(key) == "all_regions"
+
+ def test_is_rollup_false(self):
+ assert FixedKeyMapper("all").is_rollup is False
+
+ @pytest.mark.parametrize(
+ ("downstream_key", "match"),
+ [
+ pytest.param("", "non-empty str", id="empty-string"),
+ pytest.param(None, "non-empty str", id="none"),
+ pytest.param(1, "non-empty str", id="int"),
+ ],
+ )
+ def test_rejects_invalid_downstream_key(self, downstream_key, match):
+ with pytest.raises(ValueError, match=match):
+ FixedKeyMapper(downstream_key)
+
+ def test_requires_downstream_key(self):
+ with pytest.raises(TypeError):
+ FixedKeyMapper()
+
+
+class TestSdkSegmentWindow:
+ """SDK-side SegmentWindow construction and validation mirrors the core implementation."""
+
+ def test_expected_decoded_type_is_str(self):
+ assert SegmentWindow.expected_decoded_type is str
+
+ def test_deduplication(self):
+ w = SegmentWindow(["a", "b", "a"])
+ assert w._segments == frozenset({"a", "b"})
+
+ @pytest.mark.parametrize(
+ ("segments", "match"),
+ [
+ pytest.param([], "at least one segment key", id="empty-list"),
+ pytest.param([1, "b"], "must be str", id="int-element"),
+ pytest.param(["", "b"], "non-empty", id="empty-string"),
+ ],
+ )
+ def test_rejects_invalid_segments(self, segments, match):
+ with pytest.raises(ValueError, match=match):
+ SegmentWindow(segments)
+
+
+class TestSdkCategoricalRollupGuard:
+ """SDK-side RollupMapper guard mirrors core: str mapper + str window passes."""
+
+ def test_fixed_key_with_segment_window_does_not_raise(self):
+ # SDK guard: FixedKeyMapper.expected_decoded_type is str,
+ # SegmentWindow.expected_decoded_type is str -> guard passes.
+ RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=SegmentWindow(["us", "eu"]))
+
+ def test_str_mapper_with_datetime_window_raises(self):
+ # SDK guard: FixedKeyMapper (str) + DayWindow (datetime) -> raise.
+ with pytest.raises(TypeError, match="DayWindow expects decoded values of type 'datetime'"):
+ RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=DayWindow())