This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f5c2748c33 fix(providers/sql): respect soft_fail argument when exception is raised (#34199)
f5c2748c33 is described below

commit f5c2748c3346bdebf445afd615657af8849345dd
Author: Wei Lee <[email protected]>
AuthorDate: Sat Sep 9 04:09:13 2023 +0800

    fix(providers/sql): respect soft_fail argument when exception is raised (#34199)
---
 airflow/providers/common/sql/sensors/sql.py    | 28 ++++++++--
 tests/providers/common/sql/sensors/test_sql.py | 72 ++++++++++++++++++++------
 2 files changed, 80 insertions(+), 20 deletions(-)

diff --git a/airflow/providers/common/sql/sensors/sql.py b/airflow/providers/common/sql/sensors/sql.py
index 73505390fc..7eab94e558 100644
--- a/airflow/providers/common/sql/sensors/sql.py
+++ b/airflow/providers/common/sql/sensors/sql.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 from typing import Any, Sequence
 
-from airflow import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.hooks.base import BaseHook
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.sensors.base import BaseSensorOperator
@@ -96,19 +96,37 @@ class SqlSensor(BaseSensorOperator):
         records = hook.get_records(self.sql, self.parameters)
         if not records:
             if self.fail_on_empty:
-                raise AirflowException("No rows returned, raising as per 
fail_on_empty flag")
+                # TODO: remove this if block when min_airflow_version is set 
to higher than 2.7.1
+                message = "No rows returned, raising as per fail_on_empty flag"
+                if self.soft_fail:
+                    raise AirflowSkipException(message)
+                raise AirflowException(message)
             else:
                 return False
+
         first_cell = records[0][0]
         if self.failure is not None:
             if callable(self.failure):
                 if self.failure(first_cell):
-                    raise AirflowException(f"Failure criteria met. 
self.failure({first_cell}) returned True")
+                    # TODO: remove this if block when min_airflow_version is 
set to higher than 2.7.1
+                    message = f"Failure criteria met. 
self.failure({first_cell}) returned True"
+                    if self.soft_fail:
+                        raise AirflowSkipException(message)
+                    raise AirflowException(message)
             else:
-                raise AirflowException(f"self.failure is present, but not 
callable -> {self.failure}")
+                # TODO: remove this if block when min_airflow_version is set 
to higher than 2.7.1
+                message = f"self.failure is present, but not callable -> 
{self.failure}"
+                if self.soft_fail:
+                    raise AirflowSkipException(message)
+                raise AirflowException(message)
+
         if self.success is not None:
             if callable(self.success):
                 return self.success(first_cell)
             else:
-                raise AirflowException(f"self.success is present, but not 
callable -> {self.success}")
+                # TODO: remove this if block when min_airflow_version is set 
to higher than 2.7.1
+                message = f"self.success is present, but not callable -> 
{self.success}"
+                if self.soft_fail:
+                    raise AirflowSkipException(message)
+                raise AirflowException(message)
         return bool(first_cell)
diff --git a/tests/providers/common/sql/sensors/test_sql.py b/tests/providers/common/sql/sensors/test_sql.py
index b14c977bd3..7491d03e1a 100644
--- a/tests/providers/common/sql/sensors/test_sql.py
+++ b/tests/providers/common/sql/sensors/test_sql.py
@@ -21,7 +21,7 @@ from unittest import mock
 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.models.dag import DAG
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.common.sql.sensors.sql import SqlSensor
@@ -117,17 +117,26 @@ class TestSqlSensor:
         mock_get_records.return_value = [["1"]]
         assert op.poke(None)
 
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, AirflowException), (True, 
AirflowSkipException))
+    )
     @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
-    def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook):
+    def test_sql_sensor_postgres_poke_fail_on_empty(
+        self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+    ):
         op = SqlSensor(
-            task_id="sql_sensor_check", conn_id="postgres_default", 
sql="SELECT 1", fail_on_empty=True
+            task_id="sql_sensor_check",
+            conn_id="postgres_default",
+            sql="SELECT 1",
+            fail_on_empty=True,
+            soft_fail=soft_fail,
         )
 
         mock_hook.get_connection.return_value.get_hook.return_value = 
mock.MagicMock(spec=DbApiHook)
         mock_get_records = 
mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = []
-        with pytest.raises(AirflowException):
+        with pytest.raises(expected_exception):
             op.poke(None)
 
     @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
@@ -148,10 +157,19 @@ class TestSqlSensor:
         mock_get_records.return_value = [["1"]]
         assert not op.poke(None)
 
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, AirflowException), (True, 
AirflowSkipException))
+    )
     @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
