This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 50bd126  [BEAM-9804] Allow user configuration of BigQuery temporary dataset
     new 7038af2  Merge pull request #12960 from [BEAM-9804] Allow user configuration of BigQuery temporary dataset
50bd126 is described below

commit 50bd1260475191321af56182787435a9c617066f
Author: Frank Zhao <[email protected]>
AuthorDate: Mon Sep 28 00:00:36 2020 +1000

    [BEAM-9804] Allow user configuration of BigQuery temporary dataset
    
    Allow ReadFromBigQuery to use a pre-configured, user-supplied dataset as the temporary dataset.
    Using a DatasetReference also allows the temporary dataset to live in a different (cross-project) project.
---
 sdks/python/apache_beam/io/gcp/bigquery.py       | 18 ++++++++--
 sdks/python/apache_beam/io/gcp/bigquery_test.py  | 45 ++++++++++++++++++++++++
 sdks/python/apache_beam/io/gcp/bigquery_tools.py | 23 +++++++-----
 3 files changed, 75 insertions(+), 11 deletions(-)

diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py 
b/sdks/python/apache_beam/io/gcp/bigquery.py
index 694114e..e0866cf 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -474,7 +474,8 @@ class _BigQuerySource(dataflow_io.NativeSource):
       coder=None,
       use_standard_sql=False,
       flatten_results=True,
