This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3eed501dcf Add `on_skipped_callback` in to `BaseOperator` (#36374)
3eed501dcf is described below
commit 3eed501dcfb058d85de4a3ffb342c537036f7c73
Author: rom sharon <[email protected]>
AuthorDate: Sun Jan 14 22:46:57 2024 +0200
Add `on_skipped_callback` in to `BaseOperator` (#36374)
---------
Co-authored-by: Jens Scheffler <[email protected]>
---
airflow/example_dags/tutorial.py | 1 +
airflow/models/baseoperator.py | 10 ++++++++++
airflow/models/mappedoperator.py | 8 ++++++++
airflow/models/taskinstance.py | 2 ++
.../logging-monitoring/callbacks.rst | 4 ++++
tests/models/test_taskinstance.py | 20 ++++++++++++++++++++
tests/serialization/test_dag_serialization.py | 1 +
7 files changed, 46 insertions(+)
diff --git a/airflow/example_dags/tutorial.py b/airflow/example_dags/tutorial.py
index 4656f69c0b..9915810985 100644
--- a/airflow/example_dags/tutorial.py
+++ b/airflow/example_dags/tutorial.py
@@ -60,6 +60,7 @@ with DAG(
# 'on_success_callback': some_other_function, # or list of functions
# 'on_retry_callback': another_function, # or list of functions
# 'sla_miss_callback': yet_another_function, # or list of functions
+ # 'on_skipped_callback': another_function, # or list of functions
# 'trigger_rule': 'all_success'
},
# [END default_args]
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index f7f1d6ccc6..2d0244dbbf 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -247,6 +247,7 @@ def partial(
on_failure_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_success_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_retry_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
+ on_skipped_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
run_as_user: str | None | ArgNotSet = NOTSET,
executor_config: dict | None | ArgNotSet = NOTSET,
inlets: Any | None | ArgNotSet = NOTSET,
@@ -310,6 +311,7 @@ def partial(
"on_failure_callback": on_failure_callback,
"on_retry_callback": on_retry_callback,
"on_success_callback": on_success_callback,
+ "on_skipped_callback": on_skipped_callback,
"run_as_user": run_as_user,
"executor_config": executor_config,
"inlets": inlets,
@@ -597,6 +599,11 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
that it is executed when retries occur.
:param on_success_callback: much like the ``on_failure_callback`` except
that it is executed when the task succeeds.
+ :param on_skipped_callback: much like the ``on_failure_callback`` except
+ that it is executed when a skip occurs; this callback is called only if AirflowSkipException is raised.
+ Explicitly, it is NOT called if the task never starts executing because of a preceding branching
+ decision in the DAG or a trigger rule which causes execution to be skipped, so that the task execution
+ is never scheduled.
:param pre_execute: a function to be called immediately before task
execution, receiving a context dictionary; raising an exception will
prevent the task from being executed.
@@ -700,6 +707,7 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
"on_failure_callback",
"on_success_callback",
"on_retry_callback",
+ "on_skipped_callback",
"do_xcom_push",
}
@@ -759,6 +767,7 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
on_failure_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
on_success_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
on_retry_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
+ on_skipped_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
pre_execute: TaskPreExecuteHook | None = None,
post_execute: TaskPostExecuteHook | None = None,
trigger_rule: str = DEFAULT_TRIGGER_RULE,
@@ -825,6 +834,7 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
self.on_failure_callback = on_failure_callback
self.on_success_callback = on_success_callback
self.on_retry_callback = on_retry_callback
+ self.on_skipped_callback = on_skipped_callback
self._pre_execute_hook = pre_execute
self._post_execute_hook = post_execute
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 8174db145a..3c555e874f 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -522,6 +522,14 @@ class MappedOperator(AbstractOperator):
def on_success_callback(self, value: TaskStateChangeCallback | None) ->
None:
self.partial_kwargs["on_success_callback"] = value
+ @property
+ def on_skipped_callback(self) -> None | TaskStateChangeCallback |
list[TaskStateChangeCallback]:
+ return self.partial_kwargs.get("on_skipped_callback")
+
+ @on_skipped_callback.setter
+ def on_skipped_callback(self, value: TaskStateChangeCallback | None) ->
None:
+ self.partial_kwargs["on_skipped_callback"] = value
+
@property
def run_as_user(self) -> str | None:
return self.partial_kwargs.get("run_as_user")
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c0427715a2..8e6886ea11 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2361,6 +2361,8 @@ class TaskInstance(Base, LoggingMixin):
self.log.info(e)
if not test_mode:
self.refresh_from_db(lock_for_update=True, session=session)
+
_run_finished_callback(callbacks=self.task.on_skipped_callback, context=context)
+ session.commit()
self.state = TaskInstanceState.SKIPPED
except AirflowRescheduleException as reschedule_exception:
self._handle_reschedule(actual_start_date,
reschedule_exception, test_mode, session=session)
diff --git
a/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst
b/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst
index c21752ab79..a70a876ba3 100644
---
a/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst
+++
b/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst
@@ -49,6 +49,10 @@ Name Description
``sla_miss_callback`` Invoked when a task misses its
defined :ref:`SLA <concepts:slas>`
``on_retry_callback`` Invoked when the task is :ref:`up
for retry <concepts:task-instances>`
``on_execute_callback`` Invoked right before the task
begins executing.
+``on_skipped_callback`` Invoked when the task is :ref:`running <concepts:task-instances>` and AirflowSkipException is raised.
+ Explicitly, it is NOT called if the task never starts executing because of a preceding branching
+ decision in the DAG or a trigger rule which causes execution to be skipped, so that the task execution
+ is never scheduled.
===========================================
================================================================
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index 8d914214b8..319ecd98f9 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3172,6 +3172,26 @@ class TestTaskInstance:
assert
session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None
+ def test_skipped_task_call_on_skipped_callback(self, dag_maker):
+ def raise_skip_exception():
+ raise AirflowSkipException
+
+ callback_function = mock.MagicMock()
+
+ with dag_maker(dag_id="test_skipped_task"):
+ task = PythonOperator(
+ task_id="test_skipped_task",
+ python_callable=raise_skip_exception,
+ on_skipped_callback=callback_function,
+ )
+
+ dr = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+ ti = dr.task_instances[0]
+ ti.task = task
+ ti.run()
+ assert State.SKIPPED == ti.state
+ assert callback_function.called
+
@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
def test_refresh_from_task(pool_override):
diff --git a/tests/serialization/test_dag_serialization.py
b/tests/serialization/test_dag_serialization.py
index 30407eb945..83d40886c0 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1236,6 +1236,7 @@ class TestStringifiedDAGs:
"on_execute_callback": None,
"on_failure_callback": None,
"on_retry_callback": None,
+ "on_skipped_callback": None,
"on_success_callback": None,
"outlets": [],
"owner": "airflow",