This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 978adb309a Install sqlalchemy-spanner package into Google provider (#31925)
978adb309a is described below
commit 978adb309aee755df02aadab72fdafb61bec5c80
Author: Maksim <[email protected]>
AuthorDate: Fri Jul 21 11:57:04 2023 +0200
Install sqlalchemy-spanner package into Google provider (#31925)
---
airflow/providers/google/cloud/hooks/spanner.py | 49 ++++++++++++++++++++--
airflow/providers/google/provider.yaml | 1 +
generated/provider_dependencies.json | 3 +-
tests/providers/google/cloud/hooks/test_spanner.py | 35 +++++++++++++++-
4 files changed, 83 insertions(+), 5 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/spanner.py b/airflow/providers/google/cloud/hooks/spanner.py
index 00fe9a8804..a3e652dcbd 100644
--- a/airflow/providers/google/cloud/hooks/spanner.py
+++ b/airflow/providers/google/cloud/hooks/spanner.py
@@ -18,7 +18,7 @@
"""This module contains a Google Cloud Spanner Hook."""
from __future__ import annotations
-from typing import Callable, Sequence
+from typing import Callable, NamedTuple, Sequence
from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
from google.cloud.spanner_v1.client import Client
@@ -26,13 +26,23 @@ from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
from google.cloud.spanner_v1.transaction import Transaction
from google.longrunning.operations_grpc_pb2 import Operation
+from sqlalchemy import create_engine
from airflow.exceptions import AirflowException
+from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
-class SpannerHook(GoogleBaseHook):
+class SpannerConnectionParams(NamedTuple):
+ """Information about Google Spanner connection parameters."""
+
+ project_id: str | None
+ instance_id: str | None
+ database_id: str | None
+
+
+class SpannerHook(GoogleBaseHook, DbApiHook):
"""
Hook for Google Cloud Spanner APIs.
@@ -40,6 +50,11 @@ class SpannerHook(GoogleBaseHook):
keyword arguments rather than positional.
"""
+ conn_name_attr = "gcp_conn_id"
+ default_conn_name = "google_cloud_spanner_default"
+ conn_type = "gcpspanner"
+ hook_name = "Google Cloud Spanner"
+
def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
@@ -70,6 +85,34 @@ class SpannerHook(GoogleBaseHook):
)
return self._client
+ def _get_conn_params(self) -> SpannerConnectionParams:
+ """Extract spanner database connection parameters."""
+ extras = self.get_connection(self.gcp_conn_id).extra_dejson
+ project_id = get_field(extras, "project_id") or self.project_id
+ instance_id = get_field(extras, "instance_id")
+ database_id = get_field(extras, "database_id")
+ return SpannerConnectionParams(project_id, instance_id, database_id)
+
+ def get_uri(self) -> str:
+ """Override DbApiHook get_uri method for get_sqlalchemy_engine()."""
+ project_id, instance_id, database_id = self._get_conn_params()
+ if not all([instance_id, database_id]):
+ raise AirflowException("The instance_id or database_id were not specified")
+ return f"spanner+spanner:///projects/{project_id}/instances/{instance_id}/databases/{database_id}"
+
+ def get_sqlalchemy_engine(self, engine_kwargs=None):
+ """
+ Get an sqlalchemy_engine object.
+
+ :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`.
+ :return: the created engine.
+ """
+ if engine_kwargs is None:
+ engine_kwargs = {}
+ project_id, _, _ = self._get_conn_params()
+ spanner_client = self._get_client(project_id=project_id)
+ return create_engine(self.get_uri(), connect_args={"client": spanner_client}, **engine_kwargs)
+
@GoogleBaseHook.fallback_to_default_project_id
def get_instance(
self,
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 5a7b272f42..3a8471de26 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -126,6 +126,7 @@ dependencies:
- proto-plus>=1.19.6
- PyOpenSSL
- sqlalchemy-bigquery>=1.2.1
+ - sqlalchemy-spanner>=1.6.2
integrations:
- integration-name: Google Analytics360
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 09c2211b3a..abf1088cd4 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -440,7 +440,8 @@
"pandas-gbq",
"pandas>=0.17.1",
"proto-plus>=1.19.6",
- "sqlalchemy-bigquery>=1.2.1"
+ "sqlalchemy-bigquery>=1.2.1",
+ "sqlalchemy-spanner>=1.6.2"
],
"cross-providers-deps": [
"amazon",
diff --git a/tests/providers/google/cloud/hooks/test_spanner.py b/tests/providers/google/cloud/hooks/test_spanner.py
index abb346f001..1633d55816 100644
--- a/tests/providers/google/cloud/hooks/test_spanner.py
+++ b/tests/providers/google/cloud/hooks/test_spanner.py
@@ -18,9 +18,10 @@
from __future__ import annotations
from unittest import mock
-from unittest.mock import PropertyMock
+from unittest.mock import MagicMock, PropertyMock
import pytest
+import sqlalchemy
from airflow.providers.google.cloud.hooks.spanner import SpannerHook
from airflow.providers.google.common.consts import CLIENT_INFO
@@ -33,6 +34,8 @@ from tests.providers.google.cloud.utils.base_gcp_mock import (
SPANNER_INSTANCE = "instance"
SPANNER_CONFIGURATION = "configuration"
SPANNER_DATABASE = "database-name"
+SPANNER_PROJECT_ID = "test_project_id"
+SPANNER_CONN_PARAMS = (SPANNER_PROJECT_ID, SPANNER_INSTANCE, SPANNER_DATABASE)
class TestGcpSpannerHookDefaultProjectId:
@@ -431,6 +434,21 @@ class TestGcpSpannerHookDefaultProjectId:
run_in_transaction_method.assert_called_once_with(mock.ANY)
assert res is None
+ def test_get_uri(self):
+ self.spanner_hook_default_project_id._get_conn_params = MagicMock(return_value=SPANNER_CONN_PARAMS)
+ uri = self.spanner_hook_default_project_id.get_uri()
+ assert (
+ uri
+ == f"spanner+spanner:///projects/{SPANNER_PROJECT_ID}/instances/{SPANNER_INSTANCE}/databases/{SPANNER_DATABASE}"
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
+ def test_get_sqlalchemy_engine(self, get_client):
+ self.spanner_hook_default_project_id._get_conn_params = MagicMock(return_value=SPANNER_CONN_PARAMS)
+ engine = self.spanner_hook_default_project_id.get_sqlalchemy_engine()
+ assert isinstance(engine, sqlalchemy.engine.Engine)
+ assert engine.name == "spanner+spanner"
+
class TestGcpSpannerHookNoDefaultProjectID:
def setup_method(self):
@@ -675,3 +693,18 @@ class TestGcpSpannerHookNoDefaultProjectID:
database_method.assert_called_once_with(database_id="database-name")
run_in_transaction_method.assert_called_once_with(mock.ANY)
assert res is None
+
+ def test_get_uri(self):
+ self.spanner_hook_no_default_project_id._get_conn_params = MagicMock(return_value=SPANNER_CONN_PARAMS)
+ uri = self.spanner_hook_no_default_project_id.get_uri()
+ assert (
+ uri
+ == f"spanner+spanner:///projects/{SPANNER_PROJECT_ID}/instances/{SPANNER_INSTANCE}/databases/{SPANNER_DATABASE}"
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.spanner.SpannerHook._get_client")
+ def test_get_sqlalchemy_engine(self, get_client):
+ self.spanner_hook_no_default_project_id._get_conn_params = MagicMock(return_value=SPANNER_CONN_PARAMS)
+ engine = self.spanner_hook_no_default_project_id.get_sqlalchemy_engine()
+ assert isinstance(engine, sqlalchemy.engine.Engine)
+ assert engine.name == "spanner+spanner"