Copilot commented on code in PR #64571:
URL: https://github.com/apache/airflow/pull/64571#discussion_r3066486627


##########
airflow-core/tests/unit/partition_mappers/test_temporal.py:
##########
@@ -57,29 +57,34 @@ def test_to_downstream(
         ],
     )
     @pytest.mark.parametrize(
-        ("mapper_cls", "expected_outut_format"),
+        ("mapper_cls", "expected_outut_format", "extra_kwargs"),

Review Comment:
   Typo in the test parameter name: `expected_outut_format` should be 
`expected_output_format`, both in the `parametrize` tuple and in the test 
function signature.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py:
##########
@@ -115,6 +119,62 @@ def next_run_assets(
         if not event.pop("queued", None):
             event["lastUpdate"] = None
 
+    # For partitioned DAGs: enrich events with per-asset received/required 
counts,
+    # using to_upstream for rollup mappers, and fix lastUpdate for partial 
receipt.
+    if is_partitioned:
+        pending_apdr = session.execute(
+            select(AssetPartitionDagRun.id, AssetPartitionDagRun.partition_key)
+            .where(
+                AssetPartitionDagRun.target_dag_id == dag_id,
+                AssetPartitionDagRun.created_dag_run_id.is_(None),
+            )
+            .order_by(AssetPartitionDagRun.created_at.desc())
+            .limit(1)
+        ).one_or_none()
+
+        if pending_apdr is not None:
+            # Count received log entries per asset for this partition
+            received_by_asset: dict[int, int] = dict(
+                session.execute(
+                    select(
+                        PartitionedAssetKeyLog.asset_id,
+                        func.count(PartitionedAssetKeyLog.id).label("cnt"),
+                    )
+                    .where(PartitionedAssetKeyLog.asset_partition_dag_run_id 
== pending_apdr.id)
+                    .group_by(PartitionedAssetKeyLog.asset_id)
+                ).all()

Review Comment:
   This per-asset `received_by_asset` count counts raw log rows, which can 
over-count if the same upstream partition key is logged more than once. To keep 
UI progress consistent with rollup expectations, count distinct upstream keys 
per asset instead — e.g. `func.count(func.distinct(source_partition_key))`.



##########
airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx:
##########
@@ -39,11 +39,13 @@ export const AssetProgressCell = ({ dagId, partitionKey, 
totalReceived, totalReq
   const assets: Array<PartitionedDagRunAssetResponse> = data?.assets ?? [];
 
   const events: Array<NextRunEvent> = assets
-    .filter((ak: PartitionedDagRunAssetResponse) => ak.received)
+    .filter((ak: PartitionedDagRunAssetResponse) => ak.received_count > 0)
     .map((ak: PartitionedDagRunAssetResponse) => ({
       id: ak.asset_id,
-      lastUpdate: "received",
+      lastUpdate: ak.received ? "received" : null,
       name: ak.asset_name,
+      receivedCount: ak.received_count,
+      requiredCount: ak.required_count,
       uri: ak.asset_uri,
     }));

Review Comment:
   `lastUpdate` is set to the string literal `'received'`, but `AssetNode` 
renders `<Time datetime={event.lastUpdate} />` when `lastUpdate` is truthy. 
Passing `'received'` as a datetime will likely render an invalid date/time, or 
throw, depending on the `Time` component. Prefer setting `lastUpdate` to a real 
timestamp (if available), or leave it `null`/`undefined` and drive the “fully 
received” visual state exclusively from `receivedCount`/`requiredCount`.



##########
airflow-core/tests/unit/partition_mappers/test_temporal.py:
##########
@@ -57,29 +57,34 @@ def test_to_downstream(
         ],
     )
     @pytest.mark.parametrize(
-        ("mapper_cls", "expected_outut_format"),
+        ("mapper_cls", "expected_outut_format", "extra_kwargs"),
         [
-            (StartOfHourMapper, "%Y-%m-%dT%H"),
-            (StartOfDayMapper, "%Y-%m-%d"),
-            (StartOfWeekMapper, "%Y-%m-%d (W%V)"),
-            (StartOfMonthMapper, "%Y-%m"),
-            (StartOfQuarterMapper, "%Y-Q{quarter}"),
-            (StartOfYearMapper, "%Y"),
+            (StartOfHourMapper, "%Y-%m-%dT%H", {}),
+            (StartOfDayMapper, "%Y-%m-%d", {}),
+            (StartOfWeekMapper, "%Y-%m-%d (W%V)", {"week_start": 0}),
+            (StartOfMonthMapper, "%Y-%m", {"month_start_day": 1}),
+            (StartOfQuarterMapper, "%Y-Q{quarter}", {}),
+            (StartOfYearMapper, "%Y", {}),
         ],
     )
     def test_serialize(
         self,
         mapper_cls: type[_BaseTemporalMapper],
         expected_outut_format: str,

Review Comment:
   Typo in the test parameter name: `expected_outut_format` should be 
`expected_output_format`, both in the `parametrize` tuple and in the test 
function signature.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py:
##########
@@ -129,30 +157,34 @@ def get_partitioned_dag_runs(
         return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], 
total=0)
 
     if dag_id.value is not None:
-        results = [_build_response(row, required_count) for row in rows]
+        timetable, asset_info = _load_timetable_and_assets(dag_id.value, 
session)
+        results = [
+            _build_response(row, _compute_total_required(timetable, 
asset_info, row.partition_key))
+            for row in rows
+        ]
         return 
PartitionedDagRunCollectionResponse(partitioned_dag_runs=results, 
total=len(results))
 
-    # No dag_id: need to get required counts and expressions per dag
-    dag_ids = list({row.target_dag_id for row in rows})
+    # No dag_id filter: load timetables and assets for each unique DAG
+    unique_dag_ids = list({row.target_dag_id for row in rows})
+    dag_timetables_assets: dict[str, tuple[PartitionedAssetTimetable | None, 
list[tuple[str, str]]]] = {
+        did: _load_timetable_and_assets(did, session) for did in unique_dag_ids
+    }

Review Comment:
   When `dag_id` is not provided, this calls `_load_timetable_and_assets()` 
once per DAG, and `_load_timetable_and_assets()` uses 
`SerializedDagModel.get(...)` (one DB query per DAG). This can become an N+1 
pattern on larger result sets. Consider batch-loading serialized DAGs for 
`unique_dag_ids` (similar to `get_latest_serialized_dags`) and then deriving 
the `PartitionedAssetTimetable` from the preloaded results.



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1773,40 +1774,124 @@ def _do_scheduling(self, session: Session) -> int:
 
         return num_queued_tis
 
+    def _check_rollup_asset_status(
+        self,
+        *,
+        asset_id: int,
+        apdr: AssetPartitionDagRun,
+        mapper: RollupMapper,
+        actual_by_asset: dict[int, set[str]],
+    ) -> bool:
+        if TYPE_CHECKING:
+            assert apdr.partition_key is not None
+        expected = mapper.to_upstream(apdr.partition_key)
+        return expected.issubset(actual_by_asset.get(asset_id, set()))
+
+    def _resolve_asset_partition_status(
+        self,
+        *,
+        asset_id: int,
+        name: str,
+        uri: str,
+        apdr: AssetPartitionDagRun,
+        timetable: PartitionedAssetTimetable,
+        actual_by_asset: dict[int, set[str]],
+    ) -> bool | None:
+        """
+        Return the rollup status for one asset within a pending partitioned 
Dag run.
+
+        Returns *True*/*False* for rollup assets, or *None* when the asset has 
no
+        rollup mapper and should default to satisfied.
+        """
+        try:
+            mapper = timetable.get_partition_mapper(name=name, uri=uri)
+            if not mapper.is_rollup:
+                return None
+            return self._check_rollup_asset_status(
+                asset_id=asset_id,
+                apdr=apdr,
+                mapper=cast("RollupMapper", mapper),
+                actual_by_asset=actual_by_asset,
+            )
+        except Exception:
+            self.log.exception(
+                "Failed to evaluate rollup status for asset; treating as 
not-yet-satisfied. "
+                "This likely indicates a misconfigured partition mapper.",
+                dag_id=apdr.target_dag_id,
+                partition_key=apdr.partition_key,
+                asset_name=name,
+                asset_uri=uri,
+            )
+            return False
+
     def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> 
set[str]:
         partition_dag_ids: set[str] = set()
 
-        evaluator = AssetEvaluator(session)
-        for apdr in session.scalars(
+        pending_apdrs = session.scalars(
             
select(AssetPartitionDagRun).where(AssetPartitionDagRun.created_dag_run_id.is_(None))
+        ).all()
+        if not pending_apdrs:
+            return partition_dag_ids
+
+        pending_apdr_ids = [apdr.id for apdr in pending_apdrs]
+
+        # Pre-fetch all required serialized Dags in one query.
+        dag_ids = list({apdr.target_dag_id for apdr in pending_apdrs if 
apdr.target_dag_id})
+        # {"dag_id": Serialized Dag}
+        serialized_dags: dict[str, SerializedDAG] = {}
+        for serdag in 
SerializedDagModel.get_latest_serialized_dags(dag_ids=dag_ids, session=session):
+            try:
+                serdag.load_op_links = False
+                serialized_dags[serdag.dag_id] = serdag.dag
+            except Exception:
+                self.log.exception("Failed to deserialize Dag '%s'", 
serdag.dag_id)
+
+        # {apdr_id: {asset_id: set(source_key, ...)}
+        source_key_by_asset_per_apdr: dict[int, dict[int, set[str]]] = 
defaultdict(lambda: defaultdict(set))
+        # {apdr_id: {asset_id: (asset_name, asset_uri)}
+        asset_info_per_apdr: dict[int, dict[int, tuple[str, str]]] = 
defaultdict(dict)
+        for apdr_id, asset_id, source_key, name, uri in session.execute(
+            select(
+                PartitionedAssetKeyLog.asset_partition_dag_run_id,
+                PartitionedAssetKeyLog.asset_id,
+                PartitionedAssetKeyLog.source_partition_key,
+                AssetModel.name,
+                AssetModel.uri,
+            )
+            .join(AssetModel, AssetModel.id == PartitionedAssetKeyLog.asset_id)
+            
.where(PartitionedAssetKeyLog.asset_partition_dag_run_id.in_(pending_apdr_ids))
         ):

Review Comment:
   Previously this logic constrained logs by both `asset_partition_dag_run_id` 
and the APDR’s `target_partition_key` (per the removed query). This new 
prefetch only filters by `asset_partition_dag_run_id`. If 
`PartitionedAssetKeyLog` can contain rows for multiple `target_partition_key` 
values under the same APDR id (or if APDR ids can be reused), rollup 
satisfaction could be computed from unrelated keys. If the extra filter is 
still relevant, add `PartitionedAssetKeyLog.target_partition_key == 
apdr.partition_key` (or an equivalent invariant) to preserve correctness.



##########
task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py:
##########
@@ -44,16 +44,42 @@ class StartOfDayMapper(_BaseTemporalMapper):
 
 
 class StartOfWeekMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to week."""
+    """Map a time-based partition key to the start of its week."""
 
     default_output_format = "%Y-%m-%d (W%V)"
 
+    def __init__(self, *, week_start: int = 0, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.week_start = week_start
+
+
+class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper):
+    """
+    Map a time-based partition key to the start of its week, requiring all 7 
daily keys.
+
+    Use this when a partitioned Dag should only run once every daily asset 
partition
+    for a full week has been produced.
+    """
+
 

Review Comment:
   In the Task SDK, `WeeklyRollupMapper`/`MonthlyRollupMapper` inherit from 
`RollupMapper` but do not implement `to_upstream()`. As written, calls to 
`to_upstream()` will fall back to the base implementation and raise 
`NotImplementedError`, even though these are exported as usable mappers. 
Implement `to_upstream()` in these SDK classes (mirroring `airflow-core`), or 
make these classes abstract/not exported until implemented.



##########
task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py:
##########
@@ -44,16 +44,42 @@ class StartOfDayMapper(_BaseTemporalMapper):
 
 
 class StartOfWeekMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to week."""
+    """Map a time-based partition key to the start of its week."""
 
     default_output_format = "%Y-%m-%d (W%V)"
 
+    def __init__(self, *, week_start: int = 0, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.week_start = week_start
+
+
+class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper):
+    """
+    Map a time-based partition key to the start of its week, requiring all 7 
daily keys.
+
+    Use this when a partitioned Dag should only run once every daily asset 
partition
+    for a full week has been produced.
+    """
+
 
 class StartOfMonthMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to month."""
+    """Map a time-based partition key to the start of its month."""
 
     default_output_format = "%Y-%m"
 
+    def __init__(self, *, month_start_day: int = 1, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.month_start_day = month_start_day
+
+
+class MonthlyRollupMapper(StartOfMonthMapper, RollupMapper):
+    """
+    Map a time-based partition key to the start of its month, requiring all 
daily keys in that month.
+
+    Use this when a partitioned Dag should only run once every daily asset 
partition
+    for a full calendar month has been produced.
+    """

Review Comment:
   In the Task SDK, `WeeklyRollupMapper`/`MonthlyRollupMapper` inherit from 
`RollupMapper` but do not implement `to_upstream()`. As written, calls to 
`to_upstream()` will fall back to the base implementation and raise 
`NotImplementedError`, even though these are exported as usable mappers. 
Implement `to_upstream()` in these SDK classes (mirroring `airflow-core`), or 
make these classes abstract/not exported until implemented.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py:
##########
@@ -193,13 +225,17 @@ def get_pending_partitioned_dag_run(
             f"No PartitionedDagRun for dag={dag_id} partition={partition_key}",
         )
 
-    received_subq = (
-        select(PartitionedAssetKeyLog.asset_id).where(
-            PartitionedAssetKeyLog.asset_partition_dag_run_id == 
partitioned_dag_run.id
+    # Count received PartitionedAssetKeyLog entries per asset for this 
partition
+    received_count_col = (
+        select(func.count(PartitionedAssetKeyLog.id))
+        .where(
+            PartitionedAssetKeyLog.asset_partition_dag_run_id == 
partitioned_dag_run.id,
+            PartitionedAssetKeyLog.asset_id == AssetModel.id,
         )
-    ).correlate(AssetModel)
-
-    received_expr = exists(received_subq.where(PartitionedAssetKeyLog.asset_id 
== AssetModel.id))
+        .correlate(AssetModel)
+        .scalar_subquery()
+        .label("received_count")
+    )

Review Comment:
   Same issue as the list endpoint: `count(PartitionedAssetKeyLog.id)` can be 
inflated by duplicates and cause `received_count >= required_count` to become 
true incorrectly. Prefer counting distinct upstream keys (e.g. 
`distinct(source_partition_key)`) so per-asset progress and `received` are 
accurate.



##########
task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py:
##########
@@ -16,10 +16,31 @@
 # under the License.
 from __future__ import annotations
 
+from abc import ABC
+
 
 class PartitionMapper:
     """
     Base partition mapper class.
 
     Maps keys from asset events to target dag run partitions.
     """
+
+    is_rollup: bool = False
+
+
+class RollupMapper(PartitionMapper, ABC):
+    """
+    Partition mapper that supports rollup (many upstream keys → one downstream 
key).
+
+    Subclass this when the downstream Dag should wait for a complete set of 
upstream
+    partition keys before triggering. The scheduler calls ``to_upstream`` to 
discover
+    which source keys are required and only creates a Dag run once all of them 
have
+    arrived in ``PartitionedAssetKeyLog``.
+    """
+
+    is_rollup: bool = True
+
+    def to_upstream(self, downstream_key: str) -> frozenset[str]:
+        """Return the complete set of upstream partition keys required for 
*downstream_key*."""
+        raise NotImplementedError

Review Comment:
   In the Task SDK, `RollupMapper.to_upstream()` is not marked abstract, so 
subclasses are not forced to implement it at type-check time, and incomplete 
rollup mappers can still be instantiated only to fail at runtime with 
`NotImplementedError`. Consider making `to_upstream()` an `@abstractmethod` 
(as done in `airflow-core`) so incomplete rollup mappers are caught earlier, 
at instantiation time.



##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -99,27 +99,135 @@ def normalize(self, dt: datetime) -> datetime:
 
 
 class StartOfWeekMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to week."""
+    """Map a time-based partition key to the start of its week."""
 
     default_output_format = "%Y-%m-%d (W%V)"
 
+    def __init__(
+        self,
+        *,
+        week_start: int = 0,
+        timezone: str | Timezone | FixedTimezone = "UTC",
+        input_format: str = "%Y-%m-%dT%H:%M:%S",
+        output_format: str | None = None,
+    ) -> None:
+        super().__init__(timezone=timezone, input_format=input_format, 
output_format=output_format)

Review Comment:
   `week_start` and `month_start_day` are used in date arithmetic, but there’s 
no input validation. Invalid values (e.g. `week_start=7` or 
`month_start_day=31`) will either produce incorrect rollups or raise 
`ValueError` later (e.g. `.replace(day=...)`). Add validation in `__init__` 
(e.g. `0 <= week_start <= 6` and `1 <= month_start_day <= 28`) with a clear 
error message.
   ```suggestion
           super().__init__(timezone=timezone, input_format=input_format, 
output_format=output_format)
           if not 0 <= week_start <= 6:
               raise ValueError(
                   f"week_start must be an integer between 0 (Monday) and 6 
(Sunday), got {week_start}"
               )
   ```



##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -99,27 +99,135 @@ def normalize(self, dt: datetime) -> datetime:
 
 
 class StartOfWeekMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to week."""
+    """Map a time-based partition key to the start of its week."""
 
     default_output_format = "%Y-%m-%d (W%V)"
 
+    def __init__(
+        self,
+        *,
+        week_start: int = 0,
+        timezone: str | Timezone | FixedTimezone = "UTC",
+        input_format: str = "%Y-%m-%dT%H:%M:%S",
+        output_format: str | None = None,
+    ) -> None:
+        super().__init__(timezone=timezone, input_format=input_format, 
output_format=output_format)
+        self.week_start = week_start  # 0 = Monday (ISO default), 6 = Sunday
+
     def normalize(self, dt: datetime) -> datetime:
-        start = dt - timedelta(days=dt.weekday())
+        days_since_start = (dt.weekday() - self.week_start) % 7
+        start = dt - timedelta(days=days_since_start)
         return start.replace(hour=0, minute=0, second=0, microsecond=0)
 
+    def serialize(self) -> dict[str, Any]:
+        return {**super().serialize(), "week_start": self.week_start}
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
+        return cls(
+            week_start=data.get("week_start", 0),
+            timezone=parse_timezone(data.get("timezone", "UTC")),
+            input_format=data["input_format"],
+            output_format=data["output_format"],
+        )
+
+
+class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper):
+    """
+    Map a time-based partition key to the start of its week, requiring all 7 
daily keys.
+
+    Use this when a partitioned Dag should only run once every daily asset 
partition
+    for a full week has been produced. Configure ``week_start`` to set which 
day begins
+    the week (0 = Monday, 6 = Sunday).
+    """
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+        if "%Y-%m-%d" not in self.output_format:
+            raise ValueError(
+                f"WeeklyRollupMapper requires output_format to contain 
'%Y-%m-%d' so that "
+                f"to_upstream() can recover the week-start date, got: 
{self.output_format!r}"
+            )
+
+    def to_upstream(self, downstream_key: str) -> frozenset[str]:
+        # Python strptime raises ValueError when %V (ISO week number) appears 
without
+        # %G and a weekday directive, so we cannot parse via the full 
output_format.
+        # Instead, locate %Y-%m-%d in the format string — __init__ guarantees 
it is
+        # present — and parse only the matching 10-char slice of the key.
+        # The prefix before %Y-%m-%d is literal text (no format directives), 
so its
+        # length in the format string equals its length in the formatted 
output.
+        ymd_fmt = "%Y-%m-%d"
+        key_start = len(self.output_format[: 
self.output_format.index(ymd_fmt)])
+        week_start_naive = datetime.strptime(downstream_key[key_start : 
key_start + 10], ymd_fmt)

Review Comment:
   The `key_start` calculation assumes everything before `%Y-%m-%d` in 
`output_format` is literal text with identical length in the formatted output. 
That breaks if the prefix contains strftime directives (e.g. `%B`, `%z`, etc.), 
causing `downstream_key` slicing to parse the wrong substring. Options: (1) 
constrain `output_format` more strictly (e.g. require it starts with `%Y-%m-%d` 
and raise otherwise), or (2) locate the date substring by searching 
`downstream_key` (e.g. regex for `\\d{4}-\\d{2}-\\d{2}`) rather than inferring 
offsets from the format string.



##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -99,27 +99,135 @@ def normalize(self, dt: datetime) -> datetime:
 
 
 class StartOfWeekMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to week."""
+    """Map a time-based partition key to the start of its week."""
 
     default_output_format = "%Y-%m-%d (W%V)"
 
+    def __init__(
+        self,
+        *,
+        week_start: int = 0,
+        timezone: str | Timezone | FixedTimezone = "UTC",
+        input_format: str = "%Y-%m-%dT%H:%M:%S",
+        output_format: str | None = None,
+    ) -> None:
+        super().__init__(timezone=timezone, input_format=input_format, 
output_format=output_format)
+        self.week_start = week_start  # 0 = Monday (ISO default), 6 = Sunday
+
     def normalize(self, dt: datetime) -> datetime:
-        start = dt - timedelta(days=dt.weekday())
+        days_since_start = (dt.weekday() - self.week_start) % 7
+        start = dt - timedelta(days=days_since_start)
         return start.replace(hour=0, minute=0, second=0, microsecond=0)
 
+    def serialize(self) -> dict[str, Any]:
+        return {**super().serialize(), "week_start": self.week_start}
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
+        return cls(
+            week_start=data.get("week_start", 0),
+            timezone=parse_timezone(data.get("timezone", "UTC")),
+            input_format=data["input_format"],
+            output_format=data["output_format"],
+        )
+
+
+class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper):
+    """
+    Map a time-based partition key to the start of its week, requiring all 7 
daily keys.
+
+    Use this when a partitioned Dag should only run once every daily asset 
partition
+    for a full week has been produced. Configure ``week_start`` to set which 
day begins
+    the week (0 = Monday, 6 = Sunday).
+    """
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+        if "%Y-%m-%d" not in self.output_format:
+            raise ValueError(
+                f"WeeklyRollupMapper requires output_format to contain 
'%Y-%m-%d' so that "
+                f"to_upstream() can recover the week-start date, got: 
{self.output_format!r}"
+            )
+
+    def to_upstream(self, downstream_key: str) -> frozenset[str]:
+        # Python strptime raises ValueError when %V (ISO week number) appears 
without
+        # %G and a weekday directive, so we cannot parse via the full 
output_format.
+        # Instead, locate %Y-%m-%d in the format string — __init__ guarantees 
it is
+        # present — and parse only the matching 10-char slice of the key.
+        # The prefix before %Y-%m-%d is literal text (no format directives), 
so its
+        # length in the format string equals its length in the formatted 
output.
+        ymd_fmt = "%Y-%m-%d"
+        key_start = len(self.output_format[: 
self.output_format.index(ymd_fmt)])
+        week_start_naive = datetime.strptime(downstream_key[key_start : 
key_start + 10], ymd_fmt)
+        # Arithmetic stays on naive datetimes to keep day-counting unambiguous 
across
+        # DST transitions; each result is made timezone-aware before 
formatting so that
+        # %z in input_format produces the correct offset.
+        return frozenset(
+            make_aware(week_start_naive + timedelta(days=i), 
self._timezone).strftime(self.input_format)
+            for i in range(7)
+        )
+
 
 class StartOfMonthMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to month."""
+    """Map a time-based partition key to the start of its month."""
 
     default_output_format = "%Y-%m"
 
+    def __init__(
+        self,
+        *,
+        month_start_day: int = 1,
+        timezone: str | Timezone | FixedTimezone = "UTC",
+        input_format: str = "%Y-%m-%dT%H:%M:%S",
+        output_format: str | None = None,
+    ) -> None:
+        super().__init__(timezone=timezone, input_format=input_format, 
output_format=output_format)

Review Comment:
   `week_start` and `month_start_day` are used in date arithmetic, but there’s 
no input validation. Invalid values (e.g. `week_start=7` or 
`month_start_day=31`) will either produce incorrect rollups or raise 
`ValueError` later (e.g. `.replace(day=...)`). Add validation in `__init__` 
(e.g. `0 <= week_start <= 6` and `1 <= month_start_day <= 28`) with a clear 
error message.
   ```suggestion
           super().__init__(timezone=timezone, input_format=input_format, 
output_format=output_format)
           if not 1 <= month_start_day <= 28:
               raise ValueError(
                   f"month_start_day must be between 1 and 28 inclusive, got: 
{month_start_day!r}"
               )
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to