This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4ae85d754e Bugfix BigQueryToMsSqlOperator (#39171)
4ae85d754e is described below

commit 4ae85d754e9f8a65d461e86eb6111d3b9974a065
Author: max <[email protected]>
AuthorDate: Tue Apr 23 10:47:42 2024 +0000

    Bugfix BigQueryToMsSqlOperator (#39171)
---
 .../google/cloud/transfers/bigquery_to_mssql.py    |   2 +-
 tests/always/test_project_structure.py             |   1 -
 .../cloud/transfers/test_bigquery_to_mssql.py      |  88 +++++++
 .../cloud/bigquery/example_bigquery_to_mssql.py    | 276 +++++++++++++++++++--
 4 files changed, 342 insertions(+), 25 deletions(-)
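
Reviewer note (not part of the patch): the change below fixes the keyword used when BigQueryToMsSqlOperator builds its MsSqlHook — the operator's mssql_conn_id argument was being handed to the hook as mysql_conn_id, so the hook never received it under the keyword it expects. A minimal usage sketch; the connection ID and table names here are hypothetical:

    from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator

    bigquery_to_mssql = BigQueryToMsSqlOperator(
        task_id="bigquery_to_mssql",
        mssql_conn_id="my_mssql",  # hypothetical connection ID; with the old keyword it never reached the hook
        source_project_dataset_table="my-project.my_dataset.my_table",  # hypothetical
        target_table_name="my_sql_table",  # hypothetical
        replace=False,
    )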

diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
index c251ec5615..8a5749dc9e 100644
--- a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
+++ b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
@@ -91,7 +91,7 @@ class BigQueryToMsSqlOperator(BigQueryToSqlBaseOperator):
         self.source_project_dataset_table = source_project_dataset_table
 
     def get_sql_hook(self) -> MsSqlHook:
-        return MsSqlHook(schema=self.database, mysql_conn_id=self.mssql_conn_id)
+        return MsSqlHook(schema=self.database, mssql_conn_id=self.mssql_conn_id)
 
     def persist_links(self, context: Context) -> None:
         project_id, dataset_id, table_id = self.source_project_dataset_table.split(".")
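
Reviewer note (not part of the patch): a simplified sketch of why the old keyword was silently ineffective. DB-API style hooks resolve their connection ID from a class-level conn_name_attr, so a kwarg passed under any other name is typically never consulted and the hook falls back to its default connection. FakeDbApiHook below is hypothetical and only mimics that lookup pattern:

    class FakeDbApiHook:
        conn_name_attr = "mssql_conn_id"
        default_conn_name = "mssql_default"

        def __init__(self, **kwargs):
            # Only the keyword named by conn_name_attr is honoured; anything else is ignored.
            setattr(self, self.conn_name_attr, kwargs.get(self.conn_name_attr, self.default_conn_name))

    print(FakeDbApiHook(mysql_conn_id="my_custom_conn").mssql_conn_id)  # mssql_default (pre-fix behaviour)
    print(FakeDbApiHook(mssql_conn_id="my_custom_conn").mssql_conn_id)  # my_custom_conn (post-fix behaviour)
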
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 4341dd1aec..3437092e65 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -143,7 +143,6 @@ class TestProjectStructure:
             "tests/providers/google/cloud/operators/vertex_ai/test_model_service.py",
             "tests/providers/google/cloud/operators/vertex_ai/test_pipeline_job.py",
             "tests/providers/google/cloud/sensors/test_dataform.py",
-            "tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py",
             "tests/providers/google/cloud/transfers/test_bigquery_to_sql.py",
             "tests/providers/google/cloud/transfers/test_mssql_to_gcs.py",
             "tests/providers/google/cloud/transfers/test_presto_to_gcs.py",
diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py b/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py
new file mode 100644
index 0000000000..e4fd897324
--- /dev/null
+++ b/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py
@@ -0,0 +1,88 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator
+
+TASK_ID = "test-bq-create-table-operator"
+TEST_DATASET = "test-dataset"
+TEST_TABLE_ID = "test-table-id"
+TEST_DAG_ID = "test-bigquery-operators"
+TEST_PROJECT = "test-project"
+
+
+class TestBigQueryToMsSqlOperator:
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryTableLink")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
+    def test_execute_good_request_to_bq(self, mock_hook, mock_link):
+        destination_table = "table"
+        operator = BigQueryToMsSqlOperator(
+            task_id=TASK_ID,
+            source_project_dataset_table=f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}",
+            target_table_name=destination_table,
+            replace=False,
+        )
+
+        operator.execute(None)
+        mock_hook.return_value.list_rows.assert_called_once_with(
+            dataset_id=TEST_DATASET,
+            table_id=TEST_TABLE_ID,
+            max_results=1000,
+            selected_fields=None,
+            start_index=0,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
+    def test_get_sql_hook(self, mock_hook):
+        hook_expected = mock_hook.return_value
+
+        destination_table = "table"
+        operator = BigQueryToMsSqlOperator(
+            task_id=TASK_ID,
+            source_project_dataset_table=f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}",
+            target_table_name=destination_table,
+            replace=False,
+        )
+
+        hook_actual = operator.get_sql_hook()
+
+        assert hook_actual == hook_expected
+        mock_hook.assert_called_once_with(schema=operator.database, mssql_conn_id=operator.mssql_conn_id)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryTableLink")
+    def test_persist_links(self, mock_link):
+        mock_context = mock.MagicMock()
+
+        destination_table = "table"
+        operator = BigQueryToMsSqlOperator(
+            task_id=TASK_ID,
+            source_project_dataset_table=f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}",
+            target_table_name=destination_table,
+            replace=False,
+        )
+        operator.persist_links(context=mock_context)
+
+        mock_link.persist.assert_called_once_with(
+            context=mock_context,
+            task_instance=operator,
+            dataset_id=TEST_DATASET,
+            project_id=TEST_PROJECT,
+            table_id=TEST_TABLE_ID,
+        )
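
Reviewer note (not part of the patch): test_persist_links above relies on the operator splitting source_project_dataset_table on dots into exactly three parts, so the fully qualified "project.dataset.table" form is required. A tiny illustration, mirroring the test fixture:

    source = "test-project.test-dataset.test-table-id"
    project_id, dataset_id, table_id = source.split(".")
    assert (project_id, dataset_id, table_id) == ("test-project", "test-dataset", "test-table-id")
    # A two-part "dataset.table" value would raise ValueError when unpacked this way.
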
diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py
index 822020df28..2ad7671ec3 100644
--- a/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py
+++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py
@@ -17,21 +17,42 @@
 # under the License.
 """
 Example Airflow DAG for Google BigQuery service.
+
+This DAG relies on the following OS environment variables
+
+* AIRFLOW__API__GOOGLE_KEY_PATH - Path to service account key file. Note, you can skip this variable if you
+  run this DAG in a Composer environment.
 """
 
 from __future__ import annotations
 
+import logging
 import os
 from datetime import datetime
 
 import pytest
+from pendulum import duration
 
+from airflow.decorators import task
+from airflow.models import Connection
 from airflow.models.dag import DAG
+from airflow.operators.bash import BashOperator
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
+from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
+from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook
 from airflow.providers.google.cloud.operators.bigquery import (
     BigQueryCreateEmptyDatasetOperator,
     BigQueryCreateEmptyTableOperator,
     BigQueryDeleteDatasetOperator,
+    BigQueryInsertJobOperator,
 )
+from airflow.providers.google.cloud.operators.compute import (
+    ComputeEngineDeleteInstanceOperator,
+    ComputeEngineInsertInstanceOperator,
+)
+from airflow.providers.ssh.operators.ssh import SSHOperator
+from airflow.settings import Session
+from airflow.utils.trigger_rule import TriggerRule
 
 try:
     from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator
@@ -39,13 +60,102 @@ except ImportError:
     pytest.skip("MsSQL not available", allow_module_level=True)
 
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
-PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "example-project")
 DAG_ID = "example_bigquery_to_mssql"
 
-DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}"
-DATA_EXPORT_BUCKET_NAME = os.environ.get("GCP_BIGQUERY_EXPORT_BUCKET_NAME", "INVALID BUCKET NAME")
-TABLE = "table_42"
-destination_table = "mssql_table_test"
+
+REGION = "europe-west2"
+ZONE = REGION + "-a"
+NETWORK = "default"
+CONNECTION_ID = f"connection_{DAG_ID}_{ENV_ID}".replace("-", "_")
+CONNECTION_TYPE = "mssql"
+
+BIGQUERY_DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}"
+BIGQUERY_TABLE = "table_42"
+INSERT_ROWS_QUERY = (
+    f"INSERT INTO {BIGQUERY_DATASET_NAME}.{BIGQUERY_TABLE} (emp_name, salary) "
+    "VALUES ('emp 1', 10000), ('emp 2', 15000);"
+)
+
+DB_PORT = 1433
+DB_USER_NAME = "sa"
+DB_USER_PASSWORD = "5FHq4fSZ85kK6g0n"
+SETUP_MSSQL_COMMAND = f"""
+sudo apt update &&
+sudo apt install -y docker.io &&
+sudo docker run -e ACCEPT_EULA=Y -e MSSQL_SA_PASSWORD={DB_USER_PASSWORD} -p {DB_PORT}:{DB_PORT} \
+    -d mcr.microsoft.com/mssql/server:2022-latest
+"""
+SQL_TABLE = "test_table"
+SQL_CREATE_TABLE = f"""if not exists (select * from sys.tables where sys.tables.name='{SQL_TABLE}' and sys.tables.type='U')
+    create table {SQL_TABLE} (
+        emp_name VARCHAR(8),
+        salary INT
+    )
+"""
+
+GCE_MACHINE_TYPE = "n1-standard-1"
+GCE_INSTANCE_NAME = f"instance-{DAG_ID}-{ENV_ID}".replace("_", "-")
+GCE_INSTANCE_BODY = {
+    "name": GCE_INSTANCE_NAME,
+    "machine_type": f"zones/{ZONE}/machineTypes/{GCE_MACHINE_TYPE}",
+    "disks": [
+        {
+            "boot": True,
+            "device_name": GCE_INSTANCE_NAME,
+            "initialize_params": {
+                "disk_size_gb": "10",
+                "disk_type": f"zones/{ZONE}/diskTypes/pd-balanced",
+                "source_image": "projects/debian-cloud/global/images/debian-11-bullseye-v20220621",
+            },
+        }
+    ],
+    "network_interfaces": [
+        {
+            "access_configs": [{"name": "External NAT", "network_tier": "PREMIUM"}],
+            "stack_type": "IPV4_ONLY",
+            "subnetwork": f"regions/{REGION}/subnetworks/default",
+        }
+    ],
+}
+FIREWALL_RULE_NAME = f"allow-http-{DB_PORT}"
+CREATE_FIREWALL_RULE_COMMAND = f"""
+if [ $AIRFLOW__API__GOOGLE_KEY_PATH ]; then \
+ gcloud auth activate-service-account --key-file=$AIRFLOW__API__GOOGLE_KEY_PATH; \
+fi;
+
+if [ -z gcloud compute firewall-rules list --filter=name:{FIREWALL_RULE_NAME} --format="value(name)" ]; then \
+    gcloud compute firewall-rules create {FIREWALL_RULE_NAME} \
+      --project={PROJECT_ID} \
+      --direction=INGRESS \
+      --priority=100 \
+      --network={NETWORK} \
+      --action=ALLOW \
+      --rules=tcp:{DB_PORT} \
+      --source-ranges=0.0.0.0/0
+else
+    echo "Firewall rule {FIREWALL_RULE_NAME} already exists."
+fi
+"""
+DELETE_FIREWALL_RULE_COMMAND = f"""
+if [ $AIRFLOW__API__GOOGLE_KEY_PATH ]; then \
+ gcloud auth activate-service-account --key-file=$AIRFLOW__API__GOOGLE_KEY_PATH; \
+fi; \
+if [ gcloud compute firewall-rules list --filter=name:{FIREWALL_RULE_NAME} --format="value(name)" ]; then \
+    gcloud compute firewall-rules delete {FIREWALL_RULE_NAME} --project={PROJECT_ID} --quiet; \
+fi;
+"""
+DELETE_PERSISTENT_DISK_COMMAND = f"""
+if [ $AIRFLOW__API__GOOGLE_KEY_PATH ]; then \
+ gcloud auth activate-service-account --key-file=$AIRFLOW__API__GOOGLE_KEY_PATH; \
+fi;
+
+gcloud compute disks delete {GCE_INSTANCE_NAME} --project={PROJECT_ID} --zone={ZONE} --quiet
+"""
+
+
+log = logging.getLogger(__name__)
+
 
 with DAG(
     DAG_ID,
@@ -54,41 +164,161 @@ with DAG(
     catchup=False,
     tags=["example", "bigquery"],
 ) as dag:
+    create_bigquery_dataset = BigQueryCreateEmptyDatasetOperator(
+        task_id="create_bigquery_dataset", dataset_id=BIGQUERY_DATASET_NAME
+    )
+
+    create_bigquery_table = BigQueryCreateEmptyTableOperator(
+        task_id="create_bigquery_table",
+        dataset_id=BIGQUERY_DATASET_NAME,
+        table_id=BIGQUERY_TABLE,
+        schema_fields=[
+            {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
+            {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"},
+        ],
+    )
+
+    insert_bigquery_data = BigQueryInsertJobOperator(
+        task_id="insert_bigquery_data",
+        configuration={
+            "query": {
+                "query": INSERT_ROWS_QUERY,
+                "useLegacySql": False,
+                "priority": "BATCH",
+            }
+        },
+    )
+
+    create_gce_instance = ComputeEngineInsertInstanceOperator(
+        task_id="create_gce_instance",
+        project_id=PROJECT_ID,
+        zone=ZONE,
+        body=GCE_INSTANCE_BODY,
+    )
+
+    create_firewall_rule = BashOperator(
+        task_id="create_firewall_rule",
+        bash_command=CREATE_FIREWALL_RULE_COMMAND,
+    )
+
+    setup_mssql = SSHOperator(
+        task_id="setup_mssql",
+        ssh_hook=ComputeEngineSSHHook(
+            user="username",
+            instance_name=GCE_INSTANCE_NAME,
+            zone=ZONE,
+            project_id=PROJECT_ID,
+            use_oslogin=False,
+            use_iap_tunnel=False,
+            cmd_timeout=180,
+        ),
+        command=SETUP_MSSQL_COMMAND,
+        retries=4,
+    )
+
+    @task
+    def get_public_ip() -> str:
+        hook = ComputeEngineHook()
+        address = hook.get_instance_address(resource_id=GCE_INSTANCE_NAME, zone=ZONE, project_id=PROJECT_ID)
+        return address
+
+    get_public_ip_task = get_public_ip()
+
+    @task
+    def setup_connection(ip_address: str) -> None:
+        connection = Connection(
+            conn_id=CONNECTION_ID,
+            description="Example connection",
+            conn_type=CONNECTION_TYPE,
+            host=ip_address,
+            login=DB_USER_NAME,
+            password=DB_USER_PASSWORD,
+            port=DB_PORT,
+        )
+        session = Session()
+        log.info("Removing connection %s if it exists", CONNECTION_ID)
+        query = session.query(Connection).filter(Connection.conn_id == CONNECTION_ID)
+        query.delete()
+
+        session.add(connection)
+        session.commit()
+        log.info("Connection %s created", CONNECTION_ID)
+
+    setup_connection_task = setup_connection(get_public_ip_task)
+
+    create_sql_table = SQLExecuteQueryOperator(
+        task_id="create_sql_table",
+        conn_id=CONNECTION_ID,
+        sql=SQL_CREATE_TABLE,
+        retries=4,
+        retry_delay=duration(seconds=20),
+        retry_exponential_backoff=False,
+    )
+
     # [START howto_operator_bigquery_to_mssql]
     bigquery_to_mssql = BigQueryToMsSqlOperator(
         task_id="bigquery_to_mssql",
-        source_project_dataset_table=f"{PROJECT_ID}.{DATASET_NAME}.{TABLE}",
-        target_table_name=destination_table,
+        mssql_conn_id=CONNECTION_ID,
+        source_project_dataset_table=f"{PROJECT_ID}.{BIGQUERY_DATASET_NAME}.{BIGQUERY_TABLE}",
+        target_table_name=SQL_TABLE,
         replace=False,
     )
     # [END howto_operator_bigquery_to_mssql]
 
-    create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", dataset_id=DATASET_NAME)
+    delete_bigquery_dataset = BigQueryDeleteDatasetOperator(
+        task_id="delete_bigquery_dataset",
+        dataset_id=BIGQUERY_DATASET_NAME,
+        delete_contents=True,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
 
-    create_table = BigQueryCreateEmptyTableOperator(
-        task_id="create_table",
-        dataset_id=DATASET_NAME,
-        table_id=TABLE,
-        schema_fields=[
-            {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
-            {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"},
-        ],
+    delete_firewall_rule = BashOperator(
+        task_id="delete_firewall_rule",
+        bash_command=DELETE_FIREWALL_RULE_COMMAND,
+        trigger_rule=TriggerRule.ALL_DONE,
     )
 
-    delete_dataset = BigQueryDeleteDatasetOperator(
-        task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True
+    delete_gce_instance = ComputeEngineDeleteInstanceOperator(
+        task_id="delete_gce_instance",
+        resource_id=GCE_INSTANCE_NAME,
+        zone=ZONE,
+        project_id=PROJECT_ID,
+        trigger_rule=TriggerRule.ALL_DONE,
     )
 
+    delete_persistent_disk = BashOperator(
+        task_id="delete_persistent_disk",
+        bash_command=DELETE_PERSISTENT_DISK_COMMAND,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    delete_connection = BashOperator(
+        task_id="delete_connection",
+        bash_command=f"airflow connections delete {CONNECTION_ID}",
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    # TEST SETUP
+    create_bigquery_dataset >> create_bigquery_table >> insert_bigquery_data
+    create_gce_instance >> setup_mssql
+    create_gce_instance >> get_public_ip_task >> setup_connection_task
+    [setup_mssql, setup_connection_task, create_firewall_rule] >> create_sql_table
+
     (
-        # TEST SETUP
-        create_dataset
-        >> create_table
+        [insert_bigquery_data, create_sql_table]
         # TEST BODY
         >> bigquery_to_mssql
-        # TEST TEARDOWN
-        >> delete_dataset
     )
 
+    # TEST TEARDOWN
+    bigquery_to_mssql >> [
+        delete_bigquery_dataset,
+        delete_firewall_rule,
+        delete_gce_instance,
+        delete_connection,
+    ]
+    delete_gce_instance >> delete_persistent_disk
+
     from tests.system.utils.watcher import watcher
 
     # This test needs watcher in order to properly mark success/failure
