This is an automated email from the ASF dual-hosted git repository.
Lee-W pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a8c99107c9f feat(AIP-76): support forward fan-out via forward kwarg on
Window (#67475)
a8c99107c9f is described below
commit a8c99107c9fda3664979c1eba34a5c7e3e0944b3
Author: Wei Lee <[email protected]>
AuthorDate: Wed Jun 10 16:26:49 2026 +0800
feat(AIP-76): support forward fan-out via forward kwarg on Window (#67475)
---
.pre-commit-config.yaml | 10 +
airflow-core/newsfragments/67475.feature.rst | 1 +
.../src/airflow/partition_mappers/temporal.py | 16 +-
.../src/airflow/partition_mappers/window.py | 74 +++++--
airflow-core/src/airflow/serialization/encoders.py | 2 +-
.../tests/unit/partition_mappers/test_fan_out.py | 78 ++++++++
.../tests/unit/partition_mappers/test_window.py | 171 ++++++++++++++--
.../unit/serialization/test_serialized_objects.py | 6 +-
scripts/ci/prek/check_window_in_sync.py | 216 +++++++++++++++++++++
.../sdk/definitions/partition_mappers/temporal.py | 15 +-
.../sdk/definitions/partition_mappers/window.py | 31 ++-
.../task_sdk/definitions/test_partition_mappers.py | 30 +++
12 files changed, 613 insertions(+), 37 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e3c1b6faf47..1de10325dd6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -258,6 +258,16 @@ repos:
^task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal\.py$
pass_filenames: false
require_serial: true
+ - id: check-window-in-sync
+ name: Check Window definitions stay in sync (core/SDK)
+ entry: ./scripts/ci/prek/check_window_in_sync.py
+ language: python
+ files: >
+ (?x)
+ ^airflow-core/src/airflow/partition_mappers/window\.py$|
+ ^task-sdk/src/airflow/sdk/definitions/partition_mappers/window\.py$
+ pass_filenames: false
+ require_serial: true
- id: sync-uv-min-version-markers
name: Sync `# sync-uv-min-version` markers with [tool.uv]
required-version
entry: ./scripts/ci/prek/sync_uv_min_version_markers.py
diff --git a/airflow-core/newsfragments/67475.feature.rst
b/airflow-core/newsfragments/67475.feature.rst
new file mode 100644
index 00000000000..ea0b227b0d5
--- /dev/null
+++ b/airflow-core/newsfragments/67475.feature.rst
@@ -0,0 +1 @@
+``Window`` subclasses accept a ``direction`` keyword —
``Window.Direction.FORWARD`` (default) fans out the period starting at the
upstream key (forward in time); pass ``direction=Window.Direction.BACKWARD``
(e.g. ``WeekWindow(direction=Window.Direction.BACKWARD)``) to fan out the
trailing period ending at the upstream key instead.
diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py
b/airflow-core/src/airflow/partition_mappers/temporal.py
index 18656761122..4fc1ed45ed3 100644
--- a/airflow-core/src/airflow/partition_mappers/temporal.py
+++ b/airflow-core/src/airflow/partition_mappers/temporal.py
@@ -443,10 +443,21 @@ class FanOutMapper(PartitionMapper):
is N→1 (downstream waits until all members arrive), fan-out is 1→N (one
upstream event creates one downstream Dag run per member).
+ For backward fan-out (emit the trailing period ending at the upstream key,
+ instead of the period starting at it), pass
``direction=Window.Direction.BACKWARD``
+ to the window:
+
.. code-block:: python
- # Weekly upstream → 7 daily downstream Dag runs
+ from airflow.partition_mappers import WeekWindow, Window
+ from airflow.partition_mappers.temporal import FanOutMapper,
StartOfWeekMapper
+
+ # Weekly upstream → 7 daily downstream Dag runs (the 7 days the
upstream Monday represents)
FanOutMapper(upstream_mapper=StartOfWeekMapper(), window=WeekWindow())
+
+ # Weekly upstream → the 7 days ending at the upstream Monday (trailing
period)
+ backward_window = WeekWindow(direction=Window.Direction.BACKWARD)
+ FanOutMapper(upstream_mapper=StartOfWeekMapper(),
window=backward_window)
"""
# Keep ``FanOutMapper.default_downstream_mapper_by_window_name`` in sync
with
@@ -476,8 +487,7 @@ class FanOutMapper(PartitionMapper):
the SDK ``Window`` classes (used in Dag-author code) and the core
``Window`` classes (used after deserialization) both resolve to the
same default. Subclasses can extend or override the defaults by
- setting :attr:`default_downstream_mapper_by_window_name` on the
- subclass.
+ setting :attr:`default_downstream_mapper_by_window_name` on the
subclass.
"""
mapper_cls =
cls.default_downstream_mapper_by_window_name.get(type(window).__name__)
if mapper_cls is None:
diff --git a/airflow-core/src/airflow/partition_mappers/window.py
b/airflow-core/src/airflow/partition_mappers/window.py
index f0c92d458d5..087b6180f2a 100644
--- a/airflow-core/src/airflow/partition_mappers/window.py
+++ b/airflow-core/src/airflow/partition_mappers/window.py
@@ -18,10 +18,11 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
+from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar
if TYPE_CHECKING:
- from collections.abc import Iterable
+ from collections.abc import Callable, Iterable
def _require_day_one(dt: datetime, window_cls: type) -> None:
@@ -55,6 +56,27 @@ def _shift_months(dt: datetime, months: int) -> datetime:
return dt.replace(year=dt.year + total // 12, month=total % 12 + 1)
+def _build_directional_steps(
+ period_start: datetime,
+ count: int,
+ step: Callable[[datetime, int], datetime],
+ direction: Window.Direction,
+) -> Iterable[datetime]:
+ """
+ Enumerate *count* period-starts beginning at or ending at *period_start*.
+
+ *step* maps ``(base, i) -> base`` advanced by ``i`` units (e.g. ``i``
minutes
+ for an hour window, ``i`` months for a year window). For ``FORWARD`` the
+ sequence starts at *period_start*; for ``BACKWARD`` it is the trailing
+ sequence whose last member is *period_start* (the mirror of ``FORWARD``),
+ computed by stepping the base back ``count - 1`` units first. Callers that
+ need a day-1 precondition must enforce it before calling this — it is not
+ checked here.
+ """
+ base = step(period_start, -(count - 1)) if direction is
Window.Direction.BACKWARD else period_start
+ return (step(base, i) for i in range(count))
+
+
class Window(ABC):
"""
Describes a rollup window: which decoded upstream items make up one
decoded downstream period.
@@ -82,22 +104,34 @@ class Window(ABC):
mapper.to_downstream(upstream_key)`` holds.
"""
+ class Direction(str, Enum):
+ """Direction of a :class:`Window` fan-out relative to the upstream
key."""
+
+ BACKWARD = "backward"
+ """Yield the trailing period ending at the upstream key (the mirror of
FORWARD)."""
+
+ FORWARD = "forward"
+ """Default; yield the period starting at the upstream key (forward in
time)."""
+
#: Type that ``to_upstream`` expects as its ``decoded_downstream``
argument.
#: ``RollupMapper.__init__`` uses this to reject pairings where the
upstream
#: mapper's ``decode_downstream`` leaves the value as ``str`` (base
identity)
#: but the window needs a different type. Temporal windows declare
``datetime``.
expected_decoded_type: ClassVar[type] = str
+ def __init__(self, *, direction: Window.Direction = Direction.FORWARD) ->
None:
+ self.direction = self.Direction(direction)
+
@abstractmethod
def to_upstream(self, decoded_downstream: Any) -> Iterable[Any]:
"""Yield each decoded upstream item composing *decoded_downstream*."""
def serialize(self) -> dict[str, Any]:
- return {}
+ return {"direction": self.direction.value}
@classmethod
def deserialize(cls, data: dict[str, Any]) -> Window:
- return cls()
+ return cls(direction=cls.Direction(data["direction"]))
class HourWindow(Window):
@@ -106,7 +140,9 @@ class HourWindow(Window):
expected_decoded_type: ClassVar[type] = datetime
def to_upstream(self, period_start: datetime) -> Iterable[datetime]:
- return (period_start + timedelta(minutes=i) for i in range(60))
+ return _build_directional_steps(
+ period_start, 60, lambda s, i: s + timedelta(minutes=i),
self.direction
+ )
class DayWindow(Window):
@@ -138,12 +174,16 @@ class DayWindow(Window):
**Mitigation**: use UTC ``input_format`` (e.g. ``%Y-%m-%dT%H%z``) and
ensure upstream producers emit UTC partition keys so local-clock
ambiguity never arises.
+
+ The same 24-hour-stride assumption applies to
``DayWindow(direction=Window.Direction.BACKWARD)``:
+ the 24 members are enumerated as naive hourly steps ending at the
anchor, not as
+ a step back to the "previous calendar day" in local time.
"""
expected_decoded_type: ClassVar[type] = datetime
def to_upstream(self, period_start: datetime) -> Iterable[datetime]:
- return (period_start + timedelta(hours=i) for i in range(24))
+ return _build_directional_steps(period_start, 24, lambda s, i: s +
timedelta(hours=i), self.direction)
class WeekWindow(Window):
@@ -152,7 +192,7 @@ class WeekWindow(Window):
expected_decoded_type: ClassVar[type] = datetime
def to_upstream(self, period_start: datetime) -> Iterable[datetime]:
- return (period_start + timedelta(days=i) for i in range(7))
+ return _build_directional_steps(period_start, 7, lambda s, i: s +
timedelta(days=i), self.direction)
class MonthWindow(Window):
@@ -168,10 +208,22 @@ class MonthWindow(Window):
expected_decoded_type: ClassVar[type] = datetime
def to_upstream(self, period_start: datetime) -> Iterable[datetime]:
+ # Not expressible via _build_directional_steps: the member count is
not fixed (28-31)
+ # and BACKWARD is an open-closed (prev_month_start, anchor] generator,
not a
+ # shift-then-forward mirror of FORWARD.
_require_day_one(period_start, type(self))
- next_month = period_start.month % 12 + 1
- next_year = period_start.year + (1 if period_start.month == 12 else 0)
- next_start = period_start.replace(year=next_year, month=next_month)
+ if self.direction is Window.Direction.BACKWARD:
+ # Backward yields the trailing period ending at the anchor
(period_start),
+ # analogous to WeekWindow BACKWARD which yields the 7 days ending
at the
+ # anchor rather than a calendar week. The members are the
open-closed
+ # interval (prev_month_start, anchor] — every day from the day
after the
+ # previous month's 1st up to and including anchor itself. This
does NOT
+ # align to a calendar month: anchor=Mar 1 yields Feb 2…Mar 1 (29
days in
+ # 2024), not the full calendar February.
+ prev = _shift_months(period_start, -1)
+ days = (period_start - prev).days
+ return (prev + timedelta(days=i + 1) for i in range(days))
+ next_start = _shift_months(period_start, 1)
days = (next_start - period_start).days
return (period_start + timedelta(days=i) for i in range(days))
@@ -183,7 +235,7 @@ class QuarterWindow(Window):
def to_upstream(self, period_start: datetime) -> Iterable[datetime]:
_require_day_one(period_start, type(self))
- return (_shift_months(period_start, i) for i in range(3))
+ return _build_directional_steps(period_start, 3, _shift_months,
self.direction)
class YearWindow(Window):
@@ -193,4 +245,4 @@ class YearWindow(Window):
def to_upstream(self, period_start: datetime) -> Iterable[datetime]:
_require_day_one(period_start, type(self))
- return (_shift_months(period_start, i) for i in range(12))
+ return _build_directional_steps(period_start, 12, _shift_months,
self.direction)
diff --git a/airflow-core/src/airflow/serialization/encoders.py
b/airflow-core/src/airflow/serialization/encoders.py
index d82b9cd359e..f4ec081df1b 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -536,7 +536,7 @@ class _Serializer:
self,
window: HourWindow | DayWindow | WeekWindow | MonthWindow |
QuarterWindow | YearWindow,
) -> dict[str, Any]:
- return {}
+ return window.serialize()
_serializer = _Serializer()
diff --git a/airflow-core/tests/unit/partition_mappers/test_fan_out.py
b/airflow-core/tests/unit/partition_mappers/test_fan_out.py
index 683931a84f2..17a455a4bbc 100644
--- a/airflow-core/tests/unit/partition_mappers/test_fan_out.py
+++ b/airflow-core/tests/unit/partition_mappers/test_fan_out.py
@@ -36,6 +36,7 @@ from airflow.partition_mappers.window import (
MonthWindow,
QuarterWindow,
WeekWindow,
+ Window,
YearWindow,
)
@@ -233,3 +234,80 @@ class TestFanOutMapper:
assert isinstance(restored.window, window_cls)
assert isinstance(restored.downstream_mapper, expected_downstream_cls)
assert list(restored.to_downstream(upstream_key)) ==
list(mapper.to_downstream(upstream_key))
+
+ @pytest.mark.parametrize(
+ ("direction", "expected"),
+ [
+ pytest.param(
+ Window.Direction.FORWARD,
+ [
+ "2024-03-04",
+ "2024-03-05",
+ "2024-03-06",
+ "2024-03-07",
+ "2024-03-08",
+ "2024-03-09",
+ "2024-03-10",
+ ],
+ id="forward",
+ ),
+ pytest.param(
+ Window.Direction.BACKWARD,
+ [
+ "2024-02-27",
+ "2024-02-28",
+ "2024-02-29",
+ "2024-03-01",
+ "2024-03-02",
+ "2024-03-03",
+ "2024-03-04",
+ ],
+ id="backward",
+ ),
+ ],
+ )
+ def test_fan_out_with_directional_window(self, direction, expected):
+ """WeekWindow direction selects the period relative to the upstream
Monday.
+
+ 2024-03-04 is a Monday; StartOfWeekMapper normalises it to itself.
+ FORWARD yields the 7 days starting at that Monday (03-04 Mon … 03-10
Sun);
+ BACKWARD yields the trailing 7 days ending at it (02-27 Tue … 03-04
Mon,
+ including the leap-year Feb 29).
+ """
+ mapper = FanOutMapper(
+ upstream_mapper=StartOfWeekMapper(),
+ window=WeekWindow(direction=direction),
+ downstream_mapper=StartOfDayMapper(),
+ )
+ result = list(mapper.to_downstream("2024-03-04T00:00:00"))
+ assert result == expected
+
+ @pytest.mark.parametrize(
+ "direction",
+ [Window.Direction.FORWARD, Window.Direction.BACKWARD],
+ )
+ def
test_fan_out_with_directional_window_resolves_default_downstream_mapper(self,
direction):
+ """A directional WeekWindow is still a WeekWindow — default downstream
lookup works unchanged."""
+ mapper = FanOutMapper(
+ upstream_mapper=StartOfWeekMapper(),
+ window=WeekWindow(direction=direction),
+ )
+ assert isinstance(mapper.downstream_mapper, StartOfDayMapper)
+
+ @pytest.mark.parametrize(
+ "direction",
+ [Window.Direction.FORWARD, Window.Direction.BACKWARD],
+ )
+ def test_fan_out_with_directional_window_serialize_roundtrip(self,
direction):
+ """A directional WeekWindow survives serialize → deserialize
(direction and output preserved)."""
+ mapper = FanOutMapper(
+ upstream_mapper=StartOfWeekMapper(),
+ window=WeekWindow(direction=direction),
+ )
+ restored = FanOutMapper.deserialize(mapper.serialize())
+ assert isinstance(restored, FanOutMapper)
+ assert isinstance(restored.window, WeekWindow)
+ assert restored.window.direction is direction
+ assert list(restored.to_downstream("2024-03-04T00:00:00")) == list(
+ mapper.to_downstream("2024-03-04T00:00:00")
+ )
diff --git a/airflow-core/tests/unit/partition_mappers/test_window.py
b/airflow-core/tests/unit/partition_mappers/test_window.py
index f32b9423579..ddfd44cbd76 100644
--- a/airflow-core/tests/unit/partition_mappers/test_window.py
+++ b/airflow-core/tests/unit/partition_mappers/test_window.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from datetime import datetime
+from datetime import datetime, timedelta
import pytest
@@ -35,8 +35,13 @@ from airflow.partition_mappers.window import (
MonthWindow,
QuarterWindow,
WeekWindow,
+ Window,
YearWindow,
)
+from airflow.serialization.decoders import decode_partition_mapper,
decode_window
+from airflow.serialization.encoders import encode_partition_mapper,
encode_window
+from airflow.serialization.enums import Encoding
+from airflow.serialization.helpers import WindowNotSupported
class TestHourWindow:
@@ -146,6 +151,126 @@ class TestYearWindow:
list(YearWindow().to_upstream(datetime(2024, 1, 31)))
+class TestDirection:
+ def test_default_direction_is_forward(self):
+ assert WeekWindow().direction is Window.Direction.FORWARD
+
+ @pytest.mark.parametrize(
+ ("window", "anchor", "expected"),
+ [
+ pytest.param(
+ HourWindow(direction=Window.Direction.BACKWARD),
+ datetime(2024, 3, 4, 0),
+ [datetime(2024, 3, 4, 0) - timedelta(minutes=59) +
timedelta(minutes=i) for i in range(60)],
+ id="hour",
+ ),
+ pytest.param(
+ DayWindow(direction=Window.Direction.BACKWARD),
+ datetime(2024, 3, 4),
+ [datetime(2024, 3, 4) - timedelta(hours=23) +
timedelta(hours=i) for i in range(24)],
+ id="day",
+ ),
+ pytest.param(
+ WeekWindow(direction=Window.Direction.BACKWARD),
+ datetime(2024, 3, 4), # Monday
+ [datetime(2024, 2, 27) + timedelta(days=i) for i in range(7)],
+ id="week",
+ ),
+ pytest.param(
+ MonthWindow(direction=Window.Direction.BACKWARD),
+ datetime(2024, 3, 1),
+ # 2024-02 has 29 days (leap year); trailing period = Feb 2 …
Mar 1 (29 members)
+ [datetime(2024, 2, d) for d in range(2, 30)] + [datetime(2024,
3, 1)],
+ id="month_backward_trailing",
+ ),
+ pytest.param(
+ QuarterWindow(direction=Window.Direction.BACKWARD),
+ datetime(2024, 1, 1),
+ [datetime(2023, 11, 1), datetime(2023, 12, 1), datetime(2024,
1, 1)],
+ id="quarter_backward_trailing",
+ ),
+ pytest.param(
+ YearWindow(direction=Window.Direction.BACKWARD),
+ datetime(2024, 1, 1),
+ [datetime(2023, m, 1) for m in range(2, 13)] + [datetime(2024,
1, 1)],
+ id="year_backward_trailing",
+ ),
+ ],
+ )
+ def test_backward_yields_trailing_period_ending_at_anchor(self, window,
anchor, expected):
+ result = list(window.to_upstream(anchor))
+ assert result == expected
+ assert result[-1] == anchor
+
+ def test_month_backward_trailing_not_calendar_month(self):
+ """Month backward yields the trailing period (prev_month_start,
anchor], not a calendar month."""
+ anchor = datetime(2024, 3, 1)
+ result =
list(MonthWindow(direction=Window.Direction.BACKWARD).to_upstream(anchor))
+ # Must include anchor (Mar 1) and must NOT include prev_month_start
(Feb 1)
+ assert datetime(2024, 3, 1) in result
+ assert datetime(2024, 2, 1) not in result
+
+ @pytest.mark.parametrize(
+ "window",
+ [
+ pytest.param(MonthWindow(direction=Window.Direction.BACKWARD),
id="month"),
+ pytest.param(QuarterWindow(direction=Window.Direction.BACKWARD),
id="quarter"),
+ pytest.param(YearWindow(direction=Window.Direction.BACKWARD),
id="year"),
+ ],
+ )
+ def test_backward_on_non_day_one_raises(self, window):
+ """_require_day_one fires before the backward shift — non-day-1 +
BACKWARD direction still raises."""
+ with pytest.raises(ValueError, match="expects a period start on day
1"):
+ list(window.to_upstream(datetime(2024, 3, 15)))
+
+ @pytest.mark.parametrize(
+ "window",
+ [
+ pytest.param(HourWindow(direction=Window.Direction.FORWARD),
id="hour"),
+ pytest.param(DayWindow(direction=Window.Direction.FORWARD),
id="day"),
+ pytest.param(WeekWindow(direction=Window.Direction.FORWARD),
id="week"),
+ pytest.param(MonthWindow(direction=Window.Direction.FORWARD),
id="month"),
+ pytest.param(QuarterWindow(direction=Window.Direction.FORWARD),
id="quarter"),
+ pytest.param(YearWindow(direction=Window.Direction.FORWARD),
id="year"),
+ ],
+ )
+ def test_serialize_roundtrip_with_forward(self, window):
+ """Window.Direction.FORWARD survives serialize → deserialize;
behaviour is identical."""
+ restored = decode_window(encode_window(window))
+ assert restored.direction is Window.Direction.FORWARD
+ assert type(restored) is type(window)
+ if isinstance(window, WeekWindow):
+ anchor = datetime(2024, 3, 4)
+ assert list(restored.to_upstream(anchor)) ==
list(window.to_upstream(anchor))
+
+ @pytest.mark.parametrize(
+ "window_cls",
+ [HourWindow, DayWindow, WeekWindow, MonthWindow, QuarterWindow,
YearWindow],
+ )
+ def test_serialize_roundtrip_backward(self, window_cls):
+ window = window_cls(direction=Window.Direction.BACKWARD)
+ restored = decode_window(encode_window(window))
+ assert restored.direction is Window.Direction.BACKWARD
+ assert type(restored) is window_cls
+ if window_cls is WeekWindow:
+ anchor = datetime(2024, 3, 4)
+ assert list(restored.to_upstream(anchor)) ==
list(window.to_upstream(anchor))
+
+ @pytest.mark.parametrize(
+ ("window_cls", "anchor", "expected_count"),
+ [
+ (HourWindow, datetime(2024, 3, 15, 7, 30), 60),
+ (DayWindow, datetime(2024, 3, 15), 24),
+ (WeekWindow, datetime(2024, 3, 13), 7), # Wed, non-Monday
+ ],
+ ids=["hour", "day", "week"],
+ )
+ def test_backward_non_day_one_does_not_raise(self, window_cls, anchor,
expected_count):
+ window = window_cls(direction=Window.Direction.BACKWARD)
+ result = list(window.to_upstream(anchor))
+ assert len(result) == expected_count
+
+
class TestRollupMapperComposition:
def test_to_downstream_delegates_to_upstream_mapper(self):
mapper = RollupMapper(
@@ -257,9 +382,6 @@ class TestRollupMapperComposition:
assert mapper.to_upstream("2024") == frozenset(f"2024-{m:02d}" for m
in range(1, 13))
def test_serialize_round_trip(self):
- from airflow.serialization.decoders import decode_partition_mapper
- from airflow.serialization.encoders import encode_partition_mapper
-
mapper = RollupMapper(
upstream_mapper=StartOfWeekMapper(input_format="%Y-%m-%d",
output_format="%Y-%m-%d"),
window=WeekWindow(),
@@ -306,9 +428,6 @@ class TestRollupMapperComposition:
],
)
def test_window_serialize_round_trip(self, upstream_factory, window,
downstream_key):
- from airflow.serialization.decoders import decode_partition_mapper
- from airflow.serialization.encoders import encode_partition_mapper
-
mapper = RollupMapper(upstream_mapper=upstream_factory(),
window=window)
restored = decode_partition_mapper(encode_partition_mapper(mapper))
assert isinstance(restored, RollupMapper)
@@ -316,6 +435,36 @@ class TestRollupMapperComposition:
assert restored.to_upstream(downstream_key) ==
mapper.to_upstream(downstream_key)
+class TestDirectionValidation:
+ """Window.__init__ must coerce valid strings and reject invalid ones at
construction time."""
+
+ @pytest.mark.parametrize(
+ ("direction_input", "expected_member"),
+ [
+ pytest.param(Window.Direction.FORWARD, Window.Direction.FORWARD,
id="enum_forward"),
+ pytest.param(Window.Direction.BACKWARD, Window.Direction.BACKWARD,
id="enum_backward"),
+ pytest.param("forward", Window.Direction.FORWARD,
id="str_forward"),
+ pytest.param("backward", Window.Direction.BACKWARD,
id="str_backward"),
+ ],
+ )
+ def test_valid_direction_coerced_to_enum(self, direction_input,
expected_member):
+ window = WeekWindow(direction=direction_input)
+ assert window.direction is expected_member
+
+ @pytest.mark.parametrize(
+ "bad_value",
+ [
+ pytest.param("forwrd", id="typo_forwrd"),
+ pytest.param("backwards", id="typo_backwards"),
+ pytest.param("FORWARD", id="wrong_case"),
+ pytest.param("", id="empty_string"),
+ ],
+ )
+ def test_invalid_direction_raises_value_error(self, bad_value):
+ with pytest.raises(ValueError, match=r"is not a valid
Window\.Direction"):
+ WeekWindow(direction=bad_value)
+
+
class TestWindowSerializationGate:
"""``encode_window`` / ``decode_window`` must reject non-built-in Windows.
@@ -325,10 +474,6 @@ class TestWindowSerializationGate:
"""
def test_encode_rejects_custom_window_subclass(self):
- from airflow.partition_mappers.window import Window
- from airflow.serialization.encoders import encode_window
- from airflow.serialization.helpers import WindowNotSupported
-
class CustomWindow(Window):
def to_upstream(self, decoded_downstream):
return ()
@@ -337,9 +482,5 @@ class TestWindowSerializationGate:
encode_window(CustomWindow())
def test_decode_rejects_non_core_import_path(self):
- from airflow.serialization.decoders import decode_window
- from airflow.serialization.enums import Encoding
- from airflow.serialization.helpers import WindowNotSupported
-
with pytest.raises(WindowNotSupported, match="os.system"):
decode_window({Encoding.TYPE: "os.system", Encoding.VAR: {}})
diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py
b/airflow-core/tests/unit/serialization/test_serialized_objects.py
index 1e60848f2d3..58cb1f77900 100644
--- a/airflow-core/tests/unit/serialization/test_serialized_objects.py
+++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py
@@ -1119,21 +1119,21 @@ def test_encode_fan_out_mapper():
"upstream_mapper": {
Encoding.TYPE:
"airflow.partition_mappers.temporal.StartOfWeekMapper",
Encoding.VAR: {
- "timezone": "UTC",
"input_format": "%Y-%m-%dT%H:%M:%S",
"output_format": "%Y-%m-%d (W%V)",
+ "timezone": "UTC",
},
},
"window": {
Encoding.TYPE: "airflow.partition_mappers.window.WeekWindow",
- Encoding.VAR: {},
+ Encoding.VAR: {"direction": "forward"},
},
"downstream_mapper": {
Encoding.TYPE:
"airflow.partition_mappers.temporal.StartOfDayMapper",
Encoding.VAR: {
- "timezone": "UTC",
"input_format": "%Y-%m-%dT%H:%M:%S",
"output_format": "%Y-%m-%d",
+ "timezone": "UTC",
},
},
},
diff --git a/scripts/ci/prek/check_window_in_sync.py
b/scripts/ci/prek/check_window_in_sync.py
new file mode 100755
index 00000000000..1d7b9510a80
--- /dev/null
+++ b/scripts/ci/prek/check_window_in_sync.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+# "rich>=13.6.0",
+# ]
+# ///
+"""
+Verify the shared surface of ``window.py`` stays in sync between core and the
Task SDK.
+
+The ``Window`` class hierarchy is defined twice — once in core and once in the
SDK —
+because the two hierarchies are independent (the SDK cannot import core). The
SDK copy
+is the author-facing API used at Dag-parse time; the core copy carries the
scheduler-side
+runtime logic. Both must agree on the shared surface so serialization
round-trips and
+direction semantics behave identically wherever the classes are instantiated.
+
+This check parses both files via AST and asserts the following are identical:
+
+- ``Window.Direction`` enum member values
+- The set of shared class names (Direction, Window, HourWindow, DayWindow,
+ WeekWindow, MonthWindow, QuarterWindow, YearWindow)
+- ``Window.__init__`` ``direction`` kwarg default
+- ``Window.__init__`` body, ``Window.serialize`` body, ``Window.deserialize``
body
+- Each subclass ``expected_decoded_type`` ClassVar value
+
+Core-only symbols (``ABC`` base, ``abstractmethod to_upstream``, helper
functions,
+imports, docstrings) are intentionally excluded from the comparison.
+
+Run from the repo root:
+
+ uv run --project scripts python scripts/ci/prek/check_window_in_sync.py
+
+Exits 0 if synced, 1 (with a diff) otherwise.
+"""
+
+from __future__ import annotations
+
+import ast
+import sys
+from pathlib import Path
+
+from common_prek_utils import (
+ AIRFLOW_CORE_SOURCES_PATH,
+ AIRFLOW_TASK_SDK_SOURCES_PATH,
+)
+from rich.console import Console
+
+console = Console(color_system="standard", width=200)
+
+CORE_FILE = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "partition_mappers" /
"window.py"
+SDK_FILE = (
+ AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" /
"partition_mappers" / "window.py"
+)
+
+EXPECTED_CLASS_NAMES = [
+ "Direction",
+ "Window",
+ "HourWindow",
+ "DayWindow",
+ "WeekWindow",
+ "MonthWindow",
+ "QuarterWindow",
+ "YearWindow",
+]
+
+
+def _parse(file_path: Path) -> ast.Module:
+ return ast.parse(file_path.read_text(encoding="utf-8"),
filename=str(file_path))
+
+
+def _find_class(tree: ast.Module, name: str, file_path: Path) -> ast.ClassDef:
+ """Return the ClassDef named *name* from *tree*; raise ValueError if
absent."""
+ for node in ast.walk(tree):
+ if isinstance(node, ast.ClassDef) and node.name == name:
+ return node
+ raise ValueError(f"{file_path}: no class {name!r} found.")
+
+
+def _extract_enum_members(tree: ast.Module, file_path: Path) -> dict[str, str]:
+ """Return the ``{member_name: value}`` dict for ``Direction``."""
+ cls = _find_class(tree, "Direction", file_path)
+ members: dict[str, str] = {}
+ for stmt in cls.body:
+ if (
+ isinstance(stmt, ast.Assign)
+ and len(stmt.targets) == 1
+ and isinstance(stmt.targets[0], ast.Name)
+ and isinstance(stmt.value, ast.Constant)
+ and isinstance(stmt.value.value, str)
+ ):
+ members[stmt.targets[0].id] = stmt.value.value
+ return members
+
+
+def _find_method(
+ class_def: ast.ClassDef, method_name: str, file_path: Path
+) -> ast.FunctionDef | ast.AsyncFunctionDef:
+ """Return the method named *method_name* from *class_def*; raise
ValueError if absent."""
+ for stmt in class_def.body:
+ if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) and
stmt.name == method_name:
+ return stmt
+ raise ValueError(f"{file_path}: class {class_def.name!r} has no method
{method_name!r}.")
+
+
+def _extract_method_source(class_def: ast.ClassDef, method_name: str,
file_path: Path) -> str:
+ """Return the unparsed body of *method_name* in *class_def*."""
+ method = _find_method(class_def, method_name, file_path)
+ return "\n".join(ast.unparse(stmt) for stmt in method.body)
+
+
+def _extract_direction_default(class_def: ast.ClassDef, file_path: Path) ->
str:
+ """Return the ``ast.unparse`` of the ``direction`` kwarg default in
``Window.__init__``."""
+ init = _find_method(class_def, "__init__", file_path)
+ args = init.args
+ for i, kwarg in enumerate(args.kwonlyargs):
+ if kwarg.arg == "direction":
+ default = args.kw_defaults[i]
+ if default is None:
+ raise ValueError(
+ f"{file_path}: Window.__init__ 'direction' kwarg has no
default — "
+ "expected Direction.FORWARD."
+ )
+ return ast.unparse(default)
+ raise ValueError(f"{file_path}: Window.__init__ has no 'direction' kwonly
parameter.")
+
+
+def _extract_decoded_types(tree: ast.Module, file_path: Path) -> dict[str,
str]:
+ """Return ``{class_name: ast.unparse(expected_decoded_type value)}`` for
the Window class hierarchy."""
+ result: dict[str, str] = {}
+ for name in EXPECTED_CLASS_NAMES:
+ try:
+ cls = _find_class(tree, name, file_path)
+ except ValueError:
+ continue
+ for stmt in cls.body:
+ target_name = None
+ value_node = None
+ if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target,
ast.Name):
+ target_name = stmt.target.id
+ value_node = stmt.value
+ elif (
+ isinstance(stmt, ast.Assign)
+ and len(stmt.targets) == 1
+ and isinstance(stmt.targets[0], ast.Name)
+ ):
+ target_name = stmt.targets[0].id
+ value_node = stmt.value
+ if target_name == "expected_decoded_type" and value_node is not
None:
+ result[name] = ast.unparse(value_node)
+ break
+ return result
+
+
+def _build_surface(file_path: Path) -> dict[str, object]:
+ """Build the comparable surface dict for one ``window.py`` file."""
+ tree = _parse(file_path)
+ present_classes = sorted(
+ name
+ for name in EXPECTED_CLASS_NAMES
+ if any(isinstance(node, ast.ClassDef) and node.name == name for node
in ast.walk(tree))
+ )
+ window_cls = _find_class(tree, "Window", file_path)
+ return {
+ "class_names": present_classes,
+ "enum_members": _extract_enum_members(tree, file_path),
+ "direction_default": _extract_direction_default(window_cls, file_path),
+ "init_body": _extract_method_source(window_cls, "__init__", file_path),
+ "serialize_body": _extract_method_source(window_cls, "serialize",
file_path),
+ "deserialize_body": _extract_method_source(window_cls, "deserialize",
file_path),
+ "decoded_types": _extract_decoded_types(tree, file_path),
+ }
+
+
+def main() -> int:
+ try:
+ core_surface = _build_surface(CORE_FILE)
+ sdk_surface = _build_surface(SDK_FILE)
+ except ValueError as exc:
+ console.print(f"[red]Could not read the window definitions:[/red]
{exc}")
+ return 1
+
+ if core_surface == sdk_surface:
+ return 0
+
+ console.print("[red]window.py shared surface is out of sync between core
and the Task SDK.[/red]\n")
+ all_keys = sorted(set(core_surface) | set(sdk_surface))
+ for key in all_keys:
+ core_val = core_surface.get(key, "<missing>")
+ sdk_val = sdk_surface.get(key, "<missing>")
+ marker = " " if core_val == sdk_val else "->"
+ color = "" if core_val == sdk_val else "[red]"
+ end = "" if core_val == sdk_val else "[/red]"
+ console.print(f"{color}{marker} {key}: core={core_val!r}
sdk={sdk_val!r}{end}")
+ console.print(f"\nMake both copies match:\n core: {CORE_FILE}\n sdk:
{SDK_FILE}")
+ return 1
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py
index 200a547c403..3eaf4e19d74 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py
@@ -127,10 +127,20 @@ class FanOutMapper(PartitionMapper):
waits for all members), fan-out is 1→N (one upstream event creates many
downstream Dag runs).
+ For forward fan-out (emit the *next* period's members instead of the
current
+ one), pass ``direction=Window.Direction.FORWARD`` to the window:
+
.. code-block:: python
- # Weekly upstream → 7 daily downstream Dag runs
+ from airflow.sdk import WeekWindow, Window
+ from airflow.sdk.definitions.partition_mappers.temporal import
FanOutMapper, StartOfWeekMapper
+
+ # Weekly upstream → 7 daily downstream Dag runs (current week)
FanOutMapper(upstream_mapper=StartOfWeekMapper(), window=WeekWindow())
+
+ # Weekly upstream → 7 daily keys for the *following* week
+ forward_window = WeekWindow(direction=Window.Direction.FORWARD)
+ FanOutMapper(upstream_mapper=StartOfWeekMapper(),
window=forward_window)
"""
# Keep ``FanOutMapper.default_downstream_mapper_by_window_name`` in sync
with
@@ -160,8 +170,7 @@ class FanOutMapper(PartitionMapper):
the SDK ``Window`` classes (used in Dag-author code) and the core
``Window`` classes (used after deserialization) both resolve to the
same default. Subclasses can extend or override the defaults by
- setting :attr:`default_downstream_mapper_by_window_name` on the
- subclass.
+ setting :attr:`default_downstream_mapper_by_window_name` on the
subclass.
"""
mapper_cls =
cls.default_downstream_mapper_by_window_name.get(type(window).__name__)
if mapper_cls is None:
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py
index c7dacad1994..4ea008ace6f 100644
--- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py
@@ -14,10 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+# The SDK and core class hierarchies are independent (the SDK cannot import
core),
+# so both carry the same author-facing definitions; runtime logic lives in the
core
+# copy. The ``check-window-in-sync`` prek hook enforces that the two stay in
sync.
from __future__ import annotations
from datetime import datetime
-from typing import ClassVar
+from enum import Enum
+from typing import Any, ClassVar
class Window:
@@ -42,14 +47,38 @@ class Window:
therefore requires customizing **both** sides consistently so the
invariant ``upstream_key in window.to_upstream(D) ⇔ D in
mapper.to_downstream(upstream_key)`` holds.
+
+ :param direction: ``Window.Direction.FORWARD`` (default) fans out the
period
+ starting at the upstream key (forward in time);
+ ``Window.Direction.BACKWARD`` fans out the trailing period ending at
the
+ upstream key (the mirror of FORWARD).
"""
+ class Direction(str, Enum):
+ """Direction of a :class:`Window` fan-out relative to the upstream
key."""
+
+ BACKWARD = "backward"
+ """Yield the trailing period ending at the upstream key (the mirror of
FORWARD)."""
+
+ FORWARD = "forward"
+ """Default; yield the period starting at the upstream key (forward in
time)."""
+
#: Decoded type the window iterates in; ``RollupMapper.__init__`` uses this
#: to reject pairings where the upstream mapper decodes to a different
type.
#: Default ``str`` matches the identity mapper; temporal windows declare
#: ``datetime``. Mirrors the same attribute on the core ``Window``.
expected_decoded_type: ClassVar[type] = str
+ def __init__(self, *, direction: Window.Direction = Direction.FORWARD) ->
None:
+ self.direction = self.Direction(direction)
+
+ def serialize(self) -> dict[str, Any]:
+ return {"direction": self.direction.value}
+
+ @classmethod
+ def deserialize(cls, data: dict[str, Any]) -> Window:
+ return cls(direction=cls.Direction(data["direction"]))
+
class HourWindow(Window):
"""Sixty consecutive minute keys making up one hour."""
diff --git a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
index 3be4eaf177d..8b902820da0 100644
--- a/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
+++ b/task-sdk/tests/task_sdk/definitions/test_partition_mappers.py
@@ -64,6 +64,36 @@ class TestSdkRollupMapperInit:
RollupMapper(upstream_mapper=_StringOnlyMapper(),
window=_AlphaWindow())
+class TestSdkDirectionValidation:
+ """SDK Window.__init__ must coerce valid strings and reject invalid ones
at construction time."""
+
+ @pytest.mark.parametrize(
+ ("direction_input", "expected_member"),
+ [
+ pytest.param(Window.Direction.FORWARD, Window.Direction.FORWARD,
id="enum_forward"),
+ pytest.param(Window.Direction.BACKWARD, Window.Direction.BACKWARD,
id="enum_backward"),
+ pytest.param("forward", Window.Direction.FORWARD,
id="str_forward"),
+ pytest.param("backward", Window.Direction.BACKWARD,
id="str_backward"),
+ ],
+ )
+ def test_valid_direction_coerced_to_enum(self, direction_input,
expected_member):
+ window = WeekWindow(direction=direction_input)
+ assert window.direction is expected_member
+
+ @pytest.mark.parametrize(
+ "bad_value",
+ [
+ pytest.param("forwrd", id="typo_forwrd"),
+ pytest.param("backwards", id="typo_backwards"),
+ pytest.param("FORWARD", id="wrong_case"),
+ pytest.param("", id="empty_string"),
+ ],
+ )
+ def test_invalid_direction_raises_value_error(self, bad_value):
+ with pytest.raises(ValueError, match=r"is not a valid
Window\.Direction"):
+ WeekWindow(direction=bad_value)
+
+
class TestSdkWindowExpectedDecodedType:
"""Each SDK temporal window must declare ``datetime`` so the validation
lines up with core."""