This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 56be41455f8 feat: Support customizing partition_mapper through airflow plugin (#60934)
56be41455f8 is described below

commit 56be41455f8760a062915d5024cad144530e8422
Author: Wei Lee <[email protected]>
AuthorDate: Sat Jan 24 00:40:08 2026 +0800

    feat: Support customizing partition_mapper through airflow plugin (#60934)
    
    Allow users to register partition_mappers through an Airflow plugin
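    
    A minimal registration sketch (MyMapper and MyPlugin are hypothetical;
    the method signatures follow the Key1Mapper test added in this commit):
    
        from collections.abc import Iterable
    
        from airflow.partition_mapper.base import PartitionMapper
        from airflow.plugins_manager import AirflowPlugin
    
        class MyMapper(PartitionMapper):
            def to_downstream(self, key: str) -> str:
                # Map an upstream partition key to the downstream key.
                return key
    
            def to_upstream(self, key: str) -> Iterable[str]:
                # Map a downstream partition key back to upstream keys.
                yield key
    
        class MyPlugin(AirflowPlugin):
            name = "my_partition_mapper_plugin"
            partition_mappers = [MyMapper]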
---
 airflow-core/src/airflow/plugins_manager.py        |  13 ++
 airflow-core/src/airflow/serialization/decoders.py |  12 +-
 airflow-core/src/airflow/serialization/encoders.py |  23 ++--
 airflow-core/src/airflow/serialization/helpers.py  |  19 +--
 airflow-core/tests/unit/assets/test_manager.py     |   2 +-
 airflow-core/tests/unit/jobs/test_scheduler_job.py | 140 +++++++++++++--------
 .../plugins_manager/plugins_manager.py             |   8 +-
 7 files changed, 137 insertions(+), 80 deletions(-)

diff --git a/airflow-core/src/airflow/plugins_manager.py b/airflow-core/src/airflow/plugins_manager.py
index 1fb227a6740..af05f1aacd2 100644
--- a/airflow-core/src/airflow/plugins_manager.py
+++ b/airflow-core/src/airflow/plugins_manager.py
@@ -40,6 +40,7 @@ from airflow.configuration import conf
 if TYPE_CHECKING:
     from airflow.lineage.hook import HookLineageReader
     from airflow.listeners.listener import ListenerManager
+    from airflow.partition_mapper.base import PartitionMapper
     from airflow.task.priority_strategy import PriorityWeightStrategy
     from airflow.timetables.base import Timetable
 
@@ -270,6 +271,18 @@ def get_timetables_plugins() -> dict[str, type[Timetable]]:
     }
 
 
+@cache
+def get_partition_mapper_plugins() -> dict[str, type[PartitionMapper]]:
+    """Collect and get partition mapper classes registered by plugins."""
+    log.debug("Initialize extra partition mapper plugins")
+
+    return {
+        qualname(partition_mapper_cls): partition_mapper_cls
+        for plugin in _get_plugins()[0]
+        for partition_mapper_cls in plugin.partition_mappers
+    }
+
+
 @cache
 def get_hook_lineage_readers_plugins() -> list[type[HookLineageReader]]:
     """Collect and get hook lineage reader classes registered by plugins."""
diff --git a/airflow-core/src/airflow/serialization/decoders.py b/airflow-core/src/airflow/serialization/decoders.py
index 619db256fef..700abede8ec 100644
--- a/airflow-core/src/airflow/serialization/decoders.py
+++ b/airflow-core/src/airflow/serialization/decoders.py
@@ -36,10 +36,10 @@ from airflow.serialization.definitions.assets import (
 )
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
 from airflow.serialization.helpers import (
-    PartitionMapperNotFound,
+    find_registered_custom_partition_mapper,
     find_registered_custom_timetable,
+    is_core_partition_mapper_import_path,
     is_core_timetable_import_path,
-    load_partition_mapper,
 )
 
 if TYPE_CHECKING:
@@ -157,6 +157,8 @@ def decode_partition_mapper(var: dict[str, Any]) -> PartitionMapper:
     :meta private:
     """
     importable_string = var[Encoding.TYPE]
-    if (partition_mapper_class := load_partition_mapper(importable_string)) is not None:
-        return partition_mapper_class.deserialize(var[Encoding.VAR])
-    raise PartitionMapperNotFound(importable_string)
+    if is_core_partition_mapper_import_path(importable_string):
+        partition_mapper_cls = import_string(importable_string)
+    else:
+        partition_mapper_cls = find_registered_custom_partition_mapper(importable_string)
+    return partition_mapper_cls.deserialize(var[Encoding.VAR])
diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py
index fe09cccb4cd..d17135968a5 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -26,6 +26,7 @@ import attrs
 import pendulum
 
 from airflow._shared.module_loading import qualname
+from airflow.partition_mapper.base import PartitionMapper as CorePartitionMapper
 from airflow.sdk import (
     Asset,
     AssetAlias,
@@ -58,8 +59,8 @@ from airflow.serialization.definitions.assets import (
 )
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
 from airflow.serialization.helpers import (
+    find_registered_custom_partition_mapper,
     find_registered_custom_timetable,
-    is_core_partition_mapper_import_path,
     is_core_timetable_import_path,
 )
 from airflow.timetables.base import Timetable as CoreTimetable
@@ -319,8 +320,12 @@ class _Serializer:
     }
 
     @functools.singledispatchmethod
-    def serialize_partition_mapper(self, partition_mapper: PartitionMapper) -> dict[str, Any]:
-        raise NotImplementedError
+    def serialize_partition_mapper(
+        self, partition_mapper: PartitionMapper | CorePartitionMapper
+    ) -> dict[str, Any]:
+        if not isinstance(partition_mapper, CorePartitionMapper):
+            raise NotImplementedError(f"cannot serialize partition mapper {type(partition_mapper).__name__}")
+        return partition_mapper.serialize()
 
     @serialize_partition_mapper.register
     def _(self, partition_mapper: IdentityMapper) -> dict[str, Any]:
@@ -387,13 +392,11 @@ def encode_partition_mapper(var: PartitionMapper) -> dict[str, Any]:
 
     :meta private:
     """
-    var_type = qualname(type(var))
-    if not is_core_partition_mapper_import_path(var_type):
-        var_type = _serializer.BUILTIN_PARTITION_MAPPERS[type(var)]
-
-    # TODO: (AIP-76) handle airflow plugins cases (not part of 3.2)
-
+    if (importable_string := _serializer.BUILTIN_PARTITION_MAPPERS.get(var_type := type(var), None)) is None:
+        find_registered_custom_partition_mapper(
+            importable_string := qualname(var_type)
+        )  # This raises if not found.
     return {
-        Encoding.TYPE: var_type,
+        Encoding.TYPE: importable_string,
         Encoding.VAR: _serializer.serialize_partition_mapper(var),
     }
diff --git a/airflow-core/src/airflow/serialization/helpers.py b/airflow-core/src/airflow/serialization/helpers.py
index 6e0cbd97a48..ca61925edc1 100644
--- a/airflow-core/src/airflow/serialization/helpers.py
+++ b/airflow-core/src/airflow/serialization/helpers.py
@@ -21,7 +21,7 @@ from __future__ import annotations
 import contextlib
 from typing import TYPE_CHECKING, Any
 
-from airflow._shared.module_loading import import_string, qualname
+from airflow._shared.module_loading import qualname
 from airflow._shared.secrets_masker import redact
 from airflow.configuration import conf
 from airflow.settings import json
@@ -131,6 +131,16 @@ def find_registered_custom_timetable(importable_string: str) -> type[CoreTimetab
     raise TimetableNotRegistered(importable_string)
 
 
+def find_registered_custom_partition_mapper(importable_string: str) -> type[PartitionMapper]:
+    """Find a user-defined custom partition mapper class registered via a 
plugin."""
+    from airflow import plugins_manager
+
+    partition_mapper_plugins = plugins_manager.get_partition_mapper_plugins()
+    with contextlib.suppress(KeyError):
+        return partition_mapper_plugins[importable_string]
+    raise PartitionMapperNotFound(importable_string)
+
+
 def is_core_timetable_import_path(importable_string: str) -> bool:
     """Whether an importable string points to a core timetable class."""
     return importable_string.startswith("airflow.timetables.")
@@ -153,10 +163,3 @@ class PartitionMapperNotFound(ValueError):
 def is_core_partition_mapper_import_path(importable_string: str) -> bool:
     """Whether an importable string points to a core partition mapper class."""
     return importable_string.startswith("airflow.partition_mapper.")
-
-
-def load_partition_mapper(importable_string: str) -> PartitionMapper | None:
-    if is_core_partition_mapper_import_path(importable_string):
-        return import_string(importable_string)
-    # TODO: (AIP-76) handle airflow plugins cases (3.3+)
-    return None
diff --git a/airflow-core/tests/unit/assets/test_manager.py b/airflow-core/tests/unit/assets/test_manager.py
index 7929ae7c0c5..df6608607c8 100644
--- a/airflow-core/tests/unit/assets/test_manager.py
+++ b/airflow-core/tests/unit/assets/test_manager.py
@@ -222,7 +222,7 @@ class TestAssetManager:
 
     @pytest.mark.usefixtures("dag_maker", "testing_dag_bundle")
     def test_get_or_create_apdr_race_condition(self, session, caplog):
-        asm = AssetModel(uri="test://asset1/", name="parition_asset", group="asset")
+        asm = AssetModel(uri="test://asset1/", name="partition_asset", group="asset")
         testing_dag = DagModel(dag_id="testing_dag", is_stale=False, bundle_name="testing")
         session.add_all([asm, testing_dag])
         session.commit()
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 3636c2b6db9..0cf22fdf98f 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -22,7 +22,7 @@ import datetime
 import logging
 import os
 from collections import Counter, deque
-from collections.abc import Generator
+from collections.abc import Generator, Iterable
 from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -38,6 +38,7 @@ from sqlalchemy import delete, func, select, update
 from sqlalchemy.orm import joinedload
 
 from airflow import settings
+from airflow._shared.module_loading import qualname
 from airflow._shared.timezones import timezone
 from airflow.api_fastapi.auth.tokens import JWTGenerator
 from airflow.assets.manager import AssetManager
@@ -82,7 +83,6 @@ from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.providers.standard.triggers.file import FileDeleteTrigger
 from airflow.sdk import DAG, Asset, AssetAlias, AssetWatcher, task
 from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
-from airflow.sdk.definitions.partition_mapper.identity import IdentityMapper
 from airflow.sdk.definitions.timetables.assets import PartitionedAssetTimetable
 from airflow.serialization.definitions.dag import SerializedDAG
 from airflow.serialization.serialized_objects import LazyDeserializedDAG
@@ -8225,60 +8225,92 @@ def test_mark_backfills_completed(dag_maker, session):
 
 
 @pytest.mark.need_serialized_dag
-def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_populated(
-    dag_maker,
-    session,
-):
-    asset = Asset(name="hello")
-    with dag_maker(dag_id="asset_event_tester", schedule=None, 
session=session) as dag:
-        EmptyOperator(task_id="hi", outlets=[asset])
-    dag1_id = dag.dag_id
-    dr = dag_maker.create_dagrun(partition_key="abc123", session=session)
-    assert dr.partition_key == "abc123"
-    [ti] = dr.get_task_instances(session=session)
-    session.commit()
-    serialized_outlets = dag.get_task("hi").outlets
-
-    with dag_maker(
-        dag_id="asset_event_listener",
-        schedule=PartitionedAssetTimetable(assets=asset, partition_mapper=IdentityMapper()),
-        session=session,
+def test_partitioned_dag_run_with_customized_mapper(dag_maker: DagMaker, session: Session):
+    from airflow.partition_mapper.base import PartitionMapper as CorePartitionMapper
+
+    class Key1Mapper(CorePartitionMapper):
+        def to_downstream(self, key: str) -> str:
+            return "key-1"
+
+        def to_upstream(self, key: str) -> Iterable[str]:
+            yield key
+
+    def _find_registered_custom_partition_mapper(s):
+        if s == qualname(Key1Mapper):
+            return Key1Mapper
+        raise ValueError(f"unexpected class {s!r}")
+
+    asset_1 = Asset(name="asset-1")
+
+    # Consumer Dag "asset-event-consumer"
+    with (
+        mock.patch(
+            "airflow.serialization.encoders.find_registered_custom_partition_mapper",
+            _find_registered_custom_partition_mapper,
+        ),
+        mock.patch(
+            "airflow.serialization.decoders.find_registered_custom_partition_mapper",
+            _find_registered_custom_partition_mapper,
+        ),
     ):
-        EmptyOperator(task_id="hi")
-    session.commit()
+        with dag_maker(
+            dag_id="asset-event-consumer",
+            schedule=PartitionedAssetTimetable(
+                assets=asset_1,
+                # TODO: (GH-57694) this partition mapper interface will be moved into asset as per-asset mapper
+                # and the type mismatch will be handled there
+                partition_mapper=Key1Mapper(),  # type: ignore[arg-type]
+            ),
+            session=session,
+        ):
+            EmptyOperator(task_id="hi")
+        session.commit()
 
-    TaskInstance.register_asset_changes_in_db(
-        ti=ti,
-        task_outlets=[o.asprofile() for o in serialized_outlets],
-        outlet_events=[],
-        session=session,
-    )
-    session.commit()
-    event = session.scalar(
-        select(AssetEvent).where(
-            AssetEvent.source_dag_id == dag1_id,
-            AssetEvent.source_run_id == dr.run_id,
+        runner = SchedulerJobRunner(
+            job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False))
         )
-    )
-    assert event.partition_key == "abc123"
-    pakl = session.scalar(
-        select(PartitionedAssetKeyLog).where(
-            PartitionedAssetKeyLog.asset_event_id == event.id,
+
+        with dag_maker(dag_id="asset-event-producer", schedule=None, 
session=session) as dag:
+            EmptyOperator(task_id="hi", outlets=[asset_1])
+
+        dr = dag_maker.create_dagrun(partition_key="this-is-not-key-1-before-mapped", session=session)
+        [ti] = dr.get_task_instances(session=session)
+        session.commit()
+
+        serialized_outlets = dag.get_task("hi").outlets
+        TaskInstance.register_asset_changes_in_db(
+            ti=ti,
+            task_outlets=[o.asprofile() for o in serialized_outlets],
+            outlet_events=[],
+            session=session,
         )
-    )
-    apdr = session.scalar(
-        select(AssetPartitionDagRun).where(AssetPartitionDagRun.id == pakl.asset_partition_dag_run_id)
-    )
-    assert apdr is not None
-    assert apdr.created_dag_run_id is None
-    # ok, now we have established that the needed rows are there.
-    # let's see what the scheduler does
+        session.commit()
 
-    runner = SchedulerJobRunner(
-        job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False))
-    )
-    partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session)
-    session.refresh(apdr)
-    assert apdr.created_dag_run_id is not None
-    assert len(partition_dags) == 1
-    assert partition_dags == {"asset_event_listener"}
+        event = session.scalar(
+            select(AssetEvent).where(
+                AssetEvent.source_dag_id == dag.dag_id,
+                AssetEvent.source_run_id == dr.run_id,
+            )
+        )
+        assert event is not None
+        assert event.partition_key == "this-is-not-key-1-before-mapped"
+
+        apdr = session.scalar(
+            select(AssetPartitionDagRun)
+            .join(
+                PartitionedAssetKeyLog,
+                PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id,
+            )
+            .where(PartitionedAssetKeyLog.asset_event_id == event.id)
+        )
+        assert apdr is not None
+        assert apdr.created_dag_run_id is None
+        assert apdr.partition_key == "key-1"
+
+        partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session)
+        session.refresh(apdr)
+        # The event's "this-is-not-key-1-before-mapped" key was mapped to "key-1"
+        # by Key1Mapper, matching the consumer's partition, so a Dag run is created
+        assert apdr.created_dag_run_id is not None
+        assert len(partition_dags) == 1
+        assert partition_dags == {"asset-event-consumer"}
diff --git a/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py b/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
index e7fbb24377f..9ea497e5a10 100644
--- a/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
+++ b/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
@@ -108,14 +108,18 @@ class AirflowPlugin:
 
     # A list of operator extra links to override or add operator links
     # to existing Airflow Operators.
+    #
     # These extra links will be available on the task page in form of
     # buttons.
     operator_extra_links: list[Any] = []
 
-    # A list of timetable classes that can be used for DAG scheduling.
+    # A list of timetable classes that can be used for Dag scheduling.
     timetables: list[Any] = []
 
-    # A list of listeners that can be used for tracking task and DAG states.
+    # A list of partition mapper classes that can be used to map partition keys between assets.
+    partition_mappers: list[Any] = []
+
+    # A list of listeners that can be used for tracking task and Dag states.
     listeners: list[ModuleType | object] = []
 
     # A list of hook lineage reader classes that can be used for reading lineage information from a hook.
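
For reference, a rough sketch of the serialization round-trip these changes
enable (assumes the hypothetical MyMapper above is registered through an
installed plugin; Encoding.TYPE and Encoding.VAR are Airflow's standard
serialization keys):

    from airflow.serialization.decoders import decode_partition_mapper
    from airflow.serialization.encoders import encode_partition_mapper

    # The encoder stores the mapper's qualified name under Encoding.TYPE and
    # the payload from the mapper's serialize() under Encoding.VAR; classes
    # that are neither built-in nor registered raise PartitionMapperNotFound.
    encoded = encode_partition_mapper(MyMapper())

    # The decoder resolves the name through the plugin registry and calls
    # deserialize() on the resolved class.
    restored = decode_partition_mapper(encoded)
    assert isinstance(restored, MyMapper)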
