This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3eed501dcf Add `on_skipped_callback` in to `BaseOperator` (#36374)
3eed501dcf is described below

commit 3eed501dcfb058d85de4a3ffb342c537036f7c73
Author: rom sharon <[email protected]>
AuthorDate: Sun Jan 14 22:46:57 2024 +0200

    Add `on_skipped_callback` in to `BaseOperator` (#36374)
    
    
    
    ---------
    
    Co-authored-by: Jens Scheffler <[email protected]>
---
 airflow/example_dags/tutorial.py                     |  1 +
 airflow/models/baseoperator.py                       | 10 ++++++++++
 airflow/models/mappedoperator.py                     |  8 ++++++++
 airflow/models/taskinstance.py                       |  2 ++
 .../logging-monitoring/callbacks.rst                 |  4 ++++
 tests/models/test_taskinstance.py                    | 20 ++++++++++++++++++++
 tests/serialization/test_dag_serialization.py        |  1 +
 7 files changed, 46 insertions(+)

diff --git a/airflow/example_dags/tutorial.py b/airflow/example_dags/tutorial.py
index 4656f69c0b..9915810985 100644
--- a/airflow/example_dags/tutorial.py
+++ b/airflow/example_dags/tutorial.py
@@ -60,6 +60,7 @@ with DAG(
         # 'on_success_callback': some_other_function, # or list of functions
         # 'on_retry_callback': another_function, # or list of functions
         # 'sla_miss_callback': yet_another_function, # or list of functions
+        # 'on_skipped_callback': another_function, # or list of functions
         # 'trigger_rule': 'all_success'
     },
     # [END default_args]
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index f7f1d6ccc6..2d0244dbbf 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -247,6 +247,7 @@ def partial(
     on_failure_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
     on_success_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
     on_retry_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
+    on_skipped_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
     run_as_user: str | None | ArgNotSet = NOTSET,
     executor_config: dict | None | ArgNotSet = NOTSET,
     inlets: Any | None | ArgNotSet = NOTSET,
@@ -310,6 +311,7 @@ def partial(
         "on_failure_callback": on_failure_callback,
         "on_retry_callback": on_retry_callback,
         "on_success_callback": on_success_callback,
+        "on_skipped_callback": on_skipped_callback,
         "run_as_user": run_as_user,
         "executor_config": executor_config,
         "inlets": inlets,
@@ -597,6 +599,11 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         that it is executed when retries occur.
     :param on_success_callback: much like the ``on_failure_callback`` except
         that it is executed when the task succeeds.
+    :param on_skipped_callback: much like the ``on_failure_callback`` except
+        that it is executed when the task is skipped; this callback is called only if
+        an ``AirflowSkipException`` is raised during execution. It is explicitly NOT called
+        if the task never starts because a preceding branching decision in the DAG or a
+        trigger rule causes it to be skipped, so that the task execution is never scheduled.
     :param pre_execute: a function to be called immediately before task
         execution, receiving a context dictionary; raising an exception will
         prevent the task from being executed.
@@ -700,6 +707,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         "on_failure_callback",
         "on_success_callback",
         "on_retry_callback",
+        "on_skipped_callback",
         "do_xcom_push",
     }
 
@@ -759,6 +767,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         on_failure_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
         on_success_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
         on_retry_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
+        on_skipped_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
         pre_execute: TaskPreExecuteHook | None = None,
         post_execute: TaskPostExecuteHook | None = None,
         trigger_rule: str = DEFAULT_TRIGGER_RULE,
@@ -825,6 +834,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
         self.on_failure_callback = on_failure_callback
         self.on_success_callback = on_success_callback
         self.on_retry_callback = on_retry_callback
+        self.on_skipped_callback = on_skipped_callback
         self._pre_execute_hook = pre_execute
         self._post_execute_hook = post_execute
 
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 8174db145a..3c555e874f 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -522,6 +522,14 @@ class MappedOperator(AbstractOperator):
     def on_success_callback(self, value: TaskStateChangeCallback | None) -> 
None:
         self.partial_kwargs["on_success_callback"] = value
 
+    @property
+    def on_skipped_callback(self) -> None | TaskStateChangeCallback | 
list[TaskStateChangeCallback]:
+        return self.partial_kwargs.get("on_skipped_callback")
+
+    @on_skipped_callback.setter
+    def on_skipped_callback(self, value: TaskStateChangeCallback | None) -> 
None:
+        self.partial_kwargs["on_skipped_callback"] = value
+
     @property
     def run_as_user(self) -> str | None:
         return self.partial_kwargs.get("run_as_user")
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c0427715a2..8e6886ea11 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2361,6 +2361,8 @@ class TaskInstance(Base, LoggingMixin):
                     self.log.info(e)
                 if not test_mode:
                     self.refresh_from_db(lock_for_update=True, session=session)
+                
_run_finished_callback(callbacks=self.task.on_skipped_callback, context=context)
+                session.commit()
                 self.state = TaskInstanceState.SKIPPED
             except AirflowRescheduleException as reschedule_exception:
                 self._handle_reschedule(actual_start_date, 
reschedule_exception, test_mode, session=session)
diff --git a/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst b/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst
index c21752ab79..a70a876ba3 100644
--- a/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst
+++ b/docs/apache-airflow/administration-and-deployment/logging-monitoring/callbacks.rst
@@ -49,6 +49,10 @@ Name                                        Description
 ``sla_miss_callback``                       Invoked when a task misses its defined :ref:`SLA <concepts:slas>`
 ``on_retry_callback``                       Invoked when the task is :ref:`up for retry <concepts:task-instances>`
 ``on_execute_callback``                     Invoked right before the task begins executing.
+``on_skipped_callback``                     Invoked when the task is :ref:`running <concepts:task-instances>` and raises ``AirflowSkipException``.
+                                            It is explicitly NOT called if the task never starts because a preceding branching
+                                            decision in the DAG or a trigger rule causes it to be skipped, so that the task execution
+                                            is never scheduled.
 =========================================== ================================================================
 
 
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 8d914214b8..319ecd98f9 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3172,6 +3172,26 @@ class TestTaskInstance:
 
         assert 
session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None
 
+    def test_skipped_task_call_on_skipped_callback(self, dag_maker):
+        def raise_skip_exception():
+            raise AirflowSkipException
+
+        callback_function = mock.MagicMock()
+
+        with dag_maker(dag_id="test_skipped_task"):
+            task = PythonOperator(
+                task_id="test_skipped_task",
+                python_callable=raise_skip_exception,
+                on_skipped_callback=callback_function,
+            )
+
+        dr = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        ti = dr.task_instances[0]
+        ti.task = task
+        ti.run()
+        assert State.SKIPPED == ti.state
+        assert callback_function.called
+
 
 @pytest.mark.parametrize("pool_override", [None, "test_pool2"])
 def test_refresh_from_task(pool_override):
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 30407eb945..83d40886c0 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1236,6 +1236,7 @@ class TestStringifiedDAGs:
             "on_execute_callback": None,
             "on_failure_callback": None,
             "on_retry_callback": None,
+            "on_skipped_callback": None,
             "on_success_callback": None,
             "outlets": [],
             "owner": "airflow",

Reply via email to