This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 978adb309a Install sqlalchemy-spanner package into Google provider (#31925)
978adb309a is described below

commit 978adb309aee755df02aadab72fdafb61bec5c80
Author: Maksim <[email protected]>
AuthorDate: Fri Jul 21 11:57:04 2023 +0200

    Install sqlalchemy-spanner package into Google provider (#31925)
---
 airflow/providers/google/cloud/hooks/spanner.py    | 49 ++++++++++++++++++++--
 airflow/providers/google/provider.yaml             |  1 +
 generated/provider_dependencies.json               |  3 +-
 tests/providers/google/cloud/hooks/test_spanner.py | 35 +++++++++++++++-
 4 files changed, 83 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/spanner.py b/airflow/providers/google/cloud/hooks/spanner.py
index 00fe9a8804..a3e652dcbd 100644
--- a/airflow/providers/google/cloud/hooks/spanner.py
+++ b/airflow/providers/google/cloud/hooks/spanner.py
@@ -18,7 +18,7 @@
 """This module contains a Google Cloud Spanner Hook."""
 from __future__ import annotations
 
-from typing import Callable, Sequence
+from typing import Callable, NamedTuple, Sequence
 
 from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
 from google.cloud.spanner_v1.client import Client
@@ -26,13 +26,23 @@ from google.cloud.spanner_v1.database import Database
 from google.cloud.spanner_v1.instance import Instance
 from google.cloud.spanner_v1.transaction import Transaction
 from google.longrunning.operations_grpc_pb2 import Operation
+from sqlalchemy import create_engine
 
 from airflow.exceptions import AirflowException
+from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
 
 
-class SpannerHook(GoogleBaseHook):
+class SpannerConnectionParams(NamedTuple):
+    """Information about Google Spanner connection parameters."""
+
+    project_id: str | None
+    instance_id: str | None
+    database_id: str | None
+
+
+class SpannerHook(GoogleBaseHook, DbApiHook):
     """
     Hook for Google Cloud Spanner APIs.
 
@@ -40,6 +50,11 @@ class SpannerHook(GoogleBaseHook):
     keyword arguments rather than positional.
     """
 
+    conn_name_attr = "gcp_conn_id"
+    default_conn_name = "google_cloud_spanner_default"
+    conn_type = "gcpspanner"
+    hook_name = "Google Cloud Spanner"
+
     def __init__(
         self,
         gcp_conn_id: str = "google_cloud_default",
@@ -70,6 +85,34 @@ class SpannerHook(GoogleBaseHook):
             )
         return self._client
 
+    def _get_conn_params(self) -> SpannerConnectionParams:
+        """Extract spanner database connection parameters."""
+        extras = self.get_connection(self.gcp_conn_id).extra_dejson
+        project_id = get_field(extras, "project_id") or self.project_id
+        instance_id = get_field(extras, "instance_id")
+        database_id = get_field(extras, "database_id")
+        return SpannerConnectionParams(project_id, instance_id, database_id)
+
+    def get_uri(self) -> str:
+        """Override DbApiHook get_uri method for get_sqlalchemy_engine()."""
+        project_id, instance_id, database_id = self._get_conn_params()
+        if not all([instance_id, database_id]):
+            raise AirflowException("The instance_id or database_id were not specified")
+        return f"spanner+spanner:///projects/{project_id}/instances/{instance_id}/databases/{database_id}"
+
+    def get_sqlalchemy_engine(self, engine_kwargs=None):
+        """
+        Get an sqlalchemy_engine object.
+
+        :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`.
+        :return: the created engine.
+        """
+        if engine_kwargs is None:
+            engine_kwargs = {}
+        project_id, _, _ = self._get_conn_params()
+        spanner_client = self._get_client(project_id=project_id)
+        return create_engine(self.get_uri(), connect_args={"client": spanner_client}, **engine_kwargs)
+
     @GoogleBaseHook.fallback_to_default_project_id
     def get_instance(
         self,
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 5a7b272f42..3a8471de26 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -126,6 +126,7 @@ dependencies:
   - proto-plus>=1.19.6
   - PyOpenSSL
   - sqlalchemy-bigquery>=1.2.1
+  - sqlalchemy-spanner>=1.6.2
 
 integrations:
   - integration-name: Google Analytics360
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 09c2211b3a..abf1088cd4 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -440,7 +440,8 @@
       "pandas-gbq",
       "pandas>=0.17.1",
       "proto-plus>=1.19.6",
-      "sqlalchemy-bigquery>=1.2.1"
+      "sqlalchemy-bigquery>=1.2.1",
+      "sqlalchemy-spanner>=1.6.2"
     ],
     "cross-providers-deps": [
       "amazon",
diff --git a/tests/providers/google/cloud/hooks/test_spanner.py b/tests/providers/google/cloud/hooks/test_spanner.py
index abb346f001..1633d55816 100644
--- a/tests/providers/google/cloud/hooks/test_spanner.py
+++ b/tests/providers/google/cloud/hooks/test_spanner.py
@@ -18,9 +18,10 @@
 from __future__ import annotations
 
 from unittest import mock
-from unittest.mock import PropertyMock
+from unittest.mock import MagicMock, PropertyMock
 
 import pytest
+import sqlalchemy
 
 from airflow.providers.google.cloud.hooks.spanner import SpannerHook
 from airflow.providers.google.common.consts import CLIENT_INFO
@@ -33,6 +34,8 @@ from tests.providers.google.cloud.utils.base_gcp_mock import (
 SPANNER_INSTANCE = "instance"
 SPANNER_CONFIGURATION = "configuration"
 SPANNER_DATABASE = "database-name"
+SPANNER_PROJECT_ID = "test_project_id"
+SPANNER_CONN_PARAMS = (SPANNER_PROJECT_ID, SPANNER_INSTANCE, SPANNER_DATABASE)
 
 
 class TestGcpSpannerHookDefaultProjectId:
@@ -431,6 +434,21 @@ class TestGcpSpannerHookDefaultProjectId:
         run_in_transaction_method.assert_called_once_with(mock.ANY)
         assert res is None
 
+    def test_get_uri(self):
+        self.spanner_hook_default_project_id._get_conn_params = MagicMock(return_value=SPANNER_CONN_PARAMS)
+        uri = self.spanner_hook_default_project_id.get_uri()
+        assert (
+            uri
+            == f"spanner+spanner:///projects/{SPANNER_PROJECT_ID}/instances/{SPANNER_INSTANCE}/databases/{SPANNER_DATABASE}"
+        )
+
+    @mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
+    def test_get_sqlalchemy_engine(self, get_client):
+        self.spanner_hook_default_project_id._get_conn_params = MagicMock(return_value=SPANNER_CONN_PARAMS)
+        engine = self.spanner_hook_default_project_id.get_sqlalchemy_engine()
+        assert isinstance(engine, sqlalchemy.engine.Engine)
+        assert engine.name == "spanner+spanner"
+
 
 class TestGcpSpannerHookNoDefaultProjectID:
     def setup_method(self):
@@ -675,3 +693,18 @@ class TestGcpSpannerHookNoDefaultProjectID:
         database_method.assert_called_once_with(database_id="database-name")
         run_in_transaction_method.assert_called_once_with(mock.ANY)
         assert res is None
+
+    def test_get_uri(self):
+        self.spanner_hook_no_default_project_id._get_conn_params = MagicMock(return_value=SPANNER_CONN_PARAMS)
+        uri = self.spanner_hook_no_default_project_id.get_uri()
+        assert (
+            uri
+            == f"spanner+spanner:///projects/{SPANNER_PROJECT_ID}/instances/{SPANNER_INSTANCE}/databases/{SPANNER_DATABASE}"
+        )
+
+    @mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
+    def test_get_sqlalchemy_engine(self, get_client):
+        self.spanner_hook_no_default_project_id._get_conn_params = MagicMock(return_value=SPANNER_CONN_PARAMS)
+        engine = self.spanner_hook_no_default_project_id.get_sqlalchemy_engine()
+        assert isinstance(engine, sqlalchemy.engine.Engine)
+        assert engine.name == "spanner+spanner"

Reply via email to