This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 27dc7e80df Optimize `SnowflakeSqlApiOperator` execution in deferrable mode (#36850)
27dc7e80df is described below
commit 27dc7e80df3ecf5aa61718334f32a1d128b0125c
Author: vatsrahul1001 <[email protected]>
AuthorDate: Thu Jan 18 19:33:04 2024 +0530
Optimize `SnowflakeSqlApiOperator` execution in deferrable mode (#36850)
---
airflow/providers/snowflake/operators/snowflake.py | 15 +++++
.../snowflake/operators/test_snowflake.py | 67 +++++++++++++++++++++-
2 files changed, 81 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/snowflake/operators/snowflake.py
b/airflow/providers/snowflake/operators/snowflake.py
index 9e0bf3d1cf..f7890b87e1 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -514,6 +514,21 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
if self.do_xcom_push:
context["ti"].xcom_push(key="query_ids", value=self.query_ids)
+ succeeded_query_ids = []
+ for query_id in self.query_ids:
+ self.log.info("Retrieving status for query id %s", query_id)
+ statement_status = self._hook.get_sql_api_query_status(query_id)
+ if statement_status.get("status") == "running":
+ break
+ elif statement_status.get("status") == "success":
+ succeeded_query_ids.append(query_id)
+ else:
+ raise AirflowException(f"{statement_status.get('status')}: {statement_status.get('message')}")
+
+ if len(self.query_ids) == len(succeeded_query_ids):
+ self.log.info("%s completed successfully.", self.task_id)
+ return
+
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
diff --git a/tests/providers/snowflake/operators/test_snowflake.py
b/tests/providers/snowflake/operators/test_snowflake.py
index 07df5fb147..7f429277b9 100644
--- a/tests/providers/snowflake/operators/test_snowflake.py
+++ b/tests/providers/snowflake/operators/test_snowflake.py
@@ -253,7 +253,9 @@ class TestSnowflakeSqlApiOperator:
@pytest.mark.parametrize("mock_sql, statement_count", [(SQL_MULTIPLE_STMTS, 4), (SINGLE_STMT, 1)])
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.execute_query")
- def test_snowflake_sql_api_execute_operator_async(self, mock_db_hook, mock_sql, statement_count):
+ def test_snowflake_sql_api_execute_operator_async(
+ self, mock_execute_query, mock_sql, statement_count, mock_get_sql_api_query_status
+ ):
"""
Asserts that a task is deferred and a SnowflakeSqlApiTrigger will be fired
when the SnowflakeSqlApiOperator is executed.
@@ -266,6 +268,9 @@ class TestSnowflakeSqlApiOperator:
deferrable=True,
)
+ mock_execute_query.return_value = ["uuid1"]
+ mock_get_sql_api_query_status.side_effect = [{"status": "running"}]
+
with pytest.raises(TaskDeferred) as exc:
operator.execute(create_context(operator))
@@ -311,3 +316,63 @@ class TestSnowflakeSqlApiOperator:
with mock.patch.object(operator.log, "info") as mock_log_info:
operator.execute_complete(context=None, event=mock_event)
mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)
+
+
+ @mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer")
+ def test_snowflake_sql_api_execute_operator_failed_before_defer(
+ self, mock_defer, mock_execute_query, mock_get_sql_api_query_status
+ ):
+ """Asserts that a task is not deferred when it has failed"""
+
+ operator = SnowflakeSqlApiOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id="snowflake_default",
+ sql=SQL_MULTIPLE_STMTS,
+ statement_count=4,
+ do_xcom_push=False,
+ deferrable=True,
+ )
+ mock_execute_query.return_value = ["uuid1"]
+ mock_get_sql_api_query_status.side_effect = [{"status": "error"}]
+ with pytest.raises(AirflowException):
+ operator.execute(create_context(operator))
+ assert not mock_defer.called
+
+
+ @mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer")
+ def test_snowflake_sql_api_execute_operator_succeeded_before_defer(
+ self, mock_defer, mock_execute_query, mock_get_sql_api_query_status
+ ):
+ """Asserts that a task is not deferred when it has succeeded"""
+
+ operator = SnowflakeSqlApiOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id="snowflake_default",
+ sql=SQL_MULTIPLE_STMTS,
+ statement_count=4,
+ do_xcom_push=False,
+ deferrable=True,
+ )
+ mock_execute_query.return_value = ["uuid1"]
+ mock_get_sql_api_query_status.side_effect = [{"status": "success"}]
+ operator.execute(create_context(operator))
+
+ assert not mock_defer.called
+
+
+ @mock.patch("airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator.defer")
+ def test_snowflake_sql_api_execute_operator_running_before_defer(
+ self, mock_defer, mock_execute_query, mock_get_sql_api_query_status
+ ):
+ """Asserts that a task is deferred when it is running"""
+
+ operator = SnowflakeSqlApiOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id="snowflake_default",
+ sql=SQL_MULTIPLE_STMTS,
+ statement_count=4,
+ do_xcom_push=False,
+ deferrable=True,
+ )
+ mock_execute_query.return_value = ["uuid1"]
+ mock_get_sql_api_query_status.side_effect = [{"status": "running"}]
+ operator.execute(create_context(operator))
+
+ assert mock_defer.called