o-nikolas commented on code in PR #57222:
URL: https://github.com/apache/airflow/pull/57222#discussion_r2466726307
##########
task-sdk/src/airflow/sdk/definitions/deadline.py:
##########
@@ -321,3 +326,71 @@ def FIXED_DATETIME(cls, datetime: datetime) ->
DeadlineReferenceType:
(DeadlineReferenceType,),
{"_evaluate_with": lambda self, **kwargs: datetime.now()},
)()
+
+ @classmethod
+ def register_custom_reference(
+ cls,
+ reference_class: type[ReferenceModels.BaseDeadlineReference],
+ timing: DeadlineReferenceTuple | None = None,
+ ):
+ """
+ Register a custom deadline reference class.
+
+ :param reference_class: The custom reference class inheriting from
BaseDeadlineReference
+ :param timing: A DeadlineReference.TYPES for when the deadline should
be evaluated ("DAGRUN_CREATED",
+ "DAGRUN_QUEUED", etc.); defaults to
DeadlineReference.TYPES.DAGRUN_CREATED
+ """
+ from airflow.models.deadline import ReferenceModels
+
+ # Default to DAGRUN_CREATED if no timing specified
+ if timing is None:
+ timing = cls.TYPES.DAGRUN_CREATED
+
+ # Validate the reference class inherits from BaseDeadlineReference
+ if not issubclass(reference_class,
ReferenceModels.BaseDeadlineReference):
+ raise ValueError(f"{reference_class.__name__} must inherit from
BaseDeadlineReference")
+
+ # Register the new reference with ReferenceModels and
DeadlineReference for discoverability
+ setattr(ReferenceModels, reference_class.__name__, reference_class)
+ setattr(cls, reference_class.__name__, reference_class())
+ logger.info("Registered DeadlineReference %s",
reference_class.__name__)
Review Comment:
How often is this going to be logged? Just once when the plugin is loaded,
right?
##########
airflow-core/docs/howto/deadline-alerts.rst:
##########
@@ -328,24 +328,95 @@ Custom References
^^^^^^^^^^^^^^^^^
While the built-in references should cover most use cases, and more will be
released over time, you
-can create custom references by implementing a class that inherits from
DeadlineReference. This may
-be useful if you have calendar integrations or other sources that you want to
use as a reference.
+can create custom references. This may be useful if you have calendar
integrations or other sources
+that you want to use as a reference. You can create custom references by
implementing a class that
+inherits from BaseDeadlineReference, giving it an _evaluate_with() method, and
registering it. There are
+two ways to accomplish this. The recommended way is to use the
``@deadline_reference`` decorator
+but for more complicated implementations, the ``register_custom_reference()``
method is available.
Review Comment:
What are the more complicated implementations? Would we ever realistically
need this more difficult approach? We could just fully remove it from the docs.
If the only difference is when the reference is evaluated, it seems easy to create
two references for the very rare cases where a user wants that. Or just support that
condition in the decorator if this is a common need (but I'd be surprised about that).
##########
airflow-core/docs/howto/deadline-alerts.rst:
##########
@@ -328,24 +328,95 @@ Custom References
^^^^^^^^^^^^^^^^^
While the built-in references should cover most use cases, and more will be
released over time, you
-can create custom references by implementing a class that inherits from
DeadlineReference. This may
-be useful if you have calendar integrations or other sources that you want to
use as a reference.
+can create custom references. This may be useful if you have calendar
integrations or other sources
+that you want to use as a reference. You can create custom references by
implementing a class that
+inherits from BaseDeadlineReference, giving it an _evaluate_with() method, and
registering it. There are
+two ways to accomplish this. The recommended way is to use the
``@deadline_reference`` decorator
+but for more complicated implementations, the ``register_custom_reference()``
method is available.
+
+**Recommended: Using the decorator**
.. code-block:: python
- class CustomReference(DeadlineReference):
- """A deadline reference that uses a custom data source."""
+ from airflow._shared.timezones.timezone import datetime
+ from airflow.models.deadline import ReferenceModels
+ from sqlalchemy.orm import Session
+
+ from airflow.sdk.definitions.deadline import DeadlineReference,
deadline_reference
+
+
+ # By default, the evaluate_with method will be executed when the dagrun is
created.
+ @deadline_reference()
+ class MyCustomDecoratedReference(ReferenceModels.BaseDeadlineReference):
+ """A custom reference evaluated when DAG runs are created."""
+
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ # Add your business logic here
+ return your_datetime
+
+
+ # You can specify when evaluate_with will be called by providing a
DeadlineReference.TYPES value.
+ @deadline_reference(DeadlineReference.TYPES.DAGRUN_QUEUED)
+ class MyQueuedReference(ReferenceModels.BaseDeadlineReference):
+ """A custom reference evaluated when DAG runs are queued."""
+
+ required_kwargs = {"custom_param"}
+
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ custom_value = kwargs["custom_param"]
+ # Use custom_value in your calculation
+ return your_datetime
+
+**Alternative: Manual Registration**
+
+For advanced use cases requiring conditional or dynamic registration, you may
wish to use the registration method directly.
+In this case, the plugin file will look something like this:
+
+.. code-block:: python
+
+ from sqlalchemy.orm import Session
+
+ from airflow.models.deadline import ReferenceModels
+ from airflow.sdk.definitions.deadline import DeadlineReference
- # Define any required parameters for your reference
- required_kwargs = {"custom_id"}
+ class MyManualReference(ReferenceModels.BaseDeadlineReference):
def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
- """
- Evaluate the reference time using the provided session and kwargs.
-
- The session parameter can be used for database queries, and kwargs
- will contain any required parameters defined in required_kwargs.
- """
- custom_id = kwargs["custom_id"]
- # Your custom logic here to determine the reference time
+ # Add your business logic here
return your_datetime
+
+
+ # Register with specific timing based on configuration
+ timing = (
+ DeadlineReference.TYPES.DAGRUN_QUEUED if use_queued_timing else
DeadlineReference.TYPES.DAGRUN_CREATED
+ )
+ DeadlineReference.register_custom_reference(MyManualReference, timing)
+
+**Using Custom References in DAGs**
+
+Once registered, use your custom references in DAG definitions like any other
reference:
+
+.. code-block:: python
+
+ from datetime import timedelta
+ from airflow import DAG
+ from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert,
DeadlineReference
+
+ with DAG(
+ dag_id="custom_reference_example",
+ deadline=DeadlineAlert(
+ reference=DeadlineReference.MyCustomDecoratedReference(),
+ interval=timedelta(hours=2),
+ callback=AsyncCallback(my_callback),
+ ),
+ ):
+ # Your tasks here
+ ...
+
+**Important Notes:**
+
+* **Timezone Awareness**: Always return timezone-aware datetime objects
+* **Plugin Placement**: Place custom references in plugin files (e.g.,
``plugins/my_deadline_references.py``)
+* **Scheduler Restart**: Restart the Airflow scheduler after adding or
modifying custom references
Review Comment:
Does the Dag Parser need a restart too? Does it need to know how to parse the new custom references?
##########
airflow-core/tests/unit/models/test_deadline.py:
##########
@@ -583,3 +593,272 @@ def test_deadline_reference_creation(self):
custom_reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=5,
min_runs=3)
assert custom_reference.max_runs == 5
assert custom_reference.min_runs == 3
+
+
+class TestCustomDeadlineReference:
+ class MyCustomRef(ReferenceModels.BaseDeadlineReference):
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ class MyInvalidCustomRef:
+ pass
+
+ class MyCustomRefWithKwargs(ReferenceModels.BaseDeadlineReference):
+ required_kwargs = {"custom_id"}
+
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ def setup_method(self):
+ """Store original state before each test."""
+ self.original_dagrun_created = DeadlineReference.TYPES.DAGRUN_CREATED
+ self.original_dagrun_queued = DeadlineReference.TYPES.DAGRUN_QUEUED
+ self.original_dagrun = DeadlineReference.TYPES.DAGRUN
+ self.original_attrs = set(dir(ReferenceModels))
+ self.original_deadline_attrs = set(dir(DeadlineReference))
+
+ def teardown_method(self):
+ """Restore original TYPES and attrs after each test."""
+ DeadlineReference.TYPES.DAGRUN_CREATED = self.original_dagrun_created
+ DeadlineReference.TYPES.DAGRUN_QUEUED = self.original_dagrun_queued
+ DeadlineReference.TYPES.DAGRUN = self.original_dagrun
+
+ # Clean up ReferenceModels attributes.
+ for attr in set(dir(ReferenceModels)):
+ if attr not in self.original_attrs:
+ delattr(ReferenceModels, attr)
+
+ # Clean up DeadlineReference attributes.
+ for attr in set(dir(DeadlineReference)):
+ if attr not in self.original_deadline_attrs:
+ delattr(DeadlineReference, attr)
+
+ def test_custom_reference_consistent_access_pattern(self):
+ DeadlineReference.register_custom_reference(self.MyCustomRef)
+
+ # Should be accessible through DeadlineReference like built-ins.
+ custom_ref = getattr(DeadlineReference, self.MyCustomRef.__name__)
+ assert custom_ref.__class__ is self.MyCustomRef
+
+ # Should behave like built-in references.
+ assert hasattr(custom_ref, "_evaluate_with")
+ assert callable(custom_ref._evaluate_with)
+
+ def test_register_custom_reference_default_timing(self):
+ result = DeadlineReference.register_custom_reference(self.MyCustomRef)
+
+ # Should return the class.
+ assert result is self.MyCustomRef
+
+ # Should be registered with ReferenceModels.
+ assert hasattr(ReferenceModels, self.MyCustomRef.__name__)
+ assert getattr(ReferenceModels, self.MyCustomRef.__name__) is
self.MyCustomRef
+
+ # Should be accessible through DeadlineReference.
+ assert hasattr(DeadlineReference, self.MyCustomRef.__name__)
Review Comment:
Is this assert duplicated from line 654?
##########
airflow-core/tests/unit/models/test_deadline.py:
##########
@@ -583,3 +593,272 @@ def test_deadline_reference_creation(self):
custom_reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=5,
min_runs=3)
assert custom_reference.max_runs == 5
assert custom_reference.min_runs == 3
+
+
+class TestCustomDeadlineReference:
+ class MyCustomRef(ReferenceModels.BaseDeadlineReference):
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ class MyInvalidCustomRef:
+ pass
+
+ class MyCustomRefWithKwargs(ReferenceModels.BaseDeadlineReference):
+ required_kwargs = {"custom_id"}
+
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ def setup_method(self):
+ """Store original state before each test."""
+ self.original_dagrun_created = DeadlineReference.TYPES.DAGRUN_CREATED
+ self.original_dagrun_queued = DeadlineReference.TYPES.DAGRUN_QUEUED
+ self.original_dagrun = DeadlineReference.TYPES.DAGRUN
+ self.original_attrs = set(dir(ReferenceModels))
+ self.original_deadline_attrs = set(dir(DeadlineReference))
+
+ def teardown_method(self):
+ """Restore original TYPES and attrs after each test."""
+ DeadlineReference.TYPES.DAGRUN_CREATED = self.original_dagrun_created
+ DeadlineReference.TYPES.DAGRUN_QUEUED = self.original_dagrun_queued
+ DeadlineReference.TYPES.DAGRUN = self.original_dagrun
+
+ # Clean up ReferenceModels attributes.
+ for attr in set(dir(ReferenceModels)):
+ if attr not in self.original_attrs:
+ delattr(ReferenceModels, attr)
+
+ # Clean up DeadlineReference attributes.
+ for attr in set(dir(DeadlineReference)):
+ if attr not in self.original_deadline_attrs:
+ delattr(DeadlineReference, attr)
+
+ def test_custom_reference_consistent_access_pattern(self):
+ DeadlineReference.register_custom_reference(self.MyCustomRef)
+
+ # Should be accessible through DeadlineReference like built-ins.
+ custom_ref = getattr(DeadlineReference, self.MyCustomRef.__name__)
+ assert custom_ref.__class__ is self.MyCustomRef
+
+ # Should behave like built-in references.
+ assert hasattr(custom_ref, "_evaluate_with")
+ assert callable(custom_ref._evaluate_with)
+
+ def test_register_custom_reference_default_timing(self):
+ result = DeadlineReference.register_custom_reference(self.MyCustomRef)
+
+ # Should return the class.
+ assert result is self.MyCustomRef
+
+ # Should be registered with ReferenceModels.
+ assert hasattr(ReferenceModels, self.MyCustomRef.__name__)
+ assert getattr(ReferenceModels, self.MyCustomRef.__name__) is
self.MyCustomRef
+
+ # Should be accessible through DeadlineReference.
+ assert hasattr(DeadlineReference, self.MyCustomRef.__name__)
+ assert getattr(DeadlineReference, self.MyCustomRef.__name__).__class__
is self.MyCustomRef
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN_CREATED
+ assert self.MyCustomRef not in DeadlineReference.TYPES.DAGRUN_QUEUED
+
+ # Should update combined DAGRUN tuple
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN
+
+ def
test_register_custom_reference_dagrun_created_with_explicit_timing(self):
+ result = DeadlineReference.register_custom_reference(
+ self.MyCustomRef, DeadlineReference.TYPES.DAGRUN_CREATED
+ )
+
+ assert result is self.MyCustomRef
+ assert hasattr(ReferenceModels, self.MyCustomRef.__name__)
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN_CREATED
+ assert self.MyCustomRef not in DeadlineReference.TYPES.DAGRUN_QUEUED
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN
+
+ def test_register_custom_reference_with_dagrun_queued(self):
+ result = DeadlineReference.register_custom_reference(
+ self.MyCustomRef, DeadlineReference.TYPES.DAGRUN_QUEUED
+ )
+
+ assert result is self.MyCustomRef
+ assert hasattr(ReferenceModels, self.MyCustomRef.__name__)
+ assert self.MyCustomRef not in DeadlineReference.TYPES.DAGRUN_CREATED
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN_QUEUED
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN
+
+ def test_register_custom_reference_invalid_inheritance(self):
+ with pytest.raises(ValueError, match="InvalidCustomRef must inherit
from BaseDeadlineReference"):
+
DeadlineReference.register_custom_reference(self.MyInvalidCustomRef)
+
+ def test_register_custom_reference_invalid_timing(self):
+ invalid_timing = ("not", "a", "valid", "timing")
+
+ with pytest.raises(
+ ValueError, match="Invalid timing value; must be a valid
DeadlineReference.TYPES option"
+ ):
+ DeadlineReference.register_custom_reference(self.MyCustomRef,
invalid_timing)
+
+ def test_register_custom_reference_with_required_kwargs(self):
+ result =
DeadlineReference.register_custom_reference(self.MyCustomRefWithKwargs)
+
+ assert result is self.MyCustomRefWithKwargs
+ assert hasattr(ReferenceModels, self.MyCustomRefWithKwargs.__name__)
+ assert self.MyCustomRefWithKwargs in
DeadlineReference.TYPES.DAGRUN_CREATED
+
+ def test_register_multiple_custom_references(self):
+ class TestCustomRef1(ReferenceModels.BaseDeadlineReference):
+ def _evaluate_with(self, *, session: Session, **kwargs) ->
datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ class TestCustomRef2(ReferenceModels.BaseDeadlineReference):
+ def _evaluate_with(self, *, session: Session, **kwargs) ->
datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ # Register first reference to DAGRUN_CREATED
Review Comment:
Nit: I'm not sure so many of the statements need comments (throughout the
changes). I love comments as much as the next person, but if they're stating
the obvious rather than adding deeper context, then they just add clutter.
##########
task-sdk/src/airflow/sdk/definitions/deadline.py:
##########
@@ -321,3 +326,71 @@ def FIXED_DATETIME(cls, datetime: datetime) ->
DeadlineReferenceType:
(DeadlineReferenceType,),
{"_evaluate_with": lambda self, **kwargs: datetime.now()},
)()
+
+ @classmethod
+ def register_custom_reference(
+ cls,
+ reference_class: type[ReferenceModels.BaseDeadlineReference],
+ timing: DeadlineReferenceTuple | None = None,
+ ):
+ """
+ Register a custom deadline reference class.
+
+ :param reference_class: The custom reference class inheriting from
BaseDeadlineReference
+ :param timing: A DeadlineReference.TYPES for when the deadline should
be evaluated ("DAGRUN_CREATED",
+ "DAGRUN_QUEUED", etc.); defaults to
DeadlineReference.TYPES.DAGRUN_CREATED
+ """
+ from airflow.models.deadline import ReferenceModels
+
+ # Default to DAGRUN_CREATED if no timing specified
+ if timing is None:
+ timing = cls.TYPES.DAGRUN_CREATED
+
+ # Validate the reference class inherits from BaseDeadlineReference
+ if not issubclass(reference_class,
ReferenceModels.BaseDeadlineReference):
+ raise ValueError(f"{reference_class.__name__} must inherit from
BaseDeadlineReference")
+
+ # Register the new reference with ReferenceModels and
DeadlineReference for discoverability
+ setattr(ReferenceModels, reference_class.__name__, reference_class)
+ setattr(cls, reference_class.__name__, reference_class())
+ logger.info("Registered DeadlineReference %s",
reference_class.__name__)
+
+ # Add to appropriate timing classification
+ if timing is cls.TYPES.DAGRUN_CREATED:
+ cls.TYPES.DAGRUN_CREATED = cls.TYPES.DAGRUN_CREATED +
(reference_class,)
+ elif timing is cls.TYPES.DAGRUN_QUEUED:
+ cls.TYPES.DAGRUN_QUEUED = cls.TYPES.DAGRUN_QUEUED +
(reference_class,)
+ else:
+ raise ValueError("Invalid timing value; must be a valid
DeadlineReference.TYPES option.")
Review Comment:
Is there a more future-proof way of doing this rather than adding a new elif
branch each time a new timing classification is added?
##########
airflow-core/tests/unit/models/test_deadline.py:
##########
@@ -583,3 +593,272 @@ def test_deadline_reference_creation(self):
custom_reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=5,
min_runs=3)
assert custom_reference.max_runs == 5
assert custom_reference.min_runs == 3
+
+
+class TestCustomDeadlineReference:
+ class MyCustomRef(ReferenceModels.BaseDeadlineReference):
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ class MyInvalidCustomRef:
+ pass
+
+ class MyCustomRefWithKwargs(ReferenceModels.BaseDeadlineReference):
+ required_kwargs = {"custom_id"}
+
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ def setup_method(self):
+ """Store original state before each test."""
+ self.original_dagrun_created = DeadlineReference.TYPES.DAGRUN_CREATED
+ self.original_dagrun_queued = DeadlineReference.TYPES.DAGRUN_QUEUED
+ self.original_dagrun = DeadlineReference.TYPES.DAGRUN
+ self.original_attrs = set(dir(ReferenceModels))
+ self.original_deadline_attrs = set(dir(DeadlineReference))
+
+ def teardown_method(self):
+ """Restore original TYPES and attrs after each test."""
+ DeadlineReference.TYPES.DAGRUN_CREATED = self.original_dagrun_created
+ DeadlineReference.TYPES.DAGRUN_QUEUED = self.original_dagrun_queued
+ DeadlineReference.TYPES.DAGRUN = self.original_dagrun
+
+ # Clean up ReferenceModels attributes.
+ for attr in set(dir(ReferenceModels)):
+ if attr not in self.original_attrs:
+ delattr(ReferenceModels, attr)
+
+ # Clean up DeadlineReference attributes.
+ for attr in set(dir(DeadlineReference)):
+ if attr not in self.original_deadline_attrs:
+ delattr(DeadlineReference, attr)
+
+ def test_custom_reference_consistent_access_pattern(self):
Review Comment:
Is this test superseded by the tests below? The only difference is the check for
_evaluate_with, and that could be added to one of the other tests below.
##########
airflow-core/docs/howto/deadline-alerts.rst:
##########
@@ -328,24 +328,95 @@ Custom References
^^^^^^^^^^^^^^^^^
While the built-in references should cover most use cases, and more will be
released over time, you
-can create custom references by implementing a class that inherits from
DeadlineReference. This may
-be useful if you have calendar integrations or other sources that you want to
use as a reference.
+can create custom references. This may be useful if you have calendar
integrations or other sources
+that you want to use as a reference. You can create custom references by
implementing a class that
+inherits from BaseDeadlineReference, giving it an _evaluate_with() method, and
registering it. There are
+two ways to accomplish this. The recommended way is to use the
``@deadline_reference`` decorator
+but for more complicated implementations, the ``register_custom_reference()``
method is available.
+
+**Recommended: Using the decorator**
.. code-block:: python
- class CustomReference(DeadlineReference):
- """A deadline reference that uses a custom data source."""
+ from airflow._shared.timezones.timezone import datetime
+ from airflow.models.deadline import ReferenceModels
+ from sqlalchemy.orm import Session
+
+ from airflow.sdk.definitions.deadline import DeadlineReference,
deadline_reference
+
+
+ # By default, the evaluate_with method will be executed when the dagrun is
created.
+ @deadline_reference()
+ class MyCustomDecoratedReference(ReferenceModels.BaseDeadlineReference):
+ """A custom reference evaluated when DAG runs are created."""
+
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ # Add your business logic here
+ return your_datetime
+
+
+ # You can specify when evaluate_with will be called by providing a
DeadlineReference.TYPES value.
+ @deadline_reference(DeadlineReference.TYPES.DAGRUN_QUEUED)
+ class MyQueuedReference(ReferenceModels.BaseDeadlineReference):
+ """A custom reference evaluated when DAG runs are queued."""
+
+ required_kwargs = {"custom_param"}
+
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ custom_value = kwargs["custom_param"]
+ # Use custom_value in your calculation
+ return your_datetime
+
+**Alternative: Manual Registration**
+
+For advanced use cases requiring conditional or dynamic registration, you may
wish to use the registration method directly.
+In this case, the plugin file will look something like this:
+
+.. code-block:: python
+
+ from sqlalchemy.orm import Session
+
+ from airflow.models.deadline import ReferenceModels
+ from airflow.sdk.definitions.deadline import DeadlineReference
- # Define any required parameters for your reference
- required_kwargs = {"custom_id"}
+ class MyManualReference(ReferenceModels.BaseDeadlineReference):
def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
- """
- Evaluate the reference time using the provided session and kwargs.
-
- The session parameter can be used for database queries, and kwargs
- will contain any required parameters defined in required_kwargs.
- """
- custom_id = kwargs["custom_id"]
- # Your custom logic here to determine the reference time
+ # Add your business logic here
return your_datetime
+
+
+ # Register with specific timing based on configuration
+ timing = (
+ DeadlineReference.TYPES.DAGRUN_QUEUED if use_queued_timing else
DeadlineReference.TYPES.DAGRUN_CREATED
+ )
+ DeadlineReference.register_custom_reference(MyManualReference, timing)
+
+**Using Custom References in DAGs**
+
+Once registered, use your custom references in DAG definitions like any other
reference:
Review Comment:
I'm not sure we should block every PR that mentions the word Dag in docs. We
can go back and fix all occurrences en masse once the vote has concluded.
##########
airflow-core/tests/unit/models/test_deadline.py:
##########
@@ -583,3 +593,272 @@ def test_deadline_reference_creation(self):
custom_reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=5,
min_runs=3)
assert custom_reference.max_runs == 5
assert custom_reference.min_runs == 3
+
+
+class TestCustomDeadlineReference:
+ class MyCustomRef(ReferenceModels.BaseDeadlineReference):
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ class MyInvalidCustomRef:
+ pass
+
+ class MyCustomRefWithKwargs(ReferenceModels.BaseDeadlineReference):
+ required_kwargs = {"custom_id"}
+
+ def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ def setup_method(self):
+ """Store original state before each test."""
+ self.original_dagrun_created = DeadlineReference.TYPES.DAGRUN_CREATED
+ self.original_dagrun_queued = DeadlineReference.TYPES.DAGRUN_QUEUED
+ self.original_dagrun = DeadlineReference.TYPES.DAGRUN
+ self.original_attrs = set(dir(ReferenceModels))
+ self.original_deadline_attrs = set(dir(DeadlineReference))
+
+ def teardown_method(self):
+ """Restore original TYPES and attrs after each test."""
+ DeadlineReference.TYPES.DAGRUN_CREATED = self.original_dagrun_created
+ DeadlineReference.TYPES.DAGRUN_QUEUED = self.original_dagrun_queued
+ DeadlineReference.TYPES.DAGRUN = self.original_dagrun
+
+ # Clean up ReferenceModels attributes.
+ for attr in set(dir(ReferenceModels)):
+ if attr not in self.original_attrs:
+ delattr(ReferenceModels, attr)
+
+ # Clean up DeadlineReference attributes.
+ for attr in set(dir(DeadlineReference)):
+ if attr not in self.original_deadline_attrs:
+ delattr(DeadlineReference, attr)
+
+ def test_custom_reference_consistent_access_pattern(self):
+ DeadlineReference.register_custom_reference(self.MyCustomRef)
+
+ # Should be accessible through DeadlineReference like built-ins.
+ custom_ref = getattr(DeadlineReference, self.MyCustomRef.__name__)
+ assert custom_ref.__class__ is self.MyCustomRef
+
+ # Should behave like built-in references.
+ assert hasattr(custom_ref, "_evaluate_with")
+ assert callable(custom_ref._evaluate_with)
+
+ def test_register_custom_reference_default_timing(self):
+ result = DeadlineReference.register_custom_reference(self.MyCustomRef)
+
+ # Should return the class.
+ assert result is self.MyCustomRef
+
+ # Should be registered with ReferenceModels.
+ assert hasattr(ReferenceModels, self.MyCustomRef.__name__)
+ assert getattr(ReferenceModels, self.MyCustomRef.__name__) is
self.MyCustomRef
+
+ # Should be accessible through DeadlineReference.
+ assert hasattr(DeadlineReference, self.MyCustomRef.__name__)
+ assert getattr(DeadlineReference, self.MyCustomRef.__name__).__class__
is self.MyCustomRef
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN_CREATED
+ assert self.MyCustomRef not in DeadlineReference.TYPES.DAGRUN_QUEUED
+
+ # Should update combined DAGRUN tuple
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN
+
+ def
test_register_custom_reference_dagrun_created_with_explicit_timing(self):
+ result = DeadlineReference.register_custom_reference(
+ self.MyCustomRef, DeadlineReference.TYPES.DAGRUN_CREATED
+ )
+
+ assert result is self.MyCustomRef
+ assert hasattr(ReferenceModels, self.MyCustomRef.__name__)
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN_CREATED
+ assert self.MyCustomRef not in DeadlineReference.TYPES.DAGRUN_QUEUED
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN
+
+ def test_register_custom_reference_with_dagrun_queued(self):
+ result = DeadlineReference.register_custom_reference(
+ self.MyCustomRef, DeadlineReference.TYPES.DAGRUN_QUEUED
+ )
+
+ assert result is self.MyCustomRef
+ assert hasattr(ReferenceModels, self.MyCustomRef.__name__)
+ assert self.MyCustomRef not in DeadlineReference.TYPES.DAGRUN_CREATED
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN_QUEUED
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN
+
+ def test_register_custom_reference_invalid_inheritance(self):
+ with pytest.raises(ValueError, match="InvalidCustomRef must inherit
from BaseDeadlineReference"):
+
DeadlineReference.register_custom_reference(self.MyInvalidCustomRef)
+
+ def test_register_custom_reference_invalid_timing(self):
+ invalid_timing = ("not", "a", "valid", "timing")
+
+ with pytest.raises(
+ ValueError, match="Invalid timing value; must be a valid
DeadlineReference.TYPES option"
+ ):
+ DeadlineReference.register_custom_reference(self.MyCustomRef,
invalid_timing)
+
+ def test_register_custom_reference_with_required_kwargs(self):
+ result =
DeadlineReference.register_custom_reference(self.MyCustomRefWithKwargs)
+
+ assert result is self.MyCustomRefWithKwargs
+ assert hasattr(ReferenceModels, self.MyCustomRefWithKwargs.__name__)
+ assert self.MyCustomRefWithKwargs in
DeadlineReference.TYPES.DAGRUN_CREATED
+
+ def test_register_multiple_custom_references(self):
+ class TestCustomRef1(ReferenceModels.BaseDeadlineReference):
+ def _evaluate_with(self, *, session: Session, **kwargs) ->
datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ class TestCustomRef2(ReferenceModels.BaseDeadlineReference):
+ def _evaluate_with(self, *, session: Session, **kwargs) ->
datetime:
+ return timezone.datetime(DEFAULT_DATE)
+
+ # Register first reference to DAGRUN_CREATED
+ DeadlineReference.register_custom_reference(TestCustomRef1)
+
+ # Register second reference to DAGRUN_QUEUED
+ DeadlineReference.register_custom_reference(TestCustomRef2,
DeadlineReference.TYPES.DAGRUN_QUEUED)
+
+ # Both should be registered
+ assert hasattr(ReferenceModels, TestCustomRef1.__name__)
+ assert hasattr(ReferenceModels, TestCustomRef2.__name__)
+
+ # Should be in correct timing tuples
+ assert TestCustomRef1 in DeadlineReference.TYPES.DAGRUN_CREATED
+ assert TestCustomRef2 in DeadlineReference.TYPES.DAGRUN_QUEUED
+
+ # Both should be in combined DAGRUN tuple
+ assert TestCustomRef1 in DeadlineReference.TYPES.DAGRUN
+ assert TestCustomRef2 in DeadlineReference.TYPES.DAGRUN
+
+ def test_register_custom_reference_preserves_existing_types(self):
+ # Get original built-in types
+ original_created_types = set(DeadlineReference.TYPES.DAGRUN_CREATED)
+ original_queued_types = set(DeadlineReference.TYPES.DAGRUN_QUEUED)
+
+ # Register custom reference
+ DeadlineReference.register_custom_reference(self.MyCustomRef)
+
+ # Built-in types should still be present
+ for builtin_type in original_created_types:
+ assert builtin_type in DeadlineReference.TYPES.DAGRUN_CREATED
+
+ for builtin_type in original_queued_types:
+ assert builtin_type in DeadlineReference.TYPES.DAGRUN_QUEUED
+
+ # Custom type should be added
+ assert self.MyCustomRef in DeadlineReference.TYPES.DAGRUN_CREATED
+
+ def test_custom_reference_discoverable_by_get_reference_class(self):
+ DeadlineReference.register_custom_reference(self.MyCustomRef)
+
+ found_class = ReferenceModels.get_reference_class("MyCustomRef")
+ assert found_class is self.MyCustomRef
+
+ @pytest.mark.parametrize(
+ "timing, expected_in_created, expected_in_queued",
+ [
+ pytest.param(None, True, False, id="default_timing"),
+ pytest.param(DeadlineReference.TYPES.DAGRUN_CREATED, True, False,
id="explicit_created"),
+ pytest.param(DeadlineReference.TYPES.DAGRUN_QUEUED, False, True,
id="explicit_queued"),
+ ],
+ )
+ def test_custom_reference_timing_classification(self, timing,
expected_in_created, expected_in_queued):
Review Comment:
Isn't there at least one test above for this case? I'm not sure we need
both those and this one.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]