This is an automated email from the ASF dual-hosted git repository.
gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a08b399b6d4 Add return row count to SpannerQueryDatabaseInstanceOperator (#55127)
a08b399b6d4 is described below
commit a08b399b6d48082c54315428c4943debd5b99b70
Author: VladaZakharova <[email protected]>
AuthorDate: Thu Sep 4 10:19:44 2025 +0000
Add return row count to SpannerQueryDatabaseInstanceOperator (#55127)
---
.../providers/google/cloud/hooks/spanner.py | 32 +++++++--
.../providers/google/cloud/operators/spanner.py | 5 +-
.../tests/unit/google/cloud/hooks/test_spanner.py | 76 ++++++++++++++++++++--
.../unit/google/cloud/operators/test_spanner.py | 11 ++--
4 files changed, 104 insertions(+), 20 deletions(-)
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py b/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
index 93ace3ff196..d364dd17673 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
@@ -19,6 +19,7 @@
from __future__ import annotations
+from collections import OrderedDict
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, NamedTuple
@@ -388,7 +389,7 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
database_id: str,
queries: list[str],
project_id: str,
- ) -> None:
+ ) -> list[int]:
"""
Execute an arbitrary DML query (INSERT, UPDATE, DELETE).
@@ -398,12 +399,31 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
:param project_id: Optional, the ID of the Google Cloud project that
owns the Cloud Spanner
database. If set to None or missing, the default project_id from
the Google Cloud connection
is used.
+ :return: list of numbers of affected rows by DML query
"""
-
self._get_client(project_id=project_id).instance(instance_id=instance_id).database(
- database_id=database_id
- ).run_in_transaction(lambda transaction:
self._execute_sql_in_transaction(transaction, queries))
+ db = (
+ self._get_client(project_id=project_id)
+ .instance(instance_id=instance_id)
+ .database(database_id=database_id)
+ )
+
+ def _tx_runner(tx: Transaction) -> dict[str, int]:
+ return self._execute_sql_in_transaction(tx, queries)
+
+ result = db.run_in_transaction(_tx_runner)
+
+ result_rows_count_per_query = []
+ for i, (sql, rc) in enumerate(result.items(), start=1):
+ if not sql.startswith("SELECT"):
+ preview = sql if len(sql) <= 300 else sql[:300] + "…"
+ self.log.info("[DML %d/%d] affected rows=%d | %s", i,
len(result), rc, preview)
+ result_rows_count_per_query.append(rc)
+ return result_rows_count_per_query
@staticmethod
- def _execute_sql_in_transaction(transaction: Transaction, queries:
list[str]):
+ def _execute_sql_in_transaction(transaction: Transaction, queries:
list[str]) -> dict[str, int]:
+ counts: OrderedDict[str, int] = OrderedDict()
for sql in queries:
- transaction.execute_update(sql)
+ rc = transaction.execute_update(sql)
+ counts[sql] = rc
+ return counts
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
index 51c4f61f208..732b2e19b7c 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
@@ -280,8 +280,8 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
self.instance_id,
self.database_id,
)
- self.log.info(queries)
- hook.execute_dml(
+ self.log.info("Executing queries: %s", queries)
+ result_rows_count_per_query = hook.execute_dml(
project_id=self.project_id,
instance_id=self.instance_id,
database_id=self.database_id,
@@ -293,6 +293,7 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
database_id=self.database_id,
project_id=self.project_id or hook.project_id,
)
+ return result_rows_count_per_query
@staticmethod
def sanitize_queries(queries: list[str]) -> None:
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_spanner.py b/providers/google/tests/unit/google/cloud/hooks/test_spanner.py
index 527a0cb0cce..ad1f1906795 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_spanner.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_spanner.py
@@ -17,9 +17,11 @@
# under the License.
from __future__ import annotations
+from collections import OrderedDict
from unittest import mock
from unittest.mock import MagicMock, PropertyMock
+import pytest
import sqlalchemy
from airflow.providers.google.cloud.hooks.spanner import SpannerHook
@@ -405,14 +407,14 @@ class TestGcpSpannerHookDefaultProjectId:
res = self.spanner_hook_default_project_id.execute_dml(
instance_id=SPANNER_INSTANCE,
database_id=SPANNER_DATABASE,
- queries="",
+ queries=[""],
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
)
get_client.assert_called_once_with(project_id="example-project")
instance_method.assert_called_once_with(instance_id="instance")
database_method.assert_called_once_with(database_id="database-name")
run_in_transaction_method.assert_called_once_with(mock.ANY)
- assert res is None
+ assert res == []
@mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
def test_execute_dml_overridden_project_id(self, get_client):
@@ -422,13 +424,75 @@ class TestGcpSpannerHookDefaultProjectId:
database_method = instance_method.return_value.database
run_in_transaction_method =
database_method.return_value.run_in_transaction
res = self.spanner_hook_default_project_id.execute_dml(
- project_id="new-project", instance_id=SPANNER_INSTANCE,
database_id=SPANNER_DATABASE, queries=""
+ project_id="new-project", instance_id=SPANNER_INSTANCE,
database_id=SPANNER_DATABASE, queries=[""]
)
get_client.assert_called_once_with(project_id="new-project")
instance_method.assert_called_once_with(instance_id="instance")
database_method.assert_called_once_with(database_id="database-name")
run_in_transaction_method.assert_called_once_with(mock.ANY)
- assert res is None
+ assert res == []
+
+
@mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
+ def test_execute_dml_oqueries_row_count(self, get_client):
+ pass
+
+ @pytest.mark.parametrize(
+ "returned_items, expected_counts",
+ [
+ pytest.param(
+ [
+ ("DELETE FROM T WHERE archived = TRUE", 5),
+ ("SELECT * FROM T", 42),
+ ("UPDATE U SET flag = FALSE WHERE x = 1", 3),
+ ],
+ [5, 3],
+ ),
+ pytest.param(
+ [
+ ("DELETE FROM Logs WHERE created_at < '2024-01-01'", 7),
+ ],
+ [7],
+ ),
+ pytest.param(
+ [
+ (
+ "UPDATE Accounts SET active=false WHERE last_login <
DATE_SUB(CURRENT_DATE(), INTERVAL 365 DAY)",
+ 11,
+ ),
+ ("DELETE FROM Sessions WHERE expires_at <
CURRENT_TIMESTAMP()", 23),
+ ],
+ [11, 23],
+ ),
+ pytest.param(
+ [
+ ("SELECT COUNT(*) FROM Users", 50000),
+ ("SELECT * FROM BigTable", 123456),
+ ],
+ [],
+ ),
+ pytest.param(
+ [],
+ [],
+ ),
+ ],
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
+ def test_execute_dml_parametrized(self, get_client, returned_items,
expected_counts):
+ instance_method = get_client.return_value.instance
+ database_method = instance_method.return_value.database
+ run_in_tx = database_method.return_value.run_in_transaction
+
+ returned_mapping = OrderedDict(returned_items)
+ run_in_tx.return_value = returned_mapping
+
+ res = self.spanner_hook_default_project_id.execute_dml(
+ instance_id=SPANNER_INSTANCE,
+ database_id=SPANNER_DATABASE,
+ queries=[sql for sql, _ in returned_items],
+ project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
+ )
+
+ assert res == expected_counts
def test_get_uri(self):
self.spanner_hook_default_project_id._get_conn_params =
MagicMock(return_value=SPANNER_CONN_PARAMS)
@@ -682,13 +746,13 @@ class TestGcpSpannerHookNoDefaultProjectID:
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
instance_id=SPANNER_INSTANCE,
database_id=SPANNER_DATABASE,
- queries="",
+ queries=[""],
)
get_client.assert_called_once_with(project_id="example-project")
instance_method.assert_called_once_with(instance_id="instance")
database_method.assert_called_once_with(database_id="database-name")
run_in_transaction_method.assert_called_once_with(mock.ANY)
- assert res is None
+ assert res == []
def test_get_uri(self):
self.spanner_hook_no_default_project_id._get_conn_params =
MagicMock(return_value=SPANNER_CONN_PARAMS)
diff --git a/providers/google/tests/unit/google/cloud/operators/test_spanner.py b/providers/google/tests/unit/google/cloud/operators/test_spanner.py
index e9d800665bf..1784a0499aa 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_spanner.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_spanner.py
@@ -250,7 +250,7 @@ class TestCloudSpanner:
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_instance_query(self, mock_hook):
- mock_hook.return_value.execute_sql.return_value = None
+ mock_hook.return_value.execute_dml.return_value = [3]
op = SpannerQueryDatabaseInstanceOperator(
project_id=PROJECT_ID,
instance_id=INSTANCE_ID,
@@ -258,8 +258,7 @@ class TestCloudSpanner:
query=INSERT_QUERY,
task_id="id",
)
- context = mock.MagicMock()
- result = op.execute(context=context)
+ result = op.execute(context=mock.MagicMock())
mock_hook.assert_called_once_with(
gcp_conn_id="google_cloud_default",
impersonation_chain=None,
@@ -267,11 +266,11 @@ class TestCloudSpanner:
mock_hook.return_value.execute_dml.assert_called_once_with(
project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID,
queries=[INSERT_QUERY]
)
- assert result is None
+ assert result == [3]
@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_instance_query_missing_project_id(self, mock_hook):
- mock_hook.return_value.execute_sql.return_value = None
+ mock_hook.return_value.execute_dml.return_value = [3]
op = SpannerQueryDatabaseInstanceOperator(
instance_id=INSTANCE_ID, database_id=DB_ID, query=INSERT_QUERY,
task_id="id"
)
@@ -284,7 +283,7 @@ class TestCloudSpanner:
mock_hook.return_value.execute_dml.assert_called_once_with(
project_id=None, instance_id=INSTANCE_ID, database_id=DB_ID,
queries=[INSERT_QUERY]
)
- assert result is None
+ assert result == [3]
@pytest.mark.parametrize(
"project_id, instance_id, database_id, query, exp_msg",