-    def test_sql_sensor_postgres_poke_failure(self, mock_hook):
+    def test_sql_sensor_postgres_poke_failure(
+        self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+    ):
         op = SqlSensor(
-            task_id="sql_sensor_check", conn_id="postgres_default", 
sql="SELECT 1", failure=lambda x: x in [1]
+            task_id="sql_sensor_check",
+            conn_id="postgres_default",
+            sql="SELECT 1",
+            failure=lambda x: x in [1],
+            soft_fail=soft_fail,
         )
 
         mock_hook.get_connection.return_value.get_hook.return_value = 
mock.MagicMock(spec=DbApiHook)
@@ -161,17 +179,23 @@ class TestSqlSensor:
         assert not op.poke(None)
 
         mock_get_records.return_value = [[1]]
-        with pytest.raises(AirflowException):
+        with pytest.raises(expected_exception):
             op.poke(None)
 
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, AirflowException), (True, 
AirflowSkipException))
+    )
     @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
-    def test_sql_sensor_postgres_poke_failure_success(self, mock_hook):
+    def test_sql_sensor_postgres_poke_failure_success(
+        self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+    ):
         op = SqlSensor(
             task_id="sql_sensor_check",
             conn_id="postgres_default",
             sql="SELECT 1",
             failure=lambda x: x in [1],
             success=lambda x: x in [2],
+            soft_fail=soft_fail,
         )
 
         mock_hook.get_connection.return_value.get_hook.return_value = 
mock.MagicMock(spec=DbApiHook)
@@ -181,20 +205,26 @@ class TestSqlSensor:
         assert not op.poke(None)
 
         mock_get_records.return_value = [[1]]
-        with pytest.raises(AirflowException):
+        with pytest.raises(expected_exception):
             op.poke(None)
 
         mock_get_records.return_value = [[2]]
         assert op.poke(None)
 
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, AirflowException), (True, 
AirflowSkipException))
+    )
     @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
-    def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook):
+    def test_sql_sensor_postgres_poke_failure_success_same(
+        self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+    ):
         op = SqlSensor(
             task_id="sql_sensor_check",
             conn_id="postgres_default",
             sql="SELECT 1",
             failure=lambda x: x in [1],
             success=lambda x: x in [1],
+            soft_fail=soft_fail,
         )
 
         mock_hook.get_connection.return_value.get_hook.return_value = 
mock.MagicMock(spec=DbApiHook)
@@ -204,40 +234,52 @@ class TestSqlSensor:
         assert not op.poke(None)
 
         mock_get_records.return_value = [[1]]
-        with pytest.raises(AirflowException):
+        with pytest.raises(expected_exception):
             op.poke(None)
 
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, AirflowException), (True, 
AirflowSkipException))
+    )
     @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
-    def test_sql_sensor_postgres_poke_invalid_failure(self, mock_hook):
+    def test_sql_sensor_postgres_poke_invalid_failure(
+        self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+    ):
         op = SqlSensor(
             task_id="sql_sensor_check",
             conn_id="postgres_default",
             sql="SELECT 1",
             failure=[1],
+            soft_fail=soft_fail,
         )
 
         mock_hook.get_connection.return_value.get_hook.return_value = 
mock.MagicMock(spec=DbApiHook)
         mock_get_records = 
mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = [[1]]
-        with pytest.raises(AirflowException) as ctx:
+        with pytest.raises(expected_exception) as ctx:
             op.poke(None)
         assert "self.failure is present, but not callable -> [1]" == 
str(ctx.value)
 
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception", ((False, AirflowException), (True, 
AirflowSkipException))
+    )
     @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
-    def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook):
+    def test_sql_sensor_postgres_poke_invalid_success(
+        self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+    ):
         op = SqlSensor(
             task_id="sql_sensor_check",
             conn_id="postgres_default",
             sql="SELECT 1",
             success=[1],
+            soft_fail=soft_fail,
         )
 
         mock_hook.get_connection.return_value.get_hook.return_value = 
mock.MagicMock(spec=DbApiHook)
         mock_get_records = 
mock_hook.get_connection.return_value.get_hook.return_value.get_records
 
         mock_get_records.return_value = [[1]]
-        with pytest.raises(AirflowException) as ctx:
+        with pytest.raises(expected_exception) as ctx:
             op.poke(None)
         assert "self.success is present, but not callable -> [1]" == 
str(ctx.value)
 

Reply via email to