This is an automated email from the ASF dual-hosted git repository.

gopidesu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a08b399b6d4 Add return row count to SpannerQueryDatabaseInstanceOperator (#55127)
a08b399b6d4 is described below

commit a08b399b6d48082c54315428c4943debd5b99b70
Author: VladaZakharova <[email protected]>
AuthorDate: Thu Sep 4 10:19:44 2025 +0000

    Add return row count to SpannerQueryDatabaseInstanceOperator (#55127)
---
 .../providers/google/cloud/hooks/spanner.py        | 32 +++++++--
 .../providers/google/cloud/operators/spanner.py    |  5 +-
 .../tests/unit/google/cloud/hooks/test_spanner.py  | 76 ++++++++++++++++++++--
 .../unit/google/cloud/operators/test_spanner.py    | 11 ++--
 4 files changed, 104 insertions(+), 20 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py b/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
index 93ace3ff196..d364dd17673 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+from collections import OrderedDict
 from collections.abc import Callable, Sequence
 from typing import TYPE_CHECKING, NamedTuple
 
@@ -388,7 +389,7 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
         database_id: str,
         queries: list[str],
         project_id: str,
-    ) -> None:
+    ) -> list[int]:
         """
         Execute an arbitrary DML query (INSERT, UPDATE, DELETE).
 
@@ -398,12 +399,31 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
         :param project_id: Optional, the ID of the Google Cloud project that 
owns the Cloud Spanner
             database. If set to None or missing, the default project_id from 
the Google Cloud connection
             is used.
+        :return: list of numbers of affected rows by DML query
         """
-        
self._get_client(project_id=project_id).instance(instance_id=instance_id).database(
-            database_id=database_id
-        ).run_in_transaction(lambda transaction: 
self._execute_sql_in_transaction(transaction, queries))
+        db = (
+            self._get_client(project_id=project_id)
+            .instance(instance_id=instance_id)
+            .database(database_id=database_id)
+        )
+
+        def _tx_runner(tx: Transaction) -> dict[str, int]:
+            return self._execute_sql_in_transaction(tx, queries)
+
+        result = db.run_in_transaction(_tx_runner)
+
+        result_rows_count_per_query = []
+        for i, (sql, rc) in enumerate(result.items(), start=1):
+            if not sql.startswith("SELECT"):
+                preview = sql if len(sql) <= 300 else sql[:300] + "…"
+                self.log.info("[DML %d/%d] affected rows=%d | %s", i, 
len(result), rc, preview)
+                result_rows_count_per_query.append(rc)
+        return result_rows_count_per_query
 
     @staticmethod
-    def _execute_sql_in_transaction(transaction: Transaction, queries: 
list[str]):
+    def _execute_sql_in_transaction(transaction: Transaction, queries: 
list[str]) -> dict[str, int]:
+        counts: OrderedDict[str, int] = OrderedDict()
         for sql in queries:
-            transaction.execute_update(sql)
+            rc = transaction.execute_update(sql)
+            counts[sql] = rc
+        return counts
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
index 51c4f61f208..732b2e19b7c 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
@@ -280,8 +280,8 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
             self.instance_id,
             self.database_id,
         )
-        self.log.info(queries)
-        hook.execute_dml(
+        self.log.info("Executing queries: %s", queries)
+        result_rows_count_per_query = hook.execute_dml(
             project_id=self.project_id,
             instance_id=self.instance_id,
             database_id=self.database_id,
@@ -293,6 +293,7 @@ class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator):
             database_id=self.database_id,
             project_id=self.project_id or hook.project_id,
         )
+        return result_rows_count_per_query
 
     @staticmethod
     def sanitize_queries(queries: list[str]) -> None:
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_spanner.py b/providers/google/tests/unit/google/cloud/hooks/test_spanner.py
index 527a0cb0cce..ad1f1906795 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_spanner.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_spanner.py
@@ -17,9 +17,11 @@
 # under the License.
 from __future__ import annotations
 
+from collections import OrderedDict
 from unittest import mock
 from unittest.mock import MagicMock, PropertyMock
 
+import pytest
 import sqlalchemy
 
 from airflow.providers.google.cloud.hooks.spanner import SpannerHook
@@ -405,14 +407,14 @@ class TestGcpSpannerHookDefaultProjectId:
         res = self.spanner_hook_default_project_id.execute_dml(
             instance_id=SPANNER_INSTANCE,
             database_id=SPANNER_DATABASE,
-            queries="",
+            queries=[""],
             project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
         )
         get_client.assert_called_once_with(project_id="example-project")
         instance_method.assert_called_once_with(instance_id="instance")
         database_method.assert_called_once_with(database_id="database-name")
         run_in_transaction_method.assert_called_once_with(mock.ANY)
-        assert res is None
+        assert res == []
 
     
@mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
     def test_execute_dml_overridden_project_id(self, get_client):
@@ -422,13 +424,75 @@ class TestGcpSpannerHookDefaultProjectId:
         database_method = instance_method.return_value.database
         run_in_transaction_method = 
database_method.return_value.run_in_transaction
         res = self.spanner_hook_default_project_id.execute_dml(
-            project_id="new-project", instance_id=SPANNER_INSTANCE, 
database_id=SPANNER_DATABASE, queries=""
+            project_id="new-project", instance_id=SPANNER_INSTANCE, 
database_id=SPANNER_DATABASE, queries=[""]
         )
         get_client.assert_called_once_with(project_id="new-project")
         instance_method.assert_called_once_with(instance_id="instance")
         database_method.assert_called_once_with(database_id="database-name")
         run_in_transaction_method.assert_called_once_with(mock.ANY)
