This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7d2c2ee879 add description method in BigQueryCursor class (#25366)
7d2c2ee879 is described below
commit 7d2c2ee879656faf47829d1ad89fc4441e19a66e
Author: sophie-ly <[email protected]>
AuthorDate: Thu Aug 4 16:48:35 2022 +0200
add description method in BigQueryCursor class (#25366)
---
airflow/providers/google/cloud/hooks/bigquery.py | 62 +++++++++++++++++-----
.../providers/google/cloud/hooks/test_bigquery.py | 29 ++++++++--
2 files changed, 74 insertions(+), 17 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py
b/airflow/providers/google/cloud/hooks/bigquery.py
index 1e9848d767..1a29863467 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2663,11 +2663,16 @@ class BigQueryCursor(BigQueryBaseCursor):
self.job_id = None # type: Optional[str]
self.buffer = [] # type: list
self.all_pages_loaded = False # type: bool
+ self._description = [] # type: List
@property
- def description(self) -> None:
- """The schema description method is not currently implemented"""
- raise NotImplementedError
+    def description(self) -> List:
+        """Return the DB-API description of the last executed query (empty list before ``execute``)."""
+        return self._description
+
+    @description.setter
+    def description(self, value: List) -> None:
+        """Store the description; ``execute`` sets it from the query's result schema."""
+        self._description = value
def close(self) -> None:
"""By default, do nothing"""
@@ -2688,6 +2693,10 @@ class BigQueryCursor(BigQueryBaseCursor):
self.flush_results()
self.job_id = self.hook.run_query(sql)
+        query_results = self._get_query_result()
+        # The API only returns "schema" once the query has completed successfully, and
+        # statements without a result set may omit it too (TODO confirm for DDL/DML);
+        # fall back to an empty description instead of raising KeyError.
+        if "schema" in query_results:
+            self.description = _format_schema_for_description(query_results["schema"])
+        else:
+            self.description = []
+
def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
Execute a BigQuery query multiple times with different parameters.
@@ -2723,17 +2732,7 @@ class BigQueryCursor(BigQueryBaseCursor):
if self.all_pages_loaded:
return None
- query_results = (
- self.service.jobs()
- .getQueryResults(
- projectId=self.project_id,
- jobId=self.job_id,
- location=self.location,
- pageToken=self.page_token,
- )
- .execute(num_retries=self.num_retries)
- )
-
+ query_results = self._get_query_result()
if 'rows' in query_results and query_results['rows']:
self.page_token = query_results.get('pageToken')
fields = query_results['schema']['fields']
@@ -2805,6 +2804,21 @@ class BigQueryCursor(BigQueryBaseCursor):
def setoutputsize(self, size: Any, column: Any = None) -> None:
"""Does nothing by default"""
+    def _get_query_result(self) -> Dict:
+        """
+        Fetch the current page of results for ``self.job_id`` via jobs.getQueryResults.
+
+        The raw response dict is returned unchanged; it may carry the rows, the
+        result schema and the next page token, among other fields.
+        """
+        request = self.service.jobs().getQueryResults(
+            projectId=self.project_id,
+            jobId=self.job_id,
+            location=self.location,
+            pageToken=self.page_token,
+        )
+        return request.execute(num_retries=self.num_retries)
+
def _bind_parameters(operation: str, parameters: dict) -> str:
"""Helper method that binds parameters to a SQL query"""
@@ -2973,3 +2987,23 @@ def _validate_src_fmt_configs(
raise ValueError(f"{k} is not a valid src_fmt_configs for type
{source_format}.")
return src_fmt_configs
+
+
+def _format_schema_for_description(schema: Dict) -> List:
+    """
+    Reformat a BigQuery result schema into a DB-API cursor description.
+
+    Each field becomes a 7-item tuple ``(name, type_code, display_size,
+    internal_size, precision, scale, null_ok)``. Only ``name``, ``type_code``
+    and ``null_ok`` are filled in; the size/precision/scale slots are ``None``.
+
+    :param schema: the ``schema`` object of a BigQuery query response, i.e. a
+        dict holding a ``fields`` list.
+    :return: list of description tuples, one per field, in schema order.
+    """
+    description = []
+    for field in schema["fields"]:
+        # "mode" is optional in a BigQuery field schema and defaults to NULLABLE,
+        # so a field without it must not raise KeyError.
+        mode = field.get("mode", "NULLABLE")
+        field_description = (
+            field["name"],
+            field["type"],
+            None,
+            None,
+            None,
+            None,
+            mode == "NULLABLE",
+        )
+        description.append(field_description)
+    return description
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py
b/tests/providers/google/cloud/hooks/test_bigquery.py
index 2b378b1ec6..b2afe99087 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -34,6 +34,7 @@ from airflow.providers.google.cloud.hooks.bigquery import (
BigQueryHook,
_api_resource_configs_duplication_check,
_cleanse_time_partitioning,
+ _format_schema_for_description,
_validate_src_fmt_configs,
_validate_value,
split_tablename,
@@ -1239,11 +1240,33 @@ class TestBigQueryCursor(_BigQueryBaseTestClass):
]
)
+    def test_format_schema_for_description(self):
+        schema = {"fields": [{"name": "field_1", "type": "STRING", "mode": "NULLABLE"}]}
+        expected = [("field_1", "STRING", None, None, None, None, True)]
+        assert _format_schema_for_description(schema) == expected
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
-    def test_description(self, mock_get_service):
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
+    def test_description(self, mock_insert, mock_get_service):
+        jobs = mock_get_service.return_value.jobs.return_value
+        jobs.getQueryResults.return_value.execute.return_value = {
+            "schema": {"fields": [{"name": "ts", "type": "TIMESTAMP", "mode": "NULLABLE"}]},
+        }
+
         bq_cursor = self.hook.get_cursor()
-        with pytest.raises(NotImplementedError):
-            bq_cursor.description
+        bq_cursor.execute("SELECT CURRENT_TIMESTAMP() as ts")
+        assert bq_cursor.description == [("ts", "TIMESTAMP", None, None, None, None, True)]
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_close(self, mock_get_service):