This is an automated email from the ASF dual-hosted git repository.

rahulvats pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d5e5f432d86 feat: Add ChainMapper (#64094)
d5e5f432d86 is described below

commit d5e5f432d860d01b2028e825f2519096e0518e54
Author: Wei Lee <[email protected]>
AuthorDate: Mon Mar 23 19:56:17 2026 +0800

    feat: Add ChainMapper (#64094)
    
    * feat: Add SequenceMapper
    
    * refactor: Rename SequenceMapper as ChainMapper
    
    * fixup! refactor: Rename SequenceMapper as ChainMapper
    
    * fixup! fixup! refactor: Rename SequenceMapper as ChainMapper
    
    * fixup! fixup! fixup! refactor: Rename SequenceMapper as ChainMapper
---
 .../src/airflow/partition_mappers/chain.py         | 73 ++++++++++++++++++++++
 airflow-core/src/airflow/serialization/encoders.py |  6 ++
 .../tests/unit/partition_mappers/test_chain.py     | 71 +++++++++++++++++++++
 .../unit/serialization/test_serialized_objects.py  | 44 +++++++++++++
 task-sdk/docs/api.rst                              |  2 +
 task-sdk/src/airflow/sdk/__init__.py               |  3 +
 task-sdk/src/airflow/sdk/__init__.pyi              |  2 +
 .../sdk/definitions/partition_mappers/chain.py     | 32 ++++++++++
 8 files changed, 233 insertions(+)

diff --git a/airflow-core/src/airflow/partition_mappers/chain.py b/airflow-core/src/airflow/partition_mappers/chain.py
new file mode 100644
index 00000000000..850517482cf
--- /dev/null
+++ b/airflow-core/src/airflow/partition_mappers/chain.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any
+
+from airflow.partition_mappers.base import PartitionMapper
+
+
+class ChainMapper(PartitionMapper):
+    """Partition mapper that runs keys through two or more mappers, each feeding the next."""
+
+    def __init__(
+        self,
+        mapper0: PartitionMapper,
+        mapper1: PartitionMapper,
+        /,
+        *mappers: PartitionMapper,
+    ) -> None:
+        self.mappers = [mapper0, mapper1, *mappers]
+
+    def to_downstream(self, key: str) -> str | Iterable[str]:
+        """Map ``key`` through each mapper in turn; return a ``str`` if one key results, else a list."""
+        keys: list[str] = [key]
+        for mapper in self.mappers:
+            next_keys: list[str] = []
+            for current_key in keys:
+                mapped = mapper.to_downstream(current_key)
+                if isinstance(mapped, str):
+                    next_keys.append(mapped)
+                    continue
+                if not isinstance(mapped, Iterable):  # str (also Iterable) is handled above
+                    raise TypeError(
+                        f"ChainMapper child mappers must return a string or iterable of strings, "
+                        f"but {type(mapper).__name__} returned {type(mapped).__name__}"
+                    )
+                for mapped_key in mapped:
+                    if not isinstance(mapped_key, str):
+                        raise TypeError(
+                            f"ChainMapper child mappers must return an iterable of strings, "
+                            f"but {type(mapper).__name__} yielded {type(mapped_key).__name__}"
+                        )
+                    next_keys.append(mapped_key)
+            keys = next_keys
+        return keys[0] if len(keys) == 1 else keys
+
+    def serialize(self) -> dict[str, Any]:
+        from airflow.serialization.encoders import encode_partition_mapper
+
+        return {"mappers": [encode_partition_mapper(m) for m in self.mappers]}
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> PartitionMapper:
+        from airflow.serialization.decoders import decode_partition_mapper
+
+        mappers = [decode_partition_mapper(m) for m in data["mappers"]]
+        return cls(*mappers)
diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py
index 506cdb07cdb..eca84065888 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -34,6 +34,7 @@ from airflow.sdk import (
     AssetAll,
     AssetAny,
     AssetOrTimeSchedule,
+    ChainMapper,
     CronDataIntervalTimetable,
     CronTriggerTimetable,
     DeltaDataIntervalTimetable,
@@ -392,6 +393,7 @@ class _Serializer:
         }
 
     BUILTIN_PARTITION_MAPPERS: dict[type, str] = {
+        ChainMapper: "airflow.partition_mappers.chain.ChainMapper",
         IdentityMapper: "airflow.partition_mappers.identity.IdentityMapper",
         ToHourlyMapper: "airflow.partition_mappers.temporal.ToHourlyMapper",
         ToDailyMapper: "airflow.partition_mappers.temporal.ToDailyMapper",
@@ -411,6 +413,10 @@ class _Serializer:
             raise NotImplementedError(f"can not serialize timetable 
{type(partition_mapper).__name__}")
         return partition_mapper.serialize()
 
+    @serialize_partition_mapper.register
+    def _(self, partition_mapper: ChainMapper) -> dict[str, Any]:
+        return {"mappers": [encode_partition_mapper(m) for m in partition_mapper.mappers]}
+
     @serialize_partition_mapper.register
     def _(self, partition_mapper: IdentityMapper) -> dict[str, Any]:
         return {}
diff --git a/airflow-core/tests/unit/partition_mappers/test_chain.py b/airflow-core/tests/unit/partition_mappers/test_chain.py
new file mode 100644
index 00000000000..cdce301deb0
--- /dev/null
+++ b/airflow-core/tests/unit/partition_mappers/test_chain.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.partition_mappers.base import PartitionMapper
+from airflow.partition_mappers.chain import ChainMapper
+from airflow.partition_mappers.identity import IdentityMapper
+from airflow.partition_mappers.temporal import ToDailyMapper, ToHourlyMapper
+
+
+class _InvalidReturnMapper(PartitionMapper):
+    def to_downstream(self, key: str) -> None:  # type: ignore[override]
+        return None  # neither a str nor an iterable, so ChainMapper must raise TypeError
+
+
+class _InvalidIterableMapper(PartitionMapper):
+    def to_downstream(self, key: str) -> list[None]:  # type: ignore[override]
+        return [None]  # an iterable whose element is not a str
+
+
+class TestChainMapper:
+    def test_to_downstream(self):
+        chain_mapper = ChainMapper(ToHourlyMapper(), ToDailyMapper(input_format="%Y-%m-%dT%H"))
+        assert chain_mapper.to_downstream("2024-01-15T10:30:00") == "2024-01-15"
+
+    def test_to_downstream_invalid_non_iterable_return(self):
+        chain_mapper = ChainMapper(IdentityMapper(), _InvalidReturnMapper())
+        with pytest.raises(TypeError, match="must return a string or iterable of strings"):
+            chain_mapper.to_downstream("key")
+
+    def test_to_downstream_invalid_iterable_contents(self):
+        chain_mapper = ChainMapper(IdentityMapper(), _InvalidIterableMapper())
+        with pytest.raises(TypeError, match="must return an iterable of strings"):
+            chain_mapper.to_downstream("key")
+
+    def test_serialize(self):
+        from airflow.serialization.encoders import encode_partition_mapper
+
+        chain_mapper = ChainMapper(ToHourlyMapper(), ToDailyMapper(input_format="%Y-%m-%dT%H"))
+        result = chain_mapper.serialize()
+        assert result == {
+            "mappers": [
+                encode_partition_mapper(ToHourlyMapper()),
+                encode_partition_mapper(ToDailyMapper(input_format="%Y-%m-%dT%H")),
+            ],
+        }
+
+    def test_deserialize(self):
+        chain_mapper = ChainMapper(ToHourlyMapper(), ToDailyMapper(input_format="%Y-%m-%dT%H"))
+        serialized = chain_mapper.serialize()
+        restored = ChainMapper.deserialize(serialized)
+        assert isinstance(restored, ChainMapper)
+        assert len(restored.mappers) == 2
+        assert restored.to_downstream("2024-01-15T10:30:00") == "2024-01-15"
diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py
index 514abb618aa..7be81c2fe3b 100644
--- a/airflow-core/tests/unit/serialization/test_serialized_objects.py
+++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py
@@ -895,6 +895,50 @@ def test_decode_product_mapper():
     assert core_pm.to_downstream("2024-06-15T10:30:00|2024-06-15T10:30:00") == 
"2024-06-15T10|2024-06-15"
 
 
+def test_encode_chain_mapper():
+    from airflow.sdk import ChainMapper, ToDailyMapper, ToHourlyMapper
+    from airflow.serialization.encoders import encode_partition_mapper
+
+    partition_mapper = ChainMapper(ToHourlyMapper(), ToDailyMapper(input_format="%Y-%m-%dT%H"))
+    assert encode_partition_mapper(partition_mapper) == {
+        Encoding.TYPE: "airflow.partition_mappers.chain.ChainMapper",
+        Encoding.VAR: {
+            "mappers": [
+                {
+                    Encoding.TYPE: "airflow.partition_mappers.temporal.ToHourlyMapper",
+                    Encoding.VAR: {
+                        "input_format": "%Y-%m-%dT%H:%M:%S",
+                        "output_format": "%Y-%m-%dT%H",
+                    },
+                },
+                {
+                    Encoding.TYPE: "airflow.partition_mappers.temporal.ToDailyMapper",
+                    Encoding.VAR: {
+                        "input_format": "%Y-%m-%dT%H",
+                        "output_format": "%Y-%m-%d",
+                    },
+                },
+            ]
+        },
+    }
+
+
+def test_decode_chain_mapper():
+    from airflow.partition_mappers.chain import ChainMapper as CoreChainMapper
+    from airflow.sdk import ChainMapper, ToDailyMapper, ToHourlyMapper
+    from airflow.serialization.decoders import decode_partition_mapper
+    from airflow.serialization.encoders import encode_partition_mapper
+
+    partition_mapper = ChainMapper(ToHourlyMapper(), ToDailyMapper(input_format="%Y-%m-%dT%H"))
+    encoded_pm = encode_partition_mapper(partition_mapper)
+
+    core_pm = decode_partition_mapper(encoded_pm)
+
+    assert isinstance(core_pm, CoreChainMapper)
+    assert len(core_pm.mappers) == 2
+    assert core_pm.to_downstream("2024-06-15T10:30:00") == "2024-06-15"
+
+
 def test_encode_allowed_key_mapper():
     from airflow.sdk import AllowedKeyMapper
     from airflow.serialization.encoders import encode_partition_mapper
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index 5c0e6b07632..03b4d3999ba 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -201,6 +201,8 @@ Partition Mapper
 
 .. autoapiclass:: airflow.sdk.PartitionMapper
 
+.. autoapiclass:: airflow.sdk.ChainMapper
+
 .. autoapiclass:: airflow.sdk.IdentityMapper
 
 .. autoapiclass:: airflow.sdk.ToHourlyMapper
diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py
index 70521464908..8aa55a0e623 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -37,6 +37,7 @@ __all__ = [
     "BaseSensorOperator",
     "BaseXCom",
     "BranchMixIn",
+    "ChainMapper",
     "Connection",
     "Context",
     "CronDataIntervalTimetable",
@@ -125,6 +126,7 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.param import Param, ParamsDict
     from airflow.sdk.definitions.partition_mappers.allowed_key import 
AllowedKeyMapper
     from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
+    from airflow.sdk.definitions.partition_mappers.chain import ChainMapper
     from airflow.sdk.definitions.partition_mappers.identity import 
IdentityMapper
     from airflow.sdk.definitions.partition_mappers.product import ProductMapper
     from airflow.sdk.definitions.partition_mappers.temporal import (
@@ -178,6 +180,7 @@ __lazy_imports: dict[str, str] = {
     "BaseSensorOperator": ".bases.sensor",
     "BaseXCom": ".bases.xcom",
     "BranchMixIn": ".bases.branch",
+    "ChainMapper": ".definitions.partition_mappers.chain",
     "Connection": ".definitions.connection",
     "Context": ".definitions.context",
     "CronDataIntervalTimetable": ".definitions.timetables.interval",
diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi
index 905500c6164..a898ff27a70 100644
--- a/task-sdk/src/airflow/sdk/__init__.pyi
+++ b/task-sdk/src/airflow/sdk/__init__.pyi
@@ -64,6 +64,7 @@ from airflow.sdk.definitions.edges import EdgeModifier as EdgeModifier, Label as
 from airflow.sdk.definitions.param import Param as Param
 from airflow.sdk.definitions.partition_mappers.allowed_key import 
AllowedKeyMapper
 from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
+from airflow.sdk.definitions.partition_mappers.chain import ChainMapper
 from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper
 from airflow.sdk.definitions.partition_mappers.product import ProductMapper
 from airflow.sdk.definitions.partition_mappers.temporal import (
@@ -117,6 +118,7 @@ __all__ = [
     "BaseSensorOperator",
     "BaseXCom",
     "BranchMixIn",
+    "ChainMapper",
     "Connection",
     "Context",
     "CronDataIntervalTimetable",
diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
new file mode 100644
index 00000000000..0f96e2d2eaf
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/chain.py
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.sdk.definitions.partition_mappers.base import PartitionMapper
+
+
+class ChainMapper(PartitionMapper):
+    """Chain two or more partition mappers so that each mapper's output feeds the next one."""
+
+    def __init__(
+        self,
+        mapper0: PartitionMapper,
+        mapper1: PartitionMapper,
+        /,
+        *mappers: PartitionMapper,
+    ) -> None:
+        self.mappers = [mapper0, mapper1] + list(mappers)

Reply via email to