-        assert res is None
+        assert res == []
+
+    
@mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
+    def test_execute_dml_oqueries_row_count(self, get_client):
+        pass
+
+    @pytest.mark.parametrize(
+        "returned_items, expected_counts",
+        [
+            pytest.param(
+                [
+                    ("DELETE FROM T WHERE archived = TRUE", 5),
+                    ("SELECT * FROM T", 42),
+                    ("UPDATE U SET flag = FALSE WHERE x = 1", 3),
+                ],
+                [5, 3],
+            ),
+            pytest.param(
+                [
+                    ("DELETE FROM Logs WHERE created_at < '2024-01-01'", 7),
+                ],
+                [7],
+            ),
+            pytest.param(
+                [
+                    (
+                        "UPDATE Accounts SET active=false WHERE last_login < 
DATE_SUB(CURRENT_DATE(), INTERVAL 365 DAY)",
+                        11,
+                    ),
+                    ("DELETE FROM Sessions WHERE expires_at < 
CURRENT_TIMESTAMP()", 23),
+                ],
+                [11, 23],
+            ),
+            pytest.param(
+                [
+                    ("SELECT COUNT(*) FROM Users", 50000),
+                    ("SELECT * FROM BigTable", 123456),
+                ],
+                [],
+            ),
+            pytest.param(
+                [],
+                [],
+            ),
+        ],
+    )
+    
@mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
+    def test_execute_dml_parametrized(self, get_client, returned_items, 
expected_counts):
+        instance_method = get_client.return_value.instance
+        database_method = instance_method.return_value.database
+        run_in_tx = database_method.return_value.run_in_transaction
+
+        returned_mapping = OrderedDict(returned_items)
+        run_in_tx.return_value = returned_mapping
+
+        res = self.spanner_hook_default_project_id.execute_dml(
+            instance_id=SPANNER_INSTANCE,
+            database_id=SPANNER_DATABASE,
+            queries=[sql for sql, _ in returned_items],
+            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
+        )
+
+        assert res == expected_counts
 
     def test_get_uri(self):
         self.spanner_hook_default_project_id._get_conn_params = 
MagicMock(return_value=SPANNER_CONN_PARAMS)
@@ -682,13 +746,13 @@ class TestGcpSpannerHookNoDefaultProjectID:
             project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
             instance_id=SPANNER_INSTANCE,
             database_id=SPANNER_DATABASE,
-            queries="",
+            queries=[""],
         )
         get_client.assert_called_once_with(project_id="example-project")
         instance_method.assert_called_once_with(instance_id="instance")
         database_method.assert_called_once_with(database_id="database-name")
         run_in_transaction_method.assert_called_once_with(mock.ANY)
-        assert res is None
+        assert res == []
 
     def test_get_uri(self):
         self.spanner_hook_no_default_project_id._get_conn_params = 
MagicMock(return_value=SPANNER_CONN_PARAMS)
diff --git a/providers/google/tests/unit/google/cloud/operators/test_spanner.py b/providers/google/tests/unit/google/cloud/operators/test_spanner.py
index e9d800665bf..1784a0499aa 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_spanner.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_spanner.py
@@ -250,7 +250,7 @@ class TestCloudSpanner:
 
     @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
     def test_instance_query(self, mock_hook):
-        mock_hook.return_value.execute_sql.return_value = None
+        mock_hook.return_value.execute_dml.return_value = [3]
         op = SpannerQueryDatabaseInstanceOperator(
             project_id=PROJECT_ID,
             instance_id=INSTANCE_ID,
@@ -258,8 +258,7 @@ class TestCloudSpanner:
             query=INSERT_QUERY,
             task_id="id",
         )
-        context = mock.MagicMock()
-        result = op.execute(context=context)
+        result = op.execute(context=mock.MagicMock())
         mock_hook.assert_called_once_with(
             gcp_conn_id="google_cloud_default",
             impersonation_chain=None,
@@ -267,11 +266,11 @@ class TestCloudSpanner:
         mock_hook.return_value.execute_dml.assert_called_once_with(
             project_id=PROJECT_ID, instance_id=INSTANCE_ID, database_id=DB_ID, 
queries=[INSERT_QUERY]
         )
-        assert result is None
+        assert result == [3]
 
     @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
     def test_instance_query_missing_project_id(self, mock_hook):
-        mock_hook.return_value.execute_sql.return_value = None
+        mock_hook.return_value.execute_dml.return_value = [3]
         op = SpannerQueryDatabaseInstanceOperator(
             instance_id=INSTANCE_ID, database_id=DB_ID, query=INSERT_QUERY, 
task_id="id"
         )
@@ -284,7 +283,7 @@ class TestCloudSpanner:
         mock_hook.return_value.execute_dml.assert_called_once_with(
             project_id=None, instance_id=INSTANCE_ID, database_id=DB_ID, 
queries=[INSERT_QUERY]
         )
-        assert result is None
+        assert result == [3]
 
     @pytest.mark.parametrize(
         "project_id, instance_id, database_id, query, exp_msg",

Reply via email to