This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8f616551ca Fix cloudsql-query system tests (#41092)
8f616551ca is described below
commit 8f616551cadbaee53b1bb5952936c163093b0b40
Author: max <[email protected]>
AuthorDate: Tue Jul 30 05:28:01 2024 +0200
Fix cloudsql-query system tests (#41092)
---
airflow/providers/google/cloud/hooks/cloud_sql.py | 4 +-
.../cloud/cloud_sql/example_cloud_sql_query.py | 66 ++++++++--------------
.../cloud/cloud_sql/example_cloud_sql_query_ssl.py | 51 ++++++-----------
3 files changed, 44 insertions(+), 77 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py
index 0baa30ec7a..b1a1d1883c 100644
--- a/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -929,7 +929,9 @@ class CloudSQLDatabaseHook(BaseHook):
            self.log.info("Neither cert path and cert value provided. Nothing to save.")
return None
- _temp_file = NamedTemporaryFile(mode="w+b", prefix="/tmp/certs/")
+ certs_folder = "/tmp/certs/"
+ Path(certs_folder).mkdir(parents=True, exist_ok=True)
+ _temp_file = NamedTemporaryFile(mode="w+b", prefix=certs_folder)
if cert_path:
with open(cert_path, "rb") as cert_file:
_temp_file.write(cert_file.read())
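
For context (not part of the patch): NamedTemporaryFile joins its prefix onto the target path, so a prefix like "/tmp/certs/" places the temporary file inside /tmp/certs/ and the call raises FileNotFoundError when that directory does not exist yet. A minimal sketch of the pattern the fix applies, with a placeholder payload:

    from pathlib import Path
    from tempfile import NamedTemporaryFile

    certs_folder = "/tmp/certs/"
    # Without this mkdir the NamedTemporaryFile call below fails with
    # FileNotFoundError, because the target directory is missing.
    Path(certs_folder).mkdir(parents=True, exist_ok=True)
    with NamedTemporaryFile(mode="w+b", prefix=certs_folder) as temp_file:
        temp_file.write(b"placeholder certificate bytes")  # hypothetical content
        temp_file.flush()
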
diff --git a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py
index e5394c84bc..84af2a0aff 100644
--- a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py
+++ b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py
@@ -35,13 +35,13 @@ from airflow import settings
from airflow.decorators import task, task_group
from airflow.models.connection import Connection
from airflow.models.dag import DAG
-from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.operators.cloud_sql import (
CloudSQLCreateInstanceDatabaseOperator,
CloudSQLCreateInstanceOperator,
CloudSQLDeleteInstanceOperator,
CloudSQLExecuteQueryOperator,
)
+from airflow.settings import Session
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
@@ -52,18 +52,6 @@ REGION = "us-central1"
HOME_DIR = Path.home()
COMPOSER_ENVIRONMENT = os.environ.get("COMPOSER_ENVIRONMENT", "")
-if COMPOSER_ENVIRONMENT:
-    # We assume that the test is launched in Cloud Composer environment because the reserved environment
-    # variable is assigned (https://cloud.google.com/composer/docs/composer-2/set-environment-variables)
-    GET_COMPOSER_NETWORK_COMMAND = """
-    gcloud composer environments describe $COMPOSER_ENVIRONMENT \
-    --location=$COMPOSER_LOCATION \
-    --project=$GCP_PROJECT \
-    --format="value(config.nodeConfig.network)"
-    """
-else:
-    # The test is launched locally
-    GET_COMPOSER_NETWORK_COMMAND = "echo"
def run_in_composer():
@@ -115,7 +103,7 @@ def ip_configuration() -> dict[str, Any]:
"ipv4Enabled": True,
"requireSsl": False,
"enablePrivatePathForGoogleCloudServices": True,
-            "privateNetwork": """{{ task_instance.xcom_pull('get_composer_network')}}""",
+ "privateNetwork": f"projects/{PROJECT_ID}/global/networks/default",
}
else:
        # Use connection to Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0).
@@ -395,12 +383,6 @@ with DAG(
catchup=False,
tags=["example", "cloudsql", "postgres"],
) as dag:
- get_composer_network = BashOperator(
- task_id="get_composer_network",
- bash_command=GET_COMPOSER_NETWORK_COMMAND,
- do_xcom_push=True,
- )
-
for db_provider in DB_PROVIDERS:
database_type: str = db_provider["database_type"]
cloud_sql_instance_name: str = db_provider["cloud_sql_instance_name"]
@@ -467,9 +449,9 @@ with DAG(
kwargs: dict[str, Any],
) -> str | None:
session = settings.Session()
-        if session.query(Connection).filter(Connection.conn_id == connection_id).first():
- log.warning("Connection '%s' already exists", connection_id)
- return connection_id
+ log.info("Removing connection %s if it exists", connection_id)
+        query = session.query(Connection).filter(Connection.conn_id == connection_id)
+ query.delete()
connection: dict[str, Any] = deepcopy(kwargs)
connection["extra"]["instance"] = instance
@@ -523,30 +505,28 @@ with DAG(
execute_queries_task = execute_queries(db_type=database_type)
- @task_group(group_id=f"teardown_{database_type}")
- def teardown(instance: str, db_type: str):
- task_id = f"delete_cloud_sql_instance_{db_type}"
- CloudSQLDeleteInstanceOperator(
- task_id=task_id,
- project_id=PROJECT_ID,
- instance=instance,
- trigger_rule=TriggerRule.ALL_DONE,
- )
+ @task()
+ def delete_connection(connection_id: str) -> None:
+ session = Session()
+ log.info("Removing connection %s", connection_id)
+            query = session.query(Connection).filter(Connection.conn_id == connection_id)
+ query.delete()
+ session.commit()
- for conn in CONNECTIONS:
- connection_id = f"{conn.id}_{db_type}"
- BashOperator(
- task_id=f"delete_connection_{connection_id}",
-                bash_command=DELETE_CONNECTION_COMMAND.format(connection_id),
- trigger_rule=TriggerRule.ALL_DONE,
- )
+ delete_connections_task = delete_connection.expand(
+            connection_id=[f"{conn.id}_{database_type}" for conn in CONNECTIONS]
+ )
-        teardown_task = teardown(instance=cloud_sql_instance_name, db_type=database_type)
+ delete_instance = CloudSQLDeleteInstanceOperator(
+ task_id=f"delete_cloud_sql_instance_{database_type}",
+ project_id=PROJECT_ID,
+ instance=cloud_sql_instance_name,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
(
# TEST SETUP
- get_composer_network
- >> create_cloud_sql_instance
+ create_cloud_sql_instance
>> [
create_database,
create_user_task,
@@ -556,7 +536,7 @@ with DAG(
# TEST BODY
>> execute_queries_task
# TEST TEARDOWN
- >> teardown_task
+ >> [delete_instance, delete_connections_task]
)
# ### Everything below this line is not part of example ###
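
For context (not part of the patch): delete_connection.expand(connection_id=[...]) uses Airflow dynamic task mapping, which creates one mapped task instance per list element at runtime and replaces the per-connection BashOperator loop removed above. A minimal, self-contained sketch of the pattern; the DAG name and connection ids are placeholders:

    import pendulum
    from airflow.decorators import dag, task

    @dag(schedule=None, start_date=pendulum.datetime(2024, 1, 1), catchup=False)
    def mapping_sketch():
        @task
        def delete_connection(connection_id: str) -> None:
            print(f"would delete {connection_id}")  # placeholder body

        # .expand() produces one mapped task instance per element of this list.
        delete_connection.expand(connection_id=["postgres_public", "mysql_public"])

    mapping_sketch()
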
diff --git a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
index 3ae74a30d8..ef8617f926 100644
--- a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
+++ b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
@@ -38,7 +38,6 @@ from airflow import settings
from airflow.decorators import task
from airflow.models.connection import Connection
from airflow.models.dag import DAG
-from airflow.operators.bash import BashOperator
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLHook
from airflow.providers.google.cloud.hooks.secret_manager import GoogleCloudSecretManagerHook
from airflow.providers.google.cloud.operators.cloud_sql import (
@@ -47,27 +46,17 @@ from airflow.providers.google.cloud.operators.cloud_sql import (
CloudSQLDeleteInstanceOperator,
CloudSQLExecuteQueryOperator,
)
+from airflow.settings import Session
from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
-PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "Not found")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
DAG_ID = "cloudsql-query-ssl"
REGION = "us-central1"
HOME_DIR = Path.home()
COMPOSER_ENVIRONMENT = os.environ.get("COMPOSER_ENVIRONMENT", "")
-if COMPOSER_ENVIRONMENT:
-    # We assume that the test is launched in Cloud Composer environment because the reserved environment
-    # variable is assigned (https://cloud.google.com/composer/docs/composer-2/set-environment-variables)
-    GET_COMPOSER_NETWORK_COMMAND = """
-    gcloud composer environments describe $COMPOSER_ENVIRONMENT \
-    --location=$COMPOSER_LOCATION \
-    --project=$GCP_PROJECT \
-    --format="value(config.nodeConfig.network)"
-    """
-else:
-    # The test is launched locally
-    GET_COMPOSER_NETWORK_COMMAND = "echo"
def run_in_composer():
@@ -121,7 +110,7 @@ def ip_configuration() -> dict[str, Any]:
"requireSsl": False,
"sslMode": "ENCRYPTED_ONLY",
"enablePrivatePathForGoogleCloudServices": True,
-            "privateNetwork": """{{ task_instance.xcom_pull('get_composer_network')}}""",
+ "privateNetwork": f"projects/{PROJECT_ID}/global/networks/default",
}
else:
        # Use connection to Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0).
@@ -273,12 +262,6 @@ with DAG(
catchup=False,
tags=["example", "cloudsql", "postgres"],
) as dag:
- get_composer_network = BashOperator(
- task_id="get_composer_network",
- bash_command=GET_COMPOSER_NETWORK_COMMAND,
- do_xcom_push=True,
- )
-
for db_provider in DB_PROVIDERS:
database_type: str = db_provider["database_type"]
cloud_sql_instance_name: str = db_provider["cloud_sql_instance_name"]
@@ -342,9 +325,9 @@ with DAG(
        connection_id: str, instance: str, db_type: str, ip_address: str, port: str
) -> str | None:
session = settings.Session()
-        if session.query(Connection).filter(Connection.conn_id == connection_id).first():
- log.warning("Connection '%s' already exists", connection_id)
- return connection_id
+ log.info("Removing connection %s if it exists", connection_id)
+        query = session.query(Connection).filter(Connection.conn_id == connection_id)
+ query.delete()
        connection: dict[str, Any] = deepcopy(CONNECTION_PUBLIC_TCP_SSL_KWARGS)
connection["extra"]["instance"] = instance
@@ -472,12 +455,15 @@ with DAG(
trigger_rule=TriggerRule.ALL_DONE,
)
- delete_connection = BashOperator(
- task_id=f"delete_connection_{conn_id}",
- bash_command=DELETE_CONNECTION_COMMAND.format(conn_id),
- trigger_rule=TriggerRule.ALL_DONE,
- skip_on_exit_code=1,
- )
+ @task(task_id=f"delete_connection_{database_type}")
+ def delete_connection(connection_id: str) -> None:
+ session = Session()
+ log.info("Removing connection %s", connection_id)
+            query = session.query(Connection).filter(Connection.conn_id == connection_id)
+ query.delete()
+ session.commit()
+
+ delete_connection_task = delete_connection(connection_id=conn_id)
@task(task_id=f"delete_secret_{database_type}")
def delete_secret(ssl_secret_id, db_type: str) -> None:
@@ -491,8 +477,7 @@ with DAG(
(
# TEST SETUP
- get_composer_network
- >> create_cloud_sql_instance
+ create_cloud_sql_instance
>> [create_database, create_user_task, get_ip_address_task]
>> create_connection_task
>> create_ssl_certificate_task
@@ -501,7 +486,7 @@ with DAG(
>> query_task
>> query_task_secret
# TEST TEARDOWN
- >> [delete_instance, delete_connection, delete_secret_task]
+ >> [delete_instance, delete_connection_task, delete_secret_task]
)
# ### Everything below this line is not part of example ###