This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d2c2ee879 add description method in BigQueryCursor class (#25366)
7d2c2ee879 is described below

commit 7d2c2ee879656faf47829d1ad89fc4441e19a66e
Author: sophie-ly <[email protected]>
AuthorDate: Thu Aug 4 16:48:35 2022 +0200

    add description method in BigQueryCursor class (#25366)
    
    ]
---
 airflow/providers/google/cloud/hooks/bigquery.py   | 62 +++++++++++++++++-----
 .../providers/google/cloud/hooks/test_bigquery.py  | 29 ++++++++--
 2 files changed, 74 insertions(+), 17 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 1e9848d767..1a29863467 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2663,11 +2663,16 @@ class BigQueryCursor(BigQueryBaseCursor):
         self.job_id = None  # type: Optional[str]
         self.buffer = []  # type: list
         self.all_pages_loaded = False  # type: bool
+        self._description = []  # type: List
 
     @property
-    def description(self) -> None:
-        """The schema description method is not currently implemented"""
-        raise NotImplementedError
+    def description(self) -> List:
+        """Return the cursor description"""
+        return self._description
+
+    @description.setter
+    def description(self, value):
+        self._description = value
 
     def close(self) -> None:
         """By default, do nothing"""
@@ -2688,6 +2693,10 @@ class BigQueryCursor(BigQueryBaseCursor):
         self.flush_results()
         self.job_id = self.hook.run_query(sql)
 
+        query_results = self._get_query_result()
+        description = _format_schema_for_description(query_results["schema"])
+        self.description = description
+
     def executemany(self, operation: str, seq_of_parameters: list) -> None:
         """
         Execute a BigQuery query multiple times with different parameters.
@@ -2723,17 +2732,7 @@ class BigQueryCursor(BigQueryBaseCursor):
             if self.all_pages_loaded:
                 return None
 
-            query_results = (
-                self.service.jobs()
-                .getQueryResults(
-                    projectId=self.project_id,
-                    jobId=self.job_id,
-                    location=self.location,
-                    pageToken=self.page_token,
-                )
-                .execute(num_retries=self.num_retries)
-            )
-
+            query_results = self._get_query_result()
             if 'rows' in query_results and query_results['rows']:
                 self.page_token = query_results.get('pageToken')
                 fields = query_results['schema']['fields']
@@ -2805,6 +2804,21 @@ class BigQueryCursor(BigQueryBaseCursor):
     def setoutputsize(self, size: Any, column: Any = None) -> None:
         """Does nothing by default"""
 
+    def _get_query_result(self) -> Dict:
+        """Get job query results like data, schema, job type..."""
+        query_results = (
+            self.service.jobs()
+            .getQueryResults(
+                projectId=self.project_id,
+                jobId=self.job_id,
+                location=self.location,
+                pageToken=self.page_token,
+            )
+            .execute(num_retries=self.num_retries)
+        )
+
+        return query_results
+
 
 def _bind_parameters(operation: str, parameters: dict) -> str:
     """Helper method that binds parameters to a SQL query"""
@@ -2973,3 +2987,23 @@ def _validate_src_fmt_configs(
             raise ValueError(f"{k} is not a valid src_fmt_configs for type {source_format}.")
 
     return src_fmt_configs
+
+
+def _format_schema_for_description(schema: Dict) -> List:
+    """
+    Reformat the schema to match cursor description standard which is a tuple
+    of 7 elements (name, type, display_size, internal_size, precision, scale, null_ok)
+    """
+    description = []
+    for field in schema["fields"]:
+        field_description = (
+            field["name"],
+            field["type"],
+            None,
+            None,
+            None,
+            None,
+            field["mode"] == "NULLABLE",
+        )
+        description.append(field_description)
+    return description
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 2b378b1ec6..b2afe99087 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -34,6 +34,7 @@ from airflow.providers.google.cloud.hooks.bigquery import (
     BigQueryHook,
     _api_resource_configs_duplication_check,
     _cleanse_time_partitioning,
+    _format_schema_for_description,
     _validate_src_fmt_configs,
     _validate_value,
     split_tablename,
@@ -1239,11 +1240,33 @@ class TestBigQueryCursor(_BigQueryBaseTestClass):
             ]
         )
 
+    def test_format_schema_for_description(self):
+        test_query_result = {
+            "schema": {
+                "fields": [
+                    {"name": "field_1", "type": "STRING", "mode": "NULLABLE"},
+                ]
+            },
+        }
+        description = _format_schema_for_description(test_query_result["schema"])
+        assert description == [('field_1', 'STRING', None, None, None, None, True)]
+
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
-    def test_description(self, mock_get_service):
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
+    def test_description(self, mock_insert, mock_get_service):
+        mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults
+        mock_execute = mock_get_query_results.return_value.execute
+        mock_execute.return_value = {
+            "schema": {
+                "fields": [
+                    {"name": "ts", "type": "TIMESTAMP", "mode": "NULLABLE"},
+                ]
+            },
+        }
+
         bq_cursor = self.hook.get_cursor()
-        with pytest.raises(NotImplementedError):
-            bq_cursor.description
+        bq_cursor.execute("SELECT CURRENT_TIMESTAMP() as ts")
+        assert bq_cursor.description == [("ts", "TIMESTAMP", None, None, None, None, True)]
 
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
     def test_close(self, mock_get_service):

Reply via email to