This is an automated email from the ASF dual-hosted git repository.

Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f690474a26a Add FixedKeyMapper and SegmentWindow for categorical asset-partition rollup (#67716)
f690474a26a is described below

commit f690474a26a18f34e95469a5c28369b87c8f3838
Author: Wei Lee <[email protected]>
AuthorDate: Wed Jun 10 18:06:05 2026 +0800

    Add FixedKeyMapper and SegmentWindow for categorical asset-partition rollup (#67716)
---
 .pre-commit-config.yaml                            |   8 +-
 .../docs/authoring-and-scheduling/assets.rst       |  83 +++++
 airflow-core/newsfragments/67716.feature.rst       |   1 +
 .../example_dags/example_asset_partition.py        |  40 +++
 .../src/airflow/partition_mappers/__init__.py      |   4 +
 .../src/airflow/partition_mappers/fixed_key.py     |  65 ++++
 .../src/airflow/partition_mappers/window.py        |  62 ++++
 airflow-core/src/airflow/serialization/encoders.py |  12 +
 airflow-core/tests/unit/jobs/test_scheduler_job.py |  70 ++++
 .../tests/unit/partition_mappers/test_fixed_key.py | 120 +++++++
 .../tests/unit/partition_mappers/test_window.py    |  42 +++
 .../check_partition_mapper_defaults_in_sync.py     | 291 ++++++++++++++--
 ...test_check_partition_mapper_defaults_in_sync.py | 388 +++++++++++++++++++++
 task-sdk/docs/api.rst                              |   4 +
 task-sdk/src/airflow/sdk/__init__.py               |   6 +
 task-sdk/src/airflow/sdk/__init__.pyi              |   4 +
 .../sdk/definitions/partition_mappers/fixed_key.py |  52 +++
 .../sdk/definitions/partition_mappers/window.py    |  56 ++-
 .../task_sdk/definitions/test_partition_mappers.py |  66 ++++
 19 files changed, 1344 insertions(+), 30 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1de10325dd6..2be1cba4bd2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -249,13 +249,17 @@ repos:
         pass_filenames: false
         require_serial: true
       - id: check-partition-mapper-defaults-in-sync
-        name: Check FanOutMapper default mapper table stays in sync (core/SDK)
+        name: Check partition-mapper core/SDK sync (FanOutMapper table + SegmentWindow/FixedKeyMapper)
         entry: ./scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
         language: python
         files: >
           (?x)
           ^airflow-core/src/airflow/partition_mappers/temporal\.py$|
-          ^task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal\.py$
+          ^airflow-core/src/airflow/partition_mappers/window\.py$|
+          ^airflow-core/src/airflow/partition_mappers/fixed_key\.py$|
+          ^task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal\.py$|
+          ^task-sdk/src/airflow/sdk/definitions/partition_mappers/window\.py$|
+          ^task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key\.py$
         pass_filenames: false
         require_serial: true
       - id: check-window-in-sync
diff --git a/airflow-core/docs/authoring-and-scheduling/assets.rst b/airflow-core/docs/authoring-and-scheduling/assets.rst
index 7bf609e9d98..2bc06a30117 100644
--- a/airflow-core/docs/authoring-and-scheduling/assets.rst
+++ b/airflow-core/docs/authoring-and-scheduling/assets.rst
@@ -565,6 +565,12 @@ downstream Dag partition key:
   passes the key through unchanged if valid.
   For example, ``AllowedKeyMapper(["us", "eu", "apac"])`` accepts only those
   region keys and rejects all others.
+* ``FixedKeyMapper`` collapses every upstream key onto a fixed downstream key,
+  regardless of the upstream value.
+* ``SegmentWindow`` declares a fixed categorical set of string keys (e.g. regions,
+  tenants) that constitute one downstream period; paired with ``FixedKeyMapper``
+  inside a ``RollupMapper`` it holds the downstream run until every declared segment
+  has arrived (see :ref:`segment-rollup <segment-categorical-rollup>`).
 
 Example of per-asset mapper configuration and composite-key mapping:
 
@@ -733,6 +739,83 @@ so the run is held indefinitely) and the fall-back day has twenty-five (the repe
 hour is dropped). Use a UTC-based upstream mapper for any rollup that crosses a DST
 boundary; see the ``DayWindow`` class docstring for the full discussion.
 
+.. _segment-categorical-rollup:
+
+Segment (categorical) rollup
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. versionadded:: 3.3.0
+
+For categorical partitioning — regions, tenants, experiment variants — compose a
+``RollupMapper`` from two primitives:
+
+* ``SegmentWindow(["us", "eu", "apac"])`` declares the fixed set of string keys
+  that constitute one downstream period; ``to_upstream`` returns the full set
+  regardless of the downstream anchor.
+* ``FixedKeyMapper("all_regions")`` collapses every upstream key onto the single
+  downstream partition key ``"all_regions"``.
+
+The scheduler holds the downstream Dag run until every declared segment has arrived
+from the upstream producer, then fires once. All the segment events accumulate into
+one ``AssetPartitionDagRun``; the fired run's ``partition_key`` is the value passed
+to ``FixedKeyMapper``. This composition only makes sense under ``WAIT_FOR_ALL``
+semantics (the default).
+
+.. code-block:: python
+
+    from airflow.sdk import (
+        DAG,
+        Asset,
+        FixedKeyMapper,
+        PartitionAtRuntime,
+        PartitionedAssetTimetable,
+        RollupMapper,
+        SegmentWindow,
+        asset,
+        task,
+    )
+
+
+    @asset(
+        uri="file://incoming/player-stats/multi-region.csv",
+        schedule=PartitionAtRuntime(),
+    )
+    def multi_region_player_stats(self, outlet_events):
+        # Emit one event per region in a single run.
+        outlet_events[self].add_partitions(["us", "eu", "apac"])
+
+
+    # Consumer: fires once all three region partitions have arrived.
+    with DAG(
+        dag_id="segment_region_stats_rollup",
+        schedule=PartitionedAssetTimetable(
+            assets=Asset.ref(name="multi_region_player_stats"),
+            default_partition_mapper=RollupMapper(
+                upstream_mapper=FixedKeyMapper("all_regions"),
+                window=SegmentWindow(["us", "eu", "apac"]),
+            ),
+        ),
+        catchup=False,
+    ):
+
+        @task
+        def aggregate_all_regions(dag_run=None):
+            # dag_run.partition_key is the downstream key once all segments arrive.
+            print(dag_run.partition_key)
+
+        aggregate_all_regions()
+
+Construction validates both components: ``SegmentWindow`` raises ``ValueError`` for
+an empty list, non-string items, or empty-string keys; duplicate entries are silently
+deduplicated. ``FixedKeyMapper`` raises ``ValueError`` if its argument is not a
+non-empty string. Pass a distinct ``FixedKeyMapper`` key when one consumer Dag rolls
+up more than one asset, so each rollup uses a distinct bucket and they do not collide
+on the same ``(target_dag_id, partition_key)``.
+
+For a segment set that must be computed at runtime, do not encode it here — evaluate
+completeness in a consumer-side task instead (the scheduler must not run user code to
+decide a partition set).
+
 Setting partition keys at runtime
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/airflow-core/newsfragments/67716.feature.rst b/airflow-core/newsfragments/67716.feature.rst
new file mode 100644
index 00000000000..63daeda113c
--- /dev/null
+++ b/airflow-core/newsfragments/67716.feature.rst
@@ -0,0 +1 @@
+Add ``FixedKeyMapper`` and ``SegmentWindow`` for categorical asset-partition rollup. ``FixedKeyMapper`` collapses any upstream partition key onto a single fixed downstream key, and ``SegmentWindow`` enumerates a fixed categorical segment set the scheduler waits for. Composing ``RollupMapper(FixedKeyMapper(...), SegmentWindow(...))`` expresses a categorical rollup, mirroring the temporal rollup shape, and ``SegmentWindow`` also composes with ``FanOutMapper`` for categorical scatter. Both  [...]
diff --git a/airflow-core/src/airflow/example_dags/example_asset_partition.py b/airflow-core/src/airflow/example_dags/example_asset_partition.py
index 775fea80ec3..65a078d3c3f 100644
--- a/airflow-core/src/airflow/example_dags/example_asset_partition.py
+++ b/airflow-core/src/airflow/example_dags/example_asset_partition.py
@@ -26,12 +26,14 @@ from airflow.sdk import (
     CronPartitionTimetable,
     DayWindow,
     FanOutMapper,
+    FixedKeyMapper,
     IdentityMapper,
     MonthWindow,
     PartitionAtRuntime,
     PartitionedAssetTimetable,
     ProductMapper,
     RollupMapper,
+    SegmentWindow,
     StartOfDayMapper,
     StartOfHourMapper,
     StartOfMonthMapper,
@@ -409,3 +411,41 @@ with DAG(
         print(dag_run.partition_key)
 
     run_inference()
+
+
+# --- Segment (categorical) rollup -------------------------------------------
+# ``multi_region_player_stats`` (defined above) emits one partition per region
+# (``us``, ``eu``, ``apac``) from a single run.  The Dag below holds a downstream
+# run until every declared region key has arrived.
+
+with DAG(
+    dag_id="segment_region_stats_rollup",
+    schedule=PartitionedAssetTimetable(
+        assets=Asset.ref(name="multi_region_player_stats"),
+        default_partition_mapper=RollupMapper(
+            upstream_mapper=FixedKeyMapper("all_regions"),
+            window=SegmentWindow(["us", "eu", "apac"]),
+        ),
+    ),
+    catchup=False,
+    tags=["example", "player-stats", "rollup", "segment"],
+):
+    """
+    Categorical rollup: hold until all three region partitions arrive.
+
+    ``RollupMapper(upstream_mapper=FixedKeyMapper("all_regions"), window=SegmentWindow([...]))``
+    declares the fixed set of region keys required for one downstream run and collapses every
+    region key onto a single ``all_regions`` partition, so the three region events accumulate
+    into one downstream run.  The run is held until ``us``, ``eu``, and ``apac`` have all
+    arrived from ``multi_region_player_stats``; partial arrivals remain pending in the
+    next-run-assets view so operators can track progress.
+    """
+
+    @task
+    def aggregate_all_regions(dag_run=None):
+        """Produce the cross-region summary once every region partition has arrived."""
+        if TYPE_CHECKING:
+            assert dag_run
+        print(f"All region partitions received. Partition: {dag_run.partition_key}")
+
+    aggregate_all_regions()
diff --git a/airflow-core/src/airflow/partition_mappers/__init__.py b/airflow-core/src/airflow/partition_mappers/__init__.py
index 9b66876e0df..f7806038707 100644
--- a/airflow-core/src/airflow/partition_mappers/__init__.py
+++ b/airflow-core/src/airflow/partition_mappers/__init__.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 from airflow.partition_mappers.allowed_key import AllowedKeyMapper
 from airflow.partition_mappers.base import PartitionMapper, RollupMapper
 from airflow.partition_mappers.chain import ChainMapper
+from airflow.partition_mappers.fixed_key import FixedKeyMapper
 from airflow.partition_mappers.identity import IdentityMapper
 from airflow.partition_mappers.product import ProductMapper
 from airflow.partition_mappers.temporal import (
@@ -34,6 +35,7 @@ from airflow.partition_mappers.window import (
     HourWindow,
     MonthWindow,
     QuarterWindow,
+    SegmentWindow,
     WeekWindow,
     Window,
     YearWindow,
@@ -43,6 +45,7 @@ __all__ = [
     "AllowedKeyMapper",
     "ChainMapper",
     "DayWindow",
+    "FixedKeyMapper",
     "HourWindow",
     "IdentityMapper",
     "MonthWindow",
@@ -50,6 +53,7 @@ __all__ = [
     "ProductMapper",
     "QuarterWindow",
     "RollupMapper",
+    "SegmentWindow",
     "StartOfDayMapper",
     "StartOfHourMapper",
     "StartOfMonthMapper",
diff --git a/airflow-core/src/airflow/partition_mappers/fixed_key.py b/airflow-core/src/airflow/partition_mappers/fixed_key.py
new file mode 100644
index 00000000000..89388bfed12
--- /dev/null
+++ b/airflow-core/src/airflow/partition_mappers/fixed_key.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+
+import attrs
+
+from airflow.partition_mappers.base import PartitionMapper
+
+
[email protected]
+class FixedKeyMapper(PartitionMapper):
+    """
+    Collapse every upstream partition key onto one fixed downstream key.
+
+    Returns the same *downstream_key* for any upstream key passed to
+    ``to_downstream``. Does not override ``decode_downstream`` or
+    ``encode_upstream``, so it works with the string-based identity path and
+    satisfies :class:`~airflow.partition_mappers.base.RollupMapper`'s guard
+    when paired with :class:`~airflow.partition_mappers.window.SegmentWindow`.
+
+    Typical use is as the ``upstream_mapper`` inside a categorical rollup::
+
+        RollupMapper(
+            upstream_mapper=FixedKeyMapper("all_regions"),
+            window=SegmentWindow(["us", "eu", "apac"]),
+        )
+
+    :param downstream_key: The fixed downstream partition key every upstream key
+        maps to. Must be a non-empty string.
+    :raises ValueError: if *downstream_key* is not a non-empty ``str``.
+    """
+
+    downstream_key: str = attrs.field()
+
+    @downstream_key.validator
+    def _validate_downstream_key(self, attribute: attrs.Attribute, value: str) -> None:
+        if not isinstance(value, str) or value == "":
+            raise ValueError(f"FixedKeyMapper downstream_key must be a non-empty str; got {value!r}.")
+
+    def to_downstream(self, key: str) -> str:
+        """Return the fixed downstream key regardless of *key*."""
+        return self.downstream_key
+
+    def serialize(self) -> dict[str, Any]:
+        return {"downstream_key": self.downstream_key}
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> FixedKeyMapper:
+        return cls(data["downstream_key"])
diff --git a/airflow-core/src/airflow/partition_mappers/window.py b/airflow-core/src/airflow/partition_mappers/window.py
index 087b6180f2a..ffcf0e85ab0 100644
--- a/airflow-core/src/airflow/partition_mappers/window.py
+++ b/airflow-core/src/airflow/partition_mappers/window.py
@@ -21,6 +21,8 @@ from datetime import datetime, timedelta
 from enum import Enum
 from typing import TYPE_CHECKING, Any, ClassVar
 
+import attrs
+
 if TYPE_CHECKING:
     from collections.abc import Callable, Iterable
 
@@ -246,3 +248,63 @@ class YearWindow(Window):
     def to_upstream(self, period_start: datetime) -> Iterable[datetime]:
         _require_day_one(period_start, type(self))
         return _build_directional_steps(period_start, 12, _shift_months, self.direction)
+
+
+def _convert_segments(segments: Iterable[str]) -> frozenset[str]:
+    """
+    Validate and convert *segments* to a ``frozenset[str]``.
+
+    Validates each element for type and non-emptiness (with index reporting)
+    before collapsing into a frozenset, then checks the result is non-empty.
+    """
+    validated: list[str] = []
+    for i, item in enumerate(segments):
+        if not isinstance(item, str):
+            raise ValueError(
+                f"SegmentWindow segment keys must be str; got {type(item).__name__!r} at index {i}: {item!r}"
+            )
+        if not item:
+            raise ValueError(
+                f"SegmentWindow segment keys must be non-empty; got an empty string at index {i}."
+            )
+        validated.append(item)
+    result = frozenset(validated)
+    if not result:
+        raise ValueError("SegmentWindow requires at least one segment key; got an empty iterable.")
+    return result
+
+
[email protected]
+class SegmentWindow(Window):
+    """
+    A fixed categorical set of string keys that constitute one downstream period.
+
+    Paired with :class:`~airflow.partition_mappers.fixed_key.FixedKeyMapper` inside a
+    :class:`~airflow.partition_mappers.base.RollupMapper` to express a categorical
+    rollup: the scheduler holds the downstream run until every declared segment key
+    has arrived from the upstream producer, then fires once.
+
+    ``to_upstream`` returns the complete segment set regardless of the downstream
+    anchor value — the anchor is intentionally ignored because all segments map onto
+    a single downstream partition key, not a time-based period.
+
+    :param segments: Non-empty iterable of non-empty string segment keys. Duplicates
+        are silently de-duplicated.
+    :raises ValueError: if *segments* is empty, contains a non-``str`` element, or
+        contains an empty-string element.
+    """
+
+    expected_decoded_type: ClassVar[type] = str
+
+    _segments: frozenset[str] = attrs.field(converter=_convert_segments)
+
+    def to_upstream(self, decoded_downstream: Any) -> frozenset[str]:
+        """Return the full declared segment set, ignoring the downstream anchor."""
+        return self._segments
+
+    def serialize(self) -> dict[str, Any]:
+        return {"segments": sorted(self._segments)}
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> SegmentWindow:
+        return cls(data["segments"])
diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py
index f4ec081df1b..0c2f030da8f 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -43,6 +43,7 @@ from airflow.sdk import (
     DeltaTriggerTimetable,
     EventsTimetable,
     FanOutMapper,
+    FixedKeyMapper,
     HourWindow,
     IdentityMapper,
     MonthWindow,
@@ -51,6 +52,7 @@ from airflow.sdk import (
     ProductMapper,
     QuarterWindow,
     RollupMapper,
+    SegmentWindow,
     StartOfDayMapper,
     StartOfHourMapper,
     StartOfMonthMapper,
@@ -437,6 +439,7 @@ class _Serializer:
         AllowedKeyMapper: "airflow.partition_mappers.allowed_key.AllowedKeyMapper",
         ChainMapper: "airflow.partition_mappers.chain.ChainMapper",
         FanOutMapper: "airflow.partition_mappers.temporal.FanOutMapper",
+        FixedKeyMapper: "airflow.partition_mappers.fixed_key.FixedKeyMapper",
         IdentityMapper: "airflow.partition_mappers.identity.IdentityMapper",
         ProductMapper: "airflow.partition_mappers.product.ProductMapper",
         RollupMapper: "airflow.partition_mappers.base.RollupMapper",
@@ -464,6 +467,10 @@ class _Serializer:
     def _(self, partition_mapper: IdentityMapper) -> dict[str, Any]:
         return {}
 
+    @serialize_partition_mapper.register
+    def _(self, partition_mapper: FixedKeyMapper) -> dict[str, Any]:
+        return {"downstream_key": partition_mapper.downstream_key}
+
     @serialize_partition_mapper.register(StartOfHourMapper)
     @serialize_partition_mapper.register(StartOfDayMapper)
     @serialize_partition_mapper.register(StartOfWeekMapper)
@@ -517,6 +524,7 @@ class _Serializer:
         WeekWindow: "airflow.partition_mappers.window.WeekWindow",
         MonthWindow: "airflow.partition_mappers.window.MonthWindow",
         QuarterWindow: "airflow.partition_mappers.window.QuarterWindow",
+        SegmentWindow: "airflow.partition_mappers.window.SegmentWindow",
         YearWindow: "airflow.partition_mappers.window.YearWindow",
     }
 
@@ -538,6 +546,10 @@ class _Serializer:
     ) -> dict[str, Any]:
         return window.serialize()
 
+    @serialize_window.register
+    def _(self, window: SegmentWindow) -> dict[str, Any]:
+        return {"segments": sorted(window._segments)}
+
 
 _serializer = _Serializer()
 
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 89383233fad..d08628a51fb 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -104,9 +104,11 @@ from airflow.sdk import (
     AssetAlias,
     AssetWatcher,
     CronPartitionTimetable,
+    FixedKeyMapper,
     HourWindow,
     IdentityMapper,
     RollupMapper,
+    SegmentWindow,
     StartOfHourMapper,
     task,
 )
@@ -10273,6 +10275,74 @@ def test_partitioned_dag_run_rollup_holds_until_window_complete(
     assert partition_dags == {"rollup-consumer"}
 
 
+@pytest.mark.need_serialized_dag
[email protected]("clear_asset_partition_rows")
+def test_partitioned_dag_run_segment_rollup_holds_until_all_segments_arrive(
+    dag_maker: DagMaker,
+    session: Session,
+):
+    """
+    A categorical (segment) rollup fires once every declared segment has arrived.
+
+    ``RollupMapper(FixedKeyMapper("all_regions"), SegmentWindow([...]))`` collapses
+    each region key onto a single ``all_regions`` partition, so all three events
+    accumulate into one APDR, and holds the downstream run until ``us``, ``eu``,
+    and ``apac`` are all present.
+    """
+    asset_1 = Asset(name="asset-1")
+    with dag_maker(
+        dag_id="segment-rollup-consumer",
+        schedule=PartitionedAssetTimetable(
+            assets=asset_1,
+            default_partition_mapper=RollupMapper(
+                upstream_mapper=FixedKeyMapper("all_regions"),
+                window=SegmentWindow(["us", "eu", "apac"]),
+            ),
+        ),
+        session=session,
+    ):
+        EmptyOperator(task_id="hi")
+    session.commit()
+
+    runner = SchedulerJobRunner(
+        job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)]
+    )
+
+    # First region arrives — only 1 / 3 segments, so the APDR must not fire.
+    # Every region collapses onto the single ``all_regions`` partition.
+    apdr = _produce_and_register_asset_event(
+        dag_id="segment-producer-us",
+        asset=asset_1,
+        partition_key="us",
+        session=session,
+        dag_maker=dag_maker,
+        expected_partition_key="all_regions",
+    )
+    partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session)
+    session.refresh(apdr)
+    assert apdr.created_dag_run_id is None
+    assert partition_dags == set()
+
+    # The remaining two regions arrive — once all three segments are present the
+    # rollup is satisfied and the APDR creates its Dag run on the next tick. All
+    # three events share the one ``all_regions`` APDR.
+    for region in ("eu", "apac"):
+        sibling = _produce_and_register_asset_event(
+            dag_id=f"segment-producer-{region}",
+            asset=asset_1,
+            partition_key=region,
+            session=session,
+            dag_maker=dag_maker,
+            expected_partition_key="all_regions",
+        )
+        assert sibling.id == apdr.id
+    partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session)
+    session.refresh(apdr)
+    assert apdr.created_dag_run_id is not None
+    assert apdr.partition_key == "all_regions"
+    assert partition_dags == {"segment-rollup-consumer"}
+
+
 @pytest.mark.need_serialized_dag
 @pytest.mark.usefixtures("clear_asset_partition_rows")
 def test_partitioned_dag_run_rollup_treats_mapper_exception_as_not_satisfied(
diff --git a/airflow-core/tests/unit/partition_mappers/test_fixed_key.py b/airflow-core/tests/unit/partition_mappers/test_fixed_key.py
new file mode 100644
index 00000000000..efdd5be3313
--- /dev/null
+++ b/airflow-core/tests/unit/partition_mappers/test_fixed_key.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.partition_mappers.base import PartitionMapper, RollupMapper
+from airflow.partition_mappers.fixed_key import FixedKeyMapper
+from airflow.partition_mappers.window import DayWindow, SegmentWindow
+from airflow.sdk import (
+    FixedKeyMapper as SdkFixedKeyMapper,
+    RollupMapper as SdkRollupMapper,
+    SegmentWindow as SdkSegmentWindow,
+)
+from airflow.serialization.decoders import decode_partition_mapper
+from airflow.serialization.encoders import encode_partition_mapper
+
+
+class TestFixedKeyMapper:
+    @pytest.mark.parametrize("key", ["us", "eu", "apac", "anything-else"])
+    def test_to_downstream_returns_constant_for_any_key(self, key):
+        assert FixedKeyMapper("all").to_downstream(key) == "all"
+
+    def test_is_rollup_false(self):
+        # A bare FixedKeyMapper is not a rollup; rollup-ness comes from RollupMapper.
+        assert FixedKeyMapper("all").is_rollup is False
+
+    def test_does_not_override_decode_encode(self):
+        m = FixedKeyMapper("all")
+        assert type(m).decode_downstream is PartitionMapper.decode_downstream
+        assert type(m).encode_upstream is PartitionMapper.encode_upstream
+
+    @pytest.mark.parametrize(
+        ("downstream_key", "match"),
+        [
+            pytest.param("", "non-empty str", id="empty-string"),
+            pytest.param(None, "non-empty str", id="none"),
+            pytest.param(1, "non-empty str", id="int"),
+        ],
+    )
+    def test_rejects_invalid_downstream_key(self, downstream_key, match):
+        with pytest.raises(ValueError, match=match):
+            FixedKeyMapper(downstream_key)
+
+    def test_requires_downstream_key(self):
+        with pytest.raises(TypeError):
+            FixedKeyMapper()
+
+    def test_serialize_round_trip(self):
+        m = FixedKeyMapper("bucket")
+        restored = FixedKeyMapper.deserialize(m.serialize())
+        assert isinstance(restored, FixedKeyMapper)
+        assert restored.downstream_key == "bucket"
+
+
+class TestCategoricalRollupEquivalence:
+    """RollupMapper(FixedKeyMapper, SegmentWindow) behaves like old SegmentMapper."""
+
+    def setup_method(self):
+        self.m = RollupMapper(
+            upstream_mapper=FixedKeyMapper("all"),
+            window=SegmentWindow(["us", "eu", "apac"]),
+        )
+
+    def test_is_rollup_flag(self):
+        assert self.m.is_rollup is True
+
+    def test_to_downstream_collapses_every_segment_onto_downstream_key(self):
+        # Full-sequence equality: every declared segment key maps to the constant key.
+        assert [self.m.to_downstream(s) for s in ("us", "eu", "apac")] == ["all", "all", "all"]
+
+    @pytest.mark.parametrize("anchor", ["all", "anything"])
+    def test_to_upstream_returns_full_set_ignoring_anchor(self, anchor):
+        assert self.m.to_upstream(anchor) == frozenset({"us", "eu", "apac"})
+
+    def test_core_encode_decode_round_trip(self):
+        restored = decode_partition_mapper(encode_partition_mapper(self.m))
+        assert isinstance(restored, RollupMapper)
+        assert restored.is_rollup is True
+        assert restored.to_downstream("us") == "all"
+        assert restored.to_upstream("all") == frozenset({"us", "eu", "apac"})
+
+    def test_sdk_encode_decode_round_trip(self):
+        # User code authors with SDK classes; the scheduler serializes and deserializes
+        # into core classes.
+        sdk_mapper = SdkRollupMapper(
+            upstream_mapper=SdkFixedKeyMapper("all_regions"),
+            window=SdkSegmentWindow(["us", "eu", "apac"]),
+        )
+        restored = decode_partition_mapper(encode_partition_mapper(sdk_mapper))
+        assert isinstance(restored, RollupMapper)
+        assert restored.to_upstream("all_regions") == frozenset({"us", "eu", "apac"})
+
+
+class TestCategoricalRollupTypeGuard:
+    """Core-side RollupMapper guard: FixedKeyMapper(str) + SegmentWindow(str) must pass."""
+
+    def test_fixed_key_with_segment_window_does_not_raise(self):
+        # Core guard: FixedKeyMapper does not override decode_downstream,
+        # SegmentWindow.expected_decoded_type is str -> guard passes.
+        RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=SegmentWindow(["us", "eu"]))
+
+    def test_str_mapper_with_datetime_window_raises(self):
+        # Core guard: FixedKeyMapper (no decode override) + DayWindow (datetime) -> raise.
+        with pytest.raises(TypeError, match="DayWindow expects decoded values of type 'datetime'"):
+            RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=DayWindow())
diff --git a/airflow-core/tests/unit/partition_mappers/test_window.py b/airflow-core/tests/unit/partition_mappers/test_window.py
index ddfd44cbd76..b4956761c2c 100644
--- a/airflow-core/tests/unit/partition_mappers/test_window.py
+++ b/airflow-core/tests/unit/partition_mappers/test_window.py
@@ -34,6 +34,7 @@ from airflow.partition_mappers.window import (
     HourWindow,
     MonthWindow,
     QuarterWindow,
+    SegmentWindow,
     WeekWindow,
     Window,
     YearWindow,
@@ -465,6 +466,47 @@ class TestDirectionValidation:
             WeekWindow(direction=bad_value)
 
 
+class TestSegmentWindow:
+    def test_to_upstream_returns_full_set_ignoring_anchor(self):
+        w = SegmentWindow(["us", "eu", "apac"])
+        result_a = frozenset(w.to_upstream("any-anchor"))
+        result_b = frozenset(w.to_upstream("different-anchor"))
+        assert result_a == frozenset({"us", "eu", "apac"})
+        assert result_a == result_b
+
+    def test_expected_decoded_type_is_str(self):
+        assert SegmentWindow.expected_decoded_type is str
+
+    @pytest.mark.parametrize(
+        ("segments", "match"),
+        [
+            pytest.param([], "at least one segment key", id="empty-list"),
+            pytest.param(iter([]), "at least one segment key", id="empty-iterator"),
+            pytest.param([1, "b"], "must be str", id="int-element"),
+            pytest.param([None, "b"], "must be str", id="none-element"),
+            pytest.param(["", "b"], "non-empty", id="empty-string-first"),
+            pytest.param(["a", ""], "non-empty", id="empty-string-second"),
+        ],
+    )
+    def test_rejects_invalid_segments(self, segments, match):
+        with pytest.raises(ValueError, match=match):
+            SegmentWindow(segments)
+
+    def test_deduplication(self):
+        w = SegmentWindow(["us", "us", "eu"])
+        assert frozenset(w.to_upstream("any")) == frozenset({"us", "eu"})
+
+    def test_serialize_uses_sorted_order(self):
+        w = SegmentWindow(["z", "a", "m"])
+        assert w.serialize() == {"segments": ["a", "m", "z"]}
+
+    def test_deserialize_round_trip(self):
+        w = SegmentWindow(["us", "eu", "apac"])
+        restored = SegmentWindow.deserialize(w.serialize())
+        assert isinstance(restored, SegmentWindow)
+        assert frozenset(restored.to_upstream("any")) == frozenset({"us", "eu", "apac"})
+
+
 class TestWindowSerializationGate:
     """``encode_window`` / ``decode_window`` must reject non-built-in Windows.
 
diff --git a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
index 0c1adea9669..24f7520f1b5 100755
--- a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
+++ b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
@@ -23,24 +23,25 @@
 # ]
 # ///
 """
-Verify ``FanOutMapper.default_downstream_mapper_by_window_name`` stays in sync
-between core and the Task SDK.
+Verify partition-mapper definitions stay in sync between core and the Task SDK.
 
-The default downstream-mapper table is defined twice — once in the core class
-hierarchy and once in the SDK copy — because the two hierarchies are
-independent (the SDK cannot import core) and the lookup is by ``Window`` class
-*name*. Both copies must list the same ``Window`` name -> mapper class mapping,
-otherwise a ``FanOutMapper`` resolves a different default depending on whether
-it runs in Dag-author code (SDK) or after deserialization (core).
+Checks two things:
 
-This check parses the ``default_downstream_mapper_by_window_name`` class
-attribute from both files via AST and asserts the two mappings are identical.
+1. ``FanOutMapper.default_downstream_mapper_by_window_name`` — the default
+   downstream-mapper table is defined twice (core and SDK) because the two
+   hierarchies are independent. Both copies must list the same
+   ``Window`` name -> mapper class mapping.
+
+2. ``SegmentWindow`` and ``FixedKeyMapper`` — for each class, the field-name
+   set and the ``raise ValueError(...)`` message-template set must be identical
+   between core and the SDK. This catches wording drift (e.g. "non-empty" vs
+   "non-empty strings") that would otherwise silently diverge.
 
 Run from the repo root:
 
     uv run --project scripts python 
scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
 
-Exits 0 if the two tables match, 1 (with a diff) otherwise.
+Exits 0 if everything matches, 1 (with a diff) otherwise.
 """
 
 from __future__ import annotations
@@ -65,6 +66,15 @@ SDK_FILE = (
     AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / 
"partition_mappers" / "temporal.py"
 )
 
+CORE_WINDOW_FILE = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "partition_mappers" 
/ "window.py"
+SDK_WINDOW_FILE = (
+    AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / 
"partition_mappers" / "window.py"
+)
+CORE_FIXED_KEY_FILE = AIRFLOW_CORE_SOURCES_PATH / "airflow" / 
"partition_mappers" / "fixed_key.py"
+SDK_FIXED_KEY_FILE = (
+    AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / 
"partition_mappers" / "fixed_key.py"
+)
+
 
 def _find_attr_value(file_path: Path) -> ast.Dict:
     """Return the AST node assigned to 
``FanOutMapper.default_downstream_mapper_by_window_name``."""
@@ -113,7 +123,200 @@ def _extract_mapping(file_path: Path) -> dict[str, str]:
     return mapping
 
 
+def _joinedstr_to_template(node: ast.JoinedStr) -> str:
+    """
+    Convert an f-string AST node to a template string.
+
+    Literal text fragments are kept as-is; each interpolated expression
+    ``{...}`` is replaced by a ``{}`` placeholder. This lets us compare
+    f-string message templates between core and SDK without caring about the
+    exact expression inside the braces (e.g. ``{type(item).__name__!r}`` vs
+    a future refactoring).
+    """
+    parts: list[str] = []
+    for value in node.values:
+        if isinstance(value, ast.Constant):
+            parts.append(str(value.value))
+        else:
+            parts.append("{}")
+    return "".join(parts)
+
+
+def _extract_raise_messages(subtree: ast.AST) -> set[str]:
+    """
+    Collect ``raise ValueError(...)`` message templates from an arbitrary AST 
subtree.
+
+    For each ``raise ValueError(...)`` whose first argument is a plain string
+    constant, an f-string, or adjacent f-string/constant concatenation
+    (``BinOp(Add, ...)``), extracts the template:
+
+    - ``ast.Constant`` → the literal string value.
+    - ``ast.JoinedStr`` (f-string) → literal fragments joined, interpolated
+      expressions replaced by ``{}``.
+    - ``ast.BinOp(Add, ...)`` → recursively concatenated from both sides.
+    """
+    messages: set[str] = set()
+
+    def _collect_bin(n: ast.expr) -> str:
+        if isinstance(n, ast.Constant) and isinstance(n.value, str):
+            return n.value
+        if isinstance(n, ast.JoinedStr):
+            return _joinedstr_to_template(n)
+        if isinstance(n, ast.BinOp) and isinstance(n.op, ast.Add):
+            return _collect_bin(n.left) + _collect_bin(n.right)
+        return "{}"
+
+    for node in ast.walk(subtree):
+        if not isinstance(node, ast.Raise):
+            continue
+        exc = node.exc
+        if exc is None:
+            continue
+        if not (
+            isinstance(exc, ast.Call)
+            and isinstance(exc.func, ast.Name)
+            and exc.func.id == "ValueError"
+            and exc.args
+        ):
+            continue
+        arg = exc.args[0]
+        if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
+            messages.add(arg.value)
+        elif isinstance(arg, ast.JoinedStr):
+            messages.add(_joinedstr_to_template(arg))
+        elif isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.Add):
+            messages.add(_collect_bin(arg))
+
+    return messages
+
+
+def _collect_converter_names(class_node: ast.ClassDef) -> set[str]:
+    """
+    Return function names referenced via ``converter=<Name>`` in attrs field 
definitions.
+
+    Scans the top-level statements of *class_node* for annotated assignments of
+    the form ``<name>: <ann> = attrs.field(converter=<Name>)`` (or
+    ``attr.field`` / bare ``field(...)``).  Also handles unannotated
+    ``Assign`` nodes with an ``attrs.field(converter=<Name>)`` call on the RHS.
+    Only bare ``ast.Name`` references are collected; lambdas and attribute
+    lookups are ignored (they cannot resolve to a module-level function by
+    name).
+    """
+    names: set[str] = set()
+
+    def _check_call(call_node: ast.expr) -> None:
+        """If *call_node* is an attrs.field / attr.field / field() call, 
collect converter=<Name>."""
+        if not isinstance(call_node, ast.Call):
+            return
+        func = call_node.func
+        is_field_call = (
+            # attrs.field(...) or attr.field(...)
+            (isinstance(func, ast.Attribute) and func.attr == "field")
+            # bare field(...)
+            or (isinstance(func, ast.Name) and func.id == "field")
+        )
+        if not is_field_call:
+            return
+        for kw in call_node.keywords:
+            if kw.arg == "converter" and isinstance(kw.value, ast.Name):
+                names.add(kw.value.id)
+
+    for stmt in class_node.body:
+        if isinstance(stmt, ast.AnnAssign) and stmt.value is not None:
+            _check_call(stmt.value)
+        elif isinstance(stmt, ast.Assign):
+            for target_value in [stmt.value]:
+                _check_call(target_value)
+
+    return names
+
+
+def extract_class_error_messages(file_path: Path, class_name: str) -> set[str]:
+    """
+    Return the set of ``raise ValueError(...)`` message templates for 
*class_name*.
+
+    Scans two sources:
+
+    1. The class body itself (all depths) — covers validator methods and nested
+       helpers declared inside the class.
+    2. Module-level functions referenced via ``converter=<Name>`` in
+       ``attrs.field(...)`` declarations inside the class body.  These
+       converters live outside the class but are logically part of its
+       construction-time validation.
+
+    For each ``raise ValueError(...)`` whose first argument is a plain string
+    constant or an f-string, extracts the template:
+
+    - ``ast.Constant`` → the literal string value.
+    - ``ast.JoinedStr`` (f-string) → literal fragments joined, interpolated
+      expressions replaced by ``{}``.
+    """
+    tree = ast.parse(file_path.read_text(encoding="utf-8"), 
filename=str(file_path))
+
+    target_class: ast.ClassDef | None = None
+    for node in ast.walk(tree):
+        if isinstance(node, ast.ClassDef) and node.name == class_name:
+            target_class = node
+            break
+
+    if target_class is None:
+        return set()
+
+    # 1. Collect messages from the class body.
+    messages = _extract_raise_messages(target_class)
+
+    # 2. Follow converter= references to module-level functions.
+    converter_names = _collect_converter_names(target_class)
+    if converter_names:
+        # Only look at top-level (module-body) function definitions to avoid
+        # accidentally matching a same-named method inside another class.
+        for node in tree.body:
+            if isinstance(node, ast.FunctionDef) and node.name in 
converter_names:
+                messages |= _extract_raise_messages(node)
+
+    return messages
+
+
+def extract_class_field_names(file_path: Path, class_name: str) -> set[str]:
+    """
+    Return the set of annotated attribute names declared in *class_name*'s 
body.
+
+    Collects names from annotated assignments (``name: type = ...`` or
+    ``name: type``) at the top level of the class body. ClassVar annotations
+    are excluded because they are class-level constants, not instance fields.
+    """
+    tree = ast.parse(file_path.read_text(encoding="utf-8"), 
filename=str(file_path))
+    fields: set[str] = set()
+
+    for node in ast.walk(tree):
+        if not (isinstance(node, ast.ClassDef) and node.name == class_name):
+            continue
+        for stmt in node.body:
+            if not isinstance(stmt, ast.AnnAssign):
+                continue
+            if not isinstance(stmt.target, ast.Name):
+                continue
+            # Skip ClassVar[...] annotations
+            ann = stmt.annotation
+            is_classvar = False
+            if isinstance(ann, ast.Subscript):
+                if isinstance(ann.value, ast.Name) and ann.value.id == 
"ClassVar":
+                    is_classvar = True
+                elif isinstance(ann.value, ast.Attribute) and ann.value.attr 
== "ClassVar":
+                    is_classvar = True
+            elif isinstance(ann, ast.Name) and ann.id == "ClassVar":
+                is_classvar = True
+            if is_classvar:
+                continue
+            fields.add(stmt.target.id)
+        break
+
+    return fields
+
+
 def main() -> int:
+    failed = False
+
     try:
         core_mapping = _extract_mapping(CORE_FILE)
         sdk_mapping = _extract_mapping(SDK_FILE)
@@ -121,22 +324,56 @@ def main() -> int:
         console.print(f"[red]Could not read the default mapper table:[/red] 
{exc}")
         return 1
 
-    if core_mapping == sdk_mapping:
-        return 0
-
-    console.print(f"[red]{CLASS_NAME}.{ATTR_NAME} is out of sync between core 
and the Task SDK.[/red]\n")
-    all_windows = sorted(core_mapping.keys() | sdk_mapping.keys())
-    for window in all_windows:
-        core_val = core_mapping.get(window, "<missing>")
-        sdk_val = sdk_mapping.get(window, "<missing>")
-        marker = "  " if core_val == sdk_val else "->"
-        color = "" if core_val == sdk_val else "[red]"
-        end = "" if core_val == sdk_val else "[/red]"
-        console.print(f"{color}{marker} {window}: core={core_val} 
sdk={sdk_val}{end}")
-    console.print(
-        f"\nMake both tables list the same window -> mapper entries:\n  core: 
{CORE_FILE}\n  sdk:  {SDK_FILE}"
-    )
-    return 1
+    if core_mapping != sdk_mapping:
+        console.print(f"[red]{CLASS_NAME}.{ATTR_NAME} is out of sync between 
core and the Task SDK.[/red]\n")
+        all_windows = sorted(core_mapping.keys() | sdk_mapping.keys())
+        for window in all_windows:
+            core_val = core_mapping.get(window, "<missing>")
+            sdk_val = sdk_mapping.get(window, "<missing>")
+            marker = "  " if core_val == sdk_val else "->"
+            color = "" if core_val == sdk_val else "[red]"
+            end = "" if core_val == sdk_val else "[/red]"
+            console.print(f"{color}{marker} {window}: core={core_val} 
sdk={sdk_val}{end}")
+        console.print(
+            f"\nMake both tables list the same window -> mapper entries:\n  
core: {CORE_FILE}\n  sdk:  {SDK_FILE}"
+        )
+        failed = True
+
+    checks = [
+        ("SegmentWindow", CORE_WINDOW_FILE, SDK_WINDOW_FILE),
+        ("FixedKeyMapper", CORE_FIXED_KEY_FILE, SDK_FIXED_KEY_FILE),
+    ]
+
+    for class_name, core_file, sdk_file in checks:
+        core_fields = extract_class_field_names(core_file, class_name)
+        sdk_fields = extract_class_field_names(sdk_file, class_name)
+        if core_fields != sdk_fields:
+            console.print(f"[red]{class_name}: field names are out of sync 
between core and SDK.[/red]")
+            core_only = core_fields - sdk_fields
+            sdk_only = sdk_fields - core_fields
+            if core_only:
+                console.print(f"  core-only fields: {sorted(core_only)}")
+            if sdk_only:
+                console.print(f"  sdk-only  fields: {sorted(sdk_only)}")
+            console.print(f"  core: {core_file}\n  sdk:  {sdk_file}")
+            failed = True
+
+        core_msgs = extract_class_error_messages(core_file, class_name)
+        sdk_msgs = extract_class_error_messages(sdk_file, class_name)
+        if core_msgs != sdk_msgs:
+            console.print(
+                f"[red]{class_name}: raise ValueError(...) message templates 
are out of sync.[/red]"
+            )
+            core_only = core_msgs - sdk_msgs
+            sdk_only = sdk_msgs - core_msgs
+            if core_only:
+                console.print(f"  core-only messages: {sorted(core_only)}")
+            if sdk_only:
+                console.print(f"  sdk-only  messages: {sorted(sdk_only)}")
+            console.print(f"  core: {core_file}\n  sdk:  {sdk_file}")
+            failed = True
+
+    return 1 if failed else 0
 
 
 if __name__ == "__main__":
diff --git 
a/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py 
b/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py
new file mode 100644
index 00000000000..aa1742d33a0
--- /dev/null
+++ b/scripts/tests/ci/prek/test_check_partition_mapper_defaults_in_sync.py
@@ -0,0 +1,388 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import textwrap
+from pathlib import Path
+
+import pytest
+from check_partition_mapper_defaults_in_sync import (
+    extract_class_error_messages,
+    extract_class_field_names,
+)
+
+
+class TestExtractClassErrorMessages:
+    """Unit tests for ``extract_class_error_messages`` on raises inside the class body."""
+
+    @pytest.mark.parametrize(
+        ("raise_stmt", "expected"),
+        [
+            pytest.param('raise ValueError("msg a")', {"msg a"}, 
id="plain-string"),
+            pytest.param('raise ValueError(f"msg b at {i}")', {"msg b at {}"}, 
id="fstring-template"),
+        ],
+    )
+    def test_extracts_message(self, tmp_path: Path, raise_stmt: str, expected: 
set[str]):
+        """Plain strings are kept verbatim; f-string interpolations become {} placeholders."""
+        f = tmp_path / "code.py"
+        f.write_text(
+            textwrap.dedent(f"""\
+                class MyClass:
+                    def validate(self, x, i):
+                        {raise_stmt}
+            """)
+        )
+        assert extract_class_error_messages(f, "MyClass") == expected
+
+    def test_extracts_both_plain_and_fstring(self, tmp_path: Path):
+        """A plain-string raise and an f-string raise in separate methods are both collected."""
+        f = tmp_path / "code.py"
+        f.write_text(
+            textwrap.dedent("""\
+                class MyClass:
+                    def check(self, item, i):
+                        raise ValueError("msg a")
+
+                    def check2(self, item, i):
+                        raise ValueError(f"msg b at {i}")
+            """)
+        )
+        result = extract_class_error_messages(f, "MyClass")
+        assert result == {"msg a", "msg b at {}"}
+
+    def test_ignores_other_exception_types(self, tmp_path: Path):
+        """Only ValueError raises are collected; a TypeError message is ignored."""
+        f = tmp_path / "code.py"
+        f.write_text(
+            textwrap.dedent("""\
+                class MyClass:
+                    def check(self):
+                        raise TypeError("not a value error")
+            """)
+        )
+        result = extract_class_error_messages(f, "MyClass")
+        assert result == set()
+
+    def test_returns_empty_for_missing_class(self, tmp_path: Path):
+        """A file that does not define the requested class yields an empty set."""
+        f = tmp_path / "code.py"
+        f.write_text("x = 1\n")
+        result = extract_class_error_messages(f, "Missing")
+        assert result == set()
+
+    def test_nested_function_messages_included(self, tmp_path: Path):
+        """Messages raised inside nested helpers within the class body are collected."""
+        f = tmp_path / "code.py"
+        f.write_text(
+            textwrap.dedent("""\
+                class MyClass:
+                    def __init__(self, items):
+                        def _check(i, item):
+                            if not isinstance(item, str):
+                                raise ValueError(f"must be str; got 
{type(item).__name__!r} at {i}")
+                        for i, item in enumerate(items):
+                            _check(i, item)
+            """)
+        )
+        result = extract_class_error_messages(f, "MyClass")
+        assert result == {"must be str; got {} at {}"}
+
+    def test_multipart_fstring_message(self, tmp_path: Path):
+        """f-string with multiple literal parts and multiple interpolations."""
+        f = tmp_path / "code.py"
+        f.write_text(
+            textwrap.dedent("""\
+                class C:
+                    def v(self, item, i):
+                        raise ValueError(
+                            f"Prefix segment keys must be str; "
+                            f"got {type(item).__name__!r} at index {i}: 
{item!r}"
+                        )
+            """)
+        )
+        result = extract_class_error_messages(f, "C")
+        # Adjacent string literals are concatenated by the parser, so the two
+        # f-strings above arrive as a single JoinedStr node (not a BinOp) and
+        # the extractor renders them as one template.
+        assert "Prefix segment keys must be str; got {} at index {}: {}" in 
result
+
+
+class TestExtractClassFieldNames:
+    def test_extracts_annotated_instance_fields(self, tmp_path: Path):
+        f = tmp_path / "code.py"
+        f.write_text(
+            textwrap.dedent("""\
+                from typing import ClassVar
+                import attrs
+
+                @attrs.define
+                class MyClass:
+                    expected_decoded_type: ClassVar[type] = str
+                    _segments: frozenset[str] = attrs.field()
+            """)
+        )
+        result = extract_class_field_names(f, "MyClass")
+        # ClassVar is excluded; _segments is included
+        assert result == {"_segments"}
+
+    def test_excludes_classvar_fields(self, tmp_path: Path):
+        f = tmp_path / "code.py"
+        f.write_text(
+            textwrap.dedent("""\
+                from typing import ClassVar
+
+                class MyClass:
+                    flag: ClassVar[bool] = False
+                    name: str
+            """)
+        )
+        result = extract_class_field_names(f, "MyClass")
+        assert result == {"name"}
+
+    def test_returns_empty_for_missing_class(self, tmp_path: Path):
+        f = tmp_path / "code.py"
+        f.write_text("x = 1\n")
+        result = extract_class_field_names(f, "Missing")
+        assert result == set()
+
+    def test_private_field_name_preserved(self, tmp_path: Path):
+        """Leading underscore in field name is kept verbatim."""
+        f = tmp_path / "code.py"
+        f.write_text(
+            textwrap.dedent("""\
+                import attrs
+
+                @attrs.define
+                class Container:
+                    _segments: frozenset[str] = attrs.field()
+            """)
+        )
+        result = extract_class_field_names(f, "Container")
+        assert "_segments" in result
+
+
+class TestInSyncPasses:
+    def test_in_sync_passes(self, tmp_path: Path):
+        """Two files with identical field names and messages compare equal."""
+        code = textwrap.dedent("""\
+            import attrs
+
+            @attrs.define
+            class MyMapper:
+                downstream_key: str = attrs.field()
+
+                def validate(self, x):
+                    raise ValueError(f"must be non-empty str; got {x!r}.")
+        """)
+        core_file = tmp_path / "core.py"
+        sdk_file = tmp_path / "sdk.py"
+        core_file.write_text(code)
+        sdk_file.write_text(code)
+
+        core_fields = extract_class_field_names(core_file, "MyMapper")
+        sdk_fields = extract_class_field_names(sdk_file, "MyMapper")
+        assert core_fields == sdk_fields
+
+        core_msgs = extract_class_error_messages(core_file, "MyMapper")
+        sdk_msgs = extract_class_error_messages(sdk_file, "MyMapper")
+        assert core_msgs == sdk_msgs
+
+
+class TestConverterPatternExtraction:
+    """Tests for the module-level converter= follow-through in extract_class_error_messages."""
+
+    def test_extracts_messages_from_module_level_converter(self, tmp_path: 
Path):
+        """A converter function outside the class body is followed and its messages are collected."""
+        f = tmp_path / "code.py"
+        f.write_text(
+            textwrap.dedent("""\
+                import attrs
+
+                def _my_convert(items):
+                    for i, item in enumerate(items):
+                        if not isinstance(item, str):
+                            raise ValueError(f"must be str; got 
{type(item).__name__!r} at {i}")
+                    if not items:
+                        raise ValueError("must not be empty")
+                    return frozenset(items)
+
+                @attrs.define
+                class MyClass:
+                    _data: frozenset = attrs.field(converter=_my_convert)
+            """)
+        )
+        result = extract_class_error_messages(f, "MyClass")
+        # Both the f-string template and the plain-string message come from the converter.
+        assert "must be str; got {} at {}" in result
+        assert "must not be empty" in result
+
+    def test_converter_fstring_template_placeholders(self, tmp_path: Path):
+        """f-string expressions in the converter become {} placeholders in the template."""
+        f = tmp_path / "code.py"
+        f.write_text(
+            # The source is only parsed (ast.parse), never executed, so the
+            # converter's unconditional raise inside the loop is irrelevant here.
+            textwrap.dedent("""\
+                import attrs
+
+                def _convert(items):
+                    for i, item in enumerate(items):
+                        raise ValueError(f"bad item {item!r} at index {i}")
+                    return frozenset(items)
+
+                @attrs.define
+                class Widget:
+                    _items: frozenset = attrs.field(converter=_convert)
+            """)
+        )
+        result = extract_class_error_messages(f, "Widget")
+        assert "bad item {} at index {}" in result
+
+    def test_converter_messages_not_collected_for_unrelated_class(self, 
tmp_path: Path):
+        """The converter is only followed for the class that references it."""
+        f = tmp_path / "code.py"
+        f.write_text(
+            textwrap.dedent("""\
+                import attrs
+
+                def _my_convert(items):
+                    raise ValueError("converter error")
+                    return frozenset(items)
+
+                @attrs.define
+                class ClassA:
+                    _data: frozenset = attrs.field(converter=_my_convert)
+
+                @attrs.define
+                class ClassB:
+                    _name: str = attrs.field()
+            """)
+        )
+        # ClassA references the converter — should see the message
+        result_a = extract_class_error_messages(f, "ClassA")
+        assert "converter error" in result_a
+
+        # ClassB does not reference the converter — should NOT see the message
+        result_b = extract_class_error_messages(f, "ClassB")
+        assert "converter error" not in result_b
+
+
+class TestConverterDivergenceDetected:
+    """Verify that message drift in a module-level converter is caught by the extractor."""
+
+    def test_divergent_converter_message_detected(self, tmp_path: Path):
+        """Changing 'non-empty' to 'non-empty strings' in a converter is detected as drift."""
+        # The two sources differ only in the wording of the converter's f-string message.
+        core_code = textwrap.dedent("""\
+            import attrs
+
+            def _convert_segments(segments):
+                for i, item in enumerate(segments):
+                    if not item:
+                        raise ValueError(f"keys must be non-empty; got empty 
at {i}.")
+                return frozenset(segments)
+
+            @attrs.define
+            class SegmentWindow:
+                _segments: frozenset = attrs.field(converter=_convert_segments)
+        """)
+        sdk_code_diverged = textwrap.dedent("""\
+            import attrs
+
+            def _convert_segments(segments):
+                for i, item in enumerate(segments):
+                    if not item:
+                        raise ValueError(f"keys must be non-empty strings; got 
empty at {i}.")
+                return frozenset(segments)
+
+            @attrs.define
+            class SegmentWindow:
+                _segments: frozenset = attrs.field(converter=_convert_segments)
+        """)
+        core_file = tmp_path / "core.py"
+        sdk_file = tmp_path / "sdk.py"
+        core_file.write_text(core_code)
+        sdk_file.write_text(sdk_code_diverged)
+
+        core_msgs = extract_class_error_messages(core_file, "SegmentWindow")
+        sdk_msgs = extract_class_error_messages(sdk_file, "SegmentWindow")
+        assert core_msgs != sdk_msgs
+
+
+class TestDivergentMessageFails:
+    """Message wording that differs between two sources yields different template sets."""
+
+    def test_divergent_message_detected(self, tmp_path: Path):
+        """Changing 'non-empty' to 'non-empty strings' on one side is detected."""
+        # Here the raise sits in a class method rather than a module-level converter.
+        core_code = textwrap.dedent("""\
+            import attrs
+
+            @attrs.define
+            class SegmentWindow:
+                _segments: frozenset[str] = attrs.field()
+
+                def validate(self, item, i):
+                    raise ValueError(f"keys must be non-empty; got empty at 
{i}.")
+        """)
+        sdk_code_diverged = textwrap.dedent("""\
+            import attrs
+
+            @attrs.define
+            class SegmentWindow:
+                _segments: frozenset[str] = attrs.field()
+
+                def validate(self, item, i):
+                    raise ValueError(f"keys must be non-empty strings; got 
empty at {i}.")
+        """)
+        core_file = tmp_path / "core.py"
+        sdk_file = tmp_path / "sdk.py"
+        core_file.write_text(core_code)
+        sdk_file.write_text(sdk_code_diverged)
+
+        core_msgs = extract_class_error_messages(core_file, "SegmentWindow")
+        sdk_msgs = extract_class_error_messages(sdk_file, "SegmentWindow")
+        assert core_msgs != sdk_msgs
+
+    @pytest.mark.parametrize(
+        ("core_msg", "sdk_msg"),
+        [
+            pytest.param(
+                "keys must be non-empty; got empty at {}.",
+                "keys must be non-empty strings; got empty at {}.",
+                id="non-empty-vs-non-empty-strings",
+            ),
+            pytest.param(
+                "requires at least one key; got an empty iterable.",
+                "requires at least one key.",
+                id="different-constant-wording",
+            ),
+        ],
+    )
+    def test_parametrized_divergence(self, tmp_path: Path, core_msg: str, 
sdk_msg: str):
+        """Each (core, sdk) template pair is written as source and must extract differently."""
+
+        def _make_file(path: Path, msg: str) -> None:
+            # Write a one-method class whose raise reproduces *msg*: an f-string
+            # when the template contains '{}', otherwise a plain string constant.
+            if "{}" in msg:
+                # Reconstruct as f-string source: replace {} with {i}
+                src_msg = msg.replace("{}", "{i}")
+                stmt = f'raise ValueError(f"{src_msg}")'
+            else:
+                stmt = f'raise ValueError("{msg}")'
+            path.write_text(
+                textwrap.dedent(f"""\
+                    class C:
+                        def v(self, i):
+                            {stmt}
+                """)
+            )
+
+        core_file = tmp_path / "core.py"
+        sdk_file = tmp_path / "sdk.py"
+        _make_file(core_file, core_msg)
+        _make_file(sdk_file, sdk_msg)
+
+        core_msgs = extract_class_error_messages(core_file, "C")
+        sdk_msgs = extract_class_error_messages(sdk_file, "C")
+        assert core_msgs != sdk_msgs
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 8d878978980..cd9cec283c5 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -247,6 +247,8 @@ Partition Mapper
 
 .. autoapiclass:: airflow.sdk.FanOutMapper
 
+.. autoapiclass:: airflow.sdk.FixedKeyMapper
+
 Rollup Windows
 ~~~~~~~~~~~~~~
 
@@ -264,6 +266,8 @@ Rollup Windows
 
 .. autoapiclass:: airflow.sdk.YearWindow
 
+.. autoapiclass:: airflow.sdk.SegmentWindow
+
 I/O Helpers
 -----------
 .. autoapiclass:: airflow.sdk.ObjectStoragePath
diff --git a/task-sdk/src/airflow/sdk/__init__.py 
b/task-sdk/src/airflow/sdk/__init__.py
index ce5b8ab2c2e..00d724cb42d 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -55,6 +55,7 @@ __all__ = [
     "EventsTimetable",
     "ExceptionRetryPolicy",
     "FanOutMapper",
+    "FixedKeyMapper",
     "HourWindow",
     "IdentityMapper",
     "Label",
@@ -77,6 +78,7 @@ __all__ = [
     "RetryPolicy",
     "RetryRule",
     "RollupMapper",
+    "SegmentWindow",
     "SkipMixin",
     "SyncCallback",
     "StartOfDayMapper",
@@ -154,6 +156,7 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.partition_mappers.allowed_key import 
AllowedKeyMapper
     from airflow.sdk.definitions.partition_mappers.base import 
PartitionMapper, RollupMapper
     from airflow.sdk.definitions.partition_mappers.chain import ChainMapper
+    from airflow.sdk.definitions.partition_mappers.fixed_key import 
FixedKeyMapper
     from airflow.sdk.definitions.partition_mappers.identity import 
IdentityMapper
     from airflow.sdk.definitions.partition_mappers.product import ProductMapper
     from airflow.sdk.definitions.partition_mappers.temporal import (
@@ -170,6 +173,7 @@ if TYPE_CHECKING:
         HourWindow,
         MonthWindow,
         QuarterWindow,
+        SegmentWindow,
         WeekWindow,
         Window,
         YearWindow,
@@ -244,6 +248,7 @@ __lazy_imports: dict[str, str] = {
     "EventsTimetable": ".definitions.timetables.events",
     "ExceptionRetryPolicy": ".definitions.retry_policy",
     "FanOutMapper": ".definitions.partition_mappers.temporal",
+    "FixedKeyMapper": ".definitions.partition_mappers.fixed_key",
     "HourWindow": ".definitions.partition_mappers.window",
     "IdentityMapper": ".definitions.partition_mappers.identity",
     "Label": ".definitions.edges",
@@ -266,6 +271,7 @@ __lazy_imports: dict[str, str] = {
     "RetryRule": ".definitions.retry_policy",
     "RollupMapper": ".definitions.partition_mappers.base",
     "SecretCache": ".execution_time.cache",
+    "SegmentWindow": ".definitions.partition_mappers.window",
     "SkipMixin": ".bases.skipmixin",
     "SyncCallback": ".definitions.callback",
     "StartOfDayMapper": ".definitions.partition_mappers.temporal",
diff --git a/task-sdk/src/airflow/sdk/__init__.pyi 
b/task-sdk/src/airflow/sdk/__init__.pyi
index 1bb975e2202..d6fa8bc2a0d 100644
--- a/task-sdk/src/airflow/sdk/__init__.pyi
+++ b/task-sdk/src/airflow/sdk/__init__.pyi
@@ -67,6 +67,7 @@ from airflow.sdk.definitions.param import Param as Param
 from airflow.sdk.definitions.partition_mappers.allowed_key import 
AllowedKeyMapper
 from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, 
RollupMapper
 from airflow.sdk.definitions.partition_mappers.chain import ChainMapper
+from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper
 from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper
 from airflow.sdk.definitions.partition_mappers.product import ProductMapper
 from airflow.sdk.definitions.partition_mappers.temporal import (
@@ -83,6 +84,7 @@ from airflow.sdk.definitions.partition_mappers.window import (
     HourWindow,
     MonthWindow,
     QuarterWindow,
+    SegmentWindow,
     WeekWindow,
     Window,
     YearWindow,
@@ -154,6 +156,7 @@ __all__ = [
     "EventsTimetable",
     "ExceptionRetryPolicy",
     "FanOutMapper",
+    "FixedKeyMapper",
     "HourWindow",
     "IdentityMapper",
     "Label",
@@ -175,6 +178,7 @@ __all__ = [
     "ResumableJobMixin",
     "RollupMapper",
     "SecretCache",
+    "SegmentWindow",
     "SkipMixin",
     "StartOfDayMapper",
     "StartOfHourMapper",
diff --git 
a/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py 
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py
new file mode 100644
index 00000000000..bf2b07905a8
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/fixed_key.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import attrs
+
+from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
+
+
@attrs.define
class FixedKeyMapper(PartitionMapper):
    """
    Map every upstream partition key to a single, constant downstream key.

    This is the Task SDK authoring counterpart of the scheduler-side
    :class:`airflow.partition_mappers.fixed_key.FixedKeyMapper`. Combined with
    :class:`~airflow.sdk.definitions.partition_mappers.window.SegmentWindow` in a
    :class:`~airflow.sdk.definitions.partition_mappers.base.RollupMapper`, it
    describes a categorical rollup.

    The key is checked as soon as the mapper is built, so an invalid value is
    reported as a Dag parse error instead of surfacing later during scheduler
    deserialization.

    :param downstream_key: Downstream partition key that every upstream key is
        mapped to; must be a non-empty string.
    :raises ValueError: if *downstream_key* is not a ``str`` or is the empty string.
    """

    downstream_key: str = attrs.field()

    @downstream_key.validator
    def _validate_downstream_key(self, attribute: attrs.Attribute, value: str) -> None:
        # Only a non-empty str is acceptable; anything else (None, ints, "") is rejected.
        acceptable = isinstance(value, str) and value != ""
        if acceptable:
            return
        raise ValueError(f"FixedKeyMapper downstream_key must be a non-empty str; got {value!r}.")

    def to_downstream(self, key: str) -> str:
        """Ignore *key* and hand back the configured ``downstream_key``."""
        return self.downstream_key
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py 
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py
index 4ea008ace6f..f491d4279ec 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py
@@ -22,7 +22,12 @@ from __future__ import annotations
 
 from datetime import datetime
 from enum import Enum
-from typing import Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar
+
+import attrs
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable
 
 
 class Window:
@@ -114,3 +119,52 @@ class YearWindow(Window):
     """Twelve consecutive monthly keys making up one calendar year."""
 
     expected_decoded_type: ClassVar[type] = datetime
+
+
+def _convert_segments(segments: Iterable[str]) -> frozenset[str]:
+    """
+    Validate and convert *segments* to a ``frozenset[str]``.
+
+    Validates each element for type and non-emptiness (with index reporting)
+    before collapsing into a frozenset, then checks the result is non-empty.
+    """
+    validated: list[str] = []
+    for i, item in enumerate(segments):
+        if not isinstance(item, str):
+            raise ValueError(
+                f"SegmentWindow segment keys must be str; got 
{type(item).__name__!r} at index {i}: {item!r}"
+            )
+        if not item:
+            raise ValueError(
+                f"SegmentWindow segment keys must be non-empty; got an empty 
string at index {i}."
+            )
+        validated.append(item)
+    result = frozenset(validated)
+    if not result:
+        raise ValueError("SegmentWindow requires at least one segment key; got 
an empty iterable.")
+    return result
+
+
@attrs.define
class SegmentWindow(Window):
    """
    Fixed, categorical set of string keys that together form one downstream period.

    Task SDK authoring counterpart of the scheduler-side
    :class:`airflow.partition_mappers.window.SegmentWindow`. Used together with
    :class:`~airflow.sdk.definitions.partition_mappers.fixed_key.FixedKeyMapper` in a
    :class:`~airflow.sdk.definitions.partition_mappers.base.RollupMapper` to describe a
    categorical rollup.

    The segments are validated while the window is being built, so mistakes show
    up as Dag parse errors rather than as scheduler deserialization failures.

    :param segments: Non-empty iterable of non-empty string segment keys. Repeated
        keys are collapsed without raising.
    :raises ValueError: if *segments* is empty, or has an element that is not a
        ``str`` or is an empty string.
    """

    # Segment keys are plain strings, unlike the datetime-keyed temporal windows.
    expected_decoded_type: ClassVar[type] = str

    # Normalised into a frozenset by ``_convert_segments`` at construction time.
    _segments: frozenset[str] = attrs.field(converter=_convert_segments)
diff --git a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py 
b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
index 8b902820da0..7114b65d6c4 100644
--- a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
+++ b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
@@ -22,12 +22,14 @@ from typing import ClassVar
 import pytest
 
 from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, 
RollupMapper
+from airflow.sdk.definitions.partition_mappers.fixed_key import FixedKeyMapper
 from airflow.sdk.definitions.partition_mappers.temporal import StartOfDayMapper
 from airflow.sdk.definitions.partition_mappers.window import (
     DayWindow,
     HourWindow,
     MonthWindow,
     QuarterWindow,
+    SegmentWindow,
     WeekWindow,
     Window,
     YearWindow,
@@ -103,3 +105,67 @@ class TestSdkWindowExpectedDecodedType:
     )
     def test_temporal_windows_declare_datetime(self, window_cls):
         assert window_cls.expected_decoded_type is datetime
+
+
class TestSdkFixedKeyMapper:
    """Construction, validation and mapping behaviour of the SDK-side FixedKeyMapper."""

    @pytest.mark.parametrize("key", ["us", "eu", "apac"])
    def test_to_downstream_returns_constant_for_any_key(self, key):
        mapper = FixedKeyMapper("all_regions")
        mapped = mapper.to_downstream(key)
        assert mapped == "all_regions"

    def test_is_rollup_false(self):
        mapper = FixedKeyMapper("all")
        assert mapper.is_rollup is False

    @pytest.mark.parametrize(
        ("downstream_key", "match"),
        [
            pytest.param("", "non-empty str", id="empty-string"),
            pytest.param(None, "non-empty str", id="none"),
            pytest.param(1, "non-empty str", id="int"),
        ],
    )
    def test_rejects_invalid_downstream_key(self, downstream_key, match):
        with pytest.raises(ValueError, match=match):
            FixedKeyMapper(downstream_key)

    def test_requires_downstream_key(self):
        # downstream_key has no default, so omitting it fails at the constructor call.
        with pytest.raises(TypeError):
            FixedKeyMapper()
+
+
class TestSdkSegmentWindow:
    """SDK-side SegmentWindow construction and validation mirrors the core implementation."""

    def test_expected_decoded_type_is_str(self):
        assert SegmentWindow.expected_decoded_type is str

    def test_deduplication(self):
        w = SegmentWindow(["a", "b", "a"])
        assert w._segments == frozenset({"a", "b"})

    def test_segments_stored_as_frozenset(self):
        w = SegmentWindow(["us", "eu"])
        assert isinstance(w._segments, frozenset)

    def test_accepts_segments_keyword(self):
        # attrs exposes the private ``_segments`` field as the ``segments`` init argument.
        w = SegmentWindow(segments=["us", "eu"])
        assert w._segments == frozenset({"us", "eu"})

    def test_accepts_one_shot_iterator(self):
        # The converter takes any Iterable, including a generator consumed exactly once.
        w = SegmentWindow(key for key in ["us", "eu", "us"])
        assert w._segments == frozenset({"us", "eu"})

    @pytest.mark.parametrize(
        ("segments", "match"),
        [
            pytest.param([], "at least one segment key", id="empty-list"),
            pytest.param([1, "b"], "must be str", id="int-element"),
            pytest.param(["", "b"], "non-empty", id="empty-string"),
            # The error message must point at the offending position, not just the first one.
            pytest.param(["a", 1], "at index 1", id="int-element-index"),
            pytest.param(["a", ""], "at index 1", id="empty-string-index"),
        ],
    )
    def test_rejects_invalid_segments(self, segments, match):
        with pytest.raises(ValueError, match=match):
            SegmentWindow(segments)
+
+
class TestSdkCategoricalRollupGuard:
    """RollupMapper's decoded-type guard on the SDK side, for categorical (str) windows."""

    def test_fixed_key_with_segment_window_does_not_raise(self):
        # SegmentWindow expects str values, which the guard accepts for FixedKeyMapper.
        mapper = FixedKeyMapper("all")
        window = SegmentWindow(["us", "eu"])
        RollupMapper(upstream_mapper=mapper, window=window)

    def test_str_mapper_with_datetime_window_raises(self):
        # DayWindow expects datetime values, so pairing it with FixedKeyMapper is rejected.
        expected = "DayWindow expects decoded values of type 'datetime'"
        with pytest.raises(TypeError, match=expected):
            RollupMapper(upstream_mapper=FixedKeyMapper("all"), window=DayWindow())

Reply via email to