Copilot commented on code in PR #64571:
URL: https://github.com/apache/airflow/pull/64571#discussion_r3066486627
##########
airflow-core/tests/unit/partition_mappers/test_temporal.py:
##########
@@ -57,29 +57,34 @@ def test_to_downstream(
],
)
@pytest.mark.parametrize(
- ("mapper_cls", "expected_outut_format"),
+ ("mapper_cls", "expected_outut_format", "extra_kwargs"),
Review Comment:
Typo in test parameter name: `expected_outut_format` should be
`expected_output_format` for clarity.
##########
airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py:
##########
@@ -115,6 +119,62 @@ def next_run_assets(
if not event.pop("queued", None):
event["lastUpdate"] = None
+ # For partitioned DAGs: enrich events with per-asset received/required
counts,
+ # using to_upstream for rollup mappers, and fix lastUpdate for partial
receipt.
+ if is_partitioned:
+ pending_apdr = session.execute(
+ select(AssetPartitionDagRun.id, AssetPartitionDagRun.partition_key)
+ .where(
+ AssetPartitionDagRun.target_dag_id == dag_id,
+ AssetPartitionDagRun.created_dag_run_id.is_(None),
+ )
+ .order_by(AssetPartitionDagRun.created_at.desc())
+ .limit(1)
+ ).one_or_none()
+
+ if pending_apdr is not None:
+ # Count received log entries per asset for this partition
+ received_by_asset: dict[int, int] = dict(
+ session.execute(
+ select(
+ PartitionedAssetKeyLog.asset_id,
+ func.count(PartitionedAssetKeyLog.id).label("cnt"),
+ )
+ .where(PartitionedAssetKeyLog.asset_partition_dag_run_id
== pending_apdr.id)
+ .group_by(PartitionedAssetKeyLog.asset_id)
+ ).all()
Review Comment:
This per-asset `received_by_asset` count is derived from raw log rows, which can
over-count if the same upstream partition key is logged more than once. To keep
UI progress consistent with rollup expectations, count distinct upstream keys
(typically `source_partition_key`) per asset.
##########
airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx:
##########
@@ -39,11 +39,13 @@ export const AssetProgressCell = ({ dagId, partitionKey,
totalReceived, totalReq
const assets: Array<PartitionedDagRunAssetResponse> = data?.assets ?? [];
const events: Array<NextRunEvent> = assets
- .filter((ak: PartitionedDagRunAssetResponse) => ak.received)
+ .filter((ak: PartitionedDagRunAssetResponse) => ak.received_count > 0)
.map((ak: PartitionedDagRunAssetResponse) => ({
id: ak.asset_id,
- lastUpdate: "received",
+ lastUpdate: ak.received ? "received" : null,
name: ak.asset_name,
+ receivedCount: ak.received_count,
+ requiredCount: ak.required_count,
uri: ak.asset_uri,
}));
Review Comment:
`lastUpdate` is set to the string literal `'received'`, but `AssetNode`
renders `<Time datetime={event.lastUpdate} />` when `lastUpdate` is truthy.
Passing `'received'` as a datetime will likely render an invalid date/time or
throw depending on the `Time` component. Prefer setting `lastUpdate` to a real
timestamp (if available), or leave it `null/undefined` and drive the “fully
received” visual state exclusively from `receivedCount`/`requiredCount`.
##########
airflow-core/tests/unit/partition_mappers/test_temporal.py:
##########
@@ -57,29 +57,34 @@ def test_to_downstream(
],
)
@pytest.mark.parametrize(
- ("mapper_cls", "expected_outut_format"),
+ ("mapper_cls", "expected_outut_format", "extra_kwargs"),
[
- (StartOfHourMapper, "%Y-%m-%dT%H"),
- (StartOfDayMapper, "%Y-%m-%d"),
- (StartOfWeekMapper, "%Y-%m-%d (W%V)"),
- (StartOfMonthMapper, "%Y-%m"),
- (StartOfQuarterMapper, "%Y-Q{quarter}"),
- (StartOfYearMapper, "%Y"),
+ (StartOfHourMapper, "%Y-%m-%dT%H", {}),
+ (StartOfDayMapper, "%Y-%m-%d", {}),
+ (StartOfWeekMapper, "%Y-%m-%d (W%V)", {"week_start": 0}),
+ (StartOfMonthMapper, "%Y-%m", {"month_start_day": 1}),
+ (StartOfQuarterMapper, "%Y-Q{quarter}", {}),
+ (StartOfYearMapper, "%Y", {}),
],
)
def test_serialize(
self,
mapper_cls: type[_BaseTemporalMapper],
expected_outut_format: str,
Review Comment:
Typo in test parameter name: `expected_outut_format` should be
`expected_output_format` for clarity.
##########
airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py:
##########
@@ -129,30 +157,34 @@ def get_partitioned_dag_runs(
return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[],
total=0)
if dag_id.value is not None:
- results = [_build_response(row, required_count) for row in rows]
+ timetable, asset_info = _load_timetable_and_assets(dag_id.value,
session)
+ results = [
+ _build_response(row, _compute_total_required(timetable,
asset_info, row.partition_key))
+ for row in rows
+ ]
return
PartitionedDagRunCollectionResponse(partitioned_dag_runs=results,
total=len(results))
- # No dag_id: need to get required counts and expressions per dag
- dag_ids = list({row.target_dag_id for row in rows})
+ # No dag_id filter: load timetables and assets for each unique DAG
+ unique_dag_ids = list({row.target_dag_id for row in rows})
+ dag_timetables_assets: dict[str, tuple[PartitionedAssetTimetable | None,
list[tuple[str, str]]]] = {
+ did: _load_timetable_and_assets(did, session) for did in unique_dag_ids
+ }
Review Comment:
When `dag_id` is not provided, this calls `_load_timetable_and_assets()`
once per DAG, and `_load_timetable_and_assets()` uses
`SerializedDagModel.get(...)` (one DB query per DAG). This can become an N+1
pattern on larger result sets. Consider batch-loading serialized DAGs for
`unique_dag_ids` (similar to `get_latest_serialized_dags`) and then deriving
the `PartitionedAssetTimetable` from the preloaded results.
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1773,40 +1774,124 @@ def _do_scheduling(self, session: Session) -> int:
return num_queued_tis
+ def _check_rollup_asset_status(
+ self,
+ *,
+ asset_id: int,
+ apdr: AssetPartitionDagRun,
+ mapper: RollupMapper,
+ actual_by_asset: dict[int, set[str]],
+ ) -> bool:
+ if TYPE_CHECKING:
+ assert apdr.partition_key is not None
+ expected = mapper.to_upstream(apdr.partition_key)
+ return expected.issubset(actual_by_asset.get(asset_id, set()))
+
+ def _resolve_asset_partition_status(
+ self,
+ *,
+ asset_id: int,
+ name: str,
+ uri: str,
+ apdr: AssetPartitionDagRun,
+ timetable: PartitionedAssetTimetable,
+ actual_by_asset: dict[int, set[str]],
+ ) -> bool | None:
+ """
+ Return the rollup status for one asset within a pending partitioned
Dag run.
+
+ Returns *True*/*False* for rollup assets, or *None* when the asset has
no
+ rollup mapper and should default to satisfied.
+ """
+ try:
+ mapper = timetable.get_partition_mapper(name=name, uri=uri)
+ if not mapper.is_rollup:
+ return None
+ return self._check_rollup_asset_status(
+ asset_id=asset_id,
+ apdr=apdr,
+ mapper=cast("RollupMapper", mapper),
+ actual_by_asset=actual_by_asset,
+ )
+ except Exception:
+ self.log.exception(
+ "Failed to evaluate rollup status for asset; treating as
not-yet-satisfied. "
+ "This likely indicates a misconfigured partition mapper.",
+ dag_id=apdr.target_dag_id,
+ partition_key=apdr.partition_key,
+ asset_name=name,
+ asset_uri=uri,
+ )
+ return False
+
def _create_dagruns_for_partitioned_asset_dags(self, session: Session) ->
set[str]:
partition_dag_ids: set[str] = set()
- evaluator = AssetEvaluator(session)
- for apdr in session.scalars(
+ pending_apdrs = session.scalars(
select(AssetPartitionDagRun).where(AssetPartitionDagRun.created_dag_run_id.is_(None))
+ ).all()
+ if not pending_apdrs:
+ return partition_dag_ids
+
+ pending_apdr_ids = [apdr.id for apdr in pending_apdrs]
+
+ # Pre-fetch all required serialized Dags in one query.
+ dag_ids = list({apdr.target_dag_id for apdr in pending_apdrs if
apdr.target_dag_id})
+ # {"dag_id": Serialized Dag}
+ serialized_dags: dict[str, SerializedDAG] = {}
+ for serdag in
SerializedDagModel.get_latest_serialized_dags(dag_ids=dag_ids, session=session):
+ try:
+ serdag.load_op_links = False
+ serialized_dags[serdag.dag_id] = serdag.dag
+ except Exception:
+ self.log.exception("Failed to deserialize Dag '%s'",
serdag.dag_id)
+
+ # {apdr_id: {asset_id: set(source_key, ...)}
+ source_key_by_asset_per_apdr: dict[int, dict[int, set[str]]] =
defaultdict(lambda: defaultdict(set))
+ # {apdr_id: {asset_id: (asset_name, asset_uri)}
+ asset_info_per_apdr: dict[int, dict[int, tuple[str, str]]] =
defaultdict(dict)
+ for apdr_id, asset_id, source_key, name, uri in session.execute(
+ select(
+ PartitionedAssetKeyLog.asset_partition_dag_run_id,
+ PartitionedAssetKeyLog.asset_id,
+ PartitionedAssetKeyLog.source_partition_key,
+ AssetModel.name,
+ AssetModel.uri,
+ )
+ .join(AssetModel, AssetModel.id == PartitionedAssetKeyLog.asset_id)
+
.where(PartitionedAssetKeyLog.asset_partition_dag_run_id.in_(pending_apdr_ids))
):
Review Comment:
Previously this logic constrained logs by both `asset_partition_dag_run_id`
and the APDR’s `target_partition_key` (per the removed query). This new
prefetch only filters by `asset_partition_dag_run_id`. If
`PartitionedAssetKeyLog` can contain rows for multiple `target_partition_key`
values under the same APDR id (or if APDR ids can be reused), rollup
satisfaction could be computed from unrelated keys. If the extra filter is
still relevant, add `PartitionedAssetKeyLog.target_partition_key ==
apdr.partition_key` (or an equivalent invariant) to preserve correctness.
##########
task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py:
##########
@@ -44,16 +44,42 @@ class StartOfDayMapper(_BaseTemporalMapper):
class StartOfWeekMapper(_BaseTemporalMapper):
- """Map a time-based partition key to week."""
+ """Map a time-based partition key to the start of its week."""
default_output_format = "%Y-%m-%d (W%V)"
+ def __init__(self, *, week_start: int = 0, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.week_start = week_start
+
+
+class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper):
+ """
+ Map a time-based partition key to the start of its week, requiring all 7
daily keys.
+
+ Use this when a partitioned Dag should only run once every daily asset
partition
+ for a full week has been produced.
+ """
+
Review Comment:
In the Task SDK, `WeeklyRollupMapper`/`MonthlyRollupMapper` inherit from
`RollupMapper` but do not implement `to_upstream()`. As written, calls to
`to_upstream()` will fall back to the base implementation and raise
`NotImplementedError`, even though these are exported as usable mappers.
Implement `to_upstream()` in these SDK classes (mirroring `airflow-core`), or
make these classes abstract/not exported until implemented.
##########
task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py:
##########
@@ -44,16 +44,42 @@ class StartOfDayMapper(_BaseTemporalMapper):
class StartOfWeekMapper(_BaseTemporalMapper):
- """Map a time-based partition key to week."""
+ """Map a time-based partition key to the start of its week."""
default_output_format = "%Y-%m-%d (W%V)"
+ def __init__(self, *, week_start: int = 0, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.week_start = week_start
+
+
+class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper):
+ """
+ Map a time-based partition key to the start of its week, requiring all 7
daily keys.
+
+ Use this when a partitioned Dag should only run once every daily asset
partition
+ for a full week has been produced.
+ """
+
class StartOfMonthMapper(_BaseTemporalMapper):
- """Map a time-based partition key to month."""
+ """Map a time-based partition key to the start of its month."""
default_output_format = "%Y-%m"
+ def __init__(self, *, month_start_day: int = 1, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.month_start_day = month_start_day
+
+
+class MonthlyRollupMapper(StartOfMonthMapper, RollupMapper):
+ """
+ Map a time-based partition key to the start of its month, requiring all
daily keys in that month.
+
+ Use this when a partitioned Dag should only run once every daily asset
partition
+ for a full calendar month has been produced.
+ """
Review Comment:
In the Task SDK, `WeeklyRollupMapper`/`MonthlyRollupMapper` inherit from
`RollupMapper` but do not implement `to_upstream()`. As written, calls to
`to_upstream()` will fall back to the base implementation and raise
`NotImplementedError`, even though these are exported as usable mappers.
Implement `to_upstream()` in these SDK classes (mirroring `airflow-core`), or
make these classes abstract/not exported until implemented.
##########
airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py:
##########
@@ -193,13 +225,17 @@ def get_pending_partitioned_dag_run(
f"No PartitionedDagRun for dag={dag_id} partition={partition_key}",
)
- received_subq = (
- select(PartitionedAssetKeyLog.asset_id).where(
- PartitionedAssetKeyLog.asset_partition_dag_run_id ==
partitioned_dag_run.id
+ # Count received PartitionedAssetKeyLog entries per asset for this
partition
+ received_count_col = (
+ select(func.count(PartitionedAssetKeyLog.id))
+ .where(
+ PartitionedAssetKeyLog.asset_partition_dag_run_id ==
partitioned_dag_run.id,
+ PartitionedAssetKeyLog.asset_id == AssetModel.id,
)
- ).correlate(AssetModel)
-
- received_expr = exists(received_subq.where(PartitionedAssetKeyLog.asset_id
== AssetModel.id))
+ .correlate(AssetModel)
+ .scalar_subquery()
+ .label("received_count")
+ )
Review Comment:
Same issue as the list endpoint: `count(PartitionedAssetKeyLog.id)` can be
inflated by duplicates and cause `received_count >= required_count` to become
true incorrectly. Prefer counting distinct upstream keys (e.g.
`distinct(source_partition_key)`) so per-asset progress and `received` are
accurate.
##########
task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py:
##########
@@ -16,10 +16,31 @@
# under the License.
from __future__ import annotations
+from abc import ABC
+
class PartitionMapper:
"""
Base partition mapper class.
Maps keys from asset events to target dag run partitions.
"""
+
+ is_rollup: bool = False
+
+
+class RollupMapper(PartitionMapper, ABC):
+ """
+ Partition mapper that supports rollup (many upstream keys → one downstream
key).
+
+ Subclass this when the downstream Dag should wait for a complete set of
upstream
+ partition keys before triggering. The scheduler calls ``to_upstream`` to
discover
+ which source keys are required and only creates a Dag run once all of them
have
+ arrived in ``PartitionedAssetKeyLog``.
+ """
+
+ is_rollup: bool = True
+
+ def to_upstream(self, downstream_key: str) -> frozenset[str]:
+ """Return the complete set of upstream partition keys required for
*downstream_key*."""
+ raise NotImplementedError
Review Comment:
In the Task SDK, `RollupMapper.to_upstream()` is not marked abstract, so
subclasses are not forced to implement it at type-check time and can be
instantiated, only to fail at runtime. Consider making `to_upstream()` an
`@abstractmethod` (as done in `airflow-core`) so incomplete rollup mappers are
caught earlier.
##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -99,27 +99,135 @@ def normalize(self, dt: datetime) -> datetime:
class StartOfWeekMapper(_BaseTemporalMapper):
- """Map a time-based partition key to week."""
+ """Map a time-based partition key to the start of its week."""
default_output_format = "%Y-%m-%d (W%V)"
+ def __init__(
+ self,
+ *,
+ week_start: int = 0,
+ timezone: str | Timezone | FixedTimezone = "UTC",
+ input_format: str = "%Y-%m-%dT%H:%M:%S",
+ output_format: str | None = None,
+ ) -> None:
+ super().__init__(timezone=timezone, input_format=input_format,
output_format=output_format)
Review Comment:
`week_start` and `month_start_day` are used in date arithmetic, but there’s
no input validation. Invalid values (e.g. `week_start=7` or
`month_start_day=31`) will either produce incorrect rollups or raise
`ValueError` later (e.g. `.replace(day=...)`). Add validation in `__init__`
(e.g. `0 <= week_start <= 6` and `1 <= month_start_day <= 28`) with a clear
error message.
```suggestion
super().__init__(timezone=timezone, input_format=input_format,
output_format=output_format)
if not 0 <= week_start <= 6:
raise ValueError(
f"week_start must be an integer between 0 (Monday) and 6
(Sunday), got {week_start}"
)
```
##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -99,27 +99,135 @@ def normalize(self, dt: datetime) -> datetime:
class StartOfWeekMapper(_BaseTemporalMapper):
- """Map a time-based partition key to week."""
+ """Map a time-based partition key to the start of its week."""
default_output_format = "%Y-%m-%d (W%V)"
+ def __init__(
+ self,
+ *,
+ week_start: int = 0,
+ timezone: str | Timezone | FixedTimezone = "UTC",
+ input_format: str = "%Y-%m-%dT%H:%M:%S",
+ output_format: str | None = None,
+ ) -> None:
+ super().__init__(timezone=timezone, input_format=input_format,
output_format=output_format)
+ self.week_start = week_start # 0 = Monday (ISO default), 6 = Sunday
+
def normalize(self, dt: datetime) -> datetime:
- start = dt - timedelta(days=dt.weekday())
+ days_since_start = (dt.weekday() - self.week_start) % 7
+ start = dt - timedelta(days=days_since_start)
return start.replace(hour=0, minute=0, second=0, microsecond=0)
+ def serialize(self) -> dict[str, Any]:
+ return {**super().serialize(), "week_start": self.week_start}
+
+ @classmethod
+ def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
+ return cls(
+ week_start=data.get("week_start", 0),
+ timezone=parse_timezone(data.get("timezone", "UTC")),
+ input_format=data["input_format"],
+ output_format=data["output_format"],
+ )
+
+
+class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper):
+ """
+ Map a time-based partition key to the start of its week, requiring all 7
daily keys.
+
+ Use this when a partitioned Dag should only run once every daily asset
partition
+ for a full week has been produced. Configure ``week_start`` to set which
day begins
+ the week (0 = Monday, 6 = Sunday).
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ if "%Y-%m-%d" not in self.output_format:
+ raise ValueError(
+ f"WeeklyRollupMapper requires output_format to contain
'%Y-%m-%d' so that "
+ f"to_upstream() can recover the week-start date, got:
{self.output_format!r}"
+ )
+
+ def to_upstream(self, downstream_key: str) -> frozenset[str]:
+ # Python strptime raises ValueError when %V (ISO week number) appears
without
+ # %G and a weekday directive, so we cannot parse via the full
output_format.
+ # Instead, locate %Y-%m-%d in the format string — __init__ guarantees
it is
+ # present — and parse only the matching 10-char slice of the key.
+ # The prefix before %Y-%m-%d is literal text (no format directives),
so its
+ # length in the format string equals its length in the formatted
output.
+ ymd_fmt = "%Y-%m-%d"
+ key_start = len(self.output_format[:
self.output_format.index(ymd_fmt)])
+ week_start_naive = datetime.strptime(downstream_key[key_start :
key_start + 10], ymd_fmt)
Review Comment:
The `key_start` calculation assumes everything before `%Y-%m-%d` in
`output_format` is literal text with identical length in the formatted output.
That breaks if the prefix contains strftime directives (e.g. `%B`, `%z`, etc.),
causing `downstream_key` slicing to parse the wrong substring. Options: (1)
constrain `output_format` more strictly (e.g. require it starts with `%Y-%m-%d`
and raise otherwise), or (2) locate the date substring by searching
`downstream_key` (e.g. regex for `\d{4}-\d{2}-\d{2}`) rather than inferring
offsets from the format string.
##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -99,27 +99,135 @@ def normalize(self, dt: datetime) -> datetime:
class StartOfWeekMapper(_BaseTemporalMapper):
- """Map a time-based partition key to week."""
+ """Map a time-based partition key to the start of its week."""
default_output_format = "%Y-%m-%d (W%V)"
+ def __init__(
+ self,
+ *,
+ week_start: int = 0,
+ timezone: str | Timezone | FixedTimezone = "UTC",
+ input_format: str = "%Y-%m-%dT%H:%M:%S",
+ output_format: str | None = None,
+ ) -> None:
+ super().__init__(timezone=timezone, input_format=input_format,
output_format=output_format)
+ self.week_start = week_start # 0 = Monday (ISO default), 6 = Sunday
+
def normalize(self, dt: datetime) -> datetime:
- start = dt - timedelta(days=dt.weekday())
+ days_since_start = (dt.weekday() - self.week_start) % 7
+ start = dt - timedelta(days=days_since_start)
return start.replace(hour=0, minute=0, second=0, microsecond=0)
+ def serialize(self) -> dict[str, Any]:
+ return {**super().serialize(), "week_start": self.week_start}
+
+ @classmethod
+ def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
+ return cls(
+ week_start=data.get("week_start", 0),
+ timezone=parse_timezone(data.get("timezone", "UTC")),
+ input_format=data["input_format"],
+ output_format=data["output_format"],
+ )
+
+
+class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper):
+ """
+ Map a time-based partition key to the start of its week, requiring all 7
daily keys.
+
+ Use this when a partitioned Dag should only run once every daily asset
partition
+ for a full week has been produced. Configure ``week_start`` to set which
day begins
+ the week (0 = Monday, 6 = Sunday).
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ if "%Y-%m-%d" not in self.output_format:
+ raise ValueError(
+ f"WeeklyRollupMapper requires output_format to contain
'%Y-%m-%d' so that "
+ f"to_upstream() can recover the week-start date, got:
{self.output_format!r}"
+ )
+
+ def to_upstream(self, downstream_key: str) -> frozenset[str]:
+ # Python strptime raises ValueError when %V (ISO week number) appears
without
+ # %G and a weekday directive, so we cannot parse via the full
output_format.
+ # Instead, locate %Y-%m-%d in the format string — __init__ guarantees
it is
+ # present — and parse only the matching 10-char slice of the key.
+ # The prefix before %Y-%m-%d is literal text (no format directives),
so its
+ # length in the format string equals its length in the formatted
output.
+ ymd_fmt = "%Y-%m-%d"
+ key_start = len(self.output_format[:
self.output_format.index(ymd_fmt)])
+ week_start_naive = datetime.strptime(downstream_key[key_start :
key_start + 10], ymd_fmt)
+ # Arithmetic stays on naive datetimes to keep day-counting unambiguous
across
+ # DST transitions; each result is made timezone-aware before
formatting so that
+ # %z in input_format produces the correct offset.
+ return frozenset(
+ make_aware(week_start_naive + timedelta(days=i),
self._timezone).strftime(self.input_format)
+ for i in range(7)
+ )
+
class StartOfMonthMapper(_BaseTemporalMapper):
- """Map a time-based partition key to month."""
+ """Map a time-based partition key to the start of its month."""
default_output_format = "%Y-%m"
+ def __init__(
+ self,
+ *,
+ month_start_day: int = 1,
+ timezone: str | Timezone | FixedTimezone = "UTC",
+ input_format: str = "%Y-%m-%dT%H:%M:%S",
+ output_format: str | None = None,
+ ) -> None:
+ super().__init__(timezone=timezone, input_format=input_format,
output_format=output_format)
Review Comment:
`week_start` and `month_start_day` are used in date arithmetic, but there’s
no input validation. Invalid values (e.g. `week_start=7` or
`month_start_day=31`) will either produce incorrect rollups or raise
`ValueError` later (e.g. `.replace(day=...)`). Add validation in `__init__`
(e.g. `0 <= week_start <= 6` and `1 <= month_start_day <= 28`) with a clear
error message.
```suggestion
super().__init__(timezone=timezone, input_format=input_format,
output_format=output_format)
if not 1 <= month_start_day <= 28:
raise ValueError(
f"month_start_day must be between 1 and 28 inclusive, got:
{month_start_day!r}"
)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]