This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 56be41455f8 feat: Support customizing partition_mapper through airflow plugin (#60934)
56be41455f8 is described below
commit 56be41455f8760a062915d5024cad144530e8422
Author: Wei Lee <[email protected]>
AuthorDate: Sat Jan 24 00:40:08 2026 +0800
feat: Support customizing partition_mapper through airflow plugin (#60934)
Allow users to register custom partition_mappers through an Airflow plugin.
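For illustration only (not part of the committed code), a plugin using the new hook could look roughly like the sketch below. The MyHourlyMapper class, its key format, and the plugin name are hypothetical; only the PartitionMapper base class, the partition_mappers plugin attribute, and plugins_manager.get_partition_mapper_plugins() come from this commit.

    # Hypothetical example plugin; names and key formats are illustrative only.
    from __future__ import annotations

    from collections.abc import Iterable

    from airflow.partition_mapper.base import PartitionMapper
    from airflow.plugins_manager import AirflowPlugin


    class MyHourlyMapper(PartitionMapper):
        """Map hourly upstream partition keys (YYYY-MM-DDTHH) to daily downstream keys."""

        def to_downstream(self, key: str) -> str:
            # "2026-01-24T15" -> "2026-01-24"
            return key.split("T")[0]

        def to_upstream(self, key: str) -> Iterable[str]:
            # One downstream day fans out to 24 upstream hours.
            for hour in range(24):
                yield f"{key}T{hour:02d}"


    class MyPartitionMapperPlugin(AirflowPlugin):
        name = "my_partition_mapper_plugin"
        # Mappers listed here are collected by
        # plugins_manager.get_partition_mapper_plugins() and resolved by
        # qualified name during Dag (de)serialization.
        partition_mappers = [MyHourlyMapper]

A mapper registered this way can then be passed as the partition_mapper of a PartitionedAssetTimetable, mirroring how the new test below exercises Key1Mapper.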
---
airflow-core/src/airflow/plugins_manager.py | 13 ++
airflow-core/src/airflow/serialization/decoders.py | 12 +-
airflow-core/src/airflow/serialization/encoders.py | 23 ++--
airflow-core/src/airflow/serialization/helpers.py | 19 +--
airflow-core/tests/unit/assets/test_manager.py | 2 +-
airflow-core/tests/unit/jobs/test_scheduler_job.py | 140 +++++++++++++--------
.../plugins_manager/plugins_manager.py | 8 +-
7 files changed, 137 insertions(+), 80 deletions(-)
diff --git a/airflow-core/src/airflow/plugins_manager.py b/airflow-core/src/airflow/plugins_manager.py
index 1fb227a6740..af05f1aacd2 100644
--- a/airflow-core/src/airflow/plugins_manager.py
+++ b/airflow-core/src/airflow/plugins_manager.py
@@ -40,6 +40,7 @@ from airflow.configuration import conf
if TYPE_CHECKING:
from airflow.lineage.hook import HookLineageReader
from airflow.listeners.listener import ListenerManager
+ from airflow.partition_mapper.base import PartitionMapper
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.timetables.base import Timetable
@@ -270,6 +271,18 @@ def get_timetables_plugins() -> dict[str, type[Timetable]]:
}
+@cache
+def get_partition_mapper_plugins() -> dict[str, type[PartitionMapper]]:
+ """Collect and get partition mapper classes registered by plugins."""
+ log.debug("Initialize extra partition mapper plugins")
+
+ return {
+ qualname(partition_mapper_cls): partition_mapper_cls
+ for plugin in _get_plugins()[0]
+ for partition_mapper_cls in plugin.partition_mappers
+ }
+
+
@cache
def get_hook_lineage_readers_plugins() -> list[type[HookLineageReader]]:
"""Collect and get hook lineage reader classes registered by plugins."""
diff --git a/airflow-core/src/airflow/serialization/decoders.py b/airflow-core/src/airflow/serialization/decoders.py
index 619db256fef..700abede8ec 100644
--- a/airflow-core/src/airflow/serialization/decoders.py
+++ b/airflow-core/src/airflow/serialization/decoders.py
@@ -36,10 +36,10 @@ from airflow.serialization.definitions.assets import (
)
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import (
- PartitionMapperNotFound,
+ find_registered_custom_partition_mapper,
find_registered_custom_timetable,
+ is_core_partition_mapper_import_path,
is_core_timetable_import_path,
- load_partition_mapper,
)
if TYPE_CHECKING:
@@ -157,6 +157,8 @@ def decode_partition_mapper(var: dict[str, Any]) -> PartitionMapper:
:meta private:
"""
importable_string = var[Encoding.TYPE]
- if (partition_mapper_class := load_partition_mapper(importable_string)) is not None:
- return partition_mapper_class.deserialize(var[Encoding.VAR])
- raise PartitionMapperNotFound(importable_string)
+ if is_core_partition_mapper_import_path(importable_string):
+ partition_mapper_cls = import_string(importable_string)
+ else:
+ partition_mapper_cls = find_registered_custom_partition_mapper(importable_string)
+ return partition_mapper_cls.deserialize(var[Encoding.VAR])
diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py
index fe09cccb4cd..d17135968a5 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -26,6 +26,7 @@ import attrs
import pendulum
from airflow._shared.module_loading import qualname
+from airflow.partition_mapper.base import PartitionMapper as CorePartitionMapper
from airflow.sdk import (
Asset,
AssetAlias,
@@ -58,8 +59,8 @@ from airflow.serialization.definitions.assets import (
)
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import (
+ find_registered_custom_partition_mapper,
find_registered_custom_timetable,
- is_core_partition_mapper_import_path,
is_core_timetable_import_path,
)
from airflow.timetables.base import Timetable as CoreTimetable
@@ -319,8 +320,12 @@ class _Serializer:
}
@functools.singledispatchmethod
- def serialize_partition_mapper(self, partition_mapper: PartitionMapper) -> dict[str, Any]:
- raise NotImplementedError
+ def serialize_partition_mapper(
+ self, partition_mapper: PartitionMapper | CorePartitionMapper
+ ) -> dict[str, Any]:
+ if not isinstance(partition_mapper, CorePartitionMapper):
+ raise NotImplementedError(f"cannot serialize partition mapper {type(partition_mapper).__name__}")
+ return partition_mapper.serialize()
@serialize_partition_mapper.register
def _(self, partition_mapper: IdentityMapper) -> dict[str, Any]:
@@ -387,13 +392,11 @@ def encode_partition_mapper(var: PartitionMapper) -> dict[str, Any]:
:meta private:
"""
- var_type = qualname(type(var))
- if not is_core_partition_mapper_import_path(var_type):
- var_type = _serializer.BUILTIN_PARTITION_MAPPERS[type(var)]
-
- # TODO: (AIP-76) handle airflow plugins cases (not part of 3.2)
-
+ if (importable_string := _serializer.BUILTIN_PARTITION_MAPPERS.get(var_type := type(var), None)) is None:
+ find_registered_custom_partition_mapper(
+ importable_string := qualname(var_type)
+ ) # This raises if not found.
return {
- Encoding.TYPE: var_type,
+ Encoding.TYPE: importable_string,
Encoding.VAR: _serializer.serialize_partition_mapper(var),
}
diff --git a/airflow-core/src/airflow/serialization/helpers.py b/airflow-core/src/airflow/serialization/helpers.py
index 6e0cbd97a48..ca61925edc1 100644
--- a/airflow-core/src/airflow/serialization/helpers.py
+++ b/airflow-core/src/airflow/serialization/helpers.py
@@ -21,7 +21,7 @@ from __future__ import annotations
import contextlib
from typing import TYPE_CHECKING, Any
-from airflow._shared.module_loading import import_string, qualname
+from airflow._shared.module_loading import qualname
from airflow._shared.secrets_masker import redact
from airflow.configuration import conf
from airflow.settings import json
@@ -131,6 +131,16 @@ def find_registered_custom_timetable(importable_string: str) -> type[CoreTimetab
raise TimetableNotRegistered(importable_string)
+def find_registered_custom_partition_mapper(importable_string: str) -> type[PartitionMapper]:
+ """Find a user-defined custom partition mapper class registered via a plugin."""
+ from airflow import plugins_manager
+
+ partition_mapper_cls = plugins_manager.get_partition_mapper_plugins()
+ with contextlib.suppress(KeyError):
+ return partition_mapper_cls[importable_string]
+ raise PartitionMapperNotFound(importable_string)
+
+
def is_core_timetable_import_path(importable_string: str) -> bool:
"""Whether an importable string points to a core timetable class."""
return importable_string.startswith("airflow.timetables.")
@@ -153,10 +163,3 @@ class PartitionMapperNotFound(ValueError):
def is_core_partition_mapper_import_path(importable_string: str) -> bool:
"""Whether an importable string points to a core partition mapper class."""
return importable_string.startswith("airflow.partition_mapper.")
-
-
-def load_partition_mapper(importable_string: str) -> PartitionMapper | None:
- if is_core_partition_mapper_import_path(importable_string):
- return import_string(importable_string)
- # TODO: (AIP-76) handle airflow plugins cases (3.3+)
- return None
diff --git a/airflow-core/tests/unit/assets/test_manager.py b/airflow-core/tests/unit/assets/test_manager.py
index 7929ae7c0c5..df6608607c8 100644
--- a/airflow-core/tests/unit/assets/test_manager.py
+++ b/airflow-core/tests/unit/assets/test_manager.py
@@ -222,7 +222,7 @@ class TestAssetManager:
@pytest.mark.usefixtures("dag_maker", "testing_dag_bundle")
def test_get_or_create_apdr_race_condition(self, session, caplog):
- asm = AssetModel(uri="test://asset1/", name="parition_asset", group="asset")
+ asm = AssetModel(uri="test://asset1/", name="partition_asset", group="asset")
testing_dag = DagModel(dag_id="testing_dag", is_stale=False, bundle_name="testing")
session.add_all([asm, testing_dag])
session.commit()
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 3636c2b6db9..0cf22fdf98f 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -22,7 +22,7 @@ import datetime
import logging
import os
from collections import Counter, deque
-from collections.abc import Generator
+from collections.abc import Generator, Iterable
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
@@ -38,6 +38,7 @@ from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import joinedload
from airflow import settings
+from airflow._shared.module_loading import qualname
from airflow._shared.timezones import timezone
from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.assets.manager import AssetManager
@@ -82,7 +83,6 @@ from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.file import FileDeleteTrigger
from airflow.sdk import DAG, Asset, AssetAlias, AssetWatcher, task
from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
-from airflow.sdk.definitions.partition_mapper.identity import IdentityMapper
from airflow.sdk.definitions.timetables.assets import PartitionedAssetTimetable
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.serialized_objects import LazyDeserializedDAG
@@ -8225,60 +8225,92 @@ def test_mark_backfills_completed(dag_maker, session):
@pytest.mark.need_serialized_dag
-def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_populated(
- dag_maker,
- session,
-):
- asset = Asset(name="hello")
- with dag_maker(dag_id="asset_event_tester", schedule=None, session=session) as dag:
- EmptyOperator(task_id="hi", outlets=[asset])
- dag1_id = dag.dag_id
- dr = dag_maker.create_dagrun(partition_key="abc123", session=session)
- assert dr.partition_key == "abc123"
- [ti] = dr.get_task_instances(session=session)
- session.commit()
- serialized_outlets = dag.get_task("hi").outlets
-
- with dag_maker(
- dag_id="asset_event_listener",
- schedule=PartitionedAssetTimetable(assets=asset, partition_mapper=IdentityMapper()),
- session=session,
+def test_partitioned_dag_run_with_customized_mapper(dag_maker: DagMaker, session: Session):
+ from airflow.partition_mapper.base import PartitionMapper as CorePartitionMapper
+
+ class Key1Mapper(CorePartitionMapper):
+ def to_downstream(self, key: str) -> str:
+ return "key-1"
+
+ def to_upstream(self, key: str) -> Iterable[str]:
+ yield key
+
+ def _find_registered_custom_partition_mapper(s):
+ if s == qualname(Key1Mapper):
+ return Key1Mapper
+ raise ValueError(f"unexpected class {s!r}")
+
+ asset_1 = Asset(name="asset-1")
+
+ # Consumer Dag "asset-event-consumer"
+ with (
+ mock.patch(
+ "airflow.serialization.encoders.find_registered_custom_partition_mapper",
+ _find_registered_custom_partition_mapper,
+ ),
+ mock.patch(
+ "airflow.serialization.decoders.find_registered_custom_partition_mapper",
+ _find_registered_custom_partition_mapper,
+ ),
):
- EmptyOperator(task_id="hi")
- session.commit()
+ with dag_maker(
+ dag_id="asset-event-consumer",
+ schedule=PartitionedAssetTimetable(
+ assets=asset_1,
+ # TODO: (GH-57694) this partition mapper interface will be moved into asset as per-asset mapper
+ # and the type mismatch will be handled there
+ partition_mapper=Key1Mapper(), # type: ignore[arg-type]
+ ),
+ session=session,
+ ):
+ EmptyOperator(task_id="hi")
+ session.commit()
- TaskInstance.register_asset_changes_in_db(
- ti=ti,
- task_outlets=[o.asprofile() for o in serialized_outlets],
- outlet_events=[],
- session=session,
- )
- session.commit()
- event = session.scalar(
- select(AssetEvent).where(
- AssetEvent.source_dag_id == dag1_id,
- AssetEvent.source_run_id == dr.run_id,
+ runner = SchedulerJobRunner(
+ job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False))
)
- )
- assert event.partition_key == "abc123"
- pakl = session.scalar(
- select(PartitionedAssetKeyLog).where(
- PartitionedAssetKeyLog.asset_event_id == event.id,
+
+ with dag_maker(dag_id="asset-event-producer", schedule=None, session=session) as dag:
+ EmptyOperator(task_id="hi", outlets=[asset_1])
+
+ dr = dag_maker.create_dagrun(partition_key="this-is-not-key-1-before-mapped", session=session)
+ [ti] = dr.get_task_instances(session=session)
+ session.commit()
+
+ serialized_outlets = dag.get_task("hi").outlets
+ TaskInstance.register_asset_changes_in_db(
+ ti=ti,
+ task_outlets=[o.asprofile() for o in serialized_outlets],
+ outlet_events=[],
+ session=session,
)
- )
- apdr = session.scalar(
- select(AssetPartitionDagRun).where(AssetPartitionDagRun.id == pakl.asset_partition_dag_run_id)
- )
- assert apdr is not None
- assert apdr.created_dag_run_id is None
- # ok, now we have established that the needed rows are there.
- # let's see what the scheduler does
+ session.commit()
- runner = SchedulerJobRunner(
- job=Job(job_type=SchedulerJobRunner.job_type, executor=MockExecutor(do_update=False))
- )
- partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session)
- session.refresh(apdr)
- assert apdr.created_dag_run_id is not None
- assert len(partition_dags) == 1
- assert partition_dags == {"asset_event_listener"}
+ event = session.scalar(
+ select(AssetEvent).where(
+ AssetEvent.source_dag_id == dag.dag_id,
+ AssetEvent.source_run_id == dr.run_id,
+ )
+ )
+ assert event is not None
+ assert event.partition_key == "this-is-not-key-1-before-mapped"
+
+ apdr = session.scalar(
+ select(AssetPartitionDagRun)
+ .join(
+ PartitionedAssetKeyLog,
+ PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id,
+ )
+ .where(PartitionedAssetKeyLog.asset_event_id == event.id)
+ )
+ assert apdr is not None
+ assert apdr.created_dag_run_id is None
+ assert apdr.partition_key == "key-1"
+
+ partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session)
+ session.refresh(apdr)
+ # Key1Mapper mapped the event's partition key to "key-1", so the scheduler
+ # creates a Dag run for the consumer Dag.
+ assert apdr.created_dag_run_id is not None
+ assert len(partition_dags) == 1
+ assert partition_dags == {"asset-event-consumer"}
diff --git a/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py b/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
index e7fbb24377f..9ea497e5a10 100644
--- a/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
+++ b/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
@@ -108,14 +108,18 @@ class AirflowPlugin:
# A list of operator extra links to override or add operator links
# to existing Airflow Operators.
+ #
# These extra links will be available on the task page in form of
# buttons.
operator_extra_links: list[Any] = []
- # A list of timetable classes that can be used for DAG scheduling.
+ # A list of timetable classes that can be used for Dag scheduling.
timetables: list[Any] = []
- # A list of listeners that can be used for tracking task and DAG states.
+ # A list of partition mapper classes that can be used for mapping partition keys.
+ partition_mappers: list[Any] = []
+
+ # A list of listeners that can be used for tracking task and Dag states.
listeners: list[ModuleType | object] = []
# A list of hook lineage reader classes that can be used for reading lineage information from a hook.