o-nikolas commented on code in PR #57222:
URL: https://github.com/apache/airflow/pull/57222#discussion_r2467205840


##########
airflow-core/tests/unit/models/test_deadline.py:
##########
@@ -583,3 +593,272 @@ def test_deadline_reference_creation(self):
         custom_reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=5, min_runs=3)
         assert custom_reference.max_runs == 5
         assert custom_reference.min_runs == 3
+
+
+class TestCustomDeadlineReference:
+    class MyCustomRef(ReferenceModels.BaseDeadlineReference):
+        def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+            return timezone.datetime(DEFAULT_DATE)
+
+    class MyInvalidCustomRef:
+        pass
+
+    class MyCustomRefWithKwargs(ReferenceModels.BaseDeadlineReference):
+        required_kwargs = {"custom_id"}
+
+        def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
+            return timezone.datetime(DEFAULT_DATE)
+
+    def setup_method(self):
+        """Store original state before each test."""
+        self.original_dagrun_created = DeadlineReference.TYPES.DAGRUN_CREATED
+        self.original_dagrun_queued = DeadlineReference.TYPES.DAGRUN_QUEUED
+        self.original_dagrun = DeadlineReference.TYPES.DAGRUN
+        self.original_attrs = set(dir(ReferenceModels))
+        self.original_deadline_attrs = set(dir(DeadlineReference))
+
+    def teardown_method(self):
+        """Restore original TYPES and attrs after each test."""
+        DeadlineReference.TYPES.DAGRUN_CREATED = self.original_dagrun_created
+        DeadlineReference.TYPES.DAGRUN_QUEUED = self.original_dagrun_queued
+        DeadlineReference.TYPES.DAGRUN = self.original_dagrun
+
+        # Clean up ReferenceModels attributes.
+        for attr in set(dir(ReferenceModels)):
+            if attr not in self.original_attrs:
+                delattr(ReferenceModels, attr)
+
+        # Clean up DeadlineReference attributes.
+        for attr in set(dir(DeadlineReference)):
+            if attr not in self.original_deadline_attrs:
+                delattr(DeadlineReference, attr)
+
+    def test_custom_reference_consistent_access_pattern(self):
+        DeadlineReference.register_custom_reference(self.MyCustomRef)
+
+        # Should be accessible through DeadlineReference like built-ins.
+        custom_ref = getattr(DeadlineReference, self.MyCustomRef.__name__)
+        assert custom_ref.__class__ is self.MyCustomRef
+
+        # Should behave like built-in references.
+        assert hasattr(custom_ref, "_evaluate_with")
+        assert callable(custom_ref._evaluate_with)
+
+    def test_register_custom_reference_default_timing(self):
+        result = DeadlineReference.register_custom_reference(self.MyCustomRef)
+
+        # Should return the class.
+        assert result is self.MyCustomRef
+
+        # Should be registered with ReferenceModels.
+        assert hasattr(ReferenceModels, self.MyCustomRef.__name__)
+        assert getattr(ReferenceModels, self.MyCustomRef.__name__) is self.MyCustomRef
+
+        # Should be accessible through DeadlineReference.
+        assert hasattr(DeadlineReference, self.MyCustomRef.__name__)

Review Comment:
   Fair enough 👍 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to