This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f616551ca Fix cloudsql-query system tests (#41092)
8f616551ca is described below

commit 8f616551cadbaee53b1bb5952936c163093b0b40
Author: max <[email protected]>
AuthorDate: Tue Jul 30 05:28:01 2024 +0200

    Fix cloudsql-query system tests (#41092)
---
 airflow/providers/google/cloud/hooks/cloud_sql.py  |  4 +-
 .../cloud/cloud_sql/example_cloud_sql_query.py     | 66 ++++++++--------------
 .../cloud/cloud_sql/example_cloud_sql_query_ssl.py | 51 ++++++-----------
 3 files changed, 44 insertions(+), 77 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py 
b/airflow/providers/google/cloud/hooks/cloud_sql.py
index 0baa30ec7a..b1a1d1883c 100644
--- a/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -929,7 +929,9 @@ class CloudSQLDatabaseHook(BaseHook):
             self.log.info("Neither cert path and cert value provided. Nothing 
to save.")
             return None
 
-        _temp_file = NamedTemporaryFile(mode="w+b", prefix="/tmp/certs/")
+        certs_folder = "/tmp/certs/"
+        Path(certs_folder).mkdir(parents=True, exist_ok=True)
+        _temp_file = NamedTemporaryFile(mode="w+b", prefix=certs_folder)
         if cert_path:
             with open(cert_path, "rb") as cert_file:
                 _temp_file.write(cert_file.read())
diff --git 
a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py 
b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py
index e5394c84bc..84af2a0aff 100644
--- a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py
+++ b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query.py
@@ -35,13 +35,13 @@ from airflow import settings
 from airflow.decorators import task, task_group
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG
-from airflow.operators.bash import BashOperator
 from airflow.providers.google.cloud.operators.cloud_sql import (
     CloudSQLCreateInstanceDatabaseOperator,
     CloudSQLCreateInstanceOperator,
     CloudSQLDeleteInstanceOperator,
     CloudSQLExecuteQueryOperator,
 )
+from airflow.settings import Session
 from airflow.utils.trigger_rule import TriggerRule
 from tests.system.providers.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
 
@@ -52,18 +52,6 @@ REGION = "us-central1"
 HOME_DIR = Path.home()
 
 COMPOSER_ENVIRONMENT = os.environ.get("COMPOSER_ENVIRONMENT", "")
-if COMPOSER_ENVIRONMENT:
-    # We assume that the test is launched in Cloud Composer environment 
because the reserved environment
-    # variable is assigned 
(https://cloud.google.com/composer/docs/composer-2/set-environment-variables)
-    GET_COMPOSER_NETWORK_COMMAND = """
-    gcloud composer environments describe $COMPOSER_ENVIRONMENT \
-    --location=$COMPOSER_LOCATION \
-    --project=$GCP_PROJECT \
-    --format="value(config.nodeConfig.network)"
-    """
-else:
-    # The test is launched locally
-    GET_COMPOSER_NETWORK_COMMAND = "echo"
 
 
 def run_in_composer():
@@ -115,7 +103,7 @@ def ip_configuration() -> dict[str, Any]:
             "ipv4Enabled": True,
             "requireSsl": False,
             "enablePrivatePathForGoogleCloudServices": True,
-            "privateNetwork": """{{ 
task_instance.xcom_pull('get_composer_network')}}""",
+            "privateNetwork": f"projects/{PROJECT_ID}/global/networks/default",
         }
     else:
         # Use connection to Cloud SQL instance via Public IP from anywhere 
(mask 0.0.0.0/0).
@@ -395,12 +383,6 @@ with DAG(
     catchup=False,
     tags=["example", "cloudsql", "postgres"],
 ) as dag:
-    get_composer_network = BashOperator(
-        task_id="get_composer_network",
-        bash_command=GET_COMPOSER_NETWORK_COMMAND,
-        do_xcom_push=True,
-    )
-
     for db_provider in DB_PROVIDERS:
         database_type: str = db_provider["database_type"]
         cloud_sql_instance_name: str = db_provider["cloud_sql_instance_name"]
@@ -467,9 +449,9 @@ with DAG(
             kwargs: dict[str, Any],
         ) -> str | None:
             session = settings.Session()
-            if session.query(Connection).filter(Connection.conn_id == 
connection_id).first():
-                log.warning("Connection '%s' already exists", connection_id)
-                return connection_id
+            log.info("Removing connection %s if it exists", connection_id)
+            query = session.query(Connection).filter(Connection.conn_id == 
connection_id)
+            query.delete()
 
             connection: dict[str, Any] = deepcopy(kwargs)
             connection["extra"]["instance"] = instance
@@ -523,30 +505,28 @@ with DAG(
 
         execute_queries_task = execute_queries(db_type=database_type)
 
-        @task_group(group_id=f"teardown_{database_type}")
-        def teardown(instance: str, db_type: str):
-            task_id = f"delete_cloud_sql_instance_{db_type}"
-            CloudSQLDeleteInstanceOperator(
-                task_id=task_id,
-                project_id=PROJECT_ID,
-                instance=instance,
-                trigger_rule=TriggerRule.ALL_DONE,
-            )
+        @task()
+        def delete_connection(connection_id: str) -> None:
+            session = Session()
+            log.info("Removing connection %s", connection_id)
+            query = session.query(Connection).filter(Connection.conn_id == 
connection_id)
+            query.delete()
+            session.commit()
 
-            for conn in CONNECTIONS:
-                connection_id = f"{conn.id}_{db_type}"
-                BashOperator(
-                    task_id=f"delete_connection_{connection_id}",
-                    
bash_command=DELETE_CONNECTION_COMMAND.format(connection_id),
-                    trigger_rule=TriggerRule.ALL_DONE,
-                )
+        delete_connections_task = delete_connection.expand(
+            connection_id=[f"{conn.id}_{database_type}" for conn in 
CONNECTIONS]
+        )
 
-        teardown_task = teardown(instance=cloud_sql_instance_name, 
db_type=database_type)
+        delete_instance = CloudSQLDeleteInstanceOperator(
+            task_id=f"delete_cloud_sql_instance_{database_type}",
+            project_id=PROJECT_ID,
+            instance=cloud_sql_instance_name,
+            trigger_rule=TriggerRule.ALL_DONE,
+        )
 
         (
             # TEST SETUP
-            get_composer_network
-            >> create_cloud_sql_instance
+            create_cloud_sql_instance
             >> [
                 create_database,
                 create_user_task,
@@ -556,7 +536,7 @@ with DAG(
             # TEST BODY
             >> execute_queries_task
             # TEST TEARDOWN
-            >> teardown_task
+            >> [delete_instance, delete_connections_task]
         )
 
     # ### Everything below this line is not part of example ###
diff --git 
a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py 
b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
index 3ae74a30d8..ef8617f926 100644
--- 
a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
+++ 
b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py
@@ -38,7 +38,6 @@ from airflow import settings
 from airflow.decorators import task
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG
-from airflow.operators.bash import BashOperator
 from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLHook
 from airflow.providers.google.cloud.hooks.secret_manager import 
GoogleCloudSecretManagerHook
 from airflow.providers.google.cloud.operators.cloud_sql import (
@@ -47,27 +46,17 @@ from airflow.providers.google.cloud.operators.cloud_sql 
import (
     CloudSQLDeleteInstanceOperator,
     CloudSQLExecuteQueryOperator,
 )
+from airflow.settings import Session
 from airflow.utils.trigger_rule import TriggerRule
+from tests.system.providers.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
 
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
-PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "Not found")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or 
DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
 DAG_ID = "cloudsql-query-ssl"
 REGION = "us-central1"
 HOME_DIR = Path.home()
 
 COMPOSER_ENVIRONMENT = os.environ.get("COMPOSER_ENVIRONMENT", "")
-if COMPOSER_ENVIRONMENT:
-    # We assume that the test is launched in Cloud Composer environment 
because the reserved environment
-    # variable is assigned 
(https://cloud.google.com/composer/docs/composer-2/set-environment-variables)
-    GET_COMPOSER_NETWORK_COMMAND = """
-    gcloud composer environments describe $COMPOSER_ENVIRONMENT \
-    --location=$COMPOSER_LOCATION \
-    --project=$GCP_PROJECT \
-    --format="value(config.nodeConfig.network)"
-    """
-else:
-    # The test is launched locally
-    GET_COMPOSER_NETWORK_COMMAND = "echo"
 
 
 def run_in_composer():
@@ -121,7 +110,7 @@ def ip_configuration() -> dict[str, Any]:
             "requireSsl": False,
             "sslMode": "ENCRYPTED_ONLY",
             "enablePrivatePathForGoogleCloudServices": True,
-            "privateNetwork": """{{ 
task_instance.xcom_pull('get_composer_network')}}""",
+            "privateNetwork": f"projects/{PROJECT_ID}/global/networks/default",
         }
     else:
         # Use connection to Cloud SQL instance via Public IP from anywhere 
(mask 0.0.0.0/0).
@@ -273,12 +262,6 @@ with DAG(
     catchup=False,
     tags=["example", "cloudsql", "postgres"],
 ) as dag:
-    get_composer_network = BashOperator(
-        task_id="get_composer_network",
-        bash_command=GET_COMPOSER_NETWORK_COMMAND,
-        do_xcom_push=True,
-    )
-
     for db_provider in DB_PROVIDERS:
         database_type: str = db_provider["database_type"]
         cloud_sql_instance_name: str = db_provider["cloud_sql_instance_name"]
@@ -342,9 +325,9 @@ with DAG(
             connection_id: str, instance: str, db_type: str, ip_address: str, 
port: str
         ) -> str | None:
             session = settings.Session()
-            if session.query(Connection).filter(Connection.conn_id == 
connection_id).first():
-                log.warning("Connection '%s' already exists", connection_id)
-                return connection_id
+            log.info("Removing connection %s if it exists", connection_id)
+            query = session.query(Connection).filter(Connection.conn_id == 
connection_id)
+            query.delete()
 
             connection: dict[str, Any] = 
deepcopy(CONNECTION_PUBLIC_TCP_SSL_KWARGS)
             connection["extra"]["instance"] = instance
@@ -472,12 +455,15 @@ with DAG(
             trigger_rule=TriggerRule.ALL_DONE,
         )
 
-        delete_connection = BashOperator(
-            task_id=f"delete_connection_{conn_id}",
-            bash_command=DELETE_CONNECTION_COMMAND.format(conn_id),
-            trigger_rule=TriggerRule.ALL_DONE,
-            skip_on_exit_code=1,
-        )
+        @task(task_id=f"delete_connection_{database_type}")
+        def delete_connection(connection_id: str) -> None:
+            session = Session()
+            log.info("Removing connection %s", connection_id)
+            query = session.query(Connection).filter(Connection.conn_id == 
connection_id)
+            query.delete()
+            session.commit()
+
+        delete_connection_task = delete_connection(connection_id=conn_id)
 
         @task(task_id=f"delete_secret_{database_type}")
         def delete_secret(ssl_secret_id, db_type: str) -> None:
@@ -491,8 +477,7 @@ with DAG(
 
         (
             # TEST SETUP
-            get_composer_network
-            >> create_cloud_sql_instance
+            create_cloud_sql_instance
             >> [create_database, create_user_task, get_ip_address_task]
             >> create_connection_task
             >> create_ssl_certificate_task
@@ -501,7 +486,7 @@ with DAG(
             >> query_task
             >> query_task_secret
             # TEST TEARDOWN
-            >> [delete_instance, delete_connection, delete_secret_task]
+            >> [delete_instance, delete_connection_task, delete_secret_task]
         )
 
     # ### Everything below this line is not part of example ###

Reply via email to