kaxil closed pull request #4324: [AIRFLOW-3327] Add support for location in
BigQueryHook
URL: https://github.com/apache/incubator-airflow/pull/4324
This pull request was merged from a forked repository. Because GitHub
hides the original diff once a foreign (fork-based) pull request is
merged, the diff is reproduced below for the sake of provenance:
diff --git a/airflow/contrib/hooks/bigquery_hook.py
b/airflow/contrib/hooks/bigquery_hook.py
index 5cab013b28..aee8125797 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -53,10 +53,12 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook,
LoggingMixin):
def __init__(self,
bigquery_conn_id='bigquery_default',
delegate_to=None,
- use_legacy_sql=True):
+ use_legacy_sql=True,
+ location=None):
super(BigQueryHook, self).__init__(
gcp_conn_id=bigquery_conn_id, delegate_to=delegate_to)
self.use_legacy_sql = use_legacy_sql
+ self.location = location
def get_conn(self):
"""
@@ -67,7 +69,9 @@ def get_conn(self):
return BigQueryConnection(
service=service,
project_id=project,
- use_legacy_sql=self.use_legacy_sql)
+ use_legacy_sql=self.use_legacy_sql,
+ location=self.location,
+ )
def get_service(self):
"""
@@ -201,7 +205,8 @@ def __init__(self,
service,
project_id,
use_legacy_sql=True,
- api_resource_configs=None):
+ api_resource_configs=None,
+ location=None):
self.service = service
self.project_id = project_id
@@ -211,6 +216,7 @@ def __init__(self,
self.api_resource_configs = api_resource_configs \
if api_resource_configs else {}
self.running_job_id = None
+ self.location = location
def create_empty_table(self,
project_id,
@@ -512,7 +518,8 @@ def run_query(self,
priority='INTERACTIVE',
time_partitioning=None,
api_resource_configs=None,
- cluster_fields=None):
+ cluster_fields=None,
+ location=None):
"""
Executes a BigQuery SQL query. Optionally persists results in a
BigQuery
table. See here:
@@ -580,11 +587,18 @@ def run_query(self,
by one or more columns. This is only available in combination with
time_partitioning. The order of columns given determines the sort
order.
:type cluster_fields: list of str
+ :param location: The geographic location of the job. Required except
for
+ US and EU. See details at
+
https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+ :type location: str
"""
if time_partitioning is None:
time_partitioning = {}
+ if location:
+ self.location = location
+
if not api_resource_configs:
api_resource_configs = self.api_resource_configs
else:
@@ -1089,9 +1103,15 @@ def run_with_configuration(self, configuration):
keep_polling_job = True
while keep_polling_job:
try:
- job = jobs.get(
- projectId=self.project_id,
- jobId=self.running_job_id).execute()
+ if self.location:
+ job = jobs.get(
+ projectId=self.project_id,
+ jobId=self.running_job_id,
+ location=self.location).execute()
+ else:
+ job = jobs.get(
+ projectId=self.project_id,
+ jobId=self.running_job_id).execute()
if job['status']['state'] == 'DONE':
keep_polling_job = False
# Check if job had errors.
@@ -1120,7 +1140,13 @@ def run_with_configuration(self, configuration):
def poll_job_complete(self, job_id):
jobs = self.service.jobs()
try:
- job = jobs.get(projectId=self.project_id, jobId=job_id).execute()
+ if self.location:
+ job = jobs.get(projectId=self.project_id,
+ jobId=job_id,
+ location=self.location).execute()
+ else:
+ job = jobs.get(projectId=self.project_id,
+ jobId=job_id).execute()
if job['status']['state'] == 'DONE':
return True
except HttpError as err:
@@ -1143,9 +1169,15 @@ def cancel_query(self):
not self.poll_job_complete(self.running_job_id)):
self.log.info('Attempting to cancel job : %s, %s', self.project_id,
self.running_job_id)
- jobs.cancel(
- projectId=self.project_id,
- jobId=self.running_job_id).execute()
+ if self.location:
+ jobs.cancel(
+ projectId=self.project_id,
+ jobId=self.running_job_id,
+ location=self.location).execute()
+ else:
+ jobs.cancel(
+ projectId=self.project_id,
+ jobId=self.running_job_id).execute()
else:
self.log.info('No running BigQuery jobs to cancel.')
return
@@ -1617,11 +1649,13 @@ class BigQueryCursor(BigQueryBaseCursor):
https://github.com/dropbox/PyHive/blob/master/pyhive/common.py
"""
- def __init__(self, service, project_id, use_legacy_sql=True):
+ def __init__(self, service, project_id, use_legacy_sql=True,
location=None):
super(BigQueryCursor, self).__init__(
service=service,
project_id=project_id,
- use_legacy_sql=use_legacy_sql)
+ use_legacy_sql=use_legacy_sql,
+ location=location,
+ )
self.buffersize = None
self.page_token = None
self.job_id = None
diff --git a/airflow/contrib/operators/bigquery_operator.py
b/airflow/contrib/operators/bigquery_operator.py
index 106bee8b69..f597db93a5 100644
--- a/airflow/contrib/operators/bigquery_operator.py
+++ b/airflow/contrib/operators/bigquery_operator.py
@@ -97,6 +97,10 @@ class BigQueryOperator(BaseOperator):
by one or more columns. This is only available in conjunction with
time_partitioning. The order of columns given determines the sort
order.
:type cluster_fields: list of str
+ :param location: The geographic location of the job. Required except for
+ US and EU. See details at
+
https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+ :type location: str
"""
template_fields = ('sql', 'destination_dataset_table', 'labels')
@@ -124,6 +128,7 @@ def __init__(self,
time_partitioning=None,
api_resource_configs=None,
cluster_fields=None,
+ location=None,
*args,
**kwargs):
super(BigQueryOperator, self).__init__(*args, **kwargs)
@@ -147,6 +152,7 @@ def __init__(self,
self.time_partitioning = time_partitioning
self.api_resource_configs = api_resource_configs
self.cluster_fields = cluster_fields
+ self.location = location
def execute(self, context):
if self.bq_cursor is None:
@@ -154,7 +160,9 @@ def execute(self, context):
hook = BigQueryHook(
bigquery_conn_id=self.bigquery_conn_id,
use_legacy_sql=self.use_legacy_sql,
- delegate_to=self.delegate_to)
+ delegate_to=self.delegate_to,
+ location=self.location,
+ )
conn = hook.get_conn()
self.bq_cursor = conn.cursor()
self.bq_cursor.run_query(
diff --git a/tests/contrib/hooks/test_bigquery_hook.py
b/tests/contrib/hooks/test_bigquery_hook.py
index 8c59116c85..244574e5e3 100644
--- a/tests/contrib/hooks/test_bigquery_hook.py
+++ b/tests/contrib/hooks/test_bigquery_hook.py
@@ -228,9 +228,7 @@ def test_invalid_schema_update_and_write_disposition(self):
)
self.assertIn("schema_update_options is only", str(context.exception))
- @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
- @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
- def test_cancel_queries(self, mocked_time, mocked_logging):
+ def test_cancel_queries(self):
project_id = 12345
running_job_id = 3
@@ -273,8 +271,7 @@ def test_api_resource_configs(self, run_with_config):
self.assertIs(args[0]['query']['useQueryCache'], bool_val)
self.assertIs(args[0]['query']['useLegacySql'], True)
- @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_api_resource_configs_duplication_warning(self, run_with_config):
+ def test_api_resource_configs_duplication_warning(self):
with self.assertRaises(ValueError):
cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id")
cursor.run_query('query',
@@ -295,8 +292,7 @@ def test_duplication_check(self):
self.assertIsNone(_api_resource_configs_duplication_check(
"key_one", key_one, {"key_one": True}))
- @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_insert_all_succeed(self, run_with_config):
+ def test_insert_all_succeed(self):
project_id = 'bq-project'
dataset_id = 'bq_dataset'
table_id = 'bq_table'
@@ -311,7 +307,7 @@ def test_insert_all_succeed(self, run_with_config):
}
mock_service = mock.Mock()
- method = (mock_service.tabledata.return_value.insertAll)
+ method = mock_service.tabledata.return_value.insertAll
method.return_value.execute.return_value = {
"kind": "bigquery#tableDataInsertAllResponse"
}
@@ -320,8 +316,7 @@ def test_insert_all_succeed(self, run_with_config):
method.assert_called_with(projectId=project_id, datasetId=dataset_id,
tableId=table_id, body=body)
- @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_insert_all_fail(self, run_with_config):
+ def test_insert_all_fail(self):
project_id = 'bq-project'
dataset_id = 'bq_dataset'
table_id = 'bq_table'
@@ -330,7 +325,7 @@ def test_insert_all_fail(self, run_with_config):
]
mock_service = mock.Mock()
- method = (mock_service.tabledata.return_value.insertAll)
+ method = mock_service.tabledata.return_value.insertAll
method.return_value.execute.return_value = {
"kind": "bigquery#tableDataInsertAllResponse",
"insertErrors": [
@@ -345,8 +340,7 @@ def test_insert_all_fail(self, run_with_config):
cursor.insert_all(project_id, dataset_id, table_id,
rows, fail_on_error=True)
- @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_create_view_fails_on_exception(self, run_with_config):
+ def test_create_view_fails_on_exception(self):
project_id = 'bq-project'
dataset_id = 'bq_dataset'
table_id = 'bq_table_view'
@@ -356,7 +350,7 @@ def test_create_view_fails_on_exception(self,
run_with_config):
}
mock_service = mock.Mock()
- method = (mock_service.tables.return_value.insert)
+ method = mock_service.tables.return_value.insert
method.return_value.execute.side_effect = HttpError(
resp={'status': '400'}, content=b'Query is required for views')
cursor = hook.BigQueryBaseCursor(mock_service, project_id)
@@ -364,8 +358,7 @@ def test_create_view_fails_on_exception(self,
run_with_config):
cursor.create_empty_table(project_id, dataset_id, table_id,
view=view)
- @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_create_view(self, run_with_config):
+ def test_create_view(self):
project_id = 'bq-project'
dataset_id = 'bq_dataset'
table_id = 'bq_table_view'
@@ -375,7 +368,7 @@ def test_create_view(self, run_with_config):
}
mock_service = mock.Mock()
- method = (mock_service.tables.return_value.insert)
+ method = mock_service.tables.return_value.insert
cursor = hook.BigQueryBaseCursor(mock_service, project_id)
cursor.create_empty_table(project_id, dataset_id, table_id,
view=view)
@@ -419,18 +412,14 @@ def run_with_config(config):
class TestDatasetsOperations(unittest.TestCase):
- @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_create_empty_dataset_no_dataset_id_err(self,
- run_with_configuration):
+ def test_create_empty_dataset_no_dataset_id_err(self):
with self.assertRaises(ValueError):
hook.BigQueryBaseCursor(
mock.Mock(), "test_create_empty_dataset").create_empty_dataset(
dataset_id="", project_id="")
- @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_create_empty_dataset_duplicates_call_err(self,
- run_with_configuration):
+ def test_create_empty_dataset_duplicates_call_err(self):
with self.assertRaises(ValueError):
hook.BigQueryBaseCursor(
mock.Mock(), "test_create_empty_dataset").create_empty_dataset(
@@ -504,10 +493,8 @@ def test_get_datasets_list(self):
class TestTimePartitioningInRunJob(unittest.TestCase):
- @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
- @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
@mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_run_load_default(self, mocked_rwc, mocked_time, mocked_logging):
+ def test_run_load_default(self, mocked_rwc):
project_id = 12345
def run_with_config(config):
@@ -531,10 +518,8 @@ def test_run_with_auto_detect(self, run_with_config):
args, kwargs = run_with_config.call_args
self.assertIs(args[0]['load']['autodetect'], True)
- @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
- @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
@mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_run_load_with_arg(self, mocked_rwc, mocked_time, mocked_logging):
+ def test_run_load_with_arg(self, mocked_rwc):
project_id = 12345
def run_with_config(config):
@@ -558,10 +543,8 @@ def run_with_config(config):
mocked_rwc.assert_called_once()
- @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
- @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
@mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_run_query_default(self, mocked_rwc, mocked_time, mocked_logging):
+ def test_run_query_default(self, mocked_rwc):
project_id = 12345
def run_with_config(config):
@@ -573,10 +556,8 @@ def run_with_config(config):
mocked_rwc.assert_called_once()
- @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
- @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
@mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_run_query_with_arg(self, mocked_rwc, mocked_time, mocked_logging):
+ def test_run_query_with_arg(self, mocked_rwc):
project_id = 12345
def run_with_config(config):
@@ -623,10 +604,8 @@ def test_extra_time_partitioning_options(self):
class TestClusteringInRunJob(unittest.TestCase):
- @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
- @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
@mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_run_load_default(self, mocked_rwc, mocked_time, mocked_logging):
+ def test_run_load_default(self, mocked_rwc):
project_id = 12345
def run_with_config(config):
@@ -642,10 +621,8 @@ def run_with_config(config):
mocked_rwc.assert_called_once()
- @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
- @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
@mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_run_load_with_arg(self, mocked_rwc, mocked_time, mocked_logging):
+ def test_run_load_with_arg(self, mocked_rwc):
project_id = 12345
def run_with_config(config):
@@ -668,10 +645,8 @@ def run_with_config(config):
mocked_rwc.assert_called_once()
- @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
- @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
@mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_run_query_default(self, mocked_rwc, mocked_time, mocked_logging):
+ def test_run_query_default(self, mocked_rwc):
project_id = 12345
def run_with_config(config):
@@ -683,10 +658,8 @@ def run_with_config(config):
mocked_rwc.assert_called_once()
- @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
- @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
@mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
- def test_run_query_with_arg(self, mocked_rwc, mocked_time, mocked_logging):
+ def test_run_query_with_arg(self, mocked_rwc):
project_id = 12345
def run_with_config(config):
@@ -729,5 +702,21 @@ def test_legacy_sql_override_propagates_properly(self,
run_with_config):
self.assertIs(args[0]['query']['useLegacySql'], False)
+class TestBigQueryHookLocation(unittest.TestCase):
+ @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration')
+ def test_location_propagates_properly(self, run_with_config):
+ with mock.patch.object(hook.BigQueryHook, 'get_service'):
+ bq_hook = hook.BigQueryHook(location=None)
+ self.assertIsNone(bq_hook.location)
+
+ bq_cursor = hook.BigQueryBaseCursor(mock.Mock(),
+ 'test-project',
+ location=None)
+ self.assertIsNone(bq_cursor.location)
+ bq_cursor.run_query(sql='select 1', location='US')
+ run_with_config.assert_called_once()
+ self.assertEquals(bq_cursor.location, 'US')
+
+
if __name__ == '__main__':
unittest.main()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services