o-nikolas commented on code in PR #53951:
URL: https://github.com/apache/airflow/pull/53951#discussion_r2248919416
##########
task-sdk/src/airflow/sdk/definitions/deadline.py:
##########
@@ -134,11 +132,109 @@ def deserialize_deadline_alert(cls, encoded_data: dict) -> DeadlineAlert:
return cls(
reference=reference,
interval=timedelta(seconds=data[DeadlineAlertFields.INTERVAL]),
- callback=data[DeadlineAlertFields.CALLBACK], # Keep as string path
- callback_kwargs=data[DeadlineAlertFields.CALLBACK_KWARGS],
+ callback=cast("Callback",
deserialize(data[DeadlineAlertFields.CALLBACK])),
)
+class Callback(ABC):
+ """
+ Base class for deadline alert callbacks.
+
+ Callbacks are used to execute custom logic when a deadline is missed.
+
+ The `callback_callable` can be a Python callable type or a string
containing the path to the callable that
+ can be used to import the callable. It must be a top-level callable in a
module present on the host where
+ it will run.
+
+ It will be called with Airflow context and specified kwargs when a
deadline is missed.
+ """
+
+ path: str
+ kwargs: dict | None
+
+ def __init__(self, callback_callable: Callable | str, kwargs: dict | None
= None):
+ self.path = DeadlineAlert.get_callback_path(callback_callable)
+ self.kwargs = kwargs
+
+ def serialize(self) -> dict[str, Any]:
+ return {f: getattr(self, f) for f in self.serialized_fields()}
+
+ @classmethod
+ def deserialize(cls, data: dict, version):
+ path = data.pop("path")
+ return cls(callback_callable=path, **data)
+
+ @classmethod
+ def serialized_fields(cls) -> tuple:
+ return ("path", "kwargs")
+
+ def __eq__(self, other):
+ if type(self) is not type(other):
+ return NotImplemented
+ return self.serialize() == other.serialize()
+
+ def __hash__(self):
+ serialized = self.serialize()
+ hashable_items = []
+ for k, v in serialized.items():
+ if isinstance(v, dict) and v:
+ hashable_items.append((k, tuple(sorted(v.items()))))
+ else:
+ hashable_items.append((k, v))
+ return hash(tuple(sorted(hashable_items)))
+
+
+class AsyncCallback(Callback):
+ """
+ Asynchronous callback that runs in the triggerer.
+
+ The `callback_callable` can be a Python callable type or a string
containing the path to the callable that
+ can be used to import the callable. It must be a top-level awaitable
callable in a module present on the
+ triggerer.
+
+ It will be called with Airflow context and specified kwargs when a
deadline is missed.
+ """
+
+ def __init__(self, callback_callable: Callable | str, kwargs: dict | None
= None):
+ super().__init__(callback_callable=callback_callable, kwargs=kwargs)
+
+ if isinstance(callback_callable, str):
+ try:
+ callback_callable = import_string(callback_callable)
+ except ImportError as e:
+ logger.info(
+ "Failed to import callback_callable\nAssuming it exists on
the triggerer and is awaitable\n%s",
+ e,
+ )
Review Comment:
Should this be `logger.debug`?
##########
task-sdk/src/airflow/sdk/definitions/deadline.py:
##########
@@ -134,11 +132,109 @@ def deserialize_deadline_alert(cls, encoded_data: dict) -> DeadlineAlert:
return cls(
reference=reference,
interval=timedelta(seconds=data[DeadlineAlertFields.INTERVAL]),
- callback=data[DeadlineAlertFields.CALLBACK], # Keep as string path
- callback_kwargs=data[DeadlineAlertFields.CALLBACK_KWARGS],
+ callback=cast("Callback",
deserialize(data[DeadlineAlertFields.CALLBACK])),
)
+class Callback(ABC):
+ """
+ Base class for deadline alert callbacks.
+
+ Callbacks are used to execute custom logic when a deadline is missed.
+
+ The `callback_callable` can be a Python callable type or a string
containing the path to the callable that
+ can be used to import the callable. It must be a top-level callable in a
module present on the host where
+ it will run.
+
+ It will be called with Airflow context and specified kwargs when a
deadline is missed.
+ """
+
+ path: str
+ kwargs: dict | None
+
+ def __init__(self, callback_callable: Callable | str, kwargs: dict | None
= None):
+ self.path = DeadlineAlert.get_callback_path(callback_callable)
+ self.kwargs = kwargs
+
+ def serialize(self) -> dict[str, Any]:
+ return {f: getattr(self, f) for f in self.serialized_fields()}
+
+ @classmethod
+ def deserialize(cls, data: dict, version):
+ path = data.pop("path")
+ return cls(callback_callable=path, **data)
+
+ @classmethod
+ def serialized_fields(cls) -> tuple:
+ return ("path", "kwargs")
+
+ def __eq__(self, other):
+ if type(self) is not type(other):
+ return NotImplemented
+ return self.serialize() == other.serialize()
+
+ def __hash__(self):
+ serialized = self.serialize()
+ hashable_items = []
+ for k, v in serialized.items():
+ if isinstance(v, dict) and v:
+ hashable_items.append((k, tuple(sorted(v.items()))))
+ else:
+ hashable_items.append((k, v))
+ return hash(tuple(sorted(hashable_items)))
+
+
+class AsyncCallback(Callback):
+ """
+ Asynchronous callback that runs in the triggerer.
+
+ The `callback_callable` can be a Python callable type or a string
containing the path to the callable that
+ can be used to import the callable. It must be a top-level awaitable
callable in a module present on the
+ triggerer.
+
+ It will be called with Airflow context and specified kwargs when a
deadline is missed.
+ """
+
+ def __init__(self, callback_callable: Callable | str, kwargs: dict | None
= None):
+ super().__init__(callback_callable=callback_callable, kwargs=kwargs)
+
+ if isinstance(callback_callable, str):
+ try:
+ callback_callable = import_string(callback_callable)
+ except ImportError as e:
+ logger.info(
+ "Failed to import callback_callable\nAssuming it exists on
the triggerer and is awaitable\n%s",
+ e,
+ )
+ return
+
+ if not (inspect.iscoroutinefunction(callback_callable) or
hasattr(callback_callable, "__await__")):
+ raise TypeError(f"Callback {callback_callable} must be awaitable")
+
+
+class SyncCallback(Callback):
+ """
+ Synchronous callback that runs in the specified or default executor.
+
+ The `callback_callable` can be a Python callable type or a string
containing the path to the callable that
+ can be used to import the callable. It must be a top-level callable in a
module present on the executor.
+
+ It will be called with Airflow context and specified kwargs when a
deadline is missed.
+ """
+
+ executor: str | None
+
+ def __init__(
+ self, callback_callable: Callable | str, kwargs: dict | None = None,
executor: str | None = None
+ ):
Review Comment:
Is it worth checking whether it has a `__call__` dunder, as you do with
`__await__` in the async callback?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]