This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1757704d8f Enhancement for SSL-support in CloudSQLExecuteQueryOperator (#38894)
1757704d8f is described below
commit 1757704d8f7c7a335cdf8f90ccb12b4d8e6f9d9a
Author: max <[email protected]>
AuthorDate: Thu Apr 11 10:56:34 2024 +0200
Enhancement for SSL-support in CloudSQLExecuteQueryOperator (#38894)
---
.../cloud/example_dags/example_cloud_sql_query.py | 289 -----------
airflow/providers/google/cloud/hooks/cloud_sql.py | 152 +++++-
.../providers/google/cloud/hooks/secret_manager.py | 252 ++++++++-
.../providers/google/cloud/operators/cloud_sql.py | 57 +-
.../operators/cloud/cloud_sql.rst | 54 +-
docs/spelling_wordlist.txt | 8 +
.../providers/google/cloud/hooks/test_cloud_sql.py | 417 ++++++++++++++-
.../google/cloud/hooks/test_secret_manager.py | 239 ++++++++-
.../cloud/cloud_sql/example_cloud_sql_query.py | 572 +++++++++++++++++++++
.../cloud_sql/example_cloud_sql_query_mysql.py | 285 ----------
.../cloud_sql/example_cloud_sql_query_postgres.py | 290 -----------
.../cloud/cloud_sql/example_cloud_sql_query_ssl.py | 518 +++++++++++++++++++
12 files changed, 2223 insertions(+), 910 deletions(-)
diff --git a/airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py b/airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py
deleted file mode 100644
index b883ed13cf..0000000000
--- a/airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py
+++ /dev/null
@@ -1,289 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Example Airflow DAG that performs query in a Cloud SQL instance.
-
-This DAG relies on the following OS environment variables
-
-* GCP_PROJECT_ID - Google Cloud project for the Cloud SQL instance
-* GCP_REGION - Google Cloud region where the database is created
-*
-* GCSQL_POSTGRES_INSTANCE_NAME - Name of the postgres Cloud SQL instance
-* GCSQL_POSTGRES_USER - Name of the postgres database user
-* GCSQL_POSTGRES_PASSWORD - Password of the postgres database user
-* GCSQL_POSTGRES_PUBLIC_IP - Public IP of the Postgres database
-* GCSQL_POSTGRES_PUBLIC_PORT - Port of the postgres database
-*
-* GCSQL_MYSQL_INSTANCE_NAME - Name of the MySQL Cloud SQL instance
-* GCSQL_MYSQL_USER - Name of the mysql database user
-* GCSQL_MYSQL_PASSWORD - Password of the mysql database user
-* GCSQL_MYSQL_PUBLIC_IP - Public IP of the mysql database
-* GCSQL_MYSQL_PUBLIC_PORT - Port of the mysql database
-"""
-
-from __future__ import annotations
-
-import os
-import subprocess
-from datetime import datetime
-from pathlib import Path
-from urllib.parse import quote_plus
-
-from airflow.models.dag import DAG
-from airflow.providers.google.cloud.operators.cloud_sql import CloudSQLExecuteQueryOperator
-
-GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
-GCP_REGION = os.environ.get("GCP_REGION", "europe-west1")
-
-GCSQL_POSTGRES_INSTANCE_NAME_QUERY = os.environ.get(
- "GCSQL_POSTGRES_INSTANCE_NAME_QUERY", "test-postgres-query"
-)
-GCSQL_POSTGRES_DATABASE_NAME = os.environ.get("GCSQL_POSTGRES_DATABASE_NAME", "postgresdb")
-GCSQL_POSTGRES_USER = os.environ.get("GCSQL_POSTGRES_USER", "postgres_user")
-GCSQL_POSTGRES_PASSWORD = os.environ.get("GCSQL_POSTGRES_PASSWORD", "JoxHlwrPzwch0gz9")
-GCSQL_POSTGRES_PUBLIC_IP = os.environ.get("GCSQL_POSTGRES_PUBLIC_IP", "0.0.0.0")
-GCSQL_POSTGRES_PUBLIC_PORT = os.environ.get("GCSQL_POSTGRES_PUBLIC_PORT", 5432)
-GCSQL_POSTGRES_CLIENT_CERT_FILE = os.environ.get(
- "GCSQL_POSTGRES_CLIENT_CERT_FILE", ".key/postgres-client-cert.pem"
-)
-GCSQL_POSTGRES_CLIENT_KEY_FILE = os.environ.get(
- "GCSQL_POSTGRES_CLIENT_KEY_FILE", ".key/postgres-client-key.pem"
-)
-GCSQL_POSTGRES_SERVER_CA_FILE = os.environ.get("GCSQL_POSTGRES_SERVER_CA_FILE", ".key/postgres-server-ca.pem")
-
-GCSQL_MYSQL_INSTANCE_NAME_QUERY = os.environ.get("GCSQL_MYSQL_INSTANCE_NAME_QUERY", "test-mysql-query")
-GCSQL_MYSQL_DATABASE_NAME = os.environ.get("GCSQL_MYSQL_DATABASE_NAME", "mysqldb")
-GCSQL_MYSQL_USER = os.environ.get("GCSQL_MYSQL_USER", "mysql_user")
-GCSQL_MYSQL_PASSWORD = os.environ.get("GCSQL_MYSQL_PASSWORD", "JoxHlwrPzwch0gz9")
-GCSQL_MYSQL_PUBLIC_IP = os.environ.get("GCSQL_MYSQL_PUBLIC_IP", "0.0.0.0")
-GCSQL_MYSQL_PUBLIC_PORT = os.environ.get("GCSQL_MYSQL_PUBLIC_PORT", 3306)
-GCSQL_MYSQL_CLIENT_CERT_FILE = os.environ.get("GCSQL_MYSQL_CLIENT_CERT_FILE", ".key/mysql-client-cert.pem")
-GCSQL_MYSQL_CLIENT_KEY_FILE = os.environ.get("GCSQL_MYSQL_CLIENT_KEY_FILE", ".key/mysql-client-key.pem")
-GCSQL_MYSQL_SERVER_CA_FILE = os.environ.get("GCSQL_MYSQL_SERVER_CA_FILE", ".key/mysql-server-ca.pem")
-
-SQL = [
- "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
- "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)", # shows warnings
logged
- "INSERT INTO TABLE_TEST VALUES (0)",
- "CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)",
- "DROP TABLE TABLE_TEST",
- "DROP TABLE TABLE_TEST2",
-]
-
-
-# [START howto_operator_cloudsql_query_connections]
-
-HOME_DIR = Path.home()
-
-
-def get_absolute_path(path):
- """
- Returns absolute path.
- """
- return os.fspath(HOME_DIR / path)
-
-
-postgres_kwargs = {
- "user": quote_plus(GCSQL_POSTGRES_USER),
- "password": quote_plus(GCSQL_POSTGRES_PASSWORD),
- "public_port": GCSQL_POSTGRES_PUBLIC_PORT,
- "public_ip": quote_plus(GCSQL_POSTGRES_PUBLIC_IP),
- "project_id": quote_plus(GCP_PROJECT_ID),
- "location": quote_plus(GCP_REGION),
- "instance": quote_plus(GCSQL_POSTGRES_INSTANCE_NAME_QUERY),
- "database": quote_plus(GCSQL_POSTGRES_DATABASE_NAME),
- "client_cert_file":
quote_plus(get_absolute_path(GCSQL_POSTGRES_CLIENT_CERT_FILE)),
- "client_key_file":
quote_plus(get_absolute_path(GCSQL_POSTGRES_CLIENT_KEY_FILE)),
- "server_ca_file":
quote_plus(get_absolute_path(GCSQL_POSTGRES_SERVER_CA_FILE)),
-}
-
-# The connections below are created using one of the standard approaches - via environment
-# variables named AIRFLOW_CONN_* . The connections can also be created in the database
-# of AIRFLOW (using command line or UI).
-
-# Postgres: connect via proxy over TCP
-os.environ["AIRFLOW_CONN_PROXY_POSTGRES_TCP"] = (
- "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
- "database_type=postgres&"
- "project_id={project_id}&"
- "location={location}&"
- "instance={instance}&"
- "use_proxy=True&"
- "sql_proxy_use_tcp=True".format(**postgres_kwargs)
-)
-
-# Postgres: connect via proxy over UNIX socket (specific proxy version)
-os.environ["AIRFLOW_CONN_PROXY_POSTGRES_SOCKET"] = (
- "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
- "database_type=postgres&"
- "project_id={project_id}&"
- "location={location}&"
- "instance={instance}&"
- "use_proxy=True&"
- "sql_proxy_version=v1.13&"
- "sql_proxy_use_tcp=False".format(**postgres_kwargs)
-)
-
-# Postgres: connect directly via TCP (non-SSL)
-os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP"] = (
- "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
- "database_type=postgres&"
- "project_id={project_id}&"
- "location={location}&"
- "instance={instance}&"
- "use_proxy=False&"
- "use_ssl=False".format(**postgres_kwargs)
-)
-
-# Postgres: connect directly via TCP (SSL)
-os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL"] = (
- "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
- "database_type=postgres&"
- "project_id={project_id}&"
- "location={location}&"
- "instance={instance}&"
- "use_proxy=False&"
- "use_ssl=True&"
- "sslcert={client_cert_file}&"
- "sslkey={client_key_file}&"
- "sslrootcert={server_ca_file}".format(**postgres_kwargs)
-)
-
-mysql_kwargs = {
- "user": quote_plus(GCSQL_MYSQL_USER),
- "password": quote_plus(GCSQL_MYSQL_PASSWORD),
- "public_port": GCSQL_MYSQL_PUBLIC_PORT,
- "public_ip": quote_plus(GCSQL_MYSQL_PUBLIC_IP),
- "project_id": quote_plus(GCP_PROJECT_ID),
- "location": quote_plus(GCP_REGION),
- "instance": quote_plus(GCSQL_MYSQL_INSTANCE_NAME_QUERY),
- "database": quote_plus(GCSQL_MYSQL_DATABASE_NAME),
- "client_cert_file":
quote_plus(get_absolute_path(GCSQL_MYSQL_CLIENT_CERT_FILE)),
- "client_key_file":
quote_plus(get_absolute_path(GCSQL_MYSQL_CLIENT_KEY_FILE)),
- "server_ca_file":
quote_plus(get_absolute_path(GCSQL_MYSQL_SERVER_CA_FILE)),
-}
-
-# MySQL: connect via proxy over TCP (specific proxy version)
-os.environ["AIRFLOW_CONN_PROXY_MYSQL_TCP"] = (
- "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
- "database_type=mysql&"
- "project_id={project_id}&"
- "location={location}&"
- "instance={instance}&"
- "use_proxy=True&"
- "sql_proxy_version=v1.13&"
- "sql_proxy_use_tcp=True".format(**mysql_kwargs)
-)
-
-# MySQL: connect via proxy over UNIX socket using pre-downloaded Cloud Sql Proxy binary
-try:
- sql_proxy_binary_path = subprocess.check_output(["which", "cloud_sql_proxy"]).decode("utf-8").rstrip()
-except subprocess.CalledProcessError:
- sql_proxy_binary_path = "/tmp/anyhow_download_cloud_sql_proxy"
-
-os.environ["AIRFLOW_CONN_PROXY_MYSQL_SOCKET"] = (
- "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
- "database_type=mysql&"
- "project_id={project_id}&"
- "location={location}&"
- "instance={instance}&"
- "use_proxy=True&"
- "sql_proxy_use_tcp=False".format(**mysql_kwargs)
-)
-
-# MySQL: connect directly via TCP (non-SSL)
-os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP"] = (
- "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
- "database_type=mysql&"
- "project_id={project_id}&"
- "location={location}&"
- "instance={instance}&"
- "use_proxy=False&"
- "use_ssl=False".format(**mysql_kwargs)
-)
-
-# MySQL: connect directly via TCP (SSL) and with fixed Cloud Sql Proxy binary path
-os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL"] = (
- "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
- "database_type=mysql&"
- "project_id={project_id}&"
- "location={location}&"
- "instance={instance}&"
- "use_proxy=False&"
- "use_ssl=True&"
- "sslcert={client_cert_file}&"
- "sslkey={client_key_file}&"
- "sslrootcert={server_ca_file}".format(**mysql_kwargs)
-)
-
-# Special case: MySQL: connect directly via TCP (SSL) and with fixed Cloud Sql
-# Proxy binary path AND with missing project_id
-
-os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL_NO_PROJECT_ID"] = (
- "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
- "database_type=mysql&"
- "location={location}&"
- "instance={instance}&"
- "use_proxy=False&"
- "use_ssl=True&"
- "sslcert={client_cert_file}&"
- "sslkey={client_key_file}&"
- "sslrootcert={server_ca_file}".format(**mysql_kwargs)
-)
-
-
-# [END howto_operator_cloudsql_query_connections]
-
-# [START howto_operator_cloudsql_query_operators]
-
-connection_names = [
- "proxy_postgres_tcp",
- "proxy_postgres_socket",
- "public_postgres_tcp",
- "public_postgres_tcp_ssl",
- "proxy_mysql_tcp",
- "proxy_mysql_socket",
- "public_mysql_tcp",
- "public_mysql_tcp_ssl",
- "public_mysql_tcp_ssl_no_project_id",
-]
-
-tasks = []
-
-
-with DAG(
- dag_id="example_gcp_sql_query",
- start_date=datetime(2021, 1, 1),
- catchup=False,
- tags=["example"],
-) as dag:
- prev_task = None
-
- for connection_name in connection_names:
- task = CloudSQLExecuteQueryOperator(
- gcp_cloudsql_conn_id=connection_name,
- task_id="example_gcp_sql_task_" + connection_name,
- sql=SQL,
- sql_proxy_binary_path=sql_proxy_binary_path,
- )
- tasks.append(task)
- if prev_task:
- prev_task >> task
- prev_task = task
-
-# [END howto_operator_cloudsql_query_operators]
diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py
index 91d15811cf..615afde834 100644
--- a/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -19,6 +19,7 @@
from __future__ import annotations
+import base64
import errno
import json
import os
@@ -34,7 +35,7 @@ import uuid
from inspect import signature
from pathlib import Path
from subprocess import PIPE, Popen
-from tempfile import gettempdir
+from tempfile import NamedTemporaryFile, _TemporaryFileWrapper, gettempdir
from typing import TYPE_CHECKING, Any, Sequence
from urllib.parse import quote_plus
@@ -49,12 +50,16 @@ from googleapiclient.errors import HttpError
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
+from airflow.providers.google.cloud.hooks.secret_manager import (
+ GoogleCloudSecretManagerHook,
+)
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.log.logging_mixin import LoggingMixin
if TYPE_CHECKING:
+ from google.cloud.secretmanager_v1 import AccessSecretVersionResponse
from requests import Session
UNIX_PATH_MAX = 108
@@ -377,6 +382,29 @@ class CloudSQLHook(GoogleBaseHook):
except HttpError as ex:
raise AirflowException(f"Cloning of instance {instance} failed:
{ex.content}")
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_ssl_certificate(self, instance: str, body: dict, project_id: str):
+ """
+ Create SSL certificate for a Cloud SQL instance.
+
+ :param instance: Cloud SQL instance ID. This does not include the project ID.
+ :param body: The request body, as described in
+ https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/sslCerts/insert#SslCertsInsertRequest
+ :param project_id: Project ID of the project that contains the instance. If set
+ to None or missing, the default project_id from the Google Cloud connection is used.
+ :return: SslCert insert response. For more details see:
+ https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/sslCerts/insert#response-body
+ """
+ response = (
+ self.get_conn()
+ .sslCerts()
+ .insert(project=project_id, instance=instance, body=body)
+ .execute(num_retries=self.num_retries)
+ )
+ operation_name = response.get("operation", {}).get("name", {})
+ self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name)
+ return response
+
@GoogleBaseHook.fallback_to_default_project_id
def _wait_for_operation_to_complete(
self, project_id: str, operation_name: str, time_to_sleep: int = TIME_TO_SLEEP_IN_SECONDS
@@ -758,7 +786,24 @@ class CloudSQLDatabaseHook(BaseHook):
:param gcp_conn_id: The connection ID used to connect to Google Cloud for cloud-sql-proxy authentication.
:param default_gcp_project_id: Default project id used if project_id not specified
- in the connection URL
+ in the connection URL
+ :param ssl_cert: Optional. Path to client certificate to authenticate when SSL is used. Overrides the
+ connection field ``sslcert``.
+ :param ssl_key: Optional. Path to client private key to authenticate when SSL is used. Overrides the
+ connection field ``sslkey``.
+ :param ssl_root_cert: Optional. Path to server's certificate to authenticate when SSL is used. Overrides
+ the connection field ``sslrootcert``.
+ :param ssl_secret_id: Optional. ID of the secret in Google Cloud Secret Manager that stores SSL
+ certificate in the format below:
+
+ {'sslcert': '',
+ 'sslkey': '',
+ 'sslrootcert': ''}
+
+ Overrides the connection fields ``sslcert``, ``sslkey``, ``sslrootcert``.
+ Note that according to the Secret Manager requirements, the mentioned dict should be saved as a
+ string, and encoded with base64.
+ Note that this parameter is incompatible with parameters ``ssl_cert``, ``ssl_key``, ``ssl_root_cert``.
"""
conn_name_attr = "gcp_cloudsql_conn_id"
@@ -770,12 +815,18 @@ class CloudSQLDatabaseHook(BaseHook):
self,
gcp_cloudsql_conn_id: str = "google_cloud_sql_default",
gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
default_gcp_project_id: str | None = None,
sql_proxy_binary_path: str | None = None,
+ ssl_cert: str | None = None,
+ ssl_key: str | None = None,
+ ssl_root_cert: str | None = None,
+ ssl_secret_id: str | None = None,
) -> None:
super().__init__()
self.gcp_conn_id = gcp_conn_id
self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id
+ self.impersonation_chain = impersonation_chain
self.cloudsql_connection = self.get_connection(self.gcp_cloudsql_conn_id)
self.extras = self.cloudsql_connection.extra_dejson
self.project_id = self.extras.get("project_id", default_gcp_project_id)
@@ -792,9 +843,11 @@ class CloudSQLDatabaseHook(BaseHook):
self.password = self.cloudsql_connection.password
self.public_ip = self.cloudsql_connection.host
self.public_port = self.cloudsql_connection.port
- self.sslcert = self.extras.get("sslcert")
- self.sslkey = self.extras.get("sslkey")
- self.sslrootcert = self.extras.get("sslrootcert")
+ self.ssl_cert = ssl_cert
+ self.ssl_key = ssl_key
+ self.ssl_root_cert = ssl_root_cert
+ self.ssl_secret_id = ssl_secret_id
+ self._ssl_cert_temp_files: dict[str, _TemporaryFileWrapper] = {}
# Port and socket path and db_hook are automatically generated
self.sql_proxy_tcp_port = None
self.sql_proxy_unique_path: str | None = None
@@ -805,6 +858,84 @@ class CloudSQLDatabaseHook(BaseHook):
self.db_conn_id = str(uuid.uuid1())
self._validate_inputs()
+ @property
+ def sslcert(self) -> str | None:
+ return self._get_ssl_temporary_file_path(cert_name="sslcert",
cert_path=self.ssl_cert)
+
+ @property
+ def sslkey(self) -> str | None:
+ return self._get_ssl_temporary_file_path(cert_name="sslkey",
cert_path=self.ssl_key)
+
+ @property
+ def sslrootcert(self) -> str | None:
+ return self._get_ssl_temporary_file_path(cert_name="sslrootcert",
cert_path=self.ssl_root_cert)
+
+ def _get_ssl_temporary_file_path(self, cert_name: str, cert_path: str |
None) -> str | None:
+ cert_value = self._get_cert_from_secret(cert_name)
+ original_cert_path = cert_path or self.extras.get(cert_name)
+ if cert_value or original_cert_path:
+ if cert_name not in self._ssl_cert_temp_files:
+ return self._set_temporary_ssl_file(
+ cert_name=cert_name, cert_path=original_cert_path,
cert_value=cert_value
+ )
+ return self._ssl_cert_temp_files[cert_name].name
+ return None
+
+ def _get_cert_from_secret(self, cert_name: str) -> str | None:
+ if not self.ssl_secret_id:
+ return None
+
+ secret_hook = GoogleCloudSecretManagerHook(
+ gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
+ )
+ secret: AccessSecretVersionResponse = secret_hook.access_secret(
+ project_id=self.project_id,
+ secret_id=self.ssl_secret_id,
+ )
+ secret_data = json.loads(base64.b64decode(secret.payload.data))
+ if cert_name in secret_data:
+ return secret_data[cert_name]
+ else:
+ raise AirflowException(
+ "Invalid secret format. Expected dictionary with keys:
`sslcert`, `sslkey`, `sslrootcert`"
+ )
+
+ def _set_temporary_ssl_file(
+ self, cert_name: str, cert_path: str | None = None, cert_value: str | None = None
+ ) -> str | None:
+ """Save the certificate as a temporary file.
+
+ This method was implemented in order to overcome psql connection error caused by excessive file
+ permissions: "private key file "..." has group or world access; file must have permissions
+ u=rw (0600) or less if owned by the current user, or permissions u=rw,g=r (0640) or less if owned
+ by root". NamedTemporaryFile enforces using exactly one of create/read/write/append mode so the
+ created file obtains least required permissions "-rw-------" that satisfies the rules.
+
+ :param cert_name: Required. Name of the certificate (one of sslcert, sslkey, sslrootcert).
+ :param cert_path: Optional. Path to the certificate.
+ :param cert_value: Optional. The certificate content.
+
+ :returns: The path to the temporary certificate file.
+ """
+ if all([cert_path, cert_value]):
+ raise AirflowException(
+ "Both parameters were specified: `cert_path`, `cert_value`.
Please use only one of them."
+ )
+ if not any([cert_path, cert_value]):
+ self.log.info("Neither cert path and cert value provided. Nothing
to save.")
+ return None
+
+ _temp_file = NamedTemporaryFile(mode="w+b", prefix="/tmp/certs/")
+ if cert_path:
+ with open(cert_path, "rb") as cert_file:
+ _temp_file.write(cert_file.read())
+ elif cert_value:
+ _temp_file.write(cert_value.encode("ascii"))
+ _temp_file.flush()
+ self._ssl_cert_temp_files[cert_name] = _temp_file
+ self.log.info("Copied the certificate '%s' into a temporary file
'%s'", cert_name, _temp_file.name)
+ return _temp_file.name
+
@staticmethod
def _get_bool(val: Any) -> bool:
if val == "False" or val is False:
@@ -836,6 +967,17 @@ class CloudSQLDatabaseHook(BaseHook):
" SSL is not needed as Cloud SQL Proxy "
"provides encryption on its own"
)
+ if any([self.ssl_key, self.ssl_cert, self.ssl_root_cert]) and self.ssl_secret_id:
+ raise AirflowException(
+ "Invalid SSL settings. Please use either all of parameters ['ssl_cert', 'ssl_key', "
+ "'ssl_root_cert'] or a single parameter 'ssl_secret_id'."
+ )
+ if any([self.ssl_key, self.ssl_cert, self.ssl_root_cert]):
+ field_names = ["ssl_key", "ssl_cert", "ssl_root_cert"]
+ if missed_values := [field for field in field_names if not getattr(self, field)]:
+ s = "s are" if len(missed_values) > 1 else "is"
+ missed_values_str = ", ".join(f for f in missed_values)
+ raise AirflowException(f"Invalid SSL settings. Parameter{s} missing: {missed_values_str}")
def validate_ssl_certs(self) -> None:
"""
diff --git a/airflow/providers/google/cloud/hooks/secret_manager.py b/airflow/providers/google/cloud/hooks/secret_manager.py
index 5bfea5ac5e..5fd303445c 100644
--- a/airflow/providers/google/cloud/hooks/secret_manager.py
+++ b/airflow/providers/google/cloud/hooks/secret_manager.py
@@ -15,16 +15,38 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Hook for Secrets Manager service."""
+"""This module contains a Secret Manager hook."""
from __future__ import annotations
-from typing import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Sequence
+from deprecated import deprecated
+from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
+from google.cloud.secretmanager_v1 import (
+ AccessSecretVersionResponse,
+ Secret,
+ SecretManagerServiceClient,
+ SecretPayload,
+ SecretVersion,
+)
+
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud._internal_client.secret_manager_client import _SecretManagerClient
+from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+if TYPE_CHECKING:
+ from google.api_core.retry import Retry
+ from google.cloud.secretmanager_v1.services.secret_manager_service.pagers import ListSecretsPager
+
+@deprecated(
+ reason="The SecretsManagerHook is deprecated and will be removed after
01.11.2024. "
+ "Please use GoogleCloudSecretManagerHook instead.",
+ category=AirflowProviderDeprecationWarning,
+)
class SecretsManagerHook(GoogleBaseHook):
"""
Hook for the Google Secret Manager API.
@@ -86,3 +108,229 @@ class SecretsManagerHook(GoogleBaseHook):
secret_version=secret_version,
project_id=project_id, # type: ignore
)
+
+
+class GoogleCloudSecretManagerHook(GoogleBaseHook):
+ """Hook for the Google Cloud Secret Manager API.
+
+ See https://cloud.google.com/secret-manager
+ """
+
+ @cached_property
+ def client(self):
+ """Create a Secret Manager Client.
+
+ :return: Secret Manager client.
+ """
+ return SecretManagerServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+
+ def get_conn(self) -> SecretManagerServiceClient:
+ """Retrieve the connection to Secret Manager.
+
+ :return: Secret Manager client.
+ """
+ return self.client
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_secret(
+ self,
+ project_id: str,
+ secret_id: str,
+ secret: dict | Secret | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Secret:
+ """Create a secret.
+
+ .. seealso::
+ For more details see API documentation:
+ https://cloud.google.com/python/docs/reference/secretmanager/latest/google.cloud.secretmanager_v1.services.secret_manager_service.SecretManagerServiceClient#google_cloud_secretmanager_v1_services_secret_manager_service_SecretManagerServiceClient_create_secret
+
+ :param project_id: Required. ID of the GCP project that owns the job.
+ If set to ``None`` or missing, the default project_id from the GCP connection is used.
+ :param secret_id: Required. ID of the secret to create.
+ :param secret: Optional. Secret to create.
+ :param retry: Optional. Designation of what errors, if any, should be retried.
+ :param timeout: Optional. The timeout for this request.
+ :param metadata: Optional. Strings which should be sent along with the request as metadata.
+ :return: Secret object.
+ """
+ _secret = secret or {"replication": {"automatic": {}}}
+ response = self.client.create_secret(
+ request={
+ "parent": f"projects/{project_id}",
+ "secret_id": secret_id,
+ "secret": _secret,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ self.log.info("Secret Created: %s", response.name)
+ return response
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def add_secret_version(
+ self,
+ project_id: str,
+ secret_id: str,
+ secret_payload: dict | SecretPayload | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> SecretVersion:
+ """Add a version to the secret.
+
+ .. seealso::
+ For more details see API documentation:
+ https://cloud.google.com/python/docs/reference/secretmanager/latest/google.cloud.secretmanager_v1.services.secret_manager_service.SecretManagerServiceClient#google_cloud_secretmanager_v1_services_secret_manager_service_SecretManagerServiceClient_add_secret_version
+
+ :param project_id: Required. ID of the GCP project that owns the job.
+ If set to ``None`` or missing, the default project_id from the GCP connection is used.
+ :param secret_id: Required. ID of the secret to add the version to.
+ :param secret_payload: Optional. A secret payload.
+ :param retry: Optional. Designation of what errors, if any, should be retried.
+ :param timeout: Optional. The timeout for this request.
+ :param metadata: Optional. Strings which should be sent along with the request as metadata.
+ :return: Secret version object.
+ """
+ response = self.client.add_secret_version(
+ request={
+ "parent": f"projects/{project_id}/secrets/{secret_id}",
+ "payload": secret_payload,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ self.log.info("Secret version added: %s", response.name)
+ return response
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def list_secrets(
+ self,
+ project_id: str,
+ page_size: int = 0,
+ page_token: str | None = None,
+ secret_filter: str | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> ListSecretsPager:
+ """List secrets.
+
+ .. seealso::
+ For more details see API documentation:
+ https://cloud.google.com/python/docs/reference/secretmanager/latest/google.cloud.secretmanager_v1.services.secret_manager_service.SecretManagerServiceClient#google_cloud_secretmanager_v1_services_secret_manager_service_SecretManagerServiceClient_list_secrets
+
+ :param project_id: Required. ID of the GCP project that owns the job.
+ If set to ``None`` or missing, the default project_id from the GCP connection is used.
+ :param page_size: Optional, number of results to return in the list.
+ :param page_token: Optional, token to provide to skip to a particular spot in the list.
+ :param secret_filter: Optional. Filter string.
+ :param retry: Optional. Designation of what errors, if any, should be retried.
+ :param timeout: Optional. The timeout for this request.
+ :param metadata: Optional. Strings which should be sent along with the request as metadata.
+ :return: Secret List object.
+ """
+ response = self.client.list_secrets(
+ request={
+ "parent": f"projects/{project_id}",
+ "page_size": page_size,
+ "page_token": page_token,
+ "filter": secret_filter,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ self.log.info("Secrets list obtained")
+ return response
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def secret_exists(self, project_id: str, secret_id: str) -> bool:
+ """Check whether secret exists.
+
+ :param project_id: Required. ID of the GCP project that owns the job.
+ If set to ``None`` or missing, the default project_id from the GCP connection is used.
+ :param secret_id: Required. ID of the secret to find.
+ :return: True if the secret exists, False otherwise.
+ """
+ secret_filter = f"name:{secret_id}"
+ secret_name = self.client.secret_path(project_id, secret_id)
+ for secret in self.list_secrets(project_id=project_id, page_size=100, secret_filter=secret_filter):
+ if secret.name.split("/")[-1] == secret_id:
+ self.log.info("Secret %s exists.", secret_name)
+ return True
+ self.log.info("Secret %s doesn't exist.", secret_name)
+ return False
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def access_secret(
+ self,
+ project_id: str,
+ secret_id: str,
+ secret_version: str = "latest",
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> AccessSecretVersionResponse:
+ """Access a secret version.
+
+ .. seealso::
+ For more details see API documentation:
+ https://cloud.google.com/python/docs/reference/secretmanager/latest/google.cloud.secretmanager_v1.services.secret_manager_service.SecretManagerServiceClient#google_cloud_secretmanager_v1_services_secret_manager_service_SecretManagerServiceClient_access_secret_version
+
+ :param project_id: Required. ID of the GCP project that owns the job.
+ If set to ``None`` or missing, the default project_id from the GCP connection is used.
+ :param secret_id: Required. ID of the secret to access.
+ :param secret_version: Optional. Version of the secret to access. Default: latest.
+ :param retry: Optional. Designation of what errors, if any, should be retried.
+ :param timeout: Optional. The timeout for this request.
+ :param metadata: Optional. Strings which should be sent along with the request as metadata.
+ :return: Access secret version response object.
+ """
+ response = self.client.access_secret_version(
+ request={
+ "name": self.client.secret_version_path(project_id, secret_id,
secret_version),
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ self.log.info("Secret version accessed: %s", response.name)
+ return response
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def delete_secret(
+ self,
+ project_id: str,
+ secret_id: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> None:
+ """Delete a secret.
+
+ .. seealso::
+ For more details see API documentation:
+ https://cloud.google.com/python/docs/reference/secretmanager/latest/google.cloud.secretmanager_v1.services.secret_manager_service.SecretManagerServiceClient#google_cloud_secretmanager_v1_services_secret_manager_service_SecretManagerServiceClient_delete_secret
+
+ :param project_id: Required. ID of the GCP project that owns the job.
+ If set to ``None`` or missing, the default project_id from the GCP connection is used.
+ :param secret_id: Required. ID of the secret to delete.
+ :param retry: Optional. Designation of what errors, if any, should be retried.
+ :param timeout: Optional. The timeout for this request.
+ :param metadata: Optional. Strings which should be sent along with the request as metadata.
+ :return: None
+ """
+ name = self.client.secret_path(project_id, secret_id)
+ self.client.delete_secret(
+ request={"name": name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ self.log.info("Secret deleted: %s", name)
+ return None
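
A hedged usage sketch of the new ``GoogleCloudSecretManagerHook``; the project id, secret id, and payload below are placeholders, and ``SecretPayload.data`` must be bytes:

    from airflow.providers.google.cloud.hooks.secret_manager import GoogleCloudSecretManagerHook

    hook = GoogleCloudSecretManagerHook(gcp_conn_id="google_cloud_default")
    # Create the secret container once, then push versions into it.
    if not hook.secret_exists(project_id="example-project", secret_id="example-secret"):
        hook.create_secret(project_id="example-project", secret_id="example-secret")
    hook.add_secret_version(
        project_id="example-project",
        secret_id="example-secret",
        secret_payload={"data": b"example-payload"},
    )
    # access_secret() reads the latest version by default.
    response = hook.access_secret(project_id="example-project", secret_id="example-secret")
    print(response.payload.data)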
diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py
index 5595773589..c8519e1e26 100644
--- a/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -19,6 +19,7 @@
from __future__ import annotations
+from functools import cached_property
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence
from googleapiclient.errors import HttpError
@@ -1181,10 +1182,35 @@ class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
details on how to define ``gcpcloudsql://`` connection.
:param sql_proxy_binary_path: (optional) Path to the cloud-sql-proxy binary. If it
is not specified or the binary is not present, it is automatically downloaded.
+ :param ssl_cert: (optional) Path to client certificate to authenticate when SSL is used. Overrides the
+ connection field ``sslcert``.
+ :param ssl_key: (optional) Path to client private key to authenticate when SSL is used. Overrides the
+ connection field ``sslkey``.
+ :param ssl_root_cert: (optional) Path to server's certificate to authenticate when SSL is used. Overrides
+ the connection field ``sslrootcert``.
+ :param ssl_secret_id: (optional) ID of the secret in Google Cloud Secret Manager that stores SSL
+ certificate in the format below:
+
+ {'sslcert': '',
+ 'sslkey': '',
+ 'sslrootcert': ''}
+
+ Overrides the connection fields ``sslcert``, ``sslkey``, ``sslrootcert``.
+ Note that according to the Secret Manager requirements, the mentioned dict should be saved as a
+ string, and encoded with base64.
+ Note that this parameter is incompatible with parameters ``ssl_cert``, ``ssl_key``, ``ssl_root_cert``.
"""
# [START gcp_sql_query_template_fields]
- template_fields: Sequence[str] = ("sql", "gcp_cloudsql_conn_id",
"gcp_conn_id")
+ template_fields: Sequence[str] = (
+ "sql",
+ "gcp_cloudsql_conn_id",
+ "gcp_conn_id",
+ "ssl_server_cert",
+ "ssl_client_cert",
+ "ssl_client_key",
+ "ssl_secret_id",
+ )
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"sql": "sql"}
# [END gcp_sql_query_template_fields]
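
Since the ``ssl_*`` fields above are templated, the secret id could, for example, be resolved from an Airflow Variable at runtime; a sketch, with the connection id and variable name as placeholders:

    from airflow.providers.google.cloud.operators.cloud_sql import CloudSQLExecuteQueryOperator

    ssl_task = CloudSQLExecuteQueryOperator(
        task_id="example_templated_ssl_secret",
        gcp_cloudsql_conn_id="public_postgres_tcp_ssl",
        sql="SELECT 1",
        # Jinja-templated thanks to the new entry in template_fields.
        ssl_secret_id="{{ var.value.ssl_secret_id }}",
    )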
@@ -1199,6 +1225,10 @@ class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
gcp_conn_id: str = "google_cloud_default",
gcp_cloudsql_conn_id: str = "google_cloud_sql_default",
sql_proxy_binary_path: str | None = None,
+ ssl_server_cert: str | None = None,
+ ssl_client_cert: str | None = None,
+ ssl_client_key: str | None = None,
+ ssl_secret_id: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -1209,6 +1239,10 @@ class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
self.parameters = parameters
self.gcp_connection: Connection | None = None
self.sql_proxy_binary_path = sql_proxy_binary_path
+ self.ssl_server_cert = ssl_server_cert
+ self.ssl_client_cert = ssl_client_cert
+ self.ssl_client_key = ssl_client_key
+ self.ssl_secret_id = ssl_secret_id
def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook | MySqlHook) -> None:
cloud_sql_proxy_runner = None
@@ -1228,12 +1262,8 @@ class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
def execute(self, context: Context):
self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
- hook = CloudSQLDatabaseHook(
- gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
- gcp_conn_id=self.gcp_conn_id,
- default_gcp_project_id=get_field(self.gcp_connection.extra_dejson, "project"),
- sql_proxy_binary_path=self.sql_proxy_binary_path,
- )
+
+ hook = self.hook
hook.validate_ssl_certs()
connection = hook.create_connection()
hook.validate_socket_path_length()
@@ -1242,3 +1272,16 @@ class CloudSQLExecuteQueryOperator(GoogleCloudBaseOperator):
self._execute_query(hook, database_hook)
finally:
hook.cleanup_database_hook()
+
+ @cached_property
+ def hook(self):
+ return CloudSQLDatabaseHook(
+ gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
+ gcp_conn_id=self.gcp_conn_id,
+ default_gcp_project_id=get_field(self.gcp_connection.extra_dejson, "project"),
+ sql_proxy_binary_path=self.sql_proxy_binary_path,
+ ssl_root_cert=self.ssl_server_cert,
+ ssl_cert=self.ssl_client_cert,
+ ssl_key=self.ssl_client_key,
+ ssl_secret_id=self.ssl_secret_id,
+ )
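
Put together, the operator-level override delegated to the new ``hook`` property can be wired up as follows; a sketch with the connection id, paths, and SQL as placeholders (per ``_validate_inputs`` in the hook, supply either all three file parameters or ``ssl_secret_id`` alone):

    from airflow.providers.google.cloud.operators.cloud_sql import CloudSQLExecuteQueryOperator

    query_task = CloudSQLExecuteQueryOperator(
        task_id="example_cloud_sql_query_ssl",
        gcp_cloudsql_conn_id="public_postgres_tcp_ssl",
        sql="SELECT 1",
        # These override the sslcert/sslkey/sslrootcert fields of the connection.
        ssl_client_cert="/path/to/client-cert.pem",
        ssl_client_key="/path/to/client-key.pem",
        ssl_server_cert="/path/to/server-ca.pem",
    )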
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst b/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
index 9b859c63db..ec334c0895 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
@@ -574,10 +574,17 @@ certificate/key files available in predefined locations for all the workers on
which the operator can run. This can be provided for example by mounting
NFS-like volumes in the same path for all the workers.
-Example connection definitions for all non-SSL connectivity cases for Postgres. For connecting to MySQL database
-please use ``mysql`` as a ``database_type``. Note that all the components of the connection URI should be URL-encoded:
+Example connection definitions for all non-SSL connectivity. Note that all the components of the connection URI should
+be URL-encoded:
-.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_postgres.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py
+ :language: python
+ :start-after: [START howto_operator_cloudsql_query_connections]
+ :end-before: [END howto_operator_cloudsql_query_connections]
+
+Similar connection definitions for SSL-enabled connectivity:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
:language: python
:start-after: [START howto_operator_cloudsql_query_connections]
:end-before: [END howto_operator_cloudsql_query_connections]
@@ -586,33 +593,50 @@ It is also possible to configure a connection via environment variable (note tha
matches the :envvar:`AIRFLOW_CONN_{CONN_ID}` postfix uppercase if you are using a standard AIRFLOW notation for
defining connection via environment variables):
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py
:language: python
- :start-after: [START howto_operator_cloudsql_query_connections]
- :end-before: [END howto_operator_cloudsql_query_connections]
+ :start-after: [START howto_operator_cloudsql_query_connections_env]
+ :end-before: [END howto_operator_cloudsql_query_connections_env]
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
+ :language: python
+ :start-after: [START howto_operator_cloudsql_query_connections_env]
+ :end-before: [END howto_operator_cloudsql_query_connections_env]
+
+
+Using the operator
+""""""""""""""""""
Example operator below is using prepared earlier connection. It might be a connection_id from the Airflow database
or the connection configured via environment variable (note that the connection id from the operator matches the
:envvar:`AIRFLOW_CONN_{CONN_ID}` postfix uppercase if you are using a standard AIRFLOW notation for defining connection
via environment variables):
-.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_postgres.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py
:language: python
:start-after: [START howto_operator_cloudsql_query_operators]
:end-before: [END howto_operator_cloudsql_query_operators]
+SSL settings can also be specified at the operator level. In this case the SSL settings configured in the connection
+will be overridden. One of the ways to do so is to specify paths to each certificate file, as shown below.
+Note that these files will be copied into a temporary location with minimal required permissions, for security
+reasons.
-Using the operator
-""""""""""""""""""
+.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
+ :language: python
+ :start-after: [START howto_operator_cloudsql_query_operators_ssl]
+ :end-before: [END howto_operator_cloudsql_query_operators_ssl]
+
+You can also save your SSL certificates in Google Cloud Secret Manager and provide a secret id. The secret
+format is:
+
+.. code-block:: python
-Example operators below are using all connectivity options. Note connection id
-from the operator matches the :envvar:`AIRFLOW_CONN_{CONN_ID}` postfix uppercase. This is
-standard AIRFLOW notation for defining connection via environment variables):
+ {"sslcert": "", "sslkey": "", "sslrootcert": ""}
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
:language: python
- :start-after: [START howto_operator_cloudsql_query_operators]
- :end-before: [END howto_operator_cloudsql_query_operators]
+ :start-after: [START howto_operator_cloudsql_query_operators_ssl_secret_id]
+ :end-before: [END howto_operator_cloudsql_query_operators_ssl_secret_id]
Templating
""""""""""
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index ebff77ea37..48b4189e1a 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1,6 +1,7 @@
aarch
abc
accessor
+AccessSecretVersionResponse
accountmaking
aci
Ack
@@ -946,6 +947,7 @@ linter
linux
ListGenerator
ListInfoTypesResponse
+ListSecretsPager
Liveness
liveness
livy
@@ -1407,10 +1409,13 @@ sdk
sdks
searchpath
SearchResultGenerator
+secretmanager
SecretManagerClient
+SecretManagerServiceClient
secretRef
secretRefs
SecretsManagerBackend
+SecretVersion
securable
securecookie
securestring
@@ -1497,7 +1502,10 @@ sshHook
sshtunnel
SSHTunnelForwarder
ssl
+sslcert
+sslkey
sslmode
+sslrootcert
ssm
Stackdriver
stackdriver
diff --git a/tests/providers/google/cloud/hooks/test_cloud_sql.py b/tests/providers/google/cloud/hooks/test_cloud_sql.py
index 70ac3cbab7..a5d6e16664 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_sql.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_sql.py
@@ -17,12 +17,13 @@
# under the License.
from __future__ import annotations
+import base64
import json
import os
import platform
import tempfile
from unittest import mock
-from unittest.mock import PropertyMock
+from unittest.mock import PropertyMock, call, mock_open
import aiohttp
import httplib2
@@ -50,6 +51,12 @@ OPERATION_NAME = "test_operation_name"
OPERATION_URL = (
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{PROJECT_ID}/operations/{OPERATION_NAME}"
)
+SSL_CERT = "sslcert.pem"
+SSL_KEY = "sslkey.pem"
+SSL_ROOT_CERT = "sslrootcert.pem"
+CONNECTION_ID = "test-conn-id"
+IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
+SECRET_ID = "test-secret-id"
@pytest.fixture
@@ -783,11 +790,13 @@ class TestCloudSqlDatabaseHook:
],
)
@mock.patch("os.path.isfile")
+ @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook._set_temporary_ssl_file")
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_cloudsql_database_hook_validate_ssl_certs_missing_cert_params(
- self, get_connection, mock_is_file, cert_dict
+ self, get_connection, mock_set_temporary_ssl_file, mock_is_file, cert_dict
):
mock_is_file.side_effects = True
+ mock_set_temporary_ssl_file.side_effect = cert_dict.values()
connection = Connection()
extras = {"location": "test", "instance": "instance", "database_type":
"postgres", "use_ssl": "True"}
extras.update(cert_dict)
@@ -803,10 +812,18 @@ class TestCloudSqlDatabaseHook:
assert "SSL connections requires" in str(err)
@mock.patch("os.path.isfile")
+ @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook._set_temporary_ssl_file")
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
- def test_cloudsql_database_hook_validate_ssl_certs_with_ssl(self, get_connection, mock_is_file):
+ def test_cloudsql_database_hook_validate_ssl_certs_with_ssl(
+ self, get_connection, mock_set_temporary_ssl_file, mock_is_file
+ ):
connection = Connection()
mock_is_file.return_value = True
+ mock_set_temporary_ssl_file.side_effect = [
+ "/tmp/cert_file.pem",
+ "/tmp/rootcert_file.pem",
+ "/tmp/key_file.pem",
+ ]
connection.set_extra(
json.dumps(
{
@@ -827,12 +844,18 @@ class TestCloudSqlDatabaseHook:
hook.validate_ssl_certs()
@mock.patch("os.path.isfile")
+ @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook._set_temporary_ssl_file")
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable(
- self, get_connection, mock_is_file
- self, get_connection, mock_is_file
+ self, get_connection, mock_set_temporary_ssl_file, mock_is_file
):
connection = Connection()
mock_is_file.return_value = False
+ mock_set_temporary_ssl_file.side_effect = [
+ "/tmp/cert_file.pem",
+ "/tmp/rootcert_file.pem",
+ "/tmp/key_file.pem",
+ ]
connection.set_extra(
json.dumps(
{
@@ -1001,6 +1024,366 @@ class TestCloudSqlDatabaseHook:
db_hook = hook.get_database_hook(connection=connection)
assert db_hook is not None
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook._get_ssl_temporary_file_path"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook.get_connection"))
+ def test_ssl_cert_properties(self, mock_get_connection, mock_get_ssl_temporary_file_path):
+ def side_effect_func(cert_name, cert_path):
+ return f"/tmp/certs/{cert_name}"
+
+ mock_get_ssl_temporary_file_path.side_effect = side_effect_func
+ mock_get_connection.return_value = mock.MagicMock(
+ extra_dejson=dict(database_type="postgres", location="test",
instance="instance")
+ )
+
+ hook = CloudSQLDatabaseHook(
+ gcp_cloudsql_conn_id="cloudsql_connection",
+ default_gcp_project_id="google_connection",
+ ssl_cert=SSL_CERT,
+ ssl_key=SSL_KEY,
+ ssl_root_cert=SSL_ROOT_CERT,
+ )
+ sslcert = hook.sslcert
+ sslkey = hook.sslkey
+ sslrootcert = hook.sslrootcert
+
+ assert hook.ssl_cert == SSL_CERT
+ assert hook.ssl_key == SSL_KEY
+ assert hook.ssl_root_cert == SSL_ROOT_CERT
+ assert sslcert == "/tmp/certs/sslcert"
+ assert sslkey == "/tmp/certs/sslkey"
+ assert sslrootcert == "/tmp/certs/sslrootcert"
+ mock_get_ssl_temporary_file_path.assert_has_calls(
+ [
+ call(cert_name="sslcert", cert_path=SSL_CERT),
+ call(cert_name="sslkey", cert_path=SSL_KEY),
+ call(cert_name="sslrootcert", cert_path=SSL_ROOT_CERT),
+ ]
+ )
+
+ @pytest.mark.parametrize("ssl_name", ["sslcert", "sslkey", "sslrootcert"])
+ @pytest.mark.parametrize(
+ "cert_value, cert_path, extra_cert_path",
+ [
+ (None, None, "/connection/path/to/cert.pem"),
+ (None, "/path/to/cert.pem", None),
+ (None, "/path/to/cert.pem", "/connection/path/to/cert.pem"),
+ (mock.MagicMock(), None, None),
+ (mock.MagicMock(), None, "/connection/path/to/cert.pem"),
+ (mock.MagicMock(), "/path/to/cert.pem", None),
+ (mock.MagicMock(), "/path/to/cert.pem",
"/connection/path/to/cert.pem"),
+ ],
+ )
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook._set_temporary_ssl_file"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook._get_cert_from_secret"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook.get_connection"))
+ def test_get_ssl_temporary_file_path(
+ self,
+ mock_get_connection,
+ mock_get_cert_from_secret,
+ mock_set_temporary_ssl_file,
+ cert_value,
+ cert_path,
+ extra_cert_path,
+ ssl_name,
+ ):
+ expected_cert_file_path = cert_path
+ mock_get_connection.return_value = mock.MagicMock(
+ extra_dejson={
+ "database_type": "postgres",
+ "location": "test",
+ "instance": "instance",
+ ssl_name: extra_cert_path,
+ }
+ )
+ mock_get_cert_from_secret.return_value = cert_value
+ mock_set_temporary_ssl_file.return_value = expected_cert_file_path
+
+ hook = CloudSQLDatabaseHook(
+ gcp_cloudsql_conn_id="cloudsql_connection",
+ default_gcp_project_id="google_connection",
+ ssl_cert=SSL_CERT,
+ ssl_key=SSL_KEY,
+ ssl_root_cert=SSL_ROOT_CERT,
+ )
+ actual_cert_file_path = hook._get_ssl_temporary_file_path(cert_name=ssl_name, cert_path=cert_path)
+
+ assert actual_cert_file_path == expected_cert_file_path
+ assert hook.extras.get(ssl_name) == extra_cert_path
+ mock_get_cert_from_secret.assert_called_once_with(ssl_name)
+ mock_set_temporary_ssl_file.assert_called_once_with(
+ cert_name=ssl_name, cert_path=cert_path or extra_cert_path, cert_value=cert_value
+ )
+
+ @pytest.mark.parametrize("ssl_name", ["sslcert", "sslkey", "sslrootcert"])
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook._set_temporary_ssl_file"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook._get_cert_from_secret"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook.get_connection"))
+ def test_get_ssl_temporary_file_path_none(
+ self,
+ mock_get_connection,
+ mock_get_cert_from_secret,
+ mock_set_temporary_ssl_file,
+ ssl_name,
+ ):
+ expected_cert_file_path = None
+ mock_get_connection.return_value = mock.MagicMock(
+ extra_dejson={
+ "database_type": "postgres",
+ "location": "test",
+ "instance": "instance",
+ }
+ )
+ mock_get_cert_from_secret.return_value = None
+ mock_set_temporary_ssl_file.return_value = expected_cert_file_path
+
+ hook = CloudSQLDatabaseHook(
+ gcp_cloudsql_conn_id="cloudsql_connection",
+ default_gcp_project_id="google_connection",
+ ssl_cert=SSL_CERT,
+ ssl_key=SSL_KEY,
+ ssl_root_cert=SSL_ROOT_CERT,
+ )
+ actual_cert_file_path = hook._get_ssl_temporary_file_path(cert_name=ssl_name, cert_path=None)
+
+ assert actual_cert_file_path == expected_cert_file_path
+ assert hook.extras.get(ssl_name) is None
+ mock_get_cert_from_secret.assert_called_once_with(ssl_name)
+ assert not mock_set_temporary_ssl_file.called
+
+ @pytest.mark.parametrize(
+ "cert_name, cert_value",
+ [
+ ["sslcert", SSL_CERT],
+ ["sslkey", SSL_KEY],
+ ["sslrootcert", SSL_ROOT_CERT],
+ ],
+ )
+ @mock.patch(HOOK_STR.format("GoogleCloudSecretManagerHook"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook.get_connection"))
+ def test_get_cert_from_secret(
+ self,
+ mock_get_connection,
+ mock_secret_hook,
+ cert_name,
+ cert_value,
+ ):
+ mock_get_connection.return_value = mock.MagicMock(
+ extra_dejson={
+ "database_type": "postgres",
+ "location": "test",
+ "instance": "instance",
+ }
+ )
+ mock_secret = mock_secret_hook.return_value.access_secret.return_value
+ mock_secret.payload.data = base64.b64encode(json.dumps({cert_name: cert_value}).encode("ascii"))
+
+ hook = CloudSQLDatabaseHook(
+ gcp_conn_id=CONNECTION_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ default_gcp_project_id=PROJECT_ID,
+ ssl_secret_id=SECRET_ID,
+ )
+ actual_cert_value = hook._get_cert_from_secret(cert_name=cert_name)
+
+ assert actual_cert_value == cert_value
+ mock_secret_hook.assert_called_once_with(
+ gcp_conn_id=CONNECTION_ID, impersonation_chain=IMPERSONATION_CHAIN
+ )
+ mock_secret_hook.return_value.access_secret.assert_called_once_with(
+ project_id=PROJECT_ID, secret_id=SECRET_ID
+ )
+
+ @pytest.mark.parametrize(
+ "cert_name, cert_value",
+ [
+ ["sslcert", SSL_CERT],
+ ["sslkey", SSL_KEY],
+ ["sslrootcert", SSL_ROOT_CERT],
+ ],
+ )
+ @mock.patch(HOOK_STR.format("GoogleCloudSecretManagerHook"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook.get_connection"))
+ def test_get_cert_from_secret_exception(
+ self,
+ mock_get_connection,
+ mock_secret_hook,
+ cert_name,
+ cert_value,
+ ):
+ mock_get_connection.return_value = mock.MagicMock(
+ extra_dejson={
+ "database_type": "postgres",
+ "location": "test",
+ "instance": "instance",
+ }
+ )
+ mock_secret = mock_secret_hook.return_value.access_secret.return_value
+ mock_secret.payload.data = base64.b64encode(json.dumps({"wrong_key":
cert_value}).encode("ascii"))
+
+ hook = CloudSQLDatabaseHook(
+ gcp_conn_id=CONNECTION_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ default_gcp_project_id=PROJECT_ID,
+ ssl_secret_id=SECRET_ID,
+ )
+
+ with pytest.raises(AirflowException):
+ hook._get_cert_from_secret(cert_name=cert_name)
+
+ mock_secret_hook.assert_called_once_with(
+ gcp_conn_id=CONNECTION_ID, impersonation_chain=IMPERSONATION_CHAIN
+ )
+ mock_secret_hook.return_value.access_secret.assert_called_once_with(
+ project_id=PROJECT_ID, secret_id=SECRET_ID
+ )
+
+ @pytest.mark.parametrize("cert_name", ["sslcert", "sslkey", "sslrootcert"])
+ @mock.patch(HOOK_STR.format("GoogleCloudSecretManagerHook"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook.get_connection"))
+ def test_get_cert_from_secret_none(
+ self,
+ mock_get_connection,
+ mock_secret_hook,
+ cert_name,
+ ):
+ mock_get_connection.return_value = mock.MagicMock(
+ extra_dejson={
+ "database_type": "postgres",
+ "location": "test",
+ "instance": "instance",
+ }
+ )
+
+ hook = CloudSQLDatabaseHook(
+ gcp_conn_id=CONNECTION_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ default_gcp_project_id=PROJECT_ID,
+ )
+ actual_cert_value = hook._get_cert_from_secret(cert_name=cert_name)
+
+ assert actual_cert_value is None
+ assert not mock_secret_hook.called
+ assert not mock_secret_hook.return_value.access_secret.called
+
+ @pytest.mark.parametrize("cert_name", ["sslcert", "sslkey", "sslrootcert"])
+ @mock.patch("builtins.open", new_callable=mock_open, read_data="test-data")
+ @mock.patch(HOOK_STR.format("NamedTemporaryFile"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook.get_connection"))
+ def test_set_temporary_ssl_file_cert_path(
+ self,
+ mock_get_connection,
+ mock_named_temporary_file,
+ mock_open_file,
+ cert_name,
+ ):
+ mock_get_connection.return_value = mock.MagicMock(
+ extra_dejson={
+ "database_type": "postgres",
+ "location": "test",
+ "instance": "instance",
+ }
+ )
+ expected_file_name = "/test/path/to/file"
+ mock_named_temporary_file.return_value.name = expected_file_name
+ source_cert_path = "/source/cert/path/to/file"
+
+ hook = CloudSQLDatabaseHook()
+ actual_path = hook._set_temporary_ssl_file(cert_name=cert_name, cert_path=source_cert_path)
+
+ assert actual_path == expected_file_name
+ mock_named_temporary_file.assert_called_once_with(mode="w+b",
prefix="/tmp/certs/")
+ mock_open_file.assert_has_calls([call(source_cert_path, "rb")])
+
mock_named_temporary_file.return_value.write.assert_called_once_with("test-data")
+ mock_named_temporary_file.return_value.flush.assert_called_once()
+
+ @pytest.mark.parametrize("cert_name", ["sslcert", "sslkey", "sslrootcert"])
+ @mock.patch(HOOK_STR.format("NamedTemporaryFile"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook.get_connection"))
+ def test_set_temporary_ssl_file_cert_value(
+ self,
+ mock_get_connection,
+ mock_named_temporary_file,
+ cert_name,
+ ):
+ mock_get_connection.return_value = mock.MagicMock(
+ extra_dejson={
+ "database_type": "postgres",
+ "location": "test",
+ "instance": "instance",
+ }
+ )
+ expected_file_name = "/test/path/to/file"
+ mock_named_temporary_file.return_value.name = expected_file_name
+ cert_value = "test-cert-value"
+
+ hook = CloudSQLDatabaseHook()
+ actual_path = hook._set_temporary_ssl_file(cert_name=cert_name, cert_value=cert_value)
+
+ assert actual_path == expected_file_name
+ mock_named_temporary_file.assert_called_once_with(mode="w+b",
prefix="/tmp/certs/")
+
mock_named_temporary_file.return_value.write.assert_called_once_with(cert_value.encode("ascii"))
+ mock_named_temporary_file.return_value.flush.assert_called_once()
+
+ @pytest.mark.parametrize("cert_name", ["sslcert", "sslkey", "sslrootcert"])
+ @mock.patch(HOOK_STR.format("NamedTemporaryFile"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook.get_connection"))
+ def test_set_temporary_ssl_file_exception(
+ self,
+ mock_get_connection,
+ mock_named_temporary_file,
+ cert_name,
+ ):
+ mock_get_connection.return_value = mock.MagicMock(
+ extra_dejson={
+ "database_type": "postgres",
+ "location": "test",
+ "instance": "instance",
+ }
+ )
+ expected_file_name = "/test/path/to/file"
+ mock_named_temporary_file.return_value.name = expected_file_name
+ cert_value = "test-cert-value"
+ source_cert_path = "/source/cert/path/to/file"
+
+ hook = CloudSQLDatabaseHook()
+
+ with pytest.raises(AirflowException):
+ hook._set_temporary_ssl_file(
+                cert_name=cert_name, cert_value=cert_value, cert_path=source_cert_path
+ )
+
+ assert not mock_named_temporary_file.called
+ assert not mock_named_temporary_file.return_value.write.called
+ assert not mock_named_temporary_file.return_value.flush.called
+
+ @pytest.mark.parametrize("cert_name", ["sslcert", "sslkey", "sslrootcert"])
+ @mock.patch(HOOK_STR.format("NamedTemporaryFile"))
+ @mock.patch(HOOK_STR.format("CloudSQLDatabaseHook.get_connection"))
+ def test_set_temporary_ssl_file_none(
+ self,
+ mock_get_connection,
+ mock_named_temporary_file,
+ cert_name,
+ ):
+ mock_get_connection.return_value = mock.MagicMock(
+ extra_dejson={
+ "database_type": "postgres",
+ "location": "test",
+ "instance": "instance",
+ }
+ )
+ expected_file_name = "/test/path/to/file"
+ mock_named_temporary_file.return_value.name = expected_file_name
+
+ hook = CloudSQLDatabaseHook()
+
+ actual_path = hook._set_temporary_ssl_file(cert_name=cert_name)
+
+ assert actual_path is None
+ assert not mock_named_temporary_file.called
+ assert not mock_named_temporary_file.return_value.write.called
+ assert not mock_named_temporary_file.return_value.flush.called
+
class TestCloudSqlDatabaseQueryHook:
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
@@ -1085,8 +1468,13 @@ class TestCloudSqlDatabaseQueryHook:
)
self._verify_postgres_connection(get_connection, uri)
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook._set_temporary_ssl_file")
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
- def test_hook_with_correct_parameters_postgres_ssl(self, get_connection):
+    def test_hook_with_correct_parameters_postgres_ssl(self, get_connection, mock_set_temporary_ssl_file):
+        def side_effect_func(cert_name, cert_path, cert_value):
+            return f"/tmp/{cert_name}"
+
+        mock_set_temporary_ssl_file.side_effect = side_effect_func
uri = (
"gcpcloudsql://user:[email protected]:3200/testdb?database_type=postgres&"
"project_id=example-project&location=europe-west1&instance=testdb&"
@@ -1094,9 +1482,9 @@ class TestCloudSqlDatabaseQueryHook:
"sslkey=/bin/bash&sslrootcert=/bin/bash"
)
connection = self._verify_postgres_connection(get_connection, uri)
- assert "/bin/bash" == connection.extra_dejson["sslkey"]
- assert "/bin/bash" == connection.extra_dejson["sslcert"]
- assert "/bin/bash" == connection.extra_dejson["sslrootcert"]
+ assert "/tmp/sslkey" == connection.extra_dejson["sslkey"]
+ assert "/tmp/sslcert" == connection.extra_dejson["sslcert"]
+ assert "/tmp/sslrootcert" == connection.extra_dejson["sslrootcert"]
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection):
@@ -1157,8 +1545,13 @@ class TestCloudSqlDatabaseQueryHook:
)
self.verify_mysql_connection(get_connection, uri)
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook._set_temporary_ssl_file")
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
- def test_hook_with_correct_parameters_mysql_ssl(self, get_connection):
+    def test_hook_with_correct_parameters_mysql_ssl(self, get_connection, mock_set_temporary_ssl_file):
+        def side_effect_func(cert_name, cert_path, cert_value):
+            return f"/tmp/{cert_name}"
+
+        mock_set_temporary_ssl_file.side_effect = side_effect_func
uri = (
"gcpcloudsql://user:[email protected]:3200/testdb?database_type=mysql&"
"project_id=example-project&location=europe-west1&instance=testdb&"
@@ -1166,9 +1559,9 @@ class TestCloudSqlDatabaseQueryHook:
"sslkey=/bin/bash&sslrootcert=/bin/bash"
)
connection = self.verify_mysql_connection(get_connection, uri)
- assert "/bin/bash" ==
json.loads(connection.extra_dejson["ssl"])["cert"]
- assert "/bin/bash" == json.loads(connection.extra_dejson["ssl"])["key"]
- assert "/bin/bash" == json.loads(connection.extra_dejson["ssl"])["ca"]
+ assert "/tmp/sslcert" ==
json.loads(connection.extra_dejson["ssl"])["cert"]
+ assert "/tmp/sslkey" ==
json.loads(connection.extra_dejson["ssl"])["key"]
+ assert "/tmp/sslrootcert" ==
json.loads(connection.extra_dejson["ssl"])["ca"]
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection):
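
For reference while reading the hunks above: the new tests fully pin down the contract of
CloudSQLDatabaseHook._set_temporary_ssl_file. A minimal sketch of that contract, inferred
from the assertions alone (this is an illustration, not the provider's actual
implementation, and it assumes the /tmp/certs/ directory exists):

from tempfile import NamedTemporaryFile

from airflow.exceptions import AirflowException

_ssl_temp_files = {}  # keep handles alive so the temp files are not deleted early


def set_temporary_ssl_file(cert_name, cert_path=None, cert_value=None):
    # Passing both a path and a literal value is ambiguous and rejected.
    if cert_path and cert_value:
        raise AirflowException(f"Both a path and a value were given for {cert_name}; pass only one.")
    # With neither source provided there is nothing to materialize.
    if cert_path is None and cert_value is None:
        return None
    # Copy the certificate material into a temporary file under /tmp/certs/
    # so the database driver can read it from disk.
    tmp_file = NamedTemporaryFile(mode="w+b", prefix="/tmp/certs/")
    if cert_path:
        with open(cert_path, "rb") as source:
            tmp_file.write(source.read())
    else:
        tmp_file.write(cert_value.encode("ascii"))
    tmp_file.flush()
    _ssl_temp_files[cert_name] = tmp_file
    return tmp_file.name
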
diff --git a/tests/providers/google/cloud/hooks/test_secret_manager.py b/tests/providers/google/cloud/hooks/test_secret_manager.py
index 6f0f1a5339..8be41e5795 100644
--- a/tests/providers/google/cloud/hooks/test_secret_manager.py
+++ b/tests/providers/google/cloud/hooks/test_secret_manager.py
@@ -17,13 +17,18 @@
# under the License.
from __future__ import annotations
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from google.api_core.exceptions import NotFound
from google.cloud.secretmanager_v1.types.service import AccessSecretVersionResponse
-from airflow.providers.google.cloud.hooks.secret_manager import SecretsManagerHook
+from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.secret_manager import (
+ GoogleCloudSecretManagerHook,
+ SecretsManagerHook,
+)
+from airflow.providers.google.common.consts import CLIENT_INFO
from tests.providers.google.cloud.utils.base_gcp_mock import (
GCP_PROJECT_ID_HOOK_UNIT_TEST,
mock_base_gcp_hook_default_project_id,
@@ -32,12 +37,18 @@ from tests.providers.google.cloud.utils.base_gcp_mock import (
BASE_PACKAGE = "airflow.providers.google.common.hooks.base_google."
SECRETS_HOOK_PACKAGE = "airflow.providers.google.cloud.hooks.secret_manager."
INTERNAL_CLIENT_PACKAGE = "airflow.providers.google.cloud._internal_client.secret_manager_client"
+SECRET_ID = "test-secret-id"
class TestSecretsManagerHook:
def test_delegate_to_runtime_error(self):
with pytest.raises(RuntimeError):
-            SecretsManagerHook(gcp_conn_id="GCP_CONN_ID", delegate_to="delegate_to")
+        with pytest.warns(
+            AirflowProviderDeprecationWarning,
+            match="The SecretsManagerHook is deprecated and will be removed after 01.11.2024. "
+            "Please use GoogleCloudSecretManagerHook instead.",
+        ):
+            SecretsManagerHook(gcp_conn_id="GCP_CONN_ID", delegate_to="delegate_to")
@patch(INTERNAL_CLIENT_PACKAGE + "._SecretManagerClient.client", return_value=MagicMock())
@patch(
@@ -48,7 +59,12 @@ class TestSecretsManagerHook:
def test_get_missing_key(self, mock_get_credentials, mock_client):
mock_client.secret_version_path.return_value = "full-path"
mock_client.access_secret_version.side_effect = NotFound("test-msg")
- secrets_manager_hook = SecretsManagerHook(gcp_conn_id="test")
+ with pytest.warns(
+ AirflowProviderDeprecationWarning,
+ match="The SecretsManagerHook is deprecated and will be removed
after 01.11.2024. "
+ "Please use GoogleCloudSecretManagerHook instead.",
+ ):
+ secrets_manager_hook = SecretsManagerHook(gcp_conn_id="test")
mock_get_credentials.assert_called_once_with()
secret = secrets_manager_hook.get_secret(secret_id="secret")
mock_client.secret_version_path.assert_called_once_with("example-project",
"secret", "latest")
@@ -66,9 +82,222 @@ class TestSecretsManagerHook:
test_response = AccessSecretVersionResponse()
test_response.payload.data = b"result"
mock_client.access_secret_version.return_value = test_response
- secrets_manager_hook = SecretsManagerHook(gcp_conn_id="test")
+ with pytest.warns(
+ AirflowProviderDeprecationWarning,
+ match="The SecretsManagerHook is deprecated and will be removed
after 01.11.2024. "
+ "Please use GoogleCloudSecretManagerHook instead.",
+ ):
+ secrets_manager_hook = SecretsManagerHook(gcp_conn_id="test")
mock_get_credentials.assert_called_once_with()
secret = secrets_manager_hook.get_secret(secret_id="secret")
mock_client.secret_version_path.assert_called_once_with("example-project",
"secret", "latest")
mock_client.access_secret_version.assert_called_once_with(request={"name":
"full-path"})
assert "result" == secret
+
+
+class TestGoogleCloudSecretManagerHook:
+ def setup_method(self, method):
+ with patch(f"{BASE_PACKAGE}GoogleBaseHook.get_connection",
return_value=MagicMock()):
+ self.hook = GoogleCloudSecretManagerHook()
+
+    @patch(f"{SECRETS_HOOK_PACKAGE}GoogleCloudSecretManagerHook.get_credentials")
+ @patch(f"{SECRETS_HOOK_PACKAGE}SecretManagerServiceClient")
+ def test_client(self, mock_client, mock_get_credentials):
+ mock_client_result = mock_client.return_value
+ mock_credentials = self.hook.get_credentials.return_value
+
+ client_1 = self.hook.client
+ client_2 = self.hook.client
+
+ assert client_1 == mock_client_result
+ assert client_1 == client_2
+        mock_client.assert_called_once_with(credentials=mock_credentials, client_info=CLIENT_INFO)
+ mock_get_credentials.assert_called_once()
+
+ @patch(f"{SECRETS_HOOK_PACKAGE}GoogleCloudSecretManagerHook.client",
new_callable=PropertyMock)
+ def test_get_conn(self, mock_client):
+ mock_client_result = mock_client.return_value
+
+ client_1 = self.hook.get_conn()
+
+ assert client_1 == mock_client_result
+ mock_client.assert_called_once()
+
+ @pytest.mark.parametrize(
+ "input_secret, expected_secret",
+ [
+ (None, {"replication": {"automatic": {}}}),
+            (mock_secret := MagicMock(), mock_secret),  # type: ignore[name-defined]
+ ],
+ )
+ @patch(f"{SECRETS_HOOK_PACKAGE}GoogleCloudSecretManagerHook.client",
new_callable=PropertyMock)
+ def test_create_secret(self, mock_client, input_secret, expected_secret):
+ expected_parent = f"projects/{GCP_PROJECT_ID_HOOK_UNIT_TEST}"
+ expected_response = mock_client.return_value.create_secret.return_value
+        mock_retry, mock_timeout, mock_metadata = MagicMock(), MagicMock(), MagicMock()
+
+ actual_response = self.hook.create_secret(
+ project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
+ secret_id=SECRET_ID,
+ secret=input_secret,
+ retry=mock_retry,
+ timeout=mock_timeout,
+ metadata=mock_metadata,
+ )
+
+ assert actual_response == expected_response
+ mock_client.assert_called_once()
+ mock_client.return_value.create_secret.assert_called_once_with(
+ request={
+ "parent": expected_parent,
+ "secret_id": SECRET_ID,
+ "secret": expected_secret,
+ },
+ retry=mock_retry,
+ timeout=mock_timeout,
+ metadata=mock_metadata,
+ )
+
+ @patch(f"{SECRETS_HOOK_PACKAGE}GoogleCloudSecretManagerHook.client",
new_callable=PropertyMock)
+ def test_add_secret_version(self, mock_client):
+        expected_parent = f"projects/{GCP_PROJECT_ID_HOOK_UNIT_TEST}/secrets/{SECRET_ID}"
+        expected_response = mock_client.return_value.add_secret_version.return_value
+        mock_payload, mock_retry, mock_timeout, mock_metadata = (MagicMock() for _ in range(4))
+
+ actual_response = self.hook.add_secret_version(
+ project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
+ secret_id=SECRET_ID,
+ secret_payload=mock_payload,
+ retry=mock_retry,
+ timeout=mock_timeout,
+ metadata=mock_metadata,
+ )
+
+ assert actual_response == expected_response
+ mock_client.assert_called_once()
+ mock_client.return_value.add_secret_version.assert_called_once_with(
+ request={
+ "parent": expected_parent,
+ "payload": mock_payload,
+ },
+ retry=mock_retry,
+ timeout=mock_timeout,
+ metadata=mock_metadata,
+ )
+
+ @patch(f"{SECRETS_HOOK_PACKAGE}GoogleCloudSecretManagerHook.client",
new_callable=PropertyMock)
+ def test_list_secrets(self, mock_client):
+ expected_parent = f"projects/{GCP_PROJECT_ID_HOOK_UNIT_TEST}"
+ expected_response = mock_client.return_value.list_secrets.return_value
+        mock_filter, mock_retry, mock_timeout, mock_metadata = (MagicMock() for _ in range(4))
+ page_size, page_token = 20, "test-page-token"
+
+ actual_response = self.hook.list_secrets(
+ project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
+ secret_filter=mock_filter,
+ page_size=page_size,
+ page_token=page_token,
+ retry=mock_retry,
+ timeout=mock_timeout,
+ metadata=mock_metadata,
+ )
+
+ assert actual_response == expected_response
+ mock_client.assert_called_once()
+ mock_client.return_value.list_secrets.assert_called_once_with(
+ request={
+ "parent": expected_parent,
+ "page_size": page_size,
+ "page_token": page_token,
+ "filter": mock_filter,
+ },
+ retry=mock_retry,
+ timeout=mock_timeout,
+ metadata=mock_metadata,
+ )
+
+ @pytest.mark.parametrize(
+ "secret_names, secret_id, secret_exists_expected",
+ [
+ ([], SECRET_ID, False),
+ (["secret/name"], SECRET_ID, False),
+ (["secret/name1", "secret/name1"], SECRET_ID, False),
+ ([f"secret/{SECRET_ID}"], SECRET_ID, True),
+ ([f"secret/{SECRET_ID}", "secret/name"], SECRET_ID, True),
+ (["secret/name", f"secret/{SECRET_ID}"], SECRET_ID, True),
+ (["name1", SECRET_ID], SECRET_ID, True),
+ ],
+ )
+ @patch(f"{SECRETS_HOOK_PACKAGE}GoogleCloudSecretManagerHook.client",
new_callable=PropertyMock)
+ @patch(f"{SECRETS_HOOK_PACKAGE}GoogleCloudSecretManagerHook.list_secrets")
+ def test_secret_exists(
+        self, mock_list_secrets, mock_client, secret_names, secret_id, secret_exists_expected
+ ):
+ list_secrets = []
+ for secret_name in secret_names:
+ mock_secret = MagicMock()
+ mock_secret.name = secret_name
+ list_secrets.append(mock_secret)
+ mock_list_secrets.return_value = list_secrets
+ secret_filter = f"name:{secret_id}"
+
+ secret_exists_actual = self.hook.secret_exists(
+ project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, secret_id=secret_id
+ )
+
+ assert secret_exists_actual == secret_exists_expected
+        mock_client.return_value.secret_path.assert_called_once_with(GCP_PROJECT_ID_HOOK_UNIT_TEST, secret_id)
+ mock_list_secrets.assert_called_once_with(
+            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST, page_size=100, secret_filter=secret_filter
+ )
+
+ @patch(f"{SECRETS_HOOK_PACKAGE}GoogleCloudSecretManagerHook.client",
new_callable=PropertyMock)
+ def test_access_secret(self, mock_client):
+        expected_response = mock_client.return_value.access_secret_version.return_value
+        mock_retry, mock_timeout, mock_metadata = (MagicMock() for _ in range(3))
+ secret_version = "test-secret-version"
+ mock_name = mock_client.return_value.secret_version_path.return_value
+
+ actual_response = self.hook.access_secret(
+ project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
+ secret_id=SECRET_ID,
+ secret_version=secret_version,
+ retry=mock_retry,
+ timeout=mock_timeout,
+ metadata=mock_metadata,
+ )
+
+ assert actual_response == expected_response
+ assert mock_client.call_count == 2
+ mock_client.return_value.secret_version_path.assert_called_once_with(
+ GCP_PROJECT_ID_HOOK_UNIT_TEST, SECRET_ID, secret_version
+ )
+ mock_client.return_value.access_secret_version.assert_called_once_with(
+ request={"name": mock_name},
+ retry=mock_retry,
+ timeout=mock_timeout,
+ metadata=mock_metadata,
+ )
+
+ @patch(f"{SECRETS_HOOK_PACKAGE}GoogleCloudSecretManagerHook.client",
new_callable=PropertyMock)
+ def test_delete_secret(self, mock_client):
+        mock_retry, mock_timeout, mock_metadata = (MagicMock() for _ in range(3))
+ mock_name = mock_client.return_value.secret_path.return_value
+
+ actual_response = self.hook.delete_secret(
+ project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
+ secret_id=SECRET_ID,
+ retry=mock_retry,
+ timeout=mock_timeout,
+ metadata=mock_metadata,
+ )
+
+ assert actual_response is None
+ assert mock_client.call_count == 2
+        mock_client.return_value.secret_path.assert_called_once_with(GCP_PROJECT_ID_HOOK_UNIT_TEST, SECRET_ID)
+ mock_client.return_value.delete_secret.assert_called_once_with(
+ request={"name": mock_name},
+ retry=mock_retry,
+ timeout=mock_timeout,
+ metadata=mock_metadata,
+ )
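
Taken together, the tests above describe the public surface of the new
GoogleCloudSecretManagerHook. A hedged end-to-end usage sketch with placeholder project
and secret ids; the {"data": ...} payload shape follows the google-cloud-secret-manager
SecretPayload message, and the retry/timeout/metadata arguments are assumed to have
sensible defaults:

from airflow.providers.google.cloud.hooks.secret_manager import GoogleCloudSecretManagerHook

project_id = "example-project"  # placeholder
secret_id = "airflow-ssl-cert"  # placeholder

hook = GoogleCloudSecretManagerHook()
# Create the secret container once; secret=None falls back to automatic
# replication, per the parametrized test_create_secret above.
if not hook.secret_exists(project_id=project_id, secret_id=secret_id):
    hook.create_secret(project_id=project_id, secret_id=secret_id, secret=None)
# Push the actual material as a new version.
hook.add_secret_version(
    project_id=project_id,
    secret_id=secret_id,
    secret_payload={"data": b"-----BEGIN CERTIFICATE-----\n..."},
)
# Read the latest version back and decode the raw bytes of the response payload.
response = hook.access_secret(project_id=project_id, secret_id=secret_id, secret_version="latest")
certificate = response.payload.data.decode("ascii")
# Clean up when the secret is no longer needed.
hook.delete_secret(project_id=project_id, secret_id=secret_id)
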
diff --git a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py
new file mode 100644
index 0000000000..9a44a64a6b
--- /dev/null
+++ b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py
@@ -0,0 +1,572 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that performs a query in a Cloud SQL instance.
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+from collections import namedtuple
+from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Iterable
+
+from googleapiclient import discovery
+
+from airflow import settings
+from airflow.decorators import task, task_group
+from airflow.models.connection import Connection
+from airflow.models.dag import DAG
+from airflow.operators.bash import BashOperator
+from airflow.providers.google.cloud.operators.cloud_sql import (
+ CloudSQLCreateInstanceDatabaseOperator,
+ CloudSQLCreateInstanceOperator,
+ CloudSQLDeleteInstanceOperator,
+ CloudSQLExecuteQueryOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+DAG_ID = "cloudsql-query"
+REGION = "us-central1"
+HOME_DIR = Path.home()
+
+COMPOSER_ENVIRONMENT = os.environ.get("COMPOSER_ENVIRONMENT", "")
+if COMPOSER_ENVIRONMENT:
+    # We assume that the test is launched in a Cloud Composer environment because the reserved environment
+    # variable is assigned (https://cloud.google.com/composer/docs/composer-2/set-environment-variables)
+ GET_COMPOSER_NETWORK_COMMAND = """
+ gcloud composer environments describe $COMPOSER_ENVIRONMENT \
+ --location=$COMPOSER_LOCATION \
+ --project=$GCP_PROJECT \
+ --format="value(config.nodeConfig.network)"
+ """
+else:
+ # The test is launched locally
+ GET_COMPOSER_NETWORK_COMMAND = "echo"
+
+
+def run_in_composer():
+ return bool(COMPOSER_ENVIRONMENT)
+
+
+CLOUD_SQL_INSTANCE_NAME_TEMPLATE = f"{ENV_ID}-{DAG_ID}".replace("_", "-")
+CLOUD_SQL_INSTANCE_CREATE_BODY_TEMPLATE: dict[str, Any] = {
+ "name": CLOUD_SQL_INSTANCE_NAME_TEMPLATE,
+ "settings": {
+ "tier": "db-custom-1-3840",
+ "dataDiskSizeGb": 30,
+ "pricingPlan": "PER_USE",
+ "ipConfiguration": {},
+ },
+ # For using a different database version please check the link below.
+    # https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/SqlDatabaseVersion
+ "databaseVersion": "1.2.3",
+ "region": REGION,
+ "ipConfiguration": {
+ "ipv4Enabled": True,
+ "requireSsl": False,
+ "authorizedNetworks": [
+ {"value": "0.0.0.0/0"},
+ ],
+ },
+}
+DB_PROVIDERS: Iterable[dict[str, str]] = (
+ {
+ "database_type": "postgres",
+ "port": "5432",
+ "database_version": "POSTGRES_15",
+ "cloud_sql_instance_name":
f"{CLOUD_SQL_INSTANCE_NAME_TEMPLATE}-postgres",
+ },
+ {
+ "database_type": "mysql",
+ "port": "3306",
+ "database_version": "MYSQL_8_0",
+ "cloud_sql_instance_name": f"{CLOUD_SQL_INSTANCE_NAME_TEMPLATE}-mysql",
+ },
+)
+
+
+def ip_configuration() -> dict[str, Any]:
+ """Generates an ip configuration for a CloudSQL instance creation body"""
+ if run_in_composer():
+        # Use connection to Cloud SQL instance via Private IP within the Cloud Composer's network.
+ return {
+ "ipv4Enabled": True,
+ "requireSsl": False,
+ "enablePrivatePathForGoogleCloudServices": True,
+ "privateNetwork": """{{
task_instance.xcom_pull('get_composer_network')}}""",
+ }
+ else:
+        # Use connection to Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0).
+        # Consider specifying your network mask
+        # for allowing requests only from the trusted sources, not from anywhere.
+ return {
+ "ipv4Enabled": True,
+ "requireSsl": False,
+ "authorizedNetworks": [
+ {"value": "0.0.0.0/0"},
+ ],
+ }
+
+
+def cloud_sql_instance_create_body(database_provider: dict[str, Any]) -> dict[str, Any]:
+ """Generates a CloudSQL instance creation body"""
+    create_body: dict[str, Any] = deepcopy(CLOUD_SQL_INSTANCE_CREATE_BODY_TEMPLATE)
+ create_body["name"] = database_provider["cloud_sql_instance_name"]
+ create_body["databaseVersion"] = database_provider["database_version"]
+ create_body["settings"]["ipConfiguration"] = ip_configuration()
+ return create_body
+
+
+CLOUD_SQL_DATABASE_NAME = "test_db"
+CLOUD_SQL_USER = "test_user"
+CLOUD_SQL_PASSWORD = "JoxHlwrPzwch0gz9"
+CLOUD_SQL_IP_ADDRESS = "127.0.0.1"
+CLOUD_SQL_PUBLIC_PORT = 5432
+
+
+def cloud_sql_database_create_body(instance: str) -> dict[str, Any]:
+ """Generates a CloudSQL database creation body"""
+ return {
+ "instance": instance,
+ "name": CLOUD_SQL_DATABASE_NAME,
+ "project": PROJECT_ID,
+ }
+
+
+CLOUD_SQL_INSTANCE_NAME = ""
+DATABASE_TYPE = "" # "postgres|mysql|mssql"
+
+
+# [START howto_operator_cloudsql_query_connections]
+# Connect via proxy over TCP
+CONNECTION_PROXY_TCP_KWARGS = {
+ "conn_type": "gcpcloudsql",
+ "login": CLOUD_SQL_USER,
+ "password": CLOUD_SQL_PASSWORD,
+ "host": CLOUD_SQL_IP_ADDRESS,
+ "port": CLOUD_SQL_PUBLIC_PORT,
+ "schema": CLOUD_SQL_DATABASE_NAME,
+ "extra": {
+ "database_type": DATABASE_TYPE,
+ "project_id": PROJECT_ID,
+ "location": REGION,
+ "instance": CLOUD_SQL_INSTANCE_NAME,
+ "use_proxy": "True",
+ "sql_proxy_use_tcp": "True",
+ },
+}
+
+# Connect via proxy over UNIX socket (specific proxy version)
+CONNECTION_PROXY_SOCKET_KWARGS = {
+ "conn_type": "gcpcloudsql",
+ "login": CLOUD_SQL_USER,
+ "password": CLOUD_SQL_PASSWORD,
+ "host": CLOUD_SQL_IP_ADDRESS,
+ "port": CLOUD_SQL_PUBLIC_PORT,
+ "schema": CLOUD_SQL_DATABASE_NAME,
+ "extra": {
+ "database_type": DATABASE_TYPE,
+ "project_id": PROJECT_ID,
+ "location": REGION,
+ "instance": CLOUD_SQL_INSTANCE_NAME,
+ "use_proxy": "True",
+ "sql_proxy_version": "v1.33.9",
+ "sql_proxy_use_tcp": "False",
+ },
+}
+
+# Connect directly via TCP (non-SSL)
+CONNECTION_PUBLIC_TCP_KWARGS = {
+ "conn_type": "gcpcloudsql",
+ "login": CLOUD_SQL_USER,
+ "password": CLOUD_SQL_PASSWORD,
+ "host": CLOUD_SQL_IP_ADDRESS,
+ "port": CLOUD_SQL_PUBLIC_PORT,
+ "schema": CLOUD_SQL_DATABASE_NAME,
+ "extra": {
+ "database_type": DATABASE_TYPE,
+ "project_id": PROJECT_ID,
+ "location": REGION,
+ "instance": CLOUD_SQL_INSTANCE_NAME,
+ "use_proxy": "False",
+ "use_ssl": "False",
+ },
+}
+
+# [END howto_operator_cloudsql_query_connections]
+
+
+CONNECTION_PROXY_SOCKET_ID = f"{DAG_ID}_{ENV_ID}_proxy_socket"
+CONNECTION_PROXY_TCP_ID = f"{DAG_ID}_{ENV_ID}_proxy_tcp"
+CONNECTION_PUBLIC_TCP_ID = f"{DAG_ID}_{ENV_ID}_public_tcp"
+
+ConnectionConfig = namedtuple("ConnectionConfig", "id kwargs")
+CONNECTIONS = [
+    ConnectionConfig(id=CONNECTION_PROXY_SOCKET_ID, kwargs=CONNECTION_PROXY_SOCKET_KWARGS),
+    ConnectionConfig(id=CONNECTION_PROXY_TCP_ID, kwargs=CONNECTION_PROXY_TCP_KWARGS),
+    ConnectionConfig(id=CONNECTION_PUBLIC_TCP_ID, kwargs=CONNECTION_PUBLIC_TCP_KWARGS),
+]
+
+SQL = [
+ "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
+ "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
+ "INSERT INTO TABLE_TEST VALUES (0)",
+ "CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)",
+ "DROP TABLE TABLE_TEST",
+ "DROP TABLE TABLE_TEST2",
+]
+
+DELETE_CONNECTION_COMMAND = "airflow connections delete {}"
+
+# [START howto_operator_cloudsql_query_connections_env]
+
+# The connections below are created using one of the standard approaches: via environment
+# variables named AIRFLOW_CONN_*. The connections can also be created in the Airflow
+# database (using the command line or the UI).
+
+postgres_kwargs = {
+ "user": "user",
+ "password": "password",
+ "public_ip": "public_ip",
+ "public_port": "public_port",
+ "database": "database",
+ "project_id": "project_id",
+ "location": "location",
+ "instance": "instance",
+ "client_cert_file": "client_cert_file",
+ "client_key_file": "client_key_file",
+ "server_ca_file": "server_ca_file",
+}
+
+# Postgres: connect via proxy over TCP
+os.environ["AIRFLOW_CONN_PROXY_POSTGRES_TCP"] = (
+ "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
+ "database_type=postgres&"
+ "project_id={project_id}&"
+ "location={location}&"
+ "instance={instance}&"
+ "use_proxy=True&"
+ "sql_proxy_use_tcp=True".format(**postgres_kwargs)
+)
+
+# Postgres: connect via proxy over UNIX socket (specific proxy version)
+os.environ["AIRFLOW_CONN_PROXY_POSTGRES_SOCKET"] = (
+ "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
+ "database_type=postgres&"
+ "project_id={project_id}&"
+ "location={location}&"
+ "instance={instance}&"
+ "use_proxy=True&"
+ "sql_proxy_version=v1.13&"
+ "sql_proxy_use_tcp=False".format(**postgres_kwargs)
+)
+
+# Postgres: connect directly via TCP (non-SSL)
+os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP"] = (
+ "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
+ "database_type=postgres&"
+ "project_id={project_id}&"
+ "location={location}&"
+ "instance={instance}&"
+ "use_proxy=False&"
+ "use_ssl=False".format(**postgres_kwargs)
+)
+
+# Postgres: connect directly via TCP (SSL)
+os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL"] = (
+ "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
+ "database_type=postgres&"
+ "project_id={project_id}&"
+ "location={location}&"
+ "instance={instance}&"
+ "use_proxy=False&"
+ "use_ssl=True&"
+ "sslcert={client_cert_file}&"
+ "sslkey={client_key_file}&"
+ "sslrootcert={server_ca_file}".format(**postgres_kwargs)
+)
+
+mysql_kwargs = {
+ "user": "user",
+ "password": "password",
+ "public_ip": "public_ip",
+ "public_port": "public_port",
+ "database": "database",
+ "project_id": "project_id",
+ "location": "location",
+ "instance": "instance",
+ "client_cert_file": "client_cert_file",
+ "client_key_file": "client_key_file",
+ "server_ca_file": "server_ca_file",
+}
+
+# MySQL: connect via proxy over TCP (specific proxy version)
+os.environ["AIRFLOW_CONN_PROXY_MYSQL_TCP"] = (
+ "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
+ "database_type=mysql&"
+ "project_id={project_id}&"
+ "location={location}&"
+ "instance={instance}&"
+ "use_proxy=True&"
+ "sql_proxy_version=v1.13&"
+ "sql_proxy_use_tcp=True".format(**mysql_kwargs)
+)
+
+# MySQL: connect via proxy over UNIX socket using a pre-downloaded Cloud SQL Proxy binary
+os.environ["AIRFLOW_CONN_PROXY_MYSQL_SOCKET"] = (
+ "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
+ "database_type=mysql&"
+ "project_id={project_id}&"
+ "location={location}&"
+ "instance={instance}&"
+ "use_proxy=True&"
+ "sql_proxy_use_tcp=False".format(**mysql_kwargs)
+)
+
+# MySQL: connect directly via TCP (non-SSL)
+os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP"] = (
+ "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
+ "database_type=mysql&"
+ "project_id={project_id}&"
+ "location={location}&"
+ "instance={instance}&"
+ "use_proxy=False&"
+ "use_ssl=False".format(**mysql_kwargs)
+)
+
+# MySQL: connect directly via TCP (SSL) and with a fixed Cloud SQL Proxy binary path
+os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL"] = (
+ "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
+ "database_type=mysql&"
+ "project_id={project_id}&"
+ "location={location}&"
+ "instance={instance}&"
+ "use_proxy=False&"
+ "use_ssl=True&"
+ "sslcert={client_cert_file}&"
+ "sslkey={client_key_file}&"
+ "sslrootcert={server_ca_file}".format(**mysql_kwargs)
+)
+
+# Special case: MySQL: connect directly via TCP (SSL) and with a fixed Cloud SQL
+# Proxy binary path AND with missing project_id
+os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL_NO_PROJECT_ID"] = (
+ "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
+ "database_type=mysql&"
+ "location={location}&"
+ "instance={instance}&"
+ "use_proxy=False&"
+ "use_ssl=True&"
+ "sslcert={client_cert_file}&"
+ "sslkey={client_key_file}&"
+ "sslrootcert={server_ca_file}".format(**mysql_kwargs)
+)
+# [END howto_operator_cloudsql_query_connections_env]
+
+
+log = logging.getLogger(__name__)
+
+
+with DAG(
+ dag_id=DAG_ID,
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "cloudsql", "postgres"],
+) as dag:
+ get_composer_network = BashOperator(
+ task_id="get_composer_network",
+ bash_command=GET_COMPOSER_NETWORK_COMMAND,
+ do_xcom_push=True,
+ )
+
+ for db_provider in DB_PROVIDERS:
+ database_type: str = db_provider["database_type"]
+ cloud_sql_instance_name: str = db_provider["cloud_sql_instance_name"]
+
+ create_cloud_sql_instance = CloudSQLCreateInstanceOperator(
+ task_id=f"create_cloud_sql_instance_{database_type}",
+ project_id=PROJECT_ID,
+ instance=cloud_sql_instance_name,
+ body=cloud_sql_instance_create_body(database_provider=db_provider),
+ )
+
+ create_database = CloudSQLCreateInstanceDatabaseOperator(
+ task_id=f"create_database_{database_type}",
+            body=cloud_sql_database_create_body(instance=cloud_sql_instance_name),
+ instance=cloud_sql_instance_name,
+ )
+
+ @task(task_id=f"create_user_{database_type}")
+ def create_user(instance: str) -> None:
+ with discovery.build("sqladmin", "v1beta4") as service:
+ request = service.users().insert(
+ project=PROJECT_ID,
+ instance=instance,
+ body={
+ "name": CLOUD_SQL_USER,
+ "password": CLOUD_SQL_PASSWORD,
+ },
+ )
+ request.execute()
+ return None
+
+ create_user_task = create_user(instance=cloud_sql_instance_name)
+
+ @task(task_id=f"get_ip_address_{database_type}")
+ def get_ip_address(instance: str) -> str | None:
+ """Returns a Cloud SQL instance IP address.
+
+            If the test is running in Cloud Composer, the Private IP address is used, otherwise Public IP."""
+ with discovery.build("sqladmin", "v1beta4") as service:
+ request = service.connect().get(
+ project=PROJECT_ID,
+ instance=instance,
+ fields="ipAddresses",
+ )
+ response = request.execute()
+ for ip_item in response.get("ipAddresses", []):
+ if run_in_composer():
+ if ip_item["type"] == "PRIVATE":
+ return ip_item["ipAddress"]
+ else:
+ if ip_item["type"] == "PRIMARY":
+ return ip_item["ipAddress"]
+ return None
+
+ get_ip_address_task = get_ip_address(instance=cloud_sql_instance_name)
+
+ @task(task_id=f"create_connection_{database_type}")
+ def create_connection(
+ connection_id: str,
+ instance: str,
+ db_type: str,
+ ip_address: str,
+ port: str,
+ kwargs: dict[str, Any],
+ ) -> str | None:
+ session = settings.Session()
+            if session.query(Connection).filter(Connection.conn_id == connection_id).first():
+ log.warning("Connection '%s' already exists", connection_id)
+ return connection_id
+
+ connection: dict[str, Any] = deepcopy(kwargs)
+ connection["extra"]["instance"] = instance
+ connection["host"] = ip_address
+ connection["extra"]["database_type"] = db_type
+ connection["port"] = port
+ _conn = Connection(conn_id=connection_id, **connection)
+ session.add(_conn)
+ session.commit()
+ log.info("Connection created: '%s'", connection_id)
+ return connection_id
+
+ @task_group(group_id=f"create_connections_{database_type}")
+        def create_connections(instance: str, db_type: str, ip_address: str, port: str):
+ for conn in CONNECTIONS:
+ conn_id = f"{conn.id}_{database_type}"
+ create_connection(
+ connection_id=conn_id,
+ instance=instance,
+ db_type=db_type,
+ ip_address=ip_address,
+ port=port,
+ kwargs=conn.kwargs,
+ )
+
+ create_connections_task = create_connections(
+ instance=cloud_sql_instance_name,
+ db_type=database_type,
+ ip_address=get_ip_address_task,
+ port=db_provider["port"],
+ )
+
+ @task_group(group_id=f"execute_queries_{database_type}")
+ def execute_queries(db_type: str):
+ prev_task = None
+ for conn in CONNECTIONS:
+ connection_id = f"{conn.id}_{db_type}"
+ task_id = f"execute_query_{connection_id}"
+
+ # [START howto_operator_cloudsql_query_operators]
+ query_task = CloudSQLExecuteQueryOperator(
+ gcp_cloudsql_conn_id=connection_id,
+ task_id=task_id,
+ sql=SQL,
+ )
+ # [END howto_operator_cloudsql_query_operators]
+
+ if prev_task:
+ prev_task >> query_task
+ prev_task = query_task
+
+ execute_queries_task = execute_queries(db_type=database_type)
+
+ @task_group(group_id=f"teardown_{database_type}")
+ def teardown(instance: str, db_type: str):
+ task_id = f"delete_cloud_sql_instance_{db_type}"
+ CloudSQLDeleteInstanceOperator(
+ task_id=task_id,
+ project_id=PROJECT_ID,
+ instance=instance,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ for conn in CONNECTIONS:
+ connection_id = f"{conn.id}_{db_type}"
+ BashOperator(
+ task_id=f"delete_connection_{connection_id}",
+                    bash_command=DELETE_CONNECTION_COMMAND.format(connection_id),
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+        teardown_task = teardown(instance=cloud_sql_instance_name, db_type=database_type)
+
+ (
+ # TEST SETUP
+ get_composer_network
+ >> create_cloud_sql_instance
+ >> [
+ create_database,
+ create_user_task,
+ get_ip_address_task,
+ ]
+ >> create_connections_task
+ # TEST BODY
+ >> execute_queries_task
+ # TEST TEARDOWN
+ >> teardown_task
+ )
+
+ # ### Everything below this line is not part of example ###
+ # ### Just for system tests purpose ###
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
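
A quick way to sanity-check the URI-style connections that this example exports via
AIRFLOW_CONN_* is to let Airflow parse one back into a Connection object (a local
snippet assuming the usual Connection URI parsing; it is not part of the DAG itself):

import os

from airflow.models.connection import Connection

conn = Connection(uri=os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP"])
# The query parameters of a gcpcloudsql:// URI land in the connection extras.
print(conn.conn_type, conn.login, conn.host, conn.port, conn.schema)
print(conn.extra_dejson["database_type"], conn.extra_dejson["use_ssl"])
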
diff --git a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_mysql.py b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_mysql.py
deleted file mode 100644
index f869ac6f37..0000000000
--- a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_mysql.py
+++ /dev/null
@@ -1,285 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Example Airflow DAG that performs query in a Cloud SQL instance for MySQL.
-"""
-
-from __future__ import annotations
-
-import logging
-import os
-from collections import namedtuple
-from copy import deepcopy
-from datetime import datetime
-
-from googleapiclient import discovery
-
-from airflow import settings
-from airflow.decorators import task, task_group
-from airflow.models.connection import Connection
-from airflow.models.dag import DAG
-from airflow.operators.bash import BashOperator
-from airflow.providers.google.cloud.operators.cloud_sql import (
- CloudSQLCreateInstanceDatabaseOperator,
- CloudSQLCreateInstanceOperator,
- CloudSQLDeleteInstanceOperator,
- CloudSQLExecuteQueryOperator,
-)
-from airflow.utils.trigger_rule import TriggerRule
-
-# mypy: disable-error-code="call-overload"
-
-
-ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
-PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
-DAG_ID = "cloudsql-query-mysql"
-REGION = "us-central1"
-
-CLOUD_SQL_INSTANCE_NAME = f"{ENV_ID}-{DAG_ID}".replace("_", "-")
-CLOUD_SQL_DATABASE_NAME = "test_db"
-CLOUD_SQL_USER = "test_user"
-CLOUD_SQL_PASSWORD = "JoxHlwrPzwch0gz9"
-CLOUD_SQL_PUBLIC_IP = "127.0.0.1"
-CLOUD_SQL_PUBLIC_PORT = 3306
-CLOUD_SQL_DATABASE_CREATE_BODY = {
- "instance": CLOUD_SQL_INSTANCE_NAME,
- "name": CLOUD_SQL_DATABASE_NAME,
- "project": PROJECT_ID,
-}
-
-CLOUD_SQL_INSTANCE_CREATION_BODY = {
- "name": CLOUD_SQL_INSTANCE_NAME,
- "settings": {
- "tier": "db-custom-1-3840",
- "dataDiskSizeGb": 30,
- "ipConfiguration": {
- "ipv4Enabled": True,
- "requireSsl": False,
- # Consider specifying your network mask
-            # for allowing requests only from the trusted sources, not from anywhere
- "authorizedNetworks": [
- {"value": "0.0.0.0/0"},
- ],
- },
- "pricingPlan": "PER_USE",
- },
- # For using a different database version please check the link below
-    # https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/SqlDatabaseVersion
- "databaseVersion": "MYSQL_8_0",
- "region": REGION,
-}
-
-SQL = [
- "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
- "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
- "INSERT INTO TABLE_TEST VALUES (0)",
- "CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)",
- "DROP TABLE TABLE_TEST",
- "DROP TABLE TABLE_TEST2",
-]
-
-# Postgres: connect via proxy over TCP
-CONNECTION_PROXY_TCP_ID = f"connection_{DAG_ID}_{ENV_ID}_proxy_tcp"
-CONNECTION_PROXY_TCP_KWARGS = {
- "conn_type": "gcpcloudsql",
- "login": CLOUD_SQL_USER,
- "password": CLOUD_SQL_PASSWORD,
- "host": CLOUD_SQL_PUBLIC_IP,
- "port": CLOUD_SQL_PUBLIC_PORT,
- "schema": CLOUD_SQL_DATABASE_NAME,
- "extra": {
- "database_type": "mysql",
- "project_id": PROJECT_ID,
- "location": REGION,
- "instance": CLOUD_SQL_INSTANCE_NAME,
- "use_proxy": "True",
- "sql_proxy_use_tcp": "True",
- },
-}
-
-# Postgres: connect via proxy over UNIX socket (specific proxy version)
-CONNECTION_PROXY_SOCKET_ID = f"connection_{DAG_ID}_{ENV_ID}_proxy_socket"
-CONNECTION_PROXY_SOCKET_KWARGS = {
- "conn_type": "gcpcloudsql",
- "login": CLOUD_SQL_USER,
- "password": CLOUD_SQL_PASSWORD,
- "host": CLOUD_SQL_PUBLIC_IP,
- "port": CLOUD_SQL_PUBLIC_PORT,
- "schema": CLOUD_SQL_DATABASE_NAME,
- "extra": {
- "database_type": "mysql",
- "project_id": PROJECT_ID,
- "location": REGION,
- "instance": CLOUD_SQL_INSTANCE_NAME,
- "use_proxy": "True",
- "sql_proxy_version": "v1.33.9",
- "sql_proxy_use_tcp": "False",
- },
-}
-
-# Postgres: connect directly via TCP (non-SSL)
-CONNECTION_PUBLIC_TCP_ID = f"connection_{DAG_ID}_{ENV_ID}_public_tcp"
-CONNECTION_PUBLIC_TCP_KWARGS = {
- "conn_type": "gcpcloudsql",
- "login": CLOUD_SQL_USER,
- "password": CLOUD_SQL_PASSWORD,
- "host": CLOUD_SQL_PUBLIC_IP,
- "port": CLOUD_SQL_PUBLIC_PORT,
- "schema": CLOUD_SQL_DATABASE_NAME,
- "extra": {
- "database_type": "mysql",
- "project_id": PROJECT_ID,
- "location": REGION,
- "instance": CLOUD_SQL_INSTANCE_NAME,
- "use_proxy": "False",
- "use_ssl": "False",
- },
-}
-
-ConnectionConfig = namedtuple("ConnectionConfig", "id kwargs use_public_ip")
-CONNECTIONS = [
-    ConnectionConfig(id=CONNECTION_PROXY_TCP_ID, kwargs=CONNECTION_PROXY_TCP_KWARGS, use_public_ip=False),
-    ConnectionConfig(
-        id=CONNECTION_PROXY_SOCKET_ID, kwargs=CONNECTION_PROXY_SOCKET_KWARGS, use_public_ip=False
-    ),
-    ConnectionConfig(id=CONNECTION_PUBLIC_TCP_ID, kwargs=CONNECTION_PUBLIC_TCP_KWARGS, use_public_ip=True),
-]
-
-log = logging.getLogger(__name__)
-
-
-with DAG(
- dag_id=DAG_ID,
- start_date=datetime(2021, 1, 1),
- catchup=False,
- tags=["example", "cloudsql", "mysql"],
-) as dag:
- create_cloud_sql_instance = CloudSQLCreateInstanceOperator(
- task_id="create_cloud_sql_instance",
- project_id=PROJECT_ID,
- instance=CLOUD_SQL_INSTANCE_NAME,
- body=CLOUD_SQL_INSTANCE_CREATION_BODY,
- )
-
- create_database = CloudSQLCreateInstanceDatabaseOperator(
- task_id="create_database", body=CLOUD_SQL_DATABASE_CREATE_BODY,
instance=CLOUD_SQL_INSTANCE_NAME
- )
-
- @task
- def create_user() -> None:
- with discovery.build("sqladmin", "v1beta4") as service:
- request = service.users().insert(
- project=PROJECT_ID,
- instance=CLOUD_SQL_INSTANCE_NAME,
- body={
- "name": CLOUD_SQL_USER,
- "password": CLOUD_SQL_PASSWORD,
- },
- )
- request.execute()
-
- @task
- def get_public_ip() -> str | None:
- with discovery.build("sqladmin", "v1beta4") as service:
- request = service.connect().get(
-                project=PROJECT_ID, instance=CLOUD_SQL_INSTANCE_NAME, fields="ipAddresses"
- )
- response = request.execute()
- for ip_item in response.get("ipAddresses", []):
- if ip_item["type"] == "PRIMARY":
- return ip_item["ipAddress"]
- return None
-
- @task
-    def create_connection(connection_id: str, connection_kwargs: dict, use_public_ip: bool, **kwargs) -> None:
- session = settings.Session()
-        if session.query(Connection).filter(Connection.conn_id == connection_id).first():
- log.warning("Connection '%s' already exists", connection_id)
- return None
- _connection_kwargs = deepcopy(connection_kwargs)
- if use_public_ip:
- public_ip = kwargs["ti"].xcom_pull(task_ids="get_public_ip")
- _connection_kwargs["host"] = public_ip
- connection = Connection(conn_id=connection_id, **_connection_kwargs)
- session.add(connection)
- session.commit()
- log.info("Connection created: '%s'", connection_id)
-
- @task_group(group_id="create_connections")
- def create_connections():
- for con in CONNECTIONS:
- create_connection(
- connection_id=con.id,
- connection_kwargs=con.kwargs,
- use_public_ip=con.use_public_ip,
- )
-
- @task_group(group_id="execute_queries")
- def execute_queries():
- prev_task = None
- for conn in CONNECTIONS:
- connection_id = conn.id
- task_id = "execute_query_" + conn.id
- query_task = CloudSQLExecuteQueryOperator(
- gcp_cloudsql_conn_id=connection_id,
- task_id=task_id,
- sql=SQL,
- )
-
- if prev_task:
- prev_task >> query_task
- prev_task = query_task
-
- @task_group(group_id="teardown")
- def teardown():
- CloudSQLDeleteInstanceOperator(
- task_id="delete_cloud_sql_instance",
- project_id=PROJECT_ID,
- instance=CLOUD_SQL_INSTANCE_NAME,
- trigger_rule=TriggerRule.ALL_DONE,
- )
-
- for con in CONNECTIONS:
- BashOperator(
- task_id=f"delete_connection_{con.id}",
- bash_command=f"airflow connections delete {con.id}",
- trigger_rule=TriggerRule.ALL_DONE,
- )
-
- (
- # TEST SETUP
- create_cloud_sql_instance
- >> [create_database, create_user(), get_public_ip()]
- >> create_connections()
- # TEST BODY
- >> execute_queries()
- # TEST TEARDOWN
- >> teardown()
- )
-
- from tests.system.utils.watcher import watcher
-
- # This test needs watcher in order to properly mark success/failure
- # when "tearDown" task with trigger rule is part of the DAG
- list(dag.tasks) >> watcher()
-
-
-from tests.system.utils import get_test_run # noqa: E402
-
-# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
-test_run = get_test_run(dag)
diff --git a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_postgres.py b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_postgres.py
deleted file mode 100644
index e1f2d72e58..0000000000
--- a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_postgres.py
+++ /dev/null
@@ -1,290 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Example Airflow DAG that performs query in a Cloud SQL instance for Postgres.
-"""
-
-from __future__ import annotations
-
-import logging
-import os
-from collections import namedtuple
-from copy import deepcopy
-from datetime import datetime
-
-from googleapiclient import discovery
-
-from airflow import settings
-from airflow.decorators import task, task_group
-from airflow.models.connection import Connection
-from airflow.models.dag import DAG
-from airflow.operators.bash import BashOperator
-from airflow.providers.google.cloud.operators.cloud_sql import (
- CloudSQLCreateInstanceDatabaseOperator,
- CloudSQLCreateInstanceOperator,
- CloudSQLDeleteInstanceOperator,
- CloudSQLExecuteQueryOperator,
-)
-from airflow.utils.trigger_rule import TriggerRule
-
-ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
-PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
-DAG_ID = "cloudsql-query-pg"
-REGION = "us-central1"
-
-CLOUD_SQL_INSTANCE_NAME = f"instance-{ENV_ID}-{DAG_ID}".replace("_", "-")
-CLOUD_SQL_DATABASE_NAME = "test_db"
-CLOUD_SQL_USER = "test_user"
-CLOUD_SQL_PASSWORD = "JoxHlwrPzwch0gz9"
-CLOUD_SQL_PUBLIC_IP = "127.0.0.1"
-CLOUD_SQL_PUBLIC_PORT = 5432
-CLOUD_SQL_DATABASE_CREATE_BODY = {
- "instance": CLOUD_SQL_INSTANCE_NAME,
- "name": CLOUD_SQL_DATABASE_NAME,
- "project": PROJECT_ID,
-}
-
-CLOUD_SQL_INSTANCE_CREATION_BODY = {
- "name": CLOUD_SQL_INSTANCE_NAME,
- "settings": {
- "tier": "db-custom-1-3840",
- "dataDiskSizeGb": 30,
- "ipConfiguration": {
- "ipv4Enabled": True,
- "requireSsl": False,
- # Consider specifying your network mask
-            # for allowing requests only from the trusted sources, not from anywhere
- "authorizedNetworks": [
- {"value": "0.0.0.0/0"},
- ],
- },
- "pricingPlan": "PER_USE",
- },
- # For using a different database version please check the link below
-    # https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/SqlDatabaseVersion
- "databaseVersion": "POSTGRES_15",
- "region": REGION,
-}
-
-SQL = [
- "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
- "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
- "INSERT INTO TABLE_TEST VALUES (0)",
- "CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)",
- "DROP TABLE TABLE_TEST",
- "DROP TABLE TABLE_TEST2",
-]
-
-# [START howto_operator_cloudsql_query_connections]
-
-# Postgres: connect via proxy over TCP
-CONNECTION_PROXY_TCP_KWARGS = {
- "conn_type": "gcpcloudsql",
- "login": CLOUD_SQL_USER,
- "password": CLOUD_SQL_PASSWORD,
- "host": CLOUD_SQL_PUBLIC_IP,
- "port": CLOUD_SQL_PUBLIC_PORT,
- "schema": CLOUD_SQL_DATABASE_NAME,
- "extra": {
- "database_type": "postgres",
- "project_id": PROJECT_ID,
- "location": REGION,
- "instance": CLOUD_SQL_INSTANCE_NAME,
- "use_proxy": "True",
- "sql_proxy_use_tcp": "True",
- },
-}
-
-# Postgres: connect via proxy over UNIX socket (specific proxy version)
-CONNECTION_PROXY_SOCKET_KWARGS = {
- "conn_type": "gcpcloudsql",
- "login": CLOUD_SQL_USER,
- "password": CLOUD_SQL_PASSWORD,
- "host": CLOUD_SQL_PUBLIC_IP,
- "port": CLOUD_SQL_PUBLIC_PORT,
- "schema": CLOUD_SQL_DATABASE_NAME,
- "extra": {
- "database_type": "postgres",
- "project_id": PROJECT_ID,
- "location": REGION,
- "instance": CLOUD_SQL_INSTANCE_NAME,
- "use_proxy": "True",
- "sql_proxy_version": "v1.33.9",
- "sql_proxy_use_tcp": "False",
- },
-}
-
-# Postgres: connect directly via TCP (non-SSL)
-CONNECTION_PUBLIC_TCP_KWARGS = {
- "conn_type": "gcpcloudsql",
- "login": CLOUD_SQL_USER,
- "password": CLOUD_SQL_PASSWORD,
- "host": CLOUD_SQL_PUBLIC_IP,
- "port": CLOUD_SQL_PUBLIC_PORT,
- "schema": CLOUD_SQL_DATABASE_NAME,
- "extra": {
- "database_type": "postgres",
- "project_id": PROJECT_ID,
- "location": REGION,
- "instance": CLOUD_SQL_INSTANCE_NAME,
- "use_proxy": "False",
- "use_ssl": "False",
- },
-}
-
-# [END howto_operator_cloudsql_query_connections]
-
-CONNECTION_PROXY_TCP_ID = f"connection_{DAG_ID}_{ENV_ID}_proxy_tcp"
-CONNECTION_PUBLIC_TCP_ID = f"connection_{DAG_ID}_{ENV_ID}_public_tcp"
-CONNECTION_PROXY_SOCKET_ID = f"connection_{DAG_ID}_{ENV_ID}_proxy_socket"
-
-ConnectionConfig = namedtuple("ConnectionConfig", "id kwargs use_public_ip")
-CONNECTIONS = [
-    ConnectionConfig(id=CONNECTION_PROXY_TCP_ID, kwargs=CONNECTION_PROXY_TCP_KWARGS, use_public_ip=False),
-    ConnectionConfig(
-        id=CONNECTION_PROXY_SOCKET_ID, kwargs=CONNECTION_PROXY_SOCKET_KWARGS, use_public_ip=False
-    ),
-    ConnectionConfig(id=CONNECTION_PUBLIC_TCP_ID, kwargs=CONNECTION_PUBLIC_TCP_KWARGS, use_public_ip=True),
-]
-
-log = logging.getLogger(__name__)
-
-
-with DAG(
- dag_id=DAG_ID,
- start_date=datetime(2021, 1, 1),
- catchup=False,
- tags=["example", "cloudsql", "postgres"],
-) as dag:
- create_cloud_sql_instance = CloudSQLCreateInstanceOperator(
- task_id="create_cloud_sql_instance",
- project_id=PROJECT_ID,
- instance=CLOUD_SQL_INSTANCE_NAME,
- body=CLOUD_SQL_INSTANCE_CREATION_BODY,
- )
-
- create_database = CloudSQLCreateInstanceDatabaseOperator(
- task_id="create_database", body=CLOUD_SQL_DATABASE_CREATE_BODY,
instance=CLOUD_SQL_INSTANCE_NAME
- )
-
- @task
- def create_user() -> None:
- with discovery.build("sqladmin", "v1beta4") as service:
- request = service.users().insert(
- project=PROJECT_ID,
- instance=CLOUD_SQL_INSTANCE_NAME,
- body={
- "name": CLOUD_SQL_USER,
- "password": CLOUD_SQL_PASSWORD,
- },
- )
- request.execute()
-
- @task
- def get_public_ip() -> str | None:
- with discovery.build("sqladmin", "v1beta4") as service:
- request = service.connect().get(
-                project=PROJECT_ID, instance=CLOUD_SQL_INSTANCE_NAME, fields="ipAddresses"
- )
- response = request.execute()
- for ip_item in response.get("ipAddresses", []):
- if ip_item["type"] == "PRIMARY":
- return ip_item["ipAddress"]
- return None
-
- @task
-    def create_connection(connection_id: str, connection_kwargs: dict, use_public_ip: bool, **kwargs) -> None:
- session = settings.Session()
-        if session.query(Connection).filter(Connection.conn_id == connection_id).first():
- log.warning("Connection '%s' already exists", connection_id)
- return None
- _connection_kwargs = deepcopy(connection_kwargs)
- if use_public_ip:
- public_ip = kwargs["ti"].xcom_pull(task_ids="get_public_ip")
- _connection_kwargs["host"] = public_ip
- connection = Connection(conn_id=connection_id, **_connection_kwargs)
- session.add(connection)
- session.commit()
- log.info("Connection created: '%s'", connection_id)
-
- @task_group(group_id="create_connections")
- def create_connections():
- for con in CONNECTIONS:
- create_connection(
- connection_id=con.id,
- connection_kwargs=con.kwargs,
- use_public_ip=con.use_public_ip,
- )
-
- @task_group(group_id="execute_queries")
- def execute_queries():
- prev_task = None
- for conn in CONNECTIONS:
- connection_id = conn.id
- task_id = "execute_query_" + conn.id
-
- # [START howto_operator_cloudsql_query_operators]
- query_task = CloudSQLExecuteQueryOperator(
- gcp_cloudsql_conn_id=connection_id,
- task_id=task_id,
- sql=SQL,
- )
- # [END howto_operator_cloudsql_query_operators]
-
- if prev_task:
- prev_task >> query_task
- prev_task = query_task
-
- @task_group(group_id="teardown")
- def teardown():
- CloudSQLDeleteInstanceOperator(
- task_id="delete_cloud_sql_instance",
- project_id=PROJECT_ID,
- instance=CLOUD_SQL_INSTANCE_NAME,
- trigger_rule=TriggerRule.ALL_DONE,
- )
-
- for con in CONNECTIONS:
- BashOperator(
- task_id=f"delete_connection_{con.id}",
- bash_command=f"airflow connections delete {con.id}",
- trigger_rule=TriggerRule.ALL_DONE,
- )
-
- (
- # TEST SETUP
- create_cloud_sql_instance
- >> [create_database, create_user(), get_public_ip()]
- >> create_connections()
- # TEST BODY
- >> execute_queries()
- # TEST TEARDOWN
- >> teardown()
- )
-
- from tests.system.utils.watcher import watcher
-
- # This test needs watcher in order to properly mark success/failure
- # when "tearDown" task with trigger rule is part of the DAG
- list(dag.tasks) >> watcher()
-
-
-from tests.system.utils import get_test_run # noqa: E402
-
-# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
-test_run = get_test_run(dag)
diff --git a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
new file mode 100644
index 0000000000..3ae74a30d8
--- /dev/null
+++ b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
@@ -0,0 +1,518 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that performs a query in a Cloud SQL instance with SSL support.
+"""
+
+from __future__ import annotations
+
+import base64
+import json
+import logging
+import os
+import random
+import string
+from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Iterable
+
+from googleapiclient import discovery
+
+from airflow import settings
+from airflow.decorators import task
+from airflow.models.connection import Connection
+from airflow.models.dag import DAG
+from airflow.operators.bash import BashOperator
+from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLHook
+from airflow.providers.google.cloud.hooks.secret_manager import GoogleCloudSecretManagerHook
+from airflow.providers.google.cloud.operators.cloud_sql import (
+ CloudSQLCreateInstanceDatabaseOperator,
+ CloudSQLCreateInstanceOperator,
+ CloudSQLDeleteInstanceOperator,
+ CloudSQLExecuteQueryOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "Not found")
+DAG_ID = "cloudsql-query-ssl"
+REGION = "us-central1"
+HOME_DIR = Path.home()
+
+COMPOSER_ENVIRONMENT = os.environ.get("COMPOSER_ENVIRONMENT", "")
+if COMPOSER_ENVIRONMENT:
+    # We assume that the test is launched in a Cloud Composer environment because the reserved environment
+    # variable is assigned (https://cloud.google.com/composer/docs/composer-2/set-environment-variables)
+ GET_COMPOSER_NETWORK_COMMAND = """
+ gcloud composer environments describe $COMPOSER_ENVIRONMENT \
+ --location=$COMPOSER_LOCATION \
+ --project=$GCP_PROJECT \
+ --format="value(config.nodeConfig.network)"
+ """
+else:
+ # The test is launched locally
+ GET_COMPOSER_NETWORK_COMMAND = "echo"
+
+
+def run_in_composer():
+ return bool(COMPOSER_ENVIRONMENT)
+
+
+CLOUD_SQL_INSTANCE_NAME_TEMPLATE = f"{ENV_ID}-{DAG_ID}".replace("_", "-")
+CLOUD_SQL_INSTANCE_CREATE_BODY_TEMPLATE: dict[str, Any] = {
+ "name": CLOUD_SQL_INSTANCE_NAME_TEMPLATE,
+ "settings": {
+ "tier": "db-custom-1-3840",
+ "dataDiskSizeGb": 30,
+ "pricingPlan": "PER_USE",
+ "ipConfiguration": {},
+ },
+ # For using a different database version please check the link below.
+    # https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/SqlDatabaseVersion
+ "databaseVersion": "1.2.3",
+ "region": REGION,
+ "ipConfiguration": {
+ "ipv4Enabled": True,
+ "requireSsl": True,
+ "authorizedNetworks": [
+ {"value": "0.0.0.0/0"},
+ ],
+ },
+}
+
+DB_PROVIDERS: Iterable[dict[str, str]] = (
+ {
+ "database_type": "postgres",
+ "port": "5432",
+ "database_version": "POSTGRES_15",
+ "cloud_sql_instance_name":
f"{CLOUD_SQL_INSTANCE_NAME_TEMPLATE}-postgres",
+ },
+ {
+ "database_type": "mysql",
+ "port": "3306",
+ "database_version": "MYSQL_8_0",
+ "cloud_sql_instance_name": f"{CLOUD_SQL_INSTANCE_NAME_TEMPLATE}-mysql",
+ },
+)
+
+
+def ip_configuration() -> dict[str, Any]:
+ """Generates an ip configuration for a CloudSQL instance creation body"""
+ if run_in_composer():
+        # Use connection to Cloud SQL instance via Private IP within the Cloud Composer's network.
+ return {
+ "ipv4Enabled": True,
+ "requireSsl": False,
+ "sslMode": "ENCRYPTED_ONLY",
+ "enablePrivatePathForGoogleCloudServices": True,
+ "privateNetwork": """{{
task_instance.xcom_pull('get_composer_network')}}""",
+ }
+ else:
+        # Use connection to Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0).
+        # Consider specifying your network mask
+        # for allowing requests only from the trusted sources, not from anywhere.
+ return {
+ "ipv4Enabled": True,
+ "requireSsl": False,
+ "sslMode": "ENCRYPTED_ONLY",
+ "authorizedNetworks": [
+ {"value": "0.0.0.0/0"},
+ ],
+ }
+
+
+def cloud_sql_instance_create_body(database_provider: dict[str, Any]) -> dict[str, Any]:
+ """Generates a CloudSQL instance creation body"""
+    create_body: dict[str, Any] = deepcopy(CLOUD_SQL_INSTANCE_CREATE_BODY_TEMPLATE)
+ create_body["name"] = database_provider["cloud_sql_instance_name"]
+ create_body["databaseVersion"] = database_provider["database_version"]
+ create_body["settings"]["ipConfiguration"] = ip_configuration()
+ return create_body
+
+
+CLOUD_SQL_DATABASE_NAME = "test_db"
+CLOUD_SQL_USER = "test_user"
+CLOUD_SQL_PASSWORD = "JoxHlwrPzwch0gz9"
+CLOUD_SQL_IP_ADDRESS = "127.0.0.1"
+CLOUD_SQL_PUBLIC_PORT = 5432
+
+
+def cloud_sql_database_create_body(instance: str) -> dict[str, Any]:
+ """Generates a CloudSQL database creation body"""
+ return {
+ "instance": instance,
+ "name": CLOUD_SQL_DATABASE_NAME,
+ "project": PROJECT_ID,
+ }
+
+
+CLOUD_SQL_INSTANCE_NAME = ""
+DATABASE_TYPE = "" # "postgres|mysql|mssql"
+
+# [START howto_operator_cloudsql_query_connections]
+# Connect directly via TCP (SSL)
+CONNECTION_PUBLIC_TCP_SSL_KWARGS = {
+ "conn_type": "gcpcloudsql",
+ "login": CLOUD_SQL_USER,
+ "password": CLOUD_SQL_PASSWORD,
+ "host": CLOUD_SQL_IP_ADDRESS,
+ "port": CLOUD_SQL_PUBLIC_PORT,
+ "schema": CLOUD_SQL_DATABASE_NAME,
+ "extra": {
+ "database_type": DATABASE_TYPE,
+ "project_id": PROJECT_ID,
+ "location": REGION,
+ "instance": CLOUD_SQL_INSTANCE_NAME,
+ "use_proxy": "False",
+ "use_ssl": "True",
+ },
+}
+# [END howto_operator_cloudsql_query_connections]
+
+CONNECTION_PUBLIC_TCP_SSL_ID = f"{DAG_ID}_{ENV_ID}_tcp_ssl"
+
+SQL = [
+ "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
+ # The repeated CREATE is intentional: thanks to "IF NOT EXISTS" the second
+ # statement is a no-op rather than an error.
+ "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
+ "INSERT INTO TABLE_TEST VALUES (0)",
+ "CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)",
+ "DROP TABLE TABLE_TEST",
+ "DROP TABLE TABLE_TEST2",
+]
+
+DELETE_CONNECTION_COMMAND = "airflow connections delete {}"
+
+SSL_PATH = f"/{DAG_ID}/{ENV_ID}"
+SSL_LOCAL_PATH_PREFIX = "/tmp"
+SSL_COMPOSER_PATH_PREFIX = "/home/airflow/gcs/data"
+# [START howto_operator_cloudsql_query_connections_env]
+
+# The connections below are created using one of the standard approaches - via environment
+# variables named AIRFLOW_CONN_*. The connections can also be created in the Airflow
+# database (using the command line or the UI).
+
+postgres_kwargs = {
+ "user": "user",
+ "password": "password",
+ "public_ip": "public_ip",
+ "public_port": "public_port",
+ "database": "database",
+ "project_id": "project_id",
+ "location": "location",
+ "instance": "instance",
+ "client_cert_file": "client_cert_file",
+ "client_key_file": "client_key_file",
+ "server_ca_file": "server_ca_file",
+}
+
+# Postgres: connect directly via TCP (SSL)
+os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL"] = (
+ "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
+ "database_type=postgres&"
+ "project_id={project_id}&"
+ "location={location}&"
+ "instance={instance}&"
+ "use_proxy=False&"
+ "use_ssl=True&"
+ "sslcert={client_cert_file}&"
+ "sslkey={client_key_file}&"
+ "sslrootcert={server_ca_file}".format(**postgres_kwargs)
+)
+
+mysql_kwargs = {
+ "user": "user",
+ "password": "password",
+ "public_ip": "public_ip",
+ "public_port": "public_port",
+ "database": "database",
+ "project_id": "project_id",
+ "location": "location",
+ "instance": "instance",
+ "client_cert_file": "client_cert_file",
+ "client_key_file": "client_key_file",
+ "server_ca_file": "server_ca_file",
+}
+
+# MySQL: connect directly via TCP (SSL)
+os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL"] = (
+ "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?"
+ "database_type=mysql&"
+ "project_id={project_id}&"
+ "location={location}&"
+ "instance={instance}&"
+ "use_proxy=False&"
+ "use_ssl=True&"
+ "sslcert={client_cert_file}&"
+ "sslkey={client_key_file}&"
+ "sslrootcert={server_ca_file}".format(**mysql_kwargs)
+)
+# [END howto_operator_cloudsql_query_connections_env]
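+
+# A sketch of the command-line alternative mentioned above (the connection id here is a
+# hypothetical example, not one used by this test):
+#
+#   airflow connections add public_postgres_tcp_ssl \
+#       --conn-uri "$AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL"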
+
+
+log = logging.getLogger(__name__)
+
+with DAG(
+ dag_id=DAG_ID,
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "cloudsql", "postgres"],
+) as dag:
+ get_composer_network = BashOperator(
+ task_id="get_composer_network",
+ bash_command=GET_COMPOSER_NETWORK_COMMAND,
+ do_xcom_push=True,
+ )
+
+ for db_provider in DB_PROVIDERS:
+ database_type: str = db_provider["database_type"]
+ cloud_sql_instance_name: str = db_provider["cloud_sql_instance_name"]
+
+ create_cloud_sql_instance = CloudSQLCreateInstanceOperator(
+ task_id=f"create_cloud_sql_instance_{database_type}",
+ project_id=PROJECT_ID,
+ instance=cloud_sql_instance_name,
+ body=cloud_sql_instance_create_body(database_provider=db_provider),
+ )
+
+ create_database = CloudSQLCreateInstanceDatabaseOperator(
+ task_id=f"create_database_{database_type}",
+ body=cloud_sql_database_create_body(instance=cloud_sql_instance_name),
+ instance=cloud_sql_instance_name,
+ )
+
+ @task(task_id=f"create_user_{database_type}")
+ def create_user(instance: str) -> None:
+ with discovery.build("sqladmin", "v1beta4") as service:
+ request = service.users().insert(
+ project=PROJECT_ID,
+ instance=instance,
+ body={
+ "name": CLOUD_SQL_USER,
+ "password": CLOUD_SQL_PASSWORD,
+ },
+ )
+ request.execute()
+ return None
+
+ create_user_task = create_user(instance=cloud_sql_instance_name)
+
+ @task(task_id=f"get_ip_address_{database_type}")
+ def get_ip_address(instance: str) -> str | None:
+ """Returns a Cloud SQL instance IP address.
+
+ If the test is running in Cloud Composer, the Private IP address is used; otherwise, the Public IP address is used."""
+ with discovery.build("sqladmin", "v1beta4") as service:
+ request = service.connect().get(
+ project=PROJECT_ID,
+ instance=instance,
+ fields="ipAddresses",
+ )
+ response = request.execute()
+ for ip_item in response.get("ipAddresses", []):
+ if run_in_composer():
+ if ip_item["type"] == "PRIVATE":
+ return ip_item["ipAddress"]
+ else:
+ if ip_item["type"] == "PRIMARY":
+ return ip_item["ipAddress"]
+ return None
+
+ get_ip_address_task = get_ip_address(instance=cloud_sql_instance_name)
+
+ conn_id = f"{CONNECTION_PUBLIC_TCP_SSL_ID}_{database_type}"
+
+ @task(task_id=f"create_connection_{database_type}")
+ def create_connection(
+ connection_id: str, instance: str, db_type: str, ip_address: str, port: str
+ ) -> str | None:
+ session = settings.Session()
+ if session.query(Connection).filter(Connection.conn_id == connection_id).first():
+ log.warning("Connection '%s' already exists", connection_id)
+ return connection_id
+
+ connection: dict[str, Any] = deepcopy(CONNECTION_PUBLIC_TCP_SSL_KWARGS)
+ connection["extra"]["instance"] = instance
+ connection["host"] = ip_address
+ connection["extra"]["database_type"] = db_type
+ connection["port"] = port
+ conn = Connection(conn_id=connection_id, **connection)
+ session.add(conn)
+ session.commit()
+ log.info("Connection created: '%s'", connection_id)
+ return connection_id
+
+ create_connection_task = create_connection(
+ connection_id=conn_id,
+ instance=cloud_sql_instance_name,
+ db_type=database_type,
+ ip_address=get_ip_address_task,
+ port=db_provider["port"],
+ )
+
+ @task(task_id=f"create_ssl_certificates_{database_type}")
+ def create_ssl_certificate(instance: str, connection_id: str) -> dict[str, Any]:
+ hook = CloudSQLHook(api_version="v1", gcp_conn_id=connection_id)
+ certificate_name = f"test_cert_{''.join(random.choice(string.ascii_letters) for _ in range(8))}"
+ response = hook.create_ssl_certificate(
+ instance=instance,
+ body={"common_name": certificate_name},
+ project_id=PROJECT_ID,
+ )
+ return response
+
+ create_ssl_certificate_task = create_ssl_certificate(
+ instance=cloud_sql_instance_name, connection_id=create_connection_task
+ )
+
+ @task(task_id=f"save_ssl_cert_locally_{database_type}")
+ def save_ssl_cert_locally(ssl_cert: dict[str, Any], db_type: str) -> dict[str, str]:
+ folder = SSL_COMPOSER_PATH_PREFIX if run_in_composer() else SSL_LOCAL_PATH_PREFIX
+ folder += f"/certs/{db_type}/{ssl_cert['operation']['name']}"
+ os.makedirs(folder, exist_ok=True)
+ _ssl_root_cert_path = f"{folder}/sslrootcert.pem"
+ _ssl_cert_path = f"{folder}/sslcert.pem"
+ _ssl_key_path = f"{folder}/sslkey.pem"
+ with open(_ssl_root_cert_path, "w") as ssl_root_cert_file:
+ ssl_root_cert_file.write(ssl_cert["serverCaCert"]["cert"])
+ with open(_ssl_cert_path, "w") as ssl_cert_file:
+ ssl_cert_file.write(ssl_cert["clientCert"]["certInfo"]["cert"])
+ with open(_ssl_key_path, "w") as ssl_key_file:
+ ssl_key_file.write(ssl_cert["clientCert"]["certPrivateKey"])
+ return {
+ "sslrootcert": _ssl_root_cert_path,
+ "sslcert": _ssl_cert_path,
+ "sslkey": _ssl_key_path,
+ }
+
+ save_ssl_cert_locally_task = save_ssl_cert_locally(
+ ssl_cert=create_ssl_certificate_task, db_type=database_type
+ )
+
+ @task(task_id=f"save_ssl_cert_to_secret_manager_{database_type}")
+ def save_ssl_cert_to_secret_manager(ssl_cert: dict[str, Any], db_type: str) -> str:
+ hook = GoogleCloudSecretManagerHook()
+ payload = {
+ "sslrootcert": ssl_cert["serverCaCert"]["cert"],
+ "sslcert": ssl_cert["clientCert"]["certInfo"]["cert"],
+ "sslkey": ssl_cert["clientCert"]["certPrivateKey"],
+ }
+ _secret_id = f"secret_{DAG_ID}_{ENV_ID}_{db_type}"
+
+ if not hook.secret_exists(project_id=PROJECT_ID, secret_id=_secret_id):
+ hook.create_secret(
+ secret_id=_secret_id,
+ project_id=PROJECT_ID,
+ )
+
+ hook.add_secret_version(
+ project_id=PROJECT_ID,
+ secret_id=_secret_id,
+ secret_payload=dict(data=base64.b64encode(json.dumps(payload).encode("ascii"))),
+ )
+
+ return _secret_id
+
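+ # For illustration only - a sketch of reading the payload back with the underlying
+ # google-cloud-secret-manager client (the names below are assumptions, not part of
+ # this test); it mirrors the base64/JSON encoding used above:
+ #   from google.cloud import secretmanager
+ #   client = secretmanager.SecretManagerServiceClient()
+ #   name = f"projects/{PROJECT_ID}/secrets/{secret_id}/versions/latest"
+ #   response = client.access_secret_version(request={"name": name})
+ #   ssl_files = json.loads(base64.b64decode(response.payload.data))
+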
+ save_ssl_cert_to_secret_manager_task = save_ssl_cert_to_secret_manager(
+ ssl_cert=create_ssl_certificate_task, db_type=database_type
+ )
+
+ task_id = f"example_cloud_sql_query_ssl_{database_type}"
+ ssl_server_cert_path = (
+ f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslrootcert'] }}}}"
+ )
+ ssl_cert_path = (
+ f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslcert'] }}}}"
+ )
+ ssl_key_path = f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslkey'] }}}}"
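+ # At run time the templates above render to the concrete file paths returned by the
+ # save_ssl_cert_locally task, e.g. (operation name shown as a placeholder):
+ #   /tmp/certs/postgres/<operation-name>/sslrootcert.pem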
+
+ # [START howto_operator_cloudsql_query_operators_ssl]
+ query_task = CloudSQLExecuteQueryOperator(
+ gcp_cloudsql_conn_id=conn_id,
+ task_id=task_id,
+ sql=SQL,
+ ssl_client_cert=ssl_cert_path,
+ ssl_server_cert=ssl_server_cert_path,
+ ssl_client_key=ssl_key_path,
+ )
+ # [END howto_operator_cloudsql_query_operators_ssl]
+
+ task_id = f"example_cloud_sql_query_ssl_secret_{database_type}"
+ secret_id = f"{{{{
task_instance.xcom_pull('save_ssl_cert_to_secret_manager_{database_type}') }}}}"
+
+ # [START howto_operator_cloudsql_query_operators_ssl_secret_id]
+ query_task_secret = CloudSQLExecuteQueryOperator(
+ gcp_cloudsql_conn_id=conn_id,
+ task_id=task_id,
+ sql=SQL,
+ ssl_secret_id=secret_id,
+ )
+ # [END howto_operator_cloudsql_query_operators_ssl_secret_id]
+
+ delete_instance = CloudSQLDeleteInstanceOperator(
+ task_id=f"delete_cloud_sql_instance_{database_type}",
+ project_id=PROJECT_ID,
+ instance=cloud_sql_instance_name,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ delete_connection = BashOperator(
+ task_id=f"delete_connection_{conn_id}",
+ bash_command=DELETE_CONNECTION_COMMAND.format(conn_id),
+ trigger_rule=TriggerRule.ALL_DONE,
+ skip_on_exit_code=1,
+ )
+
+ @task(task_id=f"delete_secret_{database_type}")
+ def delete_secret(ssl_secret_id: str, db_type: str) -> None:
+ hook = GoogleCloudSecretManagerHook()
+ if hook.secret_exists(project_id=PROJECT_ID, secret_id=ssl_secret_id):
+ hook.delete_secret(secret_id=ssl_secret_id, project_id=PROJECT_ID)
+
+ delete_secret_task = delete_secret(
+ ssl_secret_id=save_ssl_cert_to_secret_manager_task, db_type=database_type
+ )
+
+ (
+ # TEST SETUP
+ get_composer_network
+ >> create_cloud_sql_instance
+ >> [create_database, create_user_task, get_ip_address_task]
+ >> create_connection_task
+ >> create_ssl_certificate_task
+ >> [save_ssl_cert_locally_task,
save_ssl_cert_to_secret_manager_task]
+ # TEST BODY
+ >> query_task
+ >> query_task_secret
+ # TEST TEARDOWN
+ >> [delete_instance, delete_connection, delete_secret_task]
+ )
+
+ # ### Everything below this line is not part of example ###
+ # ### Just for system tests purpose ###
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
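+# For example (the path is assumed from the diff stat above, not stated here):
+#   pytest tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py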