Repository: incubator-airflow
Updated Branches:
  refs/heads/master b75367bb5 -> 804710fda

[AIRFLOW-1688] Support load.time_partitioning in bigquery_hook

Closes #2820 from albertocalderari/master

Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/804710fd
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/804710fd
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/804710fd

Branch: refs/heads/master
Commit: 804710fda54d0eb8dfef0385e518b04a35c8fed4
Parents: b75367b
Author: alberto.calderari <[email protected]>
Authored: Thu Jan 11 09:24:21 2018 -0800
Committer: Chris Riccomini <[email protected]>
Committed: Thu Jan 11 09:24:31 2018 -0800

----------------------------------------------------------------------
 airflow/contrib/hooks/bigquery_hook.py    |  70 ++++++++++++-----
 airflow/contrib/operators/gcs_to_bq.py    |  10 ++-
 tests/contrib/hooks/test_bigquery_hook.py | 104 +++++++++++++++++++++++--
 3 files changed, 158 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/804710fd/airflow/contrib/hooks/bigquery_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index d64c2a1..fe51d50 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -22,6 +22,7 @@
 from builtins import range
 from past.builtins import basestring
 
+from airflow import AirflowException
 from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
 from airflow.hooks.dbapi_hook import DbApiHook
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -450,7 +451,8 @@ class BigQueryBaseCursor(LoggingMixin):
                  allow_quoted_newlines=False,
                  allow_jagged_rows=False,
                  schema_update_options=(),
-                 src_fmt_configs={}):
+                 src_fmt_configs={},
+                 time_partitioning={}):
         """
         Executes a BigQuery load command to load data from Google Cloud Storage
         to BigQuery. See here:
@@ -460,9 +462,11 @@ class BigQueryBaseCursor(LoggingMixin):
         For more details about these parameters.
 
         :param destination_project_dataset_table:
-            The dotted (<project>.|<project>:)<dataset>.<table> BigQuery table to load
-            data into. If <project> is not included, project will be the project defined
-            in the connection json.
+            The dotted (<project>.|<project>:)<dataset>.<table>($<partition>) BigQuery
+            table to load data into. If <project> is not included, project will be the
+            project defined in the connection json. If a partition is specified, the
+            operator will automatically append the data, create a new partition, or
+            create a new DAY-partitioned table.
         :type destination_project_dataset_table: string
         :param schema_fields: The schema field list as defined here:
             https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.load
@@ -484,20 +488,28 @@ class BigQueryBaseCursor(LoggingMixin):
         :param max_bad_records: The maximum number of bad records that BigQuery can
             ignore when running the job.
         :type max_bad_records: int
-        :param quote_character: The value that is used to quote data sections in a CSV file.
+        :param quote_character: The value that is used to quote data sections in a CSV
+            file.
         :type quote_character: string
-        :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not (false).
+        :param allow_quoted_newlines: Whether to allow quoted newlines (true) or not
+            (false).
         :type allow_quoted_newlines: boolean
         :param allow_jagged_rows: Accept rows that are missing trailing optional columns.
-            The missing values are treated as nulls. If false, records with missing trailing columns
-            are treated as bad records, and if there are too many bad records, an invalid error is
-            returned in the job result. Only applicable when soure_format is CSV.
+            The missing values are treated as nulls. If false, records with missing
+            trailing columns are treated as bad records, and if there are too many bad
+            records, an invalid error is returned in the job result. Only applicable when
+            source_format is CSV.
         :type allow_jagged_rows: bool
         :param schema_update_options: Allows the schema of the destination table to be
             updated as a side effect of the load job.
         :type schema_update_options: tuple
        :param src_fmt_configs: configure optional fields specific to the source format
        :type src_fmt_configs: dict
+        :param time_partitioning: configure optional time partitioning fields,
+            i.e. partition by field, type and expiration, as per the API
+            specifications. Note that 'field' cannot be used together with
+            dataset.table$partition.
+        :type time_partitioning: dict
         """
 
         # bigquery only allows certain source formats
@@ -518,7 +530,7 @@ class BigQueryBaseCursor(LoggingMixin):
         # bigquery also allows you to define how you want a table's schema to change
         # as a side effect of a load
         # for more details:
-        # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schemaUpdateOptions
+        # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.load.schemaUpdateOptions
         allowed_schema_update_options = [
             'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION"
         ]
@@ -547,6 +559,23 @@ class BigQueryBaseCursor(LoggingMixin):
                 'writeDisposition': write_disposition,
             }
         }
+
+        # if it is a partitioned table ($ is in the table name) add partition load option
+        if '$' in destination_project_dataset_table:
+            if time_partitioning.get('field'):
+                raise AirflowException(
+                    "Cannot specify field partition and partition name "
+                    "(dataset.table$partition) at the same time"
+                )
+            configuration['load']['timePartitioning'] = dict(type='DAY')
+
+        # can specify custom time partitioning options based on a field, or adding
+        # expiration
+        if time_partitioning:
+            if not configuration.get('load', {}).get('timePartitioning'):
+                configuration['load']['timePartitioning'] = {}
+            configuration['load']['timePartitioning'].update(time_partitioning)
+
         if schema_fields:
             configuration['load']['schema'] = {'fields': schema_fields}
@@ -777,7 +806,7 @@ class BigQueryBaseCursor(LoggingMixin):
                 default_project_id=self.project_id)
 
         try:
-            tables_resource = self.service.tables() \
+            self.service.tables() \
                 .delete(projectId=deletion_project,
                         datasetId=deletion_dataset,
                         tableId=deletion_table) \
@@ -1011,13 +1040,14 @@ class BigQueryCursor(BigQueryBaseCursor):
 
     def fetchmany(self, size=None):
         """
-        Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
-        list of tuples). An empty sequence is returned when no more rows are available.
-        The number of rows to fetch per call is specified by the parameter. If it is not given, the
-        cursor's arraysize determines the number of rows to be fetched. The method should try to
-        fetch as many rows as indicated by the size parameter. If this is not possible due to the
-        specified number of rows not being available, fewer rows may be returned.
-        An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to
+        Fetch the next set of rows of a query result, returning a sequence of sequences
+        (e.g. a list of tuples). An empty sequence is returned when no more rows are
+        available. The number of rows to fetch per call is specified by the parameter.
+        If it is not given, the cursor's arraysize determines the number of rows to be
+        fetched. The method should try to fetch as many rows as indicated by the size
+        parameter. If this is not possible due to the specified number of rows not being
+        available, fewer rows may be returned. An :py:class:`~pyhive.exc.Error`
+        (or subclass) exception is raised if the previous call to
         :py:meth:`execute` did not produce any result set or no call was issued yet.
         """
         if size is None:
@@ -1033,8 +1063,8 @@ class BigQueryCursor(BigQueryBaseCursor):
 
     def fetchall(self):
         """
-        Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
-        (e.g. a list of tuples).
+        Fetch all (remaining) rows of a query result, returning them as a sequence of
+        sequences (e.g. a list of tuples).
         """
         result = []
         while True:
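
For readers skimming the hook diff, the new partition handling in run_load
reduces to the standalone sketch below. This is a minimal illustration, not
code from the patch; the helper name build_time_partitioning is invented here.

    # Minimal sketch of the partition handling added to run_load above.
    # The helper name is invented for illustration; it is not in the patch.
    from airflow import AirflowException

    def build_time_partitioning(destination_project_dataset_table,
                                time_partitioning=None):
        time_partitioning = time_partitioning or {}
        result = {}
        # A '$' in the table name is a partition decorator such as
        # dataset.table$20170101, which implies a DAY-partitioned table.
        if '$' in destination_project_dataset_table:
            if time_partitioning.get('field'):
                raise AirflowException(
                    "Cannot specify field partition and partition name "
                    "(dataset.table$partition) at the same time")
            result = {'type': 'DAY'}
        # Explicit options (type, field, expirationMs) are layered on top.
        result.update(time_partitioning)
        return result

    # A decorator alone yields {'type': 'DAY'}; explicit options extend it.
    assert build_time_partitioning('ds.t$20170101') == {'type': 'DAY'}
    assert build_time_partitioning(
        'ds.t', {'type': 'DAY', 'field': 'ts'}) == {'type': 'DAY', 'field': 'ts'}

One subtlety worth noting: when both a $partition decorator and
time_partitioning['field'] are supplied, the hook raises AirflowException
rather than guessing which partition scheme was intended.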


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/804710fd/airflow/contrib/operators/gcs_to_bq.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py
index 730a3bc..75302b6 100644
--- a/airflow/contrib/operators/gcs_to_bq.py
+++ b/airflow/contrib/operators/gcs_to_bq.py
@@ -52,6 +52,7 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
                  delegate_to=None,
                  schema_update_options=(),
                  src_fmt_configs={},
+                 time_partitioning={},
                  *args,
                  **kwargs):
         """
@@ -119,6 +120,11 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
         :type schema_update_options: list
         :param src_fmt_configs: configure optional fields specific to the source format
         :type src_fmt_configs: dict
+        :param time_partitioning: configure optional time partitioning fields,
+            i.e. partition by field, type and expiration, as per the API
+            specifications. Note that 'field' cannot be used together with
+            dataset.table$partition.
+        :type time_partitioning: dict
         """
         super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args, **kwargs)
 
@@ -147,6 +153,7 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
 
         self.schema_update_options = schema_update_options
         self.src_fmt_configs = src_fmt_configs
+        self.time_partitioning = time_partitioning
 
     def execute(self, context):
         bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
@@ -181,7 +188,8 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
                 allow_quoted_newlines=self.allow_quoted_newlines,
                 allow_jagged_rows=self.allow_jagged_rows,
                 schema_update_options=self.schema_update_options,
-                src_fmt_configs=self.src_fmt_configs)
+                src_fmt_configs=self.src_fmt_configs,
+                time_partitioning=self.time_partitioning)
 
         if self.max_id_key:
             cursor.execute('SELECT MAX({}) FROM {}'.format(
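
As a usage illustration for the operator change, a hedged sketch follows. The
DAG id, bucket, object paths, table name, and partition column are invented;
only the time_partitioning argument itself comes from this patch.

    # Hedged usage sketch; every name below (bucket, table, field) is invented.
    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.gcs_to_bq import (
        GoogleCloudStorageToBigQueryOperator)

    dag = DAG('example_gcs_to_bq_partitioned',
              start_date=datetime(2018, 1, 1),
              schedule_interval=None)

    load_partitioned = GoogleCloudStorageToBigQueryOperator(
        task_id='gcs_to_bq_partitioned',
        bucket='example_bucket',
        source_objects=['data/part-*.csv'],
        destination_project_dataset_table='my_dataset.my_table',
        write_disposition='WRITE_APPEND',
        # Partition on a date column and expire partitions after 30 days.
        # Do not combine 'field' with a dataset.table$partition decorator.
        time_partitioning={'type': 'DAY', 'field': 'event_date',
                           'expirationMs': 30 * 24 * 60 * 60 * 1000},
        dag=dag)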


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/804710fd/tests/contrib/hooks/test_bigquery_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py
index 0365bba..86268c4 100644
--- a/tests/contrib/hooks/test_bigquery_hook.py
+++ b/tests/contrib/hooks/test_bigquery_hook.py
@@ -16,6 +16,7 @@
 import unittest
 import mock
 
+from airflow import AirflowException
 from airflow.contrib.hooks import bigquery_hook as hook
 from oauth2client.contrib.gce import HttpAccessTokenRefreshError
 
@@ -173,6 +174,7 @@ def mock_job_cancel(projectId, jobId):
     mock_canceled_jobs.append(jobId)
     return mock.Mock()
 
+
 class TestBigQueryBaseCursor(unittest.TestCase):
     def test_invalid_schema_update_options(self):
         with self.assertRaises(Exception) as context:
@@ -194,26 +196,118 @@ class TestBigQueryBaseCursor(unittest.TestCase):
                 write_disposition='WRITE_EMPTY'
             )
         self.assertIn("schema_update_options is only", str(context.exception))
-
+
     @mock.patch("airflow.contrib.hooks.bigquery_hook.LoggingMixin")
     @mock.patch("airflow.contrib.hooks.bigquery_hook.time")
     def test_cancel_queries(self, mocked_logging, mocked_time):
         project_id = 12345
         running_job_id = 3
-
+
         mock_jobs = mock.Mock()
         mock_jobs.cancel = mock.Mock(side_effect=mock_job_cancel)
         mock_service = mock.Mock()
         mock_service.jobs = mock.Mock(return_value=mock_jobs)
-
+
         bq_hook = hook.BigQueryBaseCursor(mock_service, project_id)
         bq_hook.running_job_id = running_job_id
         bq_hook.poll_job_complete = mock.Mock(side_effect=mock_poll_job_complete)
-
+
         bq_hook.cancel_query()
-
+
         mock_jobs.cancel.assert_called_with(projectId=project_id, jobId=running_job_id)
 
+
+class TestTimePartitioningInRunJob(unittest.TestCase):
+
+    class BigQueryBaseCursorTest(hook.BigQueryBaseCursor):
+        """Use this class to verify the load configuration"""
+        def run_with_configuration(self, configuration):
+            return configuration
+
+    class Serv(object):
+        """Mocks the behaviour of a successful Job"""
+
+        class Job(object):
+            """Mocks the behaviour of a successful Job"""
+            def __getitem__(self, item=None):
+                return self
+
+            def get(self, projectId, jobId=None):
+                return self.__getitem__(projectId)
+
+            def insert(self, projectId, body=None):
+                return self.get(projectId, body)
+
+            def execute(self):
+                return {
+                    'status': {'state': 'DONE'},
+                    'jobReference': {'jobId': 0}
+                }
+
+        def __init__(self, job='mock_load'):
+            self.job = job
+
+        def jobs(self):
+            return self.Job()
+
+    def test_the_job_execution_wont_break(self):
+        s = self.Serv()
+        bqc = hook.BigQueryBaseCursor(s, 'str')
+        job = bqc.run_load(
+            destination_project_dataset_table='test.teast',
+            schema_fields=[],
+            source_uris=[],
+            time_partitioning={'type': 'DAY', 'field': 'test_field',
+                               'expirationMs': 1000}
+        )
+
+        self.assertEqual(job, 0)
+
+    def test_dollar_makes_partition(self):
+        s = self.Serv()
+        bqc = self.BigQueryBaseCursorTest(s, 'str')
+        cnfg = bqc.run_load(
+            destination_project_dataset_table='test.teast$20170101',
+            schema_fields=[],
+            source_uris=[],
+            src_fmt_configs={}
+        )
+        expect = {
+            'type': 'DAY'
+        }
+        self.assertEqual(cnfg['load'].get('timePartitioning'), expect)
+
+    def test_extra_time_partitioning_options(self):
+        s = self.Serv()
+        bqc = self.BigQueryBaseCursorTest(s, 'str')
+        cnfg = bqc.run_load(
+            destination_project_dataset_table='test.teast',
+            schema_fields=[],
+            source_uris=[],
+            time_partitioning={'type': 'DAY', 'field': 'test_field',
+                               'expirationMs': 1000}
+        )
+
+        expect = {
+            'type': 'DAY',
+            'field': 'test_field',
+            'expirationMs': 1000
+        }
+
+        self.assertEqual(cnfg['load'].get('timePartitioning'), expect)
+
+    def test_cant_add_dollar_and_field_name(self):
+        s = self.Serv()
+        bqc = self.BigQueryBaseCursorTest(s, 'str')
+
+        with self.assertRaises(AirflowException):
+            tp_dict = {'type': 'DAY', 'field': 'test_field', 'expirationMs': 1000}
+            bqc.run_load(
+                destination_project_dataset_table='test.teast$20170101',
+                schema_fields=[],
+                source_uris=[],
+                time_partitioning=tp_dict
+            )
+
+
 if __name__ == '__main__':
     unittest.main()
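
Finally, the same feature exercised from the hook side: a hedged sketch of
loading one day's files into a specific partition via the '$' decorator. The
connection id, table, URI, and schema below are invented for illustration.

    # Hedged sketch; the connection id, table, URI and schema are invented.
    from airflow.contrib.hooks.bigquery_hook import BigQueryHook

    bq_hook = BigQueryHook(bigquery_conn_id='bigquery_default')
    cursor = bq_hook.get_conn().cursor()
    cursor.run_load(
        # The $20170101 decorator targets a single DAY partition; run_load
        # sets timePartitioning={'type': 'DAY'} automatically in this case.
        destination_project_dataset_table='my_dataset.my_table$20170101',
        schema_fields=[{'name': 'event_date', 'type': 'DATE',
                        'mode': 'NULLABLE'}],
        source_uris=['gs://example_bucket/data/20170101/part-*.csv'],
        write_disposition='WRITE_TRUNCATE')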
