This is an automated email from the ASF dual-hosted git repository.
rahulvats pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d5e5f432d86 feat: Add ChainMapper (#64094)
d5e5f432d86 is described below
commit d5e5f432d860d01b2028e825f2519096e0518e54
Author: Wei Lee <[email protected]>
AuthorDate: Mon Mar 23 19:56:17 2026 +0800
feat: Add ChainMapper (#64094)
* feat: Add SequenceMapper
* refactor: Rename SequenceMapper as ChainMapper
* fixup! refactor: Rename SequenceMapper as ChainMapper
* fixup! fixup! refactor: Rename SequenceMapper as ChainMapper
* fixup! fixup! fixup! refactor: Rename SequenceMapper as ChainMapper
---
.../src/airflow/partition_mappers/chain.py | 73 ++++++++++++++++++++++
airflow-core/src/airflow/serialization/encoders.py | 6 ++
.../tests/unit/partition_mappers/test_chain.py | 71 +++++++++++++++++++++
.../unit/serialization/test_serialized_objects.py | 44 +++++++++++++
task-sdk/docs/api.rst | 2 +
task-sdk/src/airflow/sdk/__init__.py | 3 +
task-sdk/src/airflow/sdk/__init__.pyi | 2 +
.../sdk/definitions/partition_mappers/chain.py | 32 ++++++++++
8 files changed, 233 insertions(+)
diff --git a/airflow-core/src/airflow/partition_mappers/chain.py b/airflow-core/src/airflow/partition_mappers/chain.py
new file mode 100644
index 00000000000..850517482cf
--- /dev/null
+++ b/airflow-core/src/airflow/partition_mappers/chain.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any
+
+from airflow.partition_mappers.base import PartitionMapper
+
+
+class ChainMapper(PartitionMapper):
+ """Partition mapper that applies multiple mappers sequentially."""
+
+ def __init__(
+ self,
+ mapper0: PartitionMapper,
+ mapper1: PartitionMapper,
+ /,
+ *mappers: PartitionMapper,
+ ) -> None:
+ self.mappers = [mapper0, mapper1, *mappers]
+
+ def to_downstream(self, key: str) -> str | Iterable[str]:
+ keys: list[str] = [key]
+ for mapper in self.mappers:
+ next_keys: list[str] = []
+ for current_key in keys:
+ mapped = mapper.to_downstream(current_key)
+ if not isinstance(mapped, (str, Iterable)):
+ raise TypeError(
+ f"ChainMapper child mappers must return a string or iterable of strings, "
+ f"but {type(mapper).__name__} returned {type(mapped).__name__}"
+ )
+
+ if isinstance(mapped, str):
+ next_keys.append(mapped)
+ elif isinstance(mapped, Iterable):
+ for mapped_key in mapped:
+ if not isinstance(mapped_key, str):
+ raise TypeError(
+ f"ChainMapper child mappers must return an iterable of strings, "
+ f"but {type(mapper).__name__} yielded {type(mapped_key).__name__}"
+ )
+ next_keys.append(mapped_key)
+ keys = next_keys
+ return keys[0] if len(keys) == 1 else keys
+
+ def serialize(self) -> dict[str, Any]:
+ from airflow.serialization.encoders import encode_partition_mapper
+
+ return {"mappers": [encode_partition_mapper(m) for m in self.mappers]}
+
+ @classmethod
+ def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
+ from airflow.serialization.decoders import decode_partition_mapper
+
+ mappers = [decode_partition_mapper(m) for m in data["mappers"]]
+ return cls(*mappers)
diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py
index 506cdb07cdb..eca84065888 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -34,6 +34,7 @@ from airflow.sdk import (
AssetAll,
AssetAny,
AssetOrTimeSchedule,
+ ChainMapper,
CronDataIntervalTimetable,
CronTriggerTimetable,
DeltaDataIntervalTimetable,
@@ -392,6 +393,7 @@ class _Serializer:
}
BUILTIN_PARTITION_MAPPERS: dict[type, str] = {
+ ChainMapper: "airflow.partition_mappers.chain.ChainMapper",
IdentityMapper: "airflow.partition_mappers.identity.IdentityMapper",
ToHourlyMapper: "airflow.partition_mappers.temporal.ToHourlyMapper",
ToDailyMapper: "airflow.partition_mappers.temporal.ToDailyMapper",
@@ -411,6 +413,10 @@ class _Serializer:
raise NotImplementedError(f"can not serialize timetable {type(partition_mapper).__name__}")
return partition_mapper.serialize()
+ @serialize_partition_mapper.register
+ def _(self, partition_mapper: ChainMapper) -> dict[str, Any]:
+ return {"mappers": [encode_partition_mapper(m) for m in partition_mapper.mappers]}
+
@serialize_partition_mapper.register
def _(self, partition_mapper: IdentityMapper) -> dict[str, Any]:
return {}
diff --git a/airflow-core/tests/unit/partition_mappers/test_chain.py b/airflow-core/tests/unit/partition_mappers/test_chain.py
new file mode 100644
index 00000000000..cdce301deb0
--- /dev/null
+++ b/airflow-core/tests/unit/partition_mappers/test_chain.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.partition_mappers.base import PartitionMapper
+from airflow.partition_mappers.chain import ChainMapper
+from airflow.partition_mappers.identity import IdentityMapper
+from airflow.partition_mappers.temporal import ToDailyMapper, ToHourlyMapper
+
+
+class _InvalidReturnMapper(PartitionMapper):
+ def to_downstream(self, key: str) -> None: # type: ignore[override]
+ return None
+
+
+class _InvalidIterableMapper(PartitionMapper):
+ def to_downstream(self, key: str) -> list[None]: # type: ignore[override]
+ return [None]
+
+
+class TestChainMapper:
+ def test_to_downstream(self):
+ sm = ChainMapper(ToHourlyMapper(), ToDailyMapper(input_format="%Y-%m-%dT%H"))
+ assert sm.to_downstream("2024-01-15T10:30:00") == "2024-01-15"
+
+ def test_to_downstream_invalid_non_iterable_return(self):
+ sm = ChainMapper(IdentityMapper(), _InvalidReturnMapper())
+ with pytest.raises(TypeError, match="must return a string or iterable of strings"):
+ sm.to_downstream("key")
+
+ def test_to_downstream_invalid_iterable_contents(self):
+ sm = ChainMapper(IdentityMapper(), _InvalidIterableMapper())
+ with pytest.raises(TypeError, match="must return an iterable of strings"):
+ sm.to_downstream("key")
+
+ def test_serialize(self):
+ from airflow.serialization.encoders import encode_partition_mapper
+
+ sm = ChainMapper(ToHourlyMapper(), ToDailyMapper(input_format="%Y-%m-%dT%H"))
+ result = sm.serialize()
+ assert result == {
+ "mappers": [
+ encode_partition_mapper(ToHourlyMapper()),
+ encode_partition_mapper(ToDailyMapper(input_format="%Y-%m-%dT%H")),
+ ],
+ }
+
+ def test_deserialize(self):
+ sm = ChainMapper(ToHourlyMapper(), ToDailyMapper(input_format="%Y-%m-%dT%H"))
+ serialized = sm.serialize()
+ restored = ChainMapper.deserialize(serialized)
+ assert isinstance(restored, ChainMapper)
+ assert len(restored.mappers) == 2
+ assert restored.to_downstream("2024-01-15T10:30:00") == "2024-01-15"
diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py
index 514abb618aa..7be81c2fe3b 100644
--- a/airflow-core/tests/unit/serialization/test_serialized_objects.py
+++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py
@@ -895,6 +895,50 @@ def test_decode_product_mapper():
assert core_pm.to_downstream("2024-06-15T10:30:00|2024-06-15T10:30:00") == "2024-06-15T10|2024-06-15"
+def test_encode_chain_mapper():
+ from airflow.sdk import ChainMapper, ToDailyMapper, ToHourlyMapper
+ from airflow.serialization.encoders import encode_partition_mapper
+
+ partition_mapper = ChainMapper(ToHourlyMapper(), ToDailyMapper(input_format="%Y-%m-%dT%H"))
+ assert encode_partition_mapper(partition_mapper) == {
+ Encoding.TYPE: "airflow.partition_mappers.chain.ChainMapper",
+ Encoding.VAR: {
+ "mappers": [
+ {
+ Encoding.TYPE: "airflow.partition_mappers.temporal.ToHourlyMapper",
+ Encoding.VAR: {
+ "input_format": "%Y-%m-%dT%H:%M:%S",
+ "output_format": "%Y-%m-%dT%H",
+ },
+ },
+ {
+ Encoding.TYPE: "airflow.partition_mappers.temporal.ToDailyMapper",
+ Encoding.VAR: {
+ "input_format": "%Y-%m-%dT%H",
+ "output_format": "%Y-%m-%d",
+ },
+ },
+ ]
+ },
+ }
+
+
+def test_decode_chain_mapper():
+ from airflow.partition_mappers.chain import ChainMapper as CoreChainMapper
+ from airflow.sdk import ChainMapper, ToDailyMapper, ToHourlyMapper
+ from airflow.serialization.decoders import decode_partition_mapper
+ from airflow.serialization.encoders import encode_partition_mapper
+
+ partition_mapper = ChainMapper(ToHourlyMapper(), ToDailyMapper(input_format="%Y-%m-%dT%H"))
+ encoded_pm = encode_partition_mapper(partition_mapper)
+
+ core_pm = decode_partition_mapper(encoded_pm)
+
+ assert isinstance(core_pm, CoreChainMapper)
+ assert len(core_pm.mappers) == 2
+ assert core_pm.to_downstream("2024-06-15T10:30:00") == "2024-06-15"
+
+
def test_encode_allowed_key_mapper():
from airflow.sdk import AllowedKeyMapper
from airflow.serialization.encoders import encode_partition_mapper
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 5c0e6b07632..03b4d3999ba 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -201,6 +201,8 @@ Partition Mapper
.. autoapiclass:: airflow.sdk.PartitionMapper
+.. autoapiclass:: airflow.sdk.ChainMapper
+
.. autoapiclass:: airflow.sdk.IdentityMapper
.. autoapiclass:: airflow.sdk.ToHourlyMapper
diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py
index 70521464908..8aa55a0e623 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -37,6 +37,7 @@ __all__ = [
"BaseSensorOperator",
"BaseXCom",
"BranchMixIn",
+ "ChainMapper",
"Connection",
"Context",
"CronDataIntervalTimetable",
@@ -125,6 +126,7 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper
from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
+ from airflow.sdk.definitions.partition_mappers.chain import ChainMapper
from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper
from airflow.sdk.definitions.partition_mappers.product import ProductMapper
from airflow.sdk.definitions.partition_mappers.temporal import (
@@ -178,6 +180,7 @@ __lazy_imports: dict[str, str] = {
"BaseSensorOperator": ".bases.sensor",
"BaseXCom": ".bases.xcom",
"BranchMixIn": ".bases.branch",
+ "ChainMapper": ".definitions.partition_mappers.chain",
"Connection": ".definitions.connection",
"Context": ".definitions.context",
"CronDataIntervalTimetable": ".definitions.timetables.interval",
diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi
index 905500c6164..a898ff27a70 100644
--- a/task-sdk/src/airflow/sdk/__init__.pyi
+++ b/task-sdk/src/airflow/sdk/__init__.pyi
@@ -64,6 +64,7 @@ from airflow.sdk.definitions.edges import EdgeModifier as EdgeModifier, Label as
from airflow.sdk.definitions.param import Param as Param
from airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper
from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
+from airflow.sdk.definitions.partition_mappers.chain import ChainMapper
from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper
from airflow.sdk.definitions.partition_mappers.product import ProductMapper
from airflow.sdk.definitions.partition_mappers.temporal import (
@@ -117,6 +118,7 @@ __all__ = [
"BaseSensorOperator",
"BaseXCom",
"BranchMixIn",
+ "ChainMapper",
"Connection",
"Context",
"CronDataIntervalTimetable",
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
new file mode 100644
index 00000000000..0f96e2d2eaf
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
+
+
+class ChainMapper(PartitionMapper):
+ """Partition mapper that applies multiple mappers sequentially."""
+
+ def __init__(
+ self,
+ mapper0: PartitionMapper,
+ mapper1: PartitionMapper,
+ /,
+ *mappers: PartitionMapper,
+ ) -> None:
+ self.mappers = [mapper0, mapper1, *mappers]