This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4ae85d754e Bugfix BigQueryToMsSqlOperator (#39171)
4ae85d754e is described below
commit 4ae85d754e9f8a65d461e86eb6111d3b9974a065
Author: max <[email protected]>
AuthorDate: Tue Apr 23 10:47:42 2024 +0000
Bugfix BigQueryToMsSqlOperator (#39171)
---
.../google/cloud/transfers/bigquery_to_mssql.py | 2 +-
tests/always/test_project_structure.py | 1 -
.../cloud/transfers/test_bigquery_to_mssql.py | 88 +++++++
.../cloud/bigquery/example_bigquery_to_mssql.py | 276 +++++++++++++++++++--
4 files changed, 342 insertions(+), 25 deletions(-)
diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
index c251ec5615..8a5749dc9e 100644
--- a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
+++ b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
@@ -91,7 +91,7 @@ class BigQueryToMsSqlOperator(BigQueryToSqlBaseOperator):
self.source_project_dataset_table = source_project_dataset_table
def get_sql_hook(self) -> MsSqlHook:
- return MsSqlHook(schema=self.database, mysql_conn_id=self.mssql_conn_id)
+ return MsSqlHook(schema=self.database, mssql_conn_id=self.mssql_conn_id)
def persist_links(self, context: Context) -> None:
project_id, dataset_id, table_id = self.source_project_dataset_table.split(".")
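For context on the one-line fix above: MsSqlHook stores its connection ID under mssql_conn_id, so passing it as mysql_conn_id meant the operator's configured connection never reached the hook, which would fall back to its default connection. A minimal sketch of the fixed behavior (the connection ID and table names below are placeholders, not part of this commit):

    from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator

    transfer = BigQueryToMsSqlOperator(
        task_id="bq_to_mssql",
        source_project_dataset_table="my-project.my_dataset.my_table",  # placeholder
        target_table_name="my_sql_table",  # placeholder
        mssql_conn_id="my_mssql",  # placeholder; before the fix this never reached the hook
        replace=False,
    )
    hook = transfer.get_sql_hook()  # now built with mssql_conn_id="my_mssql"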
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 4341dd1aec..3437092e65 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -143,7 +143,6 @@ class TestProjectStructure:
"tests/providers/google/cloud/operators/vertex_ai/test_model_service.py",
"tests/providers/google/cloud/operators/vertex_ai/test_pipeline_job.py",
"tests/providers/google/cloud/sensors/test_dataform.py",
- "tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py",
"tests/providers/google/cloud/transfers/test_bigquery_to_sql.py",
"tests/providers/google/cloud/transfers/test_mssql_to_gcs.py",
"tests/providers/google/cloud/transfers/test_presto_to_gcs.py",
diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py b/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py
new file mode 100644
index 0000000000..e4fd897324
--- /dev/null
+++ b/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py
@@ -0,0 +1,88 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator
+
+TASK_ID = "test-bq-create-table-operator"
+TEST_DATASET = "test-dataset"
+TEST_TABLE_ID = "test-table-id"
+TEST_DAG_ID = "test-bigquery-operators"
+TEST_PROJECT = "test-project"
+
+
+class TestBigQueryToMsSqlOperator:
+ @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryTableLink")
+ @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
+ def test_execute_good_request_to_bq(self, mock_hook, mock_link):
+ destination_table = "table"
+ operator = BigQueryToMsSqlOperator(
+ task_id=TASK_ID,
+ source_project_dataset_table=f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}",
+ target_table_name=destination_table,
+ replace=False,
+ )
+
+ operator.execute(None)
+ mock_hook.return_value.list_rows.assert_called_once_with(
+ dataset_id=TEST_DATASET,
+ table_id=TEST_TABLE_ID,
+ max_results=1000,
+ selected_fields=None,
+ start_index=0,
+ )
+
+ @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
+ def test_get_sql_hook(self, mock_hook):
+ hook_expected = mock_hook.return_value
+
+ destination_table = "table"
+ operator = BigQueryToMsSqlOperator(
+ task_id=TASK_ID,
+ source_project_dataset_table=f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}",
+ target_table_name=destination_table,
+ replace=False,
+ )
+
+ hook_actual = operator.get_sql_hook()
+
+ assert hook_actual == hook_expected
+ mock_hook.assert_called_once_with(schema=operator.database, mssql_conn_id=operator.mssql_conn_id)
+
+ @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryTableLink")
+ def test_persist_links(self, mock_link):
+ mock_context = mock.MagicMock()
+
+ destination_table = "table"
+ operator = BigQueryToMsSqlOperator(
+ task_id=TASK_ID,
+ source_project_dataset_table=f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}",
+ target_table_name=destination_table,
+ replace=False,
+ )
+ operator.persist_links(context=mock_context)
+
+ mock_link.persist.assert_called_once_with(
+ context=mock_context,
+ task_instance=operator,
+ dataset_id=TEST_DATASET,
+ project_id=TEST_PROJECT,
+ table_id=TEST_TABLE_ID,
+ )
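A note on the mock targets in these tests: mock.patch must patch a name in the module that looks it up, not where it is defined. execute() lives on the base class in bigquery_to_sql, so BigQueryHook is patched there, while MsSqlHook and BigQueryTableLink are resolved inside bigquery_to_mssql and are patched there. A minimal illustration of the rule, reusing the module paths from this diff:

    from unittest import mock

    # Patch the name where it is *used*: the base module instantiates
    # BigQueryHook inside execute(), so patching it under bigquery_to_mssql
    # would leave that call untouched.
    with mock.patch(
        "airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook"
    ) as mock_bq_hook:
        ...  # operator.execute(None) now hits mock_bq_hook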
diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py
index 822020df28..2ad7671ec3 100644
--- a/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py
+++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py
@@ -17,21 +17,42 @@
# under the License.
"""
Example Airflow DAG for Google BigQuery service.
+
+This DAG relies on the following OS environment variables
+
+* AIRFLOW__API__GOOGLE_KEY_PATH - Path to the service account key file. Note: you can skip this variable if you
+ run this DAG in a Composer environment.
"""
from __future__ import annotations
+import logging
import os
from datetime import datetime
import pytest
+from pendulum import duration
+from airflow.decorators import task
+from airflow.models import Connection
from airflow.models.dag import DAG
+from airflow.operators.bash import BashOperator
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
+from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
+from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook
from airflow.providers.google.cloud.operators.bigquery import (
BigQueryCreateEmptyDatasetOperator,
BigQueryCreateEmptyTableOperator,
BigQueryDeleteDatasetOperator,
+ BigQueryInsertJobOperator,
)
+from airflow.providers.google.cloud.operators.compute import (
+ ComputeEngineDeleteInstanceOperator,
+ ComputeEngineInsertInstanceOperator,
+)
+from airflow.providers.ssh.operators.ssh import SSHOperator
+from airflow.settings import Session
+from airflow.utils.trigger_rule import TriggerRule
try:
from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator
@@ -39,13 +60,102 @@ except ImportError:
pytest.skip("MsSQL not available", allow_module_level=True)
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
-PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "example-project")
DAG_ID = "example_bigquery_to_mssql"
-DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}"
-DATA_EXPORT_BUCKET_NAME = os.environ.get("GCP_BIGQUERY_EXPORT_BUCKET_NAME",
"INVALID BUCKET NAME")
-TABLE = "table_42"
-destination_table = "mssql_table_test"
+
+REGION = "europe-west2"
+ZONE = REGION + "-a"
+NETWORK = "default"
+CONNECTION_ID = f"connection_{DAG_ID}_{ENV_ID}".replace("-", "_")
+CONNECTION_TYPE = "mssql"
+
+BIGQUERY_DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}"
+BIGQUERY_TABLE = "table_42"
+INSERT_ROWS_QUERY = (
+ f"INSERT INTO {BIGQUERY_DATASET_NAME}.{BIGQUERY_TABLE} (emp_name, salary) "
+ "VALUES ('emp 1', 10000), ('emp 2', 15000);"
+)
+
+DB_PORT = 1433
+DB_USER_NAME = "sa"
+DB_USER_PASSWORD = "5FHq4fSZ85kK6g0n"
+SETUP_MSSQL_COMMAND = f"""
+sudo apt update &&
+sudo apt install -y docker.io &&
+sudo docker run -e ACCEPT_EULA=Y -e MSSQL_SA_PASSWORD={DB_USER_PASSWORD} -p {DB_PORT}:{DB_PORT} \
+ -d mcr.microsoft.com/mssql/server:2022-latest
+"""
+SQL_TABLE = "test_table"
+SQL_CREATE_TABLE = f"""if not exists (select * from sys.tables where
sys.tables.name='{SQL_TABLE}' and sys.tables.type='U')
+ create table {SQL_TABLE} (
+ emp_name VARCHAR(8),
+ salary INT
+ )
+"""
+
+GCE_MACHINE_TYPE = "n1-standard-1"
+GCE_INSTANCE_NAME = f"instance-{DAG_ID}-{ENV_ID}".replace("_", "-")
+GCE_INSTANCE_BODY = {
+ "name": GCE_INSTANCE_NAME,
+ "machine_type": f"zones/{ZONE}/machineTypes/{GCE_MACHINE_TYPE}",
+ "disks": [
+ {
+ "boot": True,
+ "device_name": GCE_INSTANCE_NAME,
+ "initialize_params": {
+ "disk_size_gb": "10",
+ "disk_type": f"zones/{ZONE}/diskTypes/pd-balanced",
+ "source_image":
"projects/debian-cloud/global/images/debian-11-bullseye-v20220621",
+ },
+ }
+ ],
+ "network_interfaces": [
+ {
+ "access_configs": [{"name": "External NAT", "network_tier":
"PREMIUM"}],
+ "stack_type": "IPV4_ONLY",
+ "subnetwork": f"regions/{REGION}/subnetworks/default",
+ }
+ ],
+}
+FIREWALL_RULE_NAME = f"allow-http-{DB_PORT}"
+CREATE_FIREWALL_RULE_COMMAND = f"""
+if [ $AIRFLOW__API__GOOGLE_KEY_PATH ]; then \
+ gcloud auth activate-service-account --key-file=$AIRFLOW__API__GOOGLE_KEY_PATH; \
+fi;
+
+if [ -z "$(gcloud compute firewall-rules list --filter=name:{FIREWALL_RULE_NAME} --format='value(name)')" ]; then \
+ gcloud compute firewall-rules create {FIREWALL_RULE_NAME} \
+ --project={PROJECT_ID} \
+ --direction=INGRESS \
+ --priority=100 \
+ --network={NETWORK} \
+ --action=ALLOW \
+ --rules=tcp:{DB_PORT} \
+ --source-ranges=0.0.0.0/0
+else
+ echo "Firewall rule {FIREWALL_RULE_NAME} already exists."
+fi
+"""
+DELETE_FIREWALL_RULE_COMMAND = f"""
+if [ $AIRFLOW__API__GOOGLE_KEY_PATH ]; then \
+ gcloud auth activate-service-account --key-file=$AIRFLOW__API__GOOGLE_KEY_PATH; \
+fi; \
+if [ -n "$(gcloud compute firewall-rules list --filter=name:{FIREWALL_RULE_NAME} --format='value(name)')" ]; then \
+ gcloud compute firewall-rules delete {FIREWALL_RULE_NAME} --project={PROJECT_ID} --quiet; \
+fi;
+"""
+DELETE_PERSISTENT_DISK_COMMAND = f"""
+if [ $AIRFLOW__API__GOOGLE_KEY_PATH ]; then \
+ gcloud auth activate-service-account --key-file=$AIRFLOW__API__GOOGLE_KEY_PATH; \
+fi;
+
+gcloud compute disks delete {GCE_INSTANCE_NAME} --project={PROJECT_ID} --zone={ZONE} --quiet
+"""
+
+
+log = logging.getLogger(__name__)
+
with DAG(
DAG_ID,
@@ -54,41 +164,161 @@ with DAG(
catchup=False,
tags=["example", "bigquery"],
) as dag:
+ create_bigquery_dataset = BigQueryCreateEmptyDatasetOperator(
+ task_id="create_bigquery_dataset", dataset_id=BIGQUERY_DATASET_NAME
+ )
+
+ create_bigquery_table = BigQueryCreateEmptyTableOperator(
+ task_id="create_bigquery_table",
+ dataset_id=BIGQUERY_DATASET_NAME,
+ table_id=BIGQUERY_TABLE,
+ schema_fields=[
+ {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
+ {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"},
+ ],
+ )
+
+ insert_bigquery_data = BigQueryInsertJobOperator(
+ task_id="insert_bigquery_data",
+ configuration={
+ "query": {
+ "query": INSERT_ROWS_QUERY,
+ "useLegacySql": False,
+ "priority": "BATCH",
+ }
+ },
+ )
+
+ create_gce_instance = ComputeEngineInsertInstanceOperator(
+ task_id="create_gce_instance",
+ project_id=PROJECT_ID,
+ zone=ZONE,
+ body=GCE_INSTANCE_BODY,
+ )
+
+ create_firewall_rule = BashOperator(
+ task_id="create_firewall_rule",
+ bash_command=CREATE_FIREWALL_RULE_COMMAND,
+ )
+
+ setup_mssql = SSHOperator(
+ task_id="setup_mssql",
+ ssh_hook=ComputeEngineSSHHook(
+ user="username",
+ instance_name=GCE_INSTANCE_NAME,
+ zone=ZONE,
+ project_id=PROJECT_ID,
+ use_oslogin=False,
+ use_iap_tunnel=False,
+ cmd_timeout=180,
+ ),
+ command=SETUP_MSSQL_COMMAND,
+ retries=4,
+ )
+
+ @task
+ def get_public_ip() -> str:
+ hook = ComputeEngineHook()
+ address = hook.get_instance_address(resource_id=GCE_INSTANCE_NAME, zone=ZONE, project_id=PROJECT_ID)
+ return address
+
+ get_public_ip_task = get_public_ip()
+
+ @task
+ def setup_connection(ip_address: str) -> None:
+ connection = Connection(
+ conn_id=CONNECTION_ID,
+ description="Example connection",
+ conn_type=CONNECTION_TYPE,
+ host=ip_address,
+ login=DB_USER_NAME,
+ password=DB_USER_PASSWORD,
+ port=DB_PORT,
+ )
+ session = Session()
+ log.info("Removing connection %s if it exists", CONNECTION_ID)
+ query = session.query(Connection).filter(Connection.conn_id == CONNECTION_ID)
+ query.delete()
+
+ session.add(connection)
+ session.commit()
+ log.info("Connection %s created", CONNECTION_ID)
+
+ setup_connection_task = setup_connection(get_public_ip_task)
+
+ create_sql_table = SQLExecuteQueryOperator(
+ task_id="create_sql_table",
+ conn_id=CONNECTION_ID,
+ sql=SQL_CREATE_TABLE,
+ retries=4,
+ retry_delay=duration(seconds=20),
+ retry_exponential_backoff=False,
+ )
+
# [START howto_operator_bigquery_to_mssql]
bigquery_to_mssql = BigQueryToMsSqlOperator(
task_id="bigquery_to_mssql",
- source_project_dataset_table=f"{PROJECT_ID}.{DATASET_NAME}.{TABLE}",
- target_table_name=destination_table,
+ mssql_conn_id=CONNECTION_ID,
+ source_project_dataset_table=f"{PROJECT_ID}.{BIGQUERY_DATASET_NAME}.{BIGQUERY_TABLE}",
+ target_table_name=SQL_TABLE,
replace=False,
)
# [END howto_operator_bigquery_to_mssql]
- create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", dataset_id=DATASET_NAME)
+ delete_bigquery_dataset = BigQueryDeleteDatasetOperator(
+ task_id="delete_bigquery_dataset",
+ dataset_id=BIGQUERY_DATASET_NAME,
+ delete_contents=True,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
- create_table = BigQueryCreateEmptyTableOperator(
- task_id="create_table",
- dataset_id=DATASET_NAME,
- table_id=TABLE,
- schema_fields=[
- {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
- {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"},
- ],
+ delete_firewall_rule = BashOperator(
+ task_id="delete_firewall_rule",
+ bash_command=DELETE_FIREWALL_RULE_COMMAND,
+ trigger_rule=TriggerRule.ALL_DONE,
)
- delete_dataset = BigQueryDeleteDatasetOperator(
- task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True
+ delete_gce_instance = ComputeEngineDeleteInstanceOperator(
+ task_id="delete_gce_instance",
+ resource_id=GCE_INSTANCE_NAME,
+ zone=ZONE,
+ project_id=PROJECT_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
)
+ delete_persistent_disk = BashOperator(
+ task_id="delete_persistent_disk",
+ bash_command=DELETE_PERSISTENT_DISK_COMMAND,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ delete_connection = BashOperator(
+ task_id="delete_connection",
+ bash_command=f"airflow connections delete {CONNECTION_ID}",
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ # TEST SETUP
+ create_bigquery_dataset >> create_bigquery_table >> insert_bigquery_data
+ create_gce_instance >> setup_mssql
+ create_gce_instance >> get_public_ip_task >> setup_connection_task
+ [setup_mssql, setup_connection_task, create_firewall_rule] >> create_sql_table
+
(
- # TEST SETUP
- create_dataset
- >> create_table
+ [insert_bigquery_data, create_sql_table]
# TEST BODY
>> bigquery_to_mssql
- # TEST TEARDOWN
- >> delete_dataset
)
+ # TEST TEARDOWN
+ bigquery_to_mssql >> [
+ delete_bigquery_dataset,
+ delete_firewall_rule,
+ delete_gce_instance,
+ delete_connection,
+ ]
+ delete_gce_instance >> delete_persistent_disk
+
from tests.system.utils.watcher import watcher
# This test needs watcher in order to properly mark success/failure
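The file presumably continues past this excerpt with the stock system-test footer from tests/system/utils, finishing the truncated comment above and fanning every task into the watcher so that teardown failures are reported. A sketch of that standard pattern (the first two lines sit inside the with DAG block; none of this is shown in the diff above):

    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()

    from tests.system.utils import get_test_run  # noqa: E402

    # Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
    test_run = get_test_run(dag)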