This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f5c2748c33 fix(providers/sql): respect soft_fail argument when exception is raised (#34199)
f5c2748c33 is described below
commit f5c2748c3346bdebf445afd615657af8849345dd
Author: Wei Lee <[email protected]>
AuthorDate: Sat Sep 9 04:09:13 2023 +0800
fix(providers/sql): respect soft_fail argument when exception is raised (#34199)
---
airflow/providers/common/sql/sensors/sql.py | 28 ++++++++--
tests/providers/common/sql/sensors/test_sql.py | 72 ++++++++++++++++++++------
2 files changed, 80 insertions(+), 20 deletions(-)
diff --git a/airflow/providers/common/sql/sensors/sql.py
b/airflow/providers/common/sql/sensors/sql.py
index 73505390fc..7eab94e558 100644
--- a/airflow/providers/common/sql/sensors/sql.py
+++ b/airflow/providers/common/sql/sensors/sql.py
@@ -18,7 +18,7 @@ from __future__ import annotations
from typing import Any, Sequence
-from airflow import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.sensors.base import BaseSensorOperator
@@ -96,19 +96,37 @@ class SqlSensor(BaseSensorOperator):
records = hook.get_records(self.sql, self.parameters)
if not records:
if self.fail_on_empty:
- raise AirflowException("No rows returned, raising as per
fail_on_empty flag")
+ # TODO: remove this if block when min_airflow_version is set
to higher than 2.7.1
+ message = "No rows returned, raising as per fail_on_empty flag"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
else:
return False
+
first_cell = records[0][0]
if self.failure is not None:
if callable(self.failure):
if self.failure(first_cell):
- raise AirflowException(f"Failure criteria met.
self.failure({first_cell}) returned True")
+ # TODO: remove this if block when min_airflow_version is
set to higher than 2.7.1
+ message = f"Failure criteria met.
self.failure({first_cell}) returned True"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
else:
- raise AirflowException(f"self.failure is present, but not
callable -> {self.failure}")
+ # TODO: remove this if block when min_airflow_version is set
to higher than 2.7.1
+ message = f"self.failure is present, but not callable ->
{self.failure}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
+
if self.success is not None:
if callable(self.success):
return self.success(first_cell)
else:
- raise AirflowException(f"self.success is present, but not
callable -> {self.success}")
+ # TODO: remove this if block when min_airflow_version is set
to higher than 2.7.1
+ message = f"self.success is present, but not callable ->
{self.success}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
return bool(first_cell)
diff --git a/tests/providers/common/sql/sensors/test_sql.py
b/tests/providers/common/sql/sensors/test_sql.py
index b14c977bd3..7491d03e1a 100644
--- a/tests/providers/common/sql/sensors/test_sql.py
+++ b/tests/providers/common/sql/sensors/test_sql.py
@@ -21,7 +21,7 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.models.dag import DAG
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.sensors.sql import SqlSensor
@@ -117,17 +117,26 @@ class TestSqlSensor:
mock_get_records.return_value = [["1"]]
assert op.poke(None)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
- def test_sql_sensor_postgres_poke_fail_on_empty(self, mock_hook):
+ def test_sql_sensor_postgres_poke_fail_on_empty(
+ self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+ ):
op = SqlSensor(
- task_id="sql_sensor_check", conn_id="postgres_default",
sql="SELECT 1", fail_on_empty=True
+ task_id="sql_sensor_check",
+ conn_id="postgres_default",
+ sql="SELECT 1",
+ fail_on_empty=True,
+ soft_fail=soft_fail,
)
mock_hook.get_connection.return_value.get_hook.return_value =
mock.MagicMock(spec=DbApiHook)
mock_get_records =
mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
op.poke(None)
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
@@ -148,10 +157,19 @@ class TestSqlSensor:
mock_get_records.return_value = [["1"]]
assert not op.poke(None)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
- def test_sql_sensor_postgres_poke_failure(self, mock_hook):
+ def test_sql_sensor_postgres_poke_failure(
+ self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+ ):
op = SqlSensor(
- task_id="sql_sensor_check", conn_id="postgres_default",
sql="SELECT 1", failure=lambda x: x in [1]
+ task_id="sql_sensor_check",
+ conn_id="postgres_default",
+ sql="SELECT 1",
+ failure=lambda x: x in [1],
+ soft_fail=soft_fail,
)
mock_hook.get_connection.return_value.get_hook.return_value =
mock.MagicMock(spec=DbApiHook)
@@ -161,17 +179,23 @@ class TestSqlSensor:
assert not op.poke(None)
mock_get_records.return_value = [[1]]
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
op.poke(None)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
- def test_sql_sensor_postgres_poke_failure_success(self, mock_hook):
+ def test_sql_sensor_postgres_poke_failure_success(
+ self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+ ):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=lambda x: x in [1],
success=lambda x: x in [2],
+ soft_fail=soft_fail,
)
mock_hook.get_connection.return_value.get_hook.return_value =
mock.MagicMock(spec=DbApiHook)
@@ -181,20 +205,26 @@ class TestSqlSensor:
assert not op.poke(None)
mock_get_records.return_value = [[1]]
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
op.poke(None)
mock_get_records.return_value = [[2]]
assert op.poke(None)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
- def test_sql_sensor_postgres_poke_failure_success_same(self, mock_hook):
+ def test_sql_sensor_postgres_poke_failure_success_same(
+ self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+ ):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=lambda x: x in [1],
success=lambda x: x in [1],
+ soft_fail=soft_fail,
)
mock_hook.get_connection.return_value.get_hook.return_value =
mock.MagicMock(spec=DbApiHook)
@@ -204,40 +234,52 @@ class TestSqlSensor:
assert not op.poke(None)
mock_get_records.return_value = [[1]]
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
op.poke(None)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
- def test_sql_sensor_postgres_poke_invalid_failure(self, mock_hook):
+ def test_sql_sensor_postgres_poke_invalid_failure(
+ self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+ ):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
failure=[1],
+ soft_fail=soft_fail,
)
mock_hook.get_connection.return_value.get_hook.return_value =
mock.MagicMock(spec=DbApiHook)
mock_get_records =
mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = [[1]]
- with pytest.raises(AirflowException) as ctx:
+ with pytest.raises(expected_exception) as ctx:
op.poke(None)
assert "self.failure is present, but not callable -> [1]" ==
str(ctx.value)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook")
- def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook):
+ def test_sql_sensor_postgres_poke_invalid_success(
+ self, mock_hook, soft_fail: bool, expected_exception: AirflowException
+ ):
op = SqlSensor(
task_id="sql_sensor_check",
conn_id="postgres_default",
sql="SELECT 1",
success=[1],
+ soft_fail=soft_fail,
)
mock_hook.get_connection.return_value.get_hook.return_value =
mock.MagicMock(spec=DbApiHook)
mock_get_records =
mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = [[1]]
- with pytest.raises(AirflowException) as ctx:
+ with pytest.raises(expected_exception) as ctx:
op.poke(None)
assert "self.success is present, but not callable -> [1]" ==
str(ctx.value)