This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 416cae48f5 Migrate system tests for CloudSQLExecuteQueryOperator
(non-SSL) (#33316)
416cae48f5 is described below
commit 416cae48f5a15dd57a242929338a7915e57eefce
Author: max <[email protected]>
AuthorDate: Sun Aug 13 22:41:06 2023 +0000
Migrate system tests for CloudSQLExecuteQueryOperator (non-SSL) (#33316)
---
.../operators/cloud/cloud_sql.rst | 24 +-
.../cloud_sql/example_cloud_sql_query_mysql.py | 280 ++++++++++++++++++++
.../cloud_sql/example_cloud_sql_query_postgres.py | 288 +++++++++++++++++++++
3 files changed, 590 insertions(+), 2 deletions(-)
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
b/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
index 6bedaef077..afdedcac74 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
@@ -574,14 +574,34 @@ certificate/key files available in predefined locations
for all the workers on
which the operator can run. This can be provided for example by mounting
NFS-like volumes in the same path for all the workers.
-Example connection definitions for all connectivity cases. Note that all the
components
-of the connection URI should be URL-encoded:
+Example connection definitions for all non-SSL connectivity cases for
+Postgres. For connecting to a MySQL database, use ``mysql`` as the
+``database_type``. Note that all the components of the connection URI
+should be URL-encoded:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_postgres.py
+ :language: python
+ :start-after: [START howto_operator_cloudsql_query_connections]
+ :end-before: [END howto_operator_cloudsql_query_connections]
+
+It is also possible to configure a connection via an environment variable
+(note that the connection id from the operator matches the uppercase postfix
+of :envvar:`AIRFLOW_CONN_{CONN_ID}` if you are using the standard Airflow
+notation for defining connections via environment variables):
.. exampleinclude::
/../../airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py
:language: python
:start-after: [START howto_operator_cloudsql_query_connections]
:end-before: [END howto_operator_cloudsql_query_connections]
+The example operator below uses the connection prepared earlier. It might be
+a connection_id from the Airflow database or a connection configured via an
+environment variable (note that the connection id from the operator matches
+the uppercase postfix of :envvar:`AIRFLOW_CONN_{CONN_ID}` if you are using
+the standard Airflow notation for defining connections via environment
+variables):
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_postgres.py
+ :language: python
+ :start-after: [START howto_operator_cloudsql_query_operators]
+ :end-before: [END howto_operator_cloudsql_query_operators]
+
+
Using the operator
""""""""""""""""""
diff --git
a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_mysql.py
b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_mysql.py
new file mode 100644
index 0000000000..2a3050ead6
--- /dev/null
+++
b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_mysql.py
@@ -0,0 +1,280 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that performs query in a Cloud SQL instance for MySQL.
+"""
+from __future__ import annotations
+
+import logging
+import os
+from collections import namedtuple
+from copy import deepcopy
+from datetime import datetime
+
+from googleapiclient import discovery
+
+from airflow import models, settings
+from airflow.decorators import task, task_group
+from airflow.models import Connection
+from airflow.operators.bash import BashOperator
+from airflow.providers.google.cloud.operators.cloud_sql import (
+ CloudSQLCreateInstanceDatabaseOperator,
+ CloudSQLCreateInstanceOperator,
+ CloudSQLDeleteInstanceOperator,
+ CloudSQLExecuteQueryOperator,
+)
+from airflow.settings import Session
+from airflow.utils.trigger_rule import TriggerRule
+
# Unique suffix for resource names; provided by the system-test environment.
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
# GCP project in which the Cloud SQL instance is created.
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
DAG_ID = "cloudsql-query-mysql"
REGION = "us-central1"

# Cloud SQL instance names may not contain underscores.
CLOUD_SQL_INSTANCE_NAME = f"{ENV_ID}-{DAG_ID}".replace("_", "-")
CLOUD_SQL_DATABASE_NAME = "test_db"
CLOUD_SQL_USER = "test_user"
CLOUD_SQL_PASSWORD = "JoxHlwrPzwch0gz9"
# Placeholder host; the real public IP is pulled at runtime by the
# get_public_ip task and substituted into the connection (see create_connection).
CLOUD_SQL_PUBLIC_IP = "127.0.0.1"
CLOUD_SQL_PUBLIC_PORT = 3306
CLOUD_SQL_DATABASE_CREATE_BODY = {
    "instance": CLOUD_SQL_INSTANCE_NAME,
    "name": CLOUD_SQL_DATABASE_NAME,
    "project": PROJECT_ID,
}

CLOUD_SQL_INSTANCE_CREATION_BODY = {
    "name": CLOUD_SQL_INSTANCE_NAME,
    "settings": {
        "tier": "db-custom-1-3840",
        "dataDiskSizeGb": 30,
        "ipConfiguration": {
            "ipv4Enabled": True,
            "requireSsl": False,
            # Consider specifying your network mask
            # for allowing requests only from the trusted sources, not from anywhere
            "authorizedNetworks": [
                {"value": "0.0.0.0/0"},
            ],
        },
        "pricingPlan": "PER_USE",
    },
    # For using a different database version please check the link below
    # https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/SqlDatabaseVersion
    "databaseVersion": "MYSQL_8_0",
    "region": REGION,
}

# Statements executed through every connection variant below.
SQL = [
    "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
    # NOTE(review): the statement above is duplicated, presumably to exercise
    # "IF NOT EXISTS" idempotency — confirm this is intentional.
    "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
    "INSERT INTO TABLE_TEST VALUES (0)",
    "CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)",
    "DROP TABLE TABLE_TEST",
    "DROP TABLE TABLE_TEST2",
]
+
# MySQL: connect via the Cloud SQL proxy over TCP.
# (Comment corrected: these kwargs configure a MySQL connection, not Postgres.)
CONNECTION_PROXY_TCP_ID = f"connection_{DAG_ID}_{ENV_ID}_proxy_tcp"
CONNECTION_PROXY_TCP_KWARGS = {
    "conn_type": "gcpcloudsql",
    "login": CLOUD_SQL_USER,
    "password": CLOUD_SQL_PASSWORD,
    "host": CLOUD_SQL_PUBLIC_IP,
    "port": CLOUD_SQL_PUBLIC_PORT,
    "schema": CLOUD_SQL_DATABASE_NAME,
    "extra": {
        "database_type": "mysql",
        "project_id": PROJECT_ID,
        "location": REGION,
        "instance": CLOUD_SQL_INSTANCE_NAME,
        "use_proxy": "True",
        "sql_proxy_use_tcp": "True",
    },
}

# MySQL: connect via the Cloud SQL proxy over a UNIX socket (pinned proxy version).
CONNECTION_PROXY_SOCKET_ID = f"connection_{DAG_ID}_{ENV_ID}_proxy_socket"
CONNECTION_PROXY_SOCKET_KWARGS = {
    "conn_type": "gcpcloudsql",
    "login": CLOUD_SQL_USER,
    "password": CLOUD_SQL_PASSWORD,
    "host": CLOUD_SQL_PUBLIC_IP,
    "port": CLOUD_SQL_PUBLIC_PORT,
    "schema": CLOUD_SQL_DATABASE_NAME,
    "extra": {
        "database_type": "mysql",
        "project_id": PROJECT_ID,
        "location": REGION,
        "instance": CLOUD_SQL_INSTANCE_NAME,
        "use_proxy": "True",
        "sql_proxy_version": "v1.33.9",
        "sql_proxy_use_tcp": "False",
    },
}

# MySQL: connect directly via TCP (non-SSL).
CONNECTION_PUBLIC_TCP_ID = f"connection_{DAG_ID}_{ENV_ID}_public_tcp"
CONNECTION_PUBLIC_TCP_KWARGS = {
    "conn_type": "gcpcloudsql",
    "login": CLOUD_SQL_USER,
    "password": CLOUD_SQL_PASSWORD,
    "host": CLOUD_SQL_PUBLIC_IP,
    "port": CLOUD_SQL_PUBLIC_PORT,
    "schema": CLOUD_SQL_DATABASE_NAME,
    "extra": {
        "database_type": "mysql",
        "project_id": PROJECT_ID,
        "location": REGION,
        "instance": CLOUD_SQL_INSTANCE_NAME,
        "use_proxy": "False",
        "use_ssl": "False",
    },
}

# One (id, kwargs, use_public_ip) triple per connectivity variant exercised below.
ConnectionConfig = namedtuple("ConnectionConfig", "id kwargs use_public_ip")
CONNECTIONS = [
    ConnectionConfig(id=CONNECTION_PROXY_TCP_ID, kwargs=CONNECTION_PROXY_TCP_KWARGS, use_public_ip=False),
    ConnectionConfig(
        id=CONNECTION_PROXY_SOCKET_ID, kwargs=CONNECTION_PROXY_SOCKET_KWARGS, use_public_ip=False
    ),
    ConnectionConfig(id=CONNECTION_PUBLIC_TCP_ID, kwargs=CONNECTION_PUBLIC_TCP_KWARGS, use_public_ip=True),
]

# Module-level logger, per the project's logging convention.
log = logging.getLogger(__name__)
+
+
with models.DAG(
    dag_id=DAG_ID,
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example", "cloudsql", "mysql"],
) as dag:
    # Provision the Cloud SQL instance that every query below runs against.
    create_cloud_sql_instance = CloudSQLCreateInstanceOperator(
        task_id="create_cloud_sql_instance",
        project_id=PROJECT_ID,
        instance=CLOUD_SQL_INSTANCE_NAME,
        body=CLOUD_SQL_INSTANCE_CREATION_BODY,
    )

    # Create the test database inside that instance.
    create_database = CloudSQLCreateInstanceDatabaseOperator(
        task_id="create_database", body=CLOUD_SQL_DATABASE_CREATE_BODY, instance=CLOUD_SQL_INSTANCE_NAME
    )
+
+ @task
+ def create_user() -> None:
+ with discovery.build("sqladmin", "v1beta4") as service:
+ request = service.users().insert(
+ project=PROJECT_ID,
+ instance=CLOUD_SQL_INSTANCE_NAME,
+ body={
+ "name": CLOUD_SQL_USER,
+ "password": CLOUD_SQL_PASSWORD,
+ },
+ )
+ request.execute()
+
+ @task
+ def get_public_ip() -> str | None:
+ with discovery.build("sqladmin", "v1beta4") as service:
+ request = service.connect().get(
+ project=PROJECT_ID, instance=CLOUD_SQL_INSTANCE_NAME,
fields="ipAddresses"
+ )
+ response = request.execute()
+ for ip_item in response.get("ipAddresses", []):
+ if ip_item["type"] == "PRIMARY":
+ return ip_item["ipAddress"]
+
+ @task
+ def create_connection(connection_id: str, connection_kwargs: dict,
use_public_ip: bool, **kwargs) -> None:
+ session: Session = settings.Session()
+ if session.query(Connection).filter(Connection.conn_id ==
connection_id).first():
+ log.warning("Connection '%s' already exists", connection_id)
+ return None
+ _connection_kwargs = deepcopy(connection_kwargs)
+ if use_public_ip:
+ public_ip = kwargs["ti"].xcom_pull(task_ids="get_public_ip")
+ _connection_kwargs["host"] = public_ip
+ connection = Connection(conn_id=connection_id, **_connection_kwargs)
+ session.add(connection)
+ session.commit()
+ log.info("Connection created: '%s'", connection_id)
+
+ @task_group(group_id="create_connections")
+ def create_connections():
+ for con in CONNECTIONS:
+ create_connection(
+ connection_id=con.id,
+ connection_kwargs=con.kwargs,
+ use_public_ip=con.use_public_ip,
+ )
+
+ @task_group(group_id="execute_queries")
+ def execute_queries():
+ prev_task = None
+ for conn in CONNECTIONS:
+ connection_id = conn.id
+ task_id = "execute_query_" + conn.id
+ query_task = CloudSQLExecuteQueryOperator(
+ gcp_cloudsql_conn_id=connection_id,
+ task_id=task_id,
+ sql=SQL,
+ )
+
+ if prev_task:
+ prev_task >> query_task
+ prev_task = query_task
+
+ @task_group(group_id="teardown")
+ def teardown():
+ CloudSQLDeleteInstanceOperator(
+ task_id="delete_cloud_sql_instance",
+ project_id=PROJECT_ID,
+ instance=CLOUD_SQL_INSTANCE_NAME,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ for con in CONNECTIONS:
+ BashOperator(
+ task_id=f"delete_connection_{con.id}",
+ bash_command=f"airflow connections delete {con.id}",
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
    (
        # TEST SETUP
        create_cloud_sql_instance
        >> [create_database, create_user(), get_public_ip()]
        >> create_connections()
        # TEST BODY
        >> execute_queries()
        # TEST TEARDOWN
        >> teardown()
    )

    # NOTE(review): imported locally, presumably to keep DAG-parse time free of
    # test-only dependencies — confirm against the system-test conventions.
    from tests.system.utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()
+
+
from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
diff --git
a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_postgres.py
b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_postgres.py
new file mode 100644
index 0000000000..e2a21f7b46
--- /dev/null
+++
b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_postgres.py
@@ -0,0 +1,288 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that performs query in a Cloud SQL instance for Postgres.
+"""
+from __future__ import annotations
+
+import logging
+import os
+from collections import namedtuple
+from copy import deepcopy
+from datetime import datetime
+
+from googleapiclient import discovery
+
+from airflow import models, settings
+from airflow.decorators import task, task_group
+from airflow.models import Connection
+from airflow.operators.bash import BashOperator
+from airflow.providers.google.cloud.operators.cloud_sql import (
+ CloudSQLCreateInstanceDatabaseOperator,
+ CloudSQLCreateInstanceOperator,
+ CloudSQLDeleteInstanceOperator,
+ CloudSQLExecuteQueryOperator,
+)
+from airflow.settings import Session
+from airflow.utils.trigger_rule import TriggerRule
+
# Unique suffix for resource names; provided by the system-test environment.
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
# GCP project in which the Cloud SQL instance is created.
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
DAG_ID = "cloudsql-query-pg"
REGION = "us-central1"

# Cloud SQL instance names may not contain underscores.
CLOUD_SQL_INSTANCE_NAME = f"{ENV_ID}-{DAG_ID}".replace("_", "-")
CLOUD_SQL_DATABASE_NAME = "test_db"
CLOUD_SQL_USER = "test_user"
CLOUD_SQL_PASSWORD = "JoxHlwrPzwch0gz9"
# Placeholder host; the real public IP is pulled at runtime by the
# get_public_ip task and substituted into the connection (see create_connection).
CLOUD_SQL_PUBLIC_IP = "127.0.0.1"
CLOUD_SQL_PUBLIC_PORT = 5432
CLOUD_SQL_DATABASE_CREATE_BODY = {
    "instance": CLOUD_SQL_INSTANCE_NAME,
    "name": CLOUD_SQL_DATABASE_NAME,
    "project": PROJECT_ID,
}

CLOUD_SQL_INSTANCE_CREATION_BODY = {
    "name": CLOUD_SQL_INSTANCE_NAME,
    "settings": {
        "tier": "db-custom-1-3840",
        "dataDiskSizeGb": 30,
        "ipConfiguration": {
            "ipv4Enabled": True,
            "requireSsl": False,
            # Consider specifying your network mask
            # for allowing requests only from the trusted sources, not from anywhere
            "authorizedNetworks": [
                {"value": "0.0.0.0/0"},
            ],
        },
        "pricingPlan": "PER_USE",
    },
    # For using a different database version please check the link below
    # (link corrected to the Postgres docs; the original pointed at MySQL):
    # https://cloud.google.com/sql/docs/postgres/admin-api/rest/v1/SqlDatabaseVersion
    "databaseVersion": "POSTGRES_15",
    "region": REGION,
}

# Statements executed through every connection variant below.
SQL = [
    "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
    # NOTE(review): the statement above is duplicated, presumably to exercise
    # "IF NOT EXISTS" idempotency — confirm this is intentional.
    "CREATE TABLE IF NOT EXISTS TABLE_TEST (I INTEGER)",
    "INSERT INTO TABLE_TEST VALUES (0)",
    "CREATE TABLE IF NOT EXISTS TABLE_TEST2 (I INTEGER)",
    "DROP TABLE TABLE_TEST",
    "DROP TABLE TABLE_TEST2",
]
+
# [START howto_operator_cloudsql_query_connections]

# Postgres: connect via the Cloud SQL proxy over TCP.
CONNECTION_PROXY_TCP_KWARGS = {
    "conn_type": "gcpcloudsql",
    "login": CLOUD_SQL_USER,
    "password": CLOUD_SQL_PASSWORD,
    "host": CLOUD_SQL_PUBLIC_IP,
    "port": CLOUD_SQL_PUBLIC_PORT,
    "schema": CLOUD_SQL_DATABASE_NAME,
    "extra": {
        "database_type": "postgres",
        "project_id": PROJECT_ID,
        "location": REGION,
        "instance": CLOUD_SQL_INSTANCE_NAME,
        "use_proxy": "True",
        "sql_proxy_use_tcp": "True",
    },
}

# Postgres: connect via the Cloud SQL proxy over a UNIX socket (pinned proxy version).
CONNECTION_PROXY_SOCKET_KWARGS = {
    "conn_type": "gcpcloudsql",
    "login": CLOUD_SQL_USER,
    "password": CLOUD_SQL_PASSWORD,
    "host": CLOUD_SQL_PUBLIC_IP,
    "port": CLOUD_SQL_PUBLIC_PORT,
    "schema": CLOUD_SQL_DATABASE_NAME,
    "extra": {
        "database_type": "postgres",
        "project_id": PROJECT_ID,
        "location": REGION,
        "instance": CLOUD_SQL_INSTANCE_NAME,
        "use_proxy": "True",
        "sql_proxy_version": "v1.33.9",
        "sql_proxy_use_tcp": "False",
    },
}

# Postgres: connect directly via TCP (non-SSL).
CONNECTION_PUBLIC_TCP_KWARGS = {
    "conn_type": "gcpcloudsql",
    "login": CLOUD_SQL_USER,
    "password": CLOUD_SQL_PASSWORD,
    "host": CLOUD_SQL_PUBLIC_IP,
    "port": CLOUD_SQL_PUBLIC_PORT,
    "schema": CLOUD_SQL_DATABASE_NAME,
    "extra": {
        "database_type": "postgres",
        "project_id": PROJECT_ID,
        "location": REGION,
        "instance": CLOUD_SQL_INSTANCE_NAME,
        "use_proxy": "False",
        "use_ssl": "False",
    },
}

# [END howto_operator_cloudsql_query_connections]

CONNECTION_PROXY_TCP_ID = f"connection_{DAG_ID}_{ENV_ID}_proxy_tcp"
CONNECTION_PUBLIC_TCP_ID = f"connection_{DAG_ID}_{ENV_ID}_public_tcp"
CONNECTION_PROXY_SOCKET_ID = f"connection_{DAG_ID}_{ENV_ID}_proxy_socket"

# One (id, kwargs, use_public_ip) triple per connectivity variant exercised below.
ConnectionConfig = namedtuple("ConnectionConfig", "id kwargs use_public_ip")
CONNECTIONS = [
    ConnectionConfig(id=CONNECTION_PROXY_TCP_ID, kwargs=CONNECTION_PROXY_TCP_KWARGS, use_public_ip=False),
    ConnectionConfig(
        id=CONNECTION_PROXY_SOCKET_ID, kwargs=CONNECTION_PROXY_SOCKET_KWARGS, use_public_ip=False
    ),
    ConnectionConfig(id=CONNECTION_PUBLIC_TCP_ID, kwargs=CONNECTION_PUBLIC_TCP_KWARGS, use_public_ip=True),
]

# Module-level logger, per the project's logging convention.
log = logging.getLogger(__name__)
+
+
with models.DAG(
    dag_id=DAG_ID,
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example", "cloudsql", "postgres"],
) as dag:
    # Provision the Cloud SQL instance that every query below runs against.
    create_cloud_sql_instance = CloudSQLCreateInstanceOperator(
        task_id="create_cloud_sql_instance",
        project_id=PROJECT_ID,
        instance=CLOUD_SQL_INSTANCE_NAME,
        body=CLOUD_SQL_INSTANCE_CREATION_BODY,
    )

    # Create the test database inside that instance.
    create_database = CloudSQLCreateInstanceDatabaseOperator(
        task_id="create_database", body=CLOUD_SQL_DATABASE_CREATE_BODY, instance=CLOUD_SQL_INSTANCE_NAME
    )
+
    @task
    def create_user() -> None:
        """Create the test database user in the instance via the Cloud SQL Admin API."""
        with discovery.build("sqladmin", "v1beta4") as service:
            request = service.users().insert(
                project=PROJECT_ID,
                instance=CLOUD_SQL_INSTANCE_NAME,
                body={
                    "name": CLOUD_SQL_USER,
                    "password": CLOUD_SQL_PASSWORD,
                },
            )
            request.execute()
+
    @task
    def get_public_ip() -> str | None:
        """Return the instance's PRIMARY (public) IP address, or None if not reported."""
        with discovery.build("sqladmin", "v1beta4") as service:
            request = service.connect().get(
                project=PROJECT_ID, instance=CLOUD_SQL_INSTANCE_NAME, fields="ipAddresses"
            )
            response = request.execute()
            for ip_item in response.get("ipAddresses", []):
                if ip_item["type"] == "PRIMARY":
                    return ip_item["ipAddress"]
            # Implicitly returns None when no PRIMARY address is present.
+
+ @task
+ def create_connection(connection_id: str, connection_kwargs: dict,
use_public_ip: bool, **kwargs) -> None:
+ session: Session = settings.Session()
+ if session.query(Connection).filter(Connection.conn_id ==
connection_id).first():
+ log.warning("Connection '%s' already exists", connection_id)
+ return None
+ _connection_kwargs = deepcopy(connection_kwargs)
+ if use_public_ip:
+ public_ip = kwargs["ti"].xcom_pull(task_ids="get_public_ip")
+ _connection_kwargs["host"] = public_ip
+ connection = Connection(conn_id=connection_id, **_connection_kwargs)
+ session.add(connection)
+ session.commit()
+ log.info("Connection created: '%s'", connection_id)
+
    @task_group(group_id="create_connections")
    def create_connections():
        """Fan out one create_connection task per configured connection variant."""
        for con in CONNECTIONS:
            create_connection(
                connection_id=con.id,
                connection_kwargs=con.kwargs,
                use_public_ip=con.use_public_ip,
            )
+
    @task_group(group_id="execute_queries")
    def execute_queries():
        """Run the test SQL once per connection, chaining the tasks sequentially."""
        prev_task = None
        for conn in CONNECTIONS:
            connection_id = conn.id
            task_id = "execute_query_" + conn.id

            # [START howto_operator_cloudsql_query_operators]
            query_task = CloudSQLExecuteQueryOperator(
                gcp_cloudsql_conn_id=connection_id,
                task_id=task_id,
                sql=SQL,
            )
            # [END howto_operator_cloudsql_query_operators]

            # Chain so the queries hit the instance one at a time.
            if prev_task:
                prev_task >> query_task
            prev_task = query_task
+
    @task_group(group_id="teardown")
    def teardown():
        """Delete the Cloud SQL instance and every connection created by this test."""
        # ALL_DONE ensures cleanup runs even when upstream tasks failed.
        CloudSQLDeleteInstanceOperator(
            task_id="delete_cloud_sql_instance",
            project_id=PROJECT_ID,
            instance=CLOUD_SQL_INSTANCE_NAME,
            trigger_rule=TriggerRule.ALL_DONE,
        )

        for con in CONNECTIONS:
            BashOperator(
                task_id=f"delete_connection_{con.id}",
                bash_command=f"airflow connections delete {con.id}",
                trigger_rule=TriggerRule.ALL_DONE,
            )
+
    (
        # TEST SETUP
        create_cloud_sql_instance
        >> [create_database, create_user(), get_public_ip()]
        >> create_connections()
        # TEST BODY
        >> execute_queries()
        # TEST TEARDOWN
        >> teardown()
    )

    # NOTE(review): imported locally, presumably to keep DAG-parse time free of
    # test-only dependencies — confirm against the system-test conventions.
    from tests.system.utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()
+
+
from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)