This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0c6fd5be86 Remove usage of deprecated methods from BigQueryCursor (#35606)
0c6fd5be86 is described below
commit 0c6fd5be864b26031d388c921ed48058a610983e
Author: Maksim <[email protected]>
AuthorDate: Fri Nov 17 12:04:25 2023 +0100
Remove usage of deprecated methods from BigQueryCursor (#35606)
---
airflow/providers/google/cloud/hooks/bigquery.py | 170 ++++++++++++++++++++-
.../providers/google/cloud/hooks/test_bigquery.py | 38 ++---
2 files changed, 187 insertions(+), 21 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 802a134765..3ad29c66f0 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -129,7 +129,8 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
def get_conn(self) -> BigQueryConnection:
"""Get a BigQuery PEP 249 connection object."""
- service = self.get_service()
+ http_authorized = self._authorize()
+ service = build("bigquery", "v2", http=http_authorized,
cache_discovery=False)
return BigQueryConnection(
service=service,
project_id=self.project_id,
@@ -2775,7 +2776,7 @@ class BigQueryCursor(BigQueryBaseCursor):
"""
sql = _bind_parameters(operation, parameters) if parameters else
operation
self.flush_results()
- self.job_id = self.hook.run_query(sql)
+ self.job_id = self._run_query(sql)
query_results = self._get_query_result()
if "schema" in query_results:
@@ -2913,6 +2914,171 @@ class BigQueryCursor(BigQueryBaseCursor):
return query_results
+ def _run_query(
+ self,
+ sql,
+ location: str | None = None,
+ ) -> str:
+ """Run job query."""
+ if not self.project_id:
+ raise ValueError("The project_id should be set")
+
+ configuration = self._prepare_query_configuration(sql)
+ job = self.hook.insert_job(configuration=configuration,
project_id=self.project_id, location=location)
+
+ return job.job_id
+
+ def _prepare_query_configuration(
+ self,
+ sql,
+ destination_dataset_table: str | None = None,
+ write_disposition: str = "WRITE_EMPTY",
+ allow_large_results: bool = False,
+ flatten_results: bool | None = None,
+ udf_config: list | None = None,
+ use_legacy_sql: bool | None = None,
+ maximum_billing_tier: int | None = None,
+ maximum_bytes_billed: float | None = None,
+ create_disposition: str = "CREATE_IF_NEEDED",
+ query_params: list | None = None,
+ labels: dict | None = None,
+ schema_update_options: Iterable | None = None,
+ priority: str | None = None,
+ time_partitioning: dict | None = None,
+ api_resource_configs: dict | None = None,
+ cluster_fields: list[str] | None = None,
+ encryption_configuration: dict | None = None,
+ ):
+ """Helper method that prepare configuration for query."""
+ labels = labels or self.hook.labels
+ schema_update_options = list(schema_update_options or [])
+
+ priority = priority or self.hook.priority
+
+ if time_partitioning is None:
+ time_partitioning = {}
+
+ if not api_resource_configs:
+ api_resource_configs = self.hook.api_resource_configs
+ else:
+ _validate_value("api_resource_configs", api_resource_configs, dict)
+
+ configuration = deepcopy(api_resource_configs)
+
+ if "query" not in configuration:
+ configuration["query"] = {}
+ else:
+ _validate_value("api_resource_configs['query']",
configuration["query"], dict)
+
+ if sql is None and not configuration["query"].get("query", None):
+ raise TypeError("`BigQueryBaseCursor.run_query` missing 1 required
positional argument: `sql`")
+
+ # BigQuery also allows you to define how you want a table's schema to
change
+ # as a side effect of a query job
+ # for more details:
+ #
https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions
+
+ allowed_schema_update_options = ["ALLOW_FIELD_ADDITION",
"ALLOW_FIELD_RELAXATION"]
+
+ if not
set(allowed_schema_update_options).issuperset(set(schema_update_options)):
+ raise ValueError(
+ f"{schema_update_options} contains invalid schema update
options."
+ f" Please only use one or more of the following options:
{allowed_schema_update_options}"
+ )
+
+ if schema_update_options:
+ if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
+ raise ValueError(
+ "schema_update_options is only "
+ "allowed if write_disposition is "
+ "'WRITE_APPEND' or 'WRITE_TRUNCATE'."
+ )
+
+ if destination_dataset_table:
+ destination_project, destination_dataset, destination_table =
self.hook.split_tablename(
+ table_input=destination_dataset_table,
default_project_id=self.project_id
+ )
+
+ destination_dataset_table = { # type: ignore
+ "projectId": destination_project,
+ "datasetId": destination_dataset,
+ "tableId": destination_table,
+ }
+
+ if cluster_fields:
+ cluster_fields = {"fields": cluster_fields} # type: ignore
+
+ query_param_list: list[tuple[Any, str, str | bool | None | dict, type
| tuple[type]]] = [
+ (sql, "query", None, (str,)),
+ (priority, "priority", priority, (str,)),
+ (use_legacy_sql, "useLegacySql", self.use_legacy_sql, bool),
+ (query_params, "queryParameters", None, list),
+ (udf_config, "userDefinedFunctionResources", None, list),
+ (maximum_billing_tier, "maximumBillingTier", None, int),
+ (maximum_bytes_billed, "maximumBytesBilled", None, float),
+ (time_partitioning, "timePartitioning", {}, dict),
+ (schema_update_options, "schemaUpdateOptions", None, list),
+ (destination_dataset_table, "destinationTable", None, dict),
+ (cluster_fields, "clustering", None, dict),
+ ]
+
+ for param, param_name, param_default, param_type in query_param_list:
+ if param_name not in configuration["query"] and param in [None,
{}, ()]:
+ if param_name == "timePartitioning":
+ param_default =
_cleanse_time_partitioning(destination_dataset_table, time_partitioning)
+ param = param_default
+
+ if param in [None, {}, ()]:
+ continue
+
+ _api_resource_configs_duplication_check(param_name, param,
configuration["query"])
+
+ configuration["query"][param_name] = param
+
+ # check valid type of provided param,
+ # it last step because we can get param from 2 sources,
+ # and first of all need to find it
+
+ _validate_value(param_name, configuration["query"][param_name],
param_type)
+
+ if param_name == "schemaUpdateOptions" and param:
+ self.log.info("Adding experimental 'schemaUpdateOptions': %s",
schema_update_options)
+
+ if param_name == "destinationTable":
+ for key in ["projectId", "datasetId", "tableId"]:
+ if key not in configuration["query"]["destinationTable"]:
+ raise ValueError(
+ "Not correct 'destinationTable' in "
+ "api_resource_configs. 'destinationTable' "
+ "must be a dict with {'projectId':'', "
+ "'datasetId':'', 'tableId':''}"
+ )
+ else:
+ configuration["query"].update(
+ {
+ "allowLargeResults": allow_large_results,
+ "flattenResults": flatten_results,
+ "writeDisposition": write_disposition,
+ "createDisposition": create_disposition,
+ }
+ )
+
+ if (
+ "useLegacySql" in configuration["query"]
+ and configuration["query"]["useLegacySql"]
+ and "queryParameters" in configuration["query"]
+ ):
+ raise ValueError("Query parameters are not allowed when using
legacy SQL")
+
+ if labels:
+ _api_resource_configs_duplication_check("labels", labels,
configuration)
+ configuration["labels"] = labels
+
+ if encryption_configuration:
+ configuration["query"]["destinationEncryptionConfiguration"] =
encryption_configuration
+
+ return configuration
+
def _bind_parameters(operation: str, parameters: dict) -> str:
"""Helper method that binds parameters to a SQL query."""
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 238be83bad..47fc204647 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -1208,7 +1208,7 @@ class TestTableOperations(_BigQueryBaseTestClass):
@pytest.mark.db_test
class TestBigQueryCursor(_BigQueryBaseTestClass):
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_execute_with_parameters(self, mock_insert, _):
bq_cursor = self.hook.get_cursor()
@@ -1223,7 +1223,7 @@ class TestBigQueryCursor(_BigQueryBaseTestClass):
}
mock_insert.assert_called_once_with(configuration=conf,
project_id=PROJECT_ID, location=None)
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_execute_many(self, mock_insert, _):
bq_cursor = self.hook.get_cursor()
@@ -1275,10 +1275,10 @@ class TestBigQueryCursor(_BigQueryBaseTestClass):
("field_3", "STRING", None, None, None, None, False),
]
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
- def test_description(self, mock_insert, mock_get_service):
- mock_get_query_results =
mock_get_service.return_value.jobs.return_value.getQueryResults
+ def test_description(self, mock_insert, mock_build):
+ mock_get_query_results =
mock_build.return_value.jobs.return_value.getQueryResults
mock_execute = mock_get_query_results.return_value.execute
mock_execute.return_value = {
"schema": {
@@ -1292,10 +1292,10 @@ class TestBigQueryCursor(_BigQueryBaseTestClass):
bq_cursor.execute("SELECT CURRENT_TIMESTAMP() as ts")
assert bq_cursor.description == [("ts", "TIMESTAMP", None, None, None,
None, True)]
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
- def test_description_no_schema(self, mock_insert, mock_get_service):
- mock_get_query_results =
mock_get_service.return_value.jobs.return_value.getQueryResults
+ def test_description_no_schema(self, mock_insert, mock_build):
+ mock_get_query_results =
mock_build.return_value.jobs.return_value.getQueryResults
mock_execute = mock_get_query_results.return_value.execute
mock_execute.return_value = {}
@@ -1369,9 +1369,9 @@ class TestBigQueryCursor(_BigQueryBaseTestClass):
result = bq_cursor.next()
assert result is None
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
- def test_next(self, mock_get_service):
- mock_get_query_results =
mock_get_service.return_value.jobs.return_value.getQueryResults
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
+ def test_next(self, mock_build):
+ mock_get_query_results =
mock_build.return_value.jobs.return_value.getQueryResults
mock_execute = mock_get_query_results.return_value.execute
mock_execute.return_value = {
"rows": [
@@ -1402,10 +1402,10 @@ class TestBigQueryCursor(_BigQueryBaseTestClass):
)
mock_execute.assert_called_once_with(num_retries=bq_cursor.num_retries)
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor.flush_results")
- def test_next_no_rows(self, mock_flush_results, mock_get_service):
- mock_get_query_results =
mock_get_service.return_value.jobs.return_value.getQueryResults
+ def test_next_no_rows(self, mock_flush_results, mock_build):
+ mock_get_query_results =
mock_build.return_value.jobs.return_value.getQueryResults
mock_execute = mock_get_query_results.return_value.execute
mock_execute.return_value = {}
@@ -1421,10 +1421,10 @@ class TestBigQueryCursor(_BigQueryBaseTestClass):
mock_execute.assert_called_once_with(num_retries=bq_cursor.num_retries)
assert mock_flush_results.call_count == 1
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor.flush_results")
- def test_flush_cursor_in_execute(self, _, mock_insert, mock_get_service):
+ def test_flush_cursor_in_execute(self, _, mock_insert, mock_build):
bq_cursor = self.hook.get_cursor()
bq_cursor.execute("SELECT %(foo)s", {"foo": "bar"})
assert mock_insert.call_count == 1
@@ -1786,7 +1786,7 @@ class TestClusteringInRunJob(_BigQueryBaseTestClass):
class TestBigQueryHookLegacySql(_BigQueryBaseTestClass):
"""Ensure `use_legacy_sql` param in `BigQueryHook` propagates properly."""
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_hook_uses_legacy_sql_by_default(self, mock_insert, _):
self.hook.get_first("query")
@@ -1797,10 +1797,10 @@ class TestBigQueryHookLegacySql(_BigQueryBaseTestClass):
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=(CREDENTIALS, PROJECT_ID),
)
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_legacy_sql_override_propagates_properly(
- self, mock_insert, mock_get_service, mock_get_creds_and_proj_id
+ self, mock_insert, mock_build, mock_get_creds_and_proj_id
):
bq_hook = BigQueryHook(use_legacy_sql=False)
bq_hook.get_first("query")