-      kms_key=None):
+      kms_key=None,
+      temp_dataset=None):
     """Initialize a :class:`BigQuerySource`.
 
     Args:
@@ -513,6 +514,10 @@ class _BigQuerySource(dataflow_io.NativeSource):
         query results. The default value is :data:`True`.
       kms_key (str): Optional Cloud KMS key name for use when creating new
         tables.
+      temp_dataset (``google.cloud.bigquery.dataset.DatasetReference``):
+        The dataset in which to create temporary tables when performing file
+        loads. By default, a new dataset is created in the execution project 
for
+        temporary tables.
 
     Raises:
       ValueError: if any of the following is true:
@@ -552,6 +557,7 @@ class _BigQuerySource(dataflow_io.NativeSource):
     self.flatten_results = flatten_results
     self.coder = coder or bigquery_tools.RowAsDictJsonCoder()
     self.kms_key = kms_key
+    self.temp_dataset = temp_dataset
 
   def display_data(self):
     if self.query is not None:
@@ -681,7 +687,8 @@ class _CustomBigQuerySource(BoundedSource):
       use_json_exports=False,
       job_name=None,
       step_name=None,
-      unique_id=None):
+      unique_id=None,
+      temp_dataset=None):
     if table is not None and query is not None:
       raise ValueError(
           'Both a BigQuery table and a query were specified.'
@@ -712,6 +719,7 @@ class _CustomBigQuerySource(BoundedSource):
     self.bq_io_metadata = None  # Populate in setup, as it may make an RPC
     self.bigquery_job_labels = bigquery_job_labels or {}
     self.use_json_exports = use_json_exports
+    self.temp_dataset = temp_dataset
     self._job_name = job_name or 'AUTOMATIC_JOB_NAME'
     self._step_name = step_name
     self._source_uuid = unique_id
@@ -781,6 +789,8 @@ class _CustomBigQuerySource(BoundedSource):
     project = self.options.view_as(GoogleCloudOptions).project
     if isinstance(project, vp.ValueProvider):
       project = project.get()
+    if self.temp_dataset:
+      return self.temp_dataset.projectId
     if not project:
       project = self.project
     return project
@@ -798,7 +808,9 @@ class _CustomBigQuerySource(BoundedSource):
 
   def split(self, desired_bundle_size, start_position=None, 
stop_position=None):
     if self.split_result is None:
-      bq = bigquery_tools.BigQueryWrapper()
+      bq = bigquery_tools.BigQueryWrapper(
+          temp_dataset_id=(
+              self.temp_dataset.datasetId if self.temp_dataset else None))
 
       if self.query is not None:
         self._setup_temporary_dataset(bq)
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py 
b/sdks/python/apache_beam/io/gcp/bigquery_test.py
index 114f200..da3f34f 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -51,6 +51,7 @@ from apache_beam.io.gcp.bigquery import _StreamToBigQuery
 from apache_beam.io.gcp.bigquery_file_loads_test import _ELEMENTS
 from apache_beam.io.gcp.bigquery_read_internal import 
bigquery_export_destination_uri
 from apache_beam.io.gcp.bigquery_tools import JSON_COMPLIANCE_ERROR
+from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
 from apache_beam.io.gcp.bigquery_tools import RetryStrategy
 from apache_beam.io.gcp.internal.clients import bigquery
 from apache_beam.io.gcp.pubsub import ReadFromPubSub
@@ -436,6 +437,50 @@ class TestReadFromBigQuery(unittest.TestCase):
             'empty, using temp_location instead'
         ])
 
+  @mock.patch.object(BigQueryWrapper, '_delete_dataset')
+  @mock.patch('apache_beam.io.gcp.internal.clients.bigquery.BigqueryV2')
+  def test_temp_dataset_location_is_configurable(self, api, delete_dataset):
+    temp_dataset = bigquery.DatasetReference(
+        projectId='temp-project', datasetId='bq_dataset')
+    bq = BigQueryWrapper(client=api, temp_dataset_id=temp_dataset.datasetId)
+    gcs_location = 'gs://gcs_location'
+
+    # bq.get_or_create_dataset.return_value = temp_dataset
+    c = beam.io.gcp.bigquery._CustomBigQuerySource(
+        query='select * from test_table',
+        gcs_location=gcs_location,
+        validate=True,
+        pipeline_options=beam.options.pipeline_options.PipelineOptions(),
+        job_name='job_name',
+        step_name='step_name',
+        project='execution_project',
+        **{'temp_dataset': temp_dataset})
+
+    api.datasets.Get.side_effect = HttpError({
+        'status_code': 404, 'status': 404
+    },
+                                             '',
+                                             '')
+
+    c._setup_temporary_dataset(bq)
+    api.datasets.Insert.assert_called_with(
+        bigquery.BigqueryDatasetsInsertRequest(
+            dataset=bigquery.Dataset(datasetReference=temp_dataset),
+            projectId=temp_dataset.projectId))
+
+    api.datasets.Get.return_value = temp_dataset
+    api.datasets.Get.side_effect = None
+    bq.clean_up_temporary_dataset(temp_dataset.projectId)
+    delete_dataset.assert_called_with(
+        temp_dataset.projectId, temp_dataset.datasetId, True)
+
+    self.assertEqual(
+        bq._get_temp_table(temp_dataset.projectId),
+        bigquery.TableReference(
+            projectId=temp_dataset.projectId,
+            datasetId=temp_dataset.datasetId,
+            tableId=BigQueryWrapper.TEMP_TABLE + bq._temporary_table_suffix))
+
 
 @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
 class TestBigQuerySink(unittest.TestCase):
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py 
b/sdks/python/apache_beam/io/gcp/bigquery_tools.py
index 0495453..76d60ec 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py
@@ -266,7 +266,7 @@ class BigQueryWrapper(object):
 
   HISTOGRAM_METRIC_LOGGER = MetricLogger()
 
-  def __init__(self, client=None):
+  def __init__(self, client=None, temp_dataset_id=None):
     self.client = client or bigquery.BigqueryV2(
         http=get_new_http(),
         credentials=auth.get_service_credentials(),
@@ -281,6 +281,7 @@ class BigQueryWrapper(object):
         'latency_histogram_ms',
         LinearBucket(0, 20, 3000),
         BigQueryWrapper.HISTOGRAM_METRIC_LOGGER)
+    self.temp_dataset_id = temp_dataset_id or self._get_temp_dataset()
 
   @property
   def unique_row_id(self):
@@ -300,9 +301,12 @@ class BigQueryWrapper(object):
   def _get_temp_table(self, project_id):
     return parse_table_reference(
         table=BigQueryWrapper.TEMP_TABLE + self._temporary_table_suffix,
-        dataset=BigQueryWrapper.TEMP_DATASET + self._temporary_table_suffix,
+        dataset=self.temp_dataset_id,
         project=project_id)
 
+  def _get_temp_dataset(self):
+    return BigQueryWrapper.TEMP_DATASET + self._temporary_table_suffix
+
   @retry.with_exponential_backoff(
       num_retries=MAX_RETRIES,
       retry_filter=retry.retry_on_server_errors_and_timeout_filter)
@@ -705,26 +709,29 @@ class BigQueryWrapper(object):
       num_retries=MAX_RETRIES,
       retry_filter=retry.retry_on_server_errors_and_timeout_filter)
   def create_temporary_dataset(self, project_id, location):
-    dataset_id = BigQueryWrapper.TEMP_DATASET + self._temporary_table_suffix
+    is_user_configured_dataset = \
+      not self.temp_dataset_id.startswith(self.TEMP_DATASET)
     # Check if dataset exists to make sure that the temporary id is unique
     try:
       self.client.datasets.Get(
           bigquery.BigqueryDatasetsGetRequest(
-              projectId=project_id, datasetId=dataset_id))
-      if project_id is not None:
+              projectId=project_id, datasetId=self.temp_dataset_id))
+      if project_id is not None and not is_user_configured_dataset:
         # Unittests don't pass projectIds so they can be run without error
+        # User configured datasets are allowed to pre-exist.
         raise RuntimeError(
             'Dataset %s:%s already exists so cannot be used as temporary.' %
-            (project_id, dataset_id))
+            (project_id, self.temp_dataset_id))
     except HttpError as exn:
       if exn.status_code == 404:
         _LOGGER.warning(
             'Dataset %s:%s does not exist so we will create it as temporary '
             'with location=%s',
             project_id,
-            dataset_id,
+            self.temp_dataset_id,
             location)
-        self.get_or_create_dataset(project_id, dataset_id, location=location)
+        self.get_or_create_dataset(
+            project_id, self.temp_dataset_id, location=location)
       else:
         raise
 

Reply via email to