Repository: incubator-airflow Updated Branches: refs/heads/master 4c674ccff -> 86063ba4e
[AIRFLOW-1568] Add datastore export/import operators Closes #2568 from jgao54/ds-import-export Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/86063ba4 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/86063ba4 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/86063ba4 Branch: refs/heads/master Commit: 86063ba4e9babfa280a4d8374a8b45c0cc72aed0 Parents: 4c674cc Author: Joy Gao <[email protected]> Authored: Wed Sep 6 09:45:56 2017 -0700 Committer: Chris Riccomini <[email protected]> Committed: Wed Sep 6 09:45:56 2017 -0700 ---------------------------------------------------------------------- airflow/contrib/hooks/bigquery_hook.py | 35 +++++-- airflow/contrib/hooks/datastore_hook.py | 100 +++++++++++++++--- airflow/contrib/hooks/gcp_api_base_hook.py | 20 +++- .../operators/datastore_export_operator.py | 104 +++++++++++++++++++ .../operators/datastore_import_operator.py | 95 +++++++++++++++++ airflow/contrib/operators/gcs_to_bq.py | 12 ++- 6 files changed, 335 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/86063ba4/airflow/contrib/hooks/bigquery_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py index e60f597..d7c1126 100644 --- a/airflow/contrib/hooks/bigquery_hook.py +++ b/airflow/contrib/hooks/bigquery_hook.py @@ -390,7 +390,8 @@ class BigQueryBaseCursor(object): max_bad_records=0, quote_character=None, allow_quoted_newlines=False, - schema_update_options=()): + schema_update_options=(), + src_fmt_configs={}): """ Executes a BigQuery load command to load data from Google Cloud Storage to BigQuery. See here: @@ -431,6 +432,8 @@ class BigQueryBaseCursor(object): :param schema_update_options: Allows the schema of the desitination table to be updated as a side effect of the load job. :type schema_update_options: list + :param src_fmt_configs: configure optional fields specific to the source format + :type src_fmt_configs: dict """ # bigquery only allows certain source formats @@ -439,7 +442,7 @@ class BigQueryBaseCursor(object): # Refer to this link for more details: # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat source_format = source_format.upper() - allowed_formats = ["CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS"] + allowed_formats = ["CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS", "DATASTORE_BACKUP"] if source_format not in allowed_formats: raise ValueError("{0} is not a valid source format. " "Please use one of the following types: {1}" @@ -497,18 +500,32 @@ class BigQueryBaseCursor(object): ) configuration['load']['schemaUpdateOptions'] = schema_update_options - if source_format == 'CSV': - configuration['load']['skipLeadingRows'] = skip_leading_rows - configuration['load']['fieldDelimiter'] = field_delimiter - if max_bad_records: configuration['load']['maxBadRecords'] = max_bad_records + # if following fields are not specified in src_fmt_configs, + # honor the top-level params for backward-compatibility + if 'skip_leading_rows' not in src_fmt_configs: + src_fmt_configs['skip_leading_rows'] = skip_leading_rows + if 'fieldDelimiter' not in src_fmt_configs: + src_fmt_configs['fieldDelimiter'] = field_delimiter if quote_character: - configuration['load']['quote'] = quote_character - + src_fmt_configs['quote'] = quote_character if allow_quoted_newlines: - configuration['load']['allowQuotedNewlines'] = allow_quoted_newlines + src_fmt_configs['allowQuotedNewlines'] = allow_quoted_newlines + + src_fmt_to_configs_mapping = { + 'CSV': ['allowJaggedRows', 'allowQuotedNewlines', 'autodetect', + 'fieldDelimiter', 'skipLeadingRows', 'ignoreUnknownValues', + 'nullMarker', 'quote'], + 'DATASTORE_BACKUP': ['projectionFields'], + 'NEWLINE_DELIMITED_JSON': ['autodetect', 'ignoreUnknownValues'], + 'AVRO': [], + } + valid_configs = src_fmt_to_configs_mapping[source_format] + src_fmt_configs = {k: v for k, v in src_fmt_configs.items() + if k in valid_configs} + configuration['load'].update(src_fmt_configs) return self.run_with_configuration(configuration) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/86063ba4/airflow/contrib/hooks/datastore_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/datastore_hook.py b/airflow/contrib/hooks/datastore_hook.py index 9b1dee3..7a4386a 100644 --- a/airflow/contrib/hooks/datastore_hook.py +++ b/airflow/contrib/hooks/datastore_hook.py @@ -13,6 +13,9 @@ # limitations under the License. # +import json +import time +import logging from apiclient.discovery import build from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook @@ -30,53 +33,52 @@ class DatastoreHook(GoogleCloudBaseHook): datastore_conn_id='google_cloud_datastore_default', delegate_to=None): super(DatastoreHook, self).__init__(datastore_conn_id, delegate_to) - # datasetId is the same as the project name - self.dataset_id = self._get_field('project') self.connection = self.get_conn() + self.admin_connection = self.get_conn('v1beta1') - def get_conn(self): + def get_conn(self, version='v1'): """ Returns a Google Cloud Storage service object. """ http_authorized = self._authorize() - return build('datastore', 'v1beta2', http=http_authorized) + return build('datastore', version, http=http_authorized) def allocate_ids(self, partialKeys): """ Allocate IDs for incomplete keys. - see https://cloud.google.com/datastore/docs/apis/v1beta2/datasets/allocateIds + see https://cloud.google.com/datastore/docs/reference/rest/v1/projects/allocateIds :param partialKeys: a list of partial keys :return: a list of full keys. """ - resp = self.connection.datasets().allocateIds(datasetId=self.dataset_id, body={'keys': partialKeys}).execute() + resp = self.connection.projects().allocateIds(projectId=self.project_id, body={'keys': partialKeys}).execute() return resp['keys'] def begin_transaction(self): """ Get a new transaction handle - see https://cloud.google.com/datastore/docs/apis/v1beta2/datasets/beginTransaction + see https://cloud.google.com/datastore/docs/reference/rest/v1/projects/beginTransaction :return: a transaction handle """ - resp = self.connection.datasets().beginTransaction(datasetId=self.dataset_id, body={}).execute() + resp = self.connection.projects().beginTransaction(projectId=self.project_id, body={}).execute() return resp['transaction'] def commit(self, body): """ Commit a transaction, optionally creating, deleting or modifying some entities. - see https://cloud.google.com/datastore/docs/apis/v1beta2/datasets/commit + see https://cloud.google.com/datastore/docs/reference/rest/v1/projects/commit :param body: the body of the commit request :return: the response body of the commit request """ - resp = self.connection.datasets().commit(datasetId=self.dataset_id, body=body).execute() + resp = self.connection.projects().commit(projectId=self.project_id, body=body).execute() return resp def lookup(self, keys, read_consistency=None, transaction=None): """ Lookup some entities by key - see https://cloud.google.com/datastore/docs/apis/v1beta2/datasets/lookup + see https://cloud.google.com/datastore/docs/reference/rest/v1/projects/lookup :param keys: the keys to lookup :param read_consistency: the read consistency to use. default, strong or eventual. Cannot be used with a transaction. @@ -88,23 +90,89 @@ class DatastoreHook(GoogleCloudBaseHook): body['readConsistency'] = read_consistency if transaction: body['transaction'] = transaction - return self.connection.datasets().lookup(datasetId=self.dataset_id, body=body).execute() + return self.connection.projects().lookup(projectId=self.project_id, body=body).execute() def rollback(self, transaction): """ Roll back a transaction - see https://cloud.google.com/datastore/docs/apis/v1beta2/datasets/rollback + see https://cloud.google.com/datastore/docs/reference/rest/v1/projects/rollback :param transaction: the transaction to roll back """ - self.connection.datasets().rollback(datasetId=self.dataset_id, body={'transaction': transaction})\ + self.connection.projects().rollback(projectId=self.project_id, body={'transaction': transaction})\ .execute() def run_query(self, body): """ Run a query for entities. - see https://cloud.google.com/datastore/docs/apis/v1beta2/datasets/runQuery + see https://cloud.google.com/datastore/docs/reference/rest/v1/projects/runQuery :param body: the body of the query request :return: the batch of query results. """ - resp = self.connection.datasets().runQuery(datasetId=self.dataset_id, body=body).execute() + resp = self.connection.projects().runQuery(projectId=self.project_id, body=body).execute() return resp['batch'] + + def get_operation(self, name): + """ + Gets the latest state of a long-running operation + + :param name: the name of the operation resource + """ + resp = self.connection.projects().operations().get(name=name).execute() + return resp + + def delete_operation(self, name): + """ + Deletes the long-running operation + + :param name: the name of the operation resource + """ + resp = self.connection.projects().operations().delete(name=name).execute() + return resp + + def poll_operation_until_done(self, name, polling_interval_in_seconds): + """ + Poll backup operation state until it's completed + """ + while True: + result = self.get_operation(name) + state = result['metadata']['common']['state'] + if state == 'PROCESSING': + logging.info('Operation is processing. Re-polling state in {} seconds' + .format(polling_interval_in_seconds)) + time.sleep(polling_interval_in_seconds) + else: + return result + + def export_to_storage_bucket(self, bucket, namespace=None, entity_filter=None, labels=None): + """ + Export entities from Cloud Datastore to Cloud Storage for backup + """ + output_uri_prefix = 'gs://' + ('/').join(filter(None, [bucket, namespace])) + if not entity_filter: + entity_filter = {} + if not labels: + labels = {} + body = { + 'outputUrlPrefix': output_uri_prefix, + 'entityFilter': entity_filter, + 'labels': labels, + } + resp = self.admin_connection.projects().export(projectId=self.project_id, body=body).execute() + return resp + + def import_from_storage_bucket(self, bucket, file, namespace=None, entity_filter=None, labels=None): + """ + Import a backup from Cloud Storage to Cloud Datastore + """ + input_url = 'gs://' + ('/').join(filter(None, [bucket, namespace, file])) + if not entity_filter: + entity_filter = {} + if not labels: + labels = {} + body = { + 'inputUrl': input_url, + 'entityFilter': entity_filter, + 'labels': labels, + } + resp = self.admin_connection.projects().import_(projectId=self.project_id, body=body).execute() + return resp http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/86063ba4/airflow/contrib/hooks/gcp_api_base_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/gcp_api_base_hook.py b/airflow/contrib/hooks/gcp_api_base_hook.py index 2260e7b..48c5979 100644 --- a/airflow/contrib/hooks/gcp_api_base_hook.py +++ b/airflow/contrib/hooks/gcp_api_base_hook.py @@ -14,6 +14,7 @@ # import logging +import json import httplib2 from oauth2client.client import GoogleCredentials @@ -22,7 +23,6 @@ from oauth2client.service_account import ServiceAccountCredentials from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook - class GoogleCloudBaseHook(BaseHook): """ A base hook for Google cloud-related hooks. Google cloud has a shared REST @@ -57,10 +57,9 @@ class GoogleCloudBaseHook(BaseHook): self.delegate_to = delegate_to self.extras = self.get_connection(conn_id).extra_dejson - def _authorize(self): + def _get_credentials(self): """ - Returns an authorized HTTP object to be used to build a Google cloud - service hook connection. + Returns the Credentials object for Google API """ key_path = self._get_field('key_path', False) scope = self._get_field('scope', False) @@ -86,7 +85,20 @@ class GoogleCloudBaseHook(BaseHook): 'use a JSON key file.') else: raise AirflowException('Unrecognised extension for key file.') + return credentials + + def _get_access_token(self): + """ + Returns a valid access token from Google API Credentials + """ + return self._get_credentials().get_access_token().access_token + def _authorize(self): + """ + Returns an authorized HTTP object to be used to build a Google cloud + service hook connection. + """ + credentials = self._get_credentials() http = httplib2.Http() return credentials.authorize(http) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/86063ba4/airflow/contrib/operators/datastore_export_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/datastore_export_operator.py b/airflow/contrib/operators/datastore_export_operator.py new file mode 100644 index 0000000..1980dfe --- /dev/null +++ b/airflow/contrib/operators/datastore_export_operator.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from airflow.contrib.hooks.datastore_hook import DatastoreHook +from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + +class DatastoreExportOperator(BaseOperator): + """ + Export entities from Google Cloud Datastore to Cloud Storage + + :param bucket: name of the cloud storage bucket to backup data + :type bucket: string + :param namespace: optional namespace path in the specified Cloud Storage bucket + to backup data. If this namespace does not exist in GCS, it will be created. + :type namespace: str + :param datastore_conn_id: the name of the Datastore connection id to use + :type datastore_conn_id: string + :param cloud_storage_conn_id: the name of the cloud storage connection id to force-write + backup + :type cloud_storage_conn_id: string + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: string + :param entity_filter: description of what data from the project is included in the export, + refer to https://cloud.google.com/datastore/docs/reference/rest/Shared.Types/EntityFilter + :type entity_filter: dict + :param labels: client-assigned labels for cloud storage + :type labels: dict + :param polling_interval_in_seconds: number of seconds to wait before polling for + execution status again + :type polling_interval_in_seconds: int + :param overwrite_existing: if the storage bucket + namespace is not empty, it will be + emptied prior to exports. This enables overwriting existing backups. + :type overwrite_existing: bool + :param xcom_push: push operation name to xcom for reference + :type xcom_push: bool + """ + + @apply_defaults + def __init__(self, + bucket, + namespace=None, + datastore_conn_id='google_cloud_default', + cloud_storage_conn_id='google_cloud_default', + delegate_to=None, + entity_filter=None, + labels=None, + polling_interval_in_seconds=10, + overwrite_existing=False, + xcom_push=False, + *args, + **kwargs): + super(DatastoreExportOperator, self).__init__(*args, **kwargs) + self.datastore_conn_id = datastore_conn_id + self.cloud_storage_conn_id = cloud_storage_conn_id + self.delegate_to = delegate_to + self.bucket = bucket + self.namespace = namespace + self.entity_filter = entity_filter + self.labels = labels + self.polling_interval_in_seconds = polling_interval_in_seconds + self.overwrite_existing = overwrite_existing + self.xcom_push = xcom_push + + def execute(self, context): + logging.info('Exporting data to Cloud Storage bucket ' + self.bucket) + + if self.overwrite_existing and self.namespace: + gcs_hook = GoogleCloudStorageHook(self.cloud_storage_conn_id) + objects = gcs_hook.list(self.bucket, prefix=self.namespace) + for o in objects: + gcs_hook.delete(self.bucket, o) + + ds_hook = DatastoreHook(self.datastore_conn_id,self.delegate_to) + result = ds_hook.export_to_storage_bucket(bucket=self.bucket, + namespace=self.namespace, + entity_filter=self.entity_filter, + labels=self.labels) + operation_name = result['name'] + result = ds_hook.poll_operation_until_done(operation_name, + self.polling_interval_in_seconds) + + state = result['metadata']['common']['state'] + if state != 'SUCCESSFUL': + raise AirflowException('Operation failed: result={}'.format(result)) + + if self.xcom_push: + return result http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/86063ba4/airflow/contrib/operators/datastore_import_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/datastore_import_operator.py b/airflow/contrib/operators/datastore_import_operator.py new file mode 100644 index 0000000..3427ba5 --- /dev/null +++ b/airflow/contrib/operators/datastore_import_operator.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging + +from airflow.contrib.hooks.datastore_hook import DatastoreHook +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + +class DatastoreImportOperator(BaseOperator): + """ + Import entities from Cloud Storage to Google Cloud Datastore + + :param bucket: container in Cloud Storage to store data + :type bucket: string + :param file: path of the backup metadata file in the specified Cloud Storage bucket. + It should have the extension .overall_export_metadata + :type file: string + :param namespace: optional namespace of the backup metadata file in + the specified Cloud Storage bucket. + :type namespace: str + :param entity_filter: description of what data from the project is included in the export, + refer to https://cloud.google.com/datastore/docs/reference/rest/Shared.Types/EntityFilter + :type entity_filter: dict + :param labels: client-assigned labels for cloud storage + :type labels: dict + :param datastore_conn_id: the name of the connection id to use + :type datastore_conn_id: string + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have domain-wide + delegation enabled. + :type delegate_to: string + :param polling_interval_in_seconds: number of seconds to wait before polling for + execution status again + :type polling_interval_in_seconds: int + :param xcom_push: push operation name to xcom for reference + :type xcom_push: bool + """ + + @apply_defaults + def __init__(self, + bucket, + file, + namespace=None, + entity_filter=None, + labels=None, + datastore_conn_id='google_cloud_default', + delegate_to=None, + polling_interval_in_seconds=10, + xcom_push=False, + *args, + **kwargs): + super(DatastoreImportOperator, self).__init__(*args, **kwargs) + self.datastore_conn_id = datastore_conn_id + self.delegate_to = delegate_to + self.bucket = bucket + self.file = file + self.namespace = namespace + self.entity_filter = entity_filter + self.labels = labels + self.polling_interval_in_seconds = polling_interval_in_seconds + self.xcom_push = xcom_push + + def execute(self, context): + logging.info('Importing data from Cloud Storage bucket ' + self.bucket) + ds_hook = DatastoreHook(self.datastore_conn_id, self.delegate_to) + result = ds_hook.import_from_storage_bucket(bucket=self.bucket, + file=self.file, + namespace=self.namespace, + entity_filter=self.entity_filter, + labels=self.labels) + operation_name = result['name'] + result = ds_hook.poll_operation_until_done(operation_name, + self.polling_interval_in_seconds) + + + state = result['metadata']['common']['state'] + if state != 'SUCCESSFUL': + raise AirflowException('Operation failed: result={}'.format(result)) + + if self.xcom_push: + return result + http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/86063ba4/airflow/contrib/operators/gcs_to_bq.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py index bab5abe..9981cd4 100644 --- a/airflow/contrib/operators/gcs_to_bq.py +++ b/airflow/contrib/operators/gcs_to_bq.py @@ -51,6 +51,7 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator): google_cloud_storage_conn_id='google_cloud_storage_default', delegate_to=None, schema_update_options=(), + src_fmt_configs={}, *args, **kwargs): """ @@ -62,6 +63,7 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator): :param bucket: The bucket to load from. :type bucket: string :param source_objects: List of Google cloud storage URIs to load from. + If source_format is 'DATASTORE_BACKUP', the list must only contain a single URI. :type object: list :param destination_project_dataset_table: The dotted (<project>.)<dataset>.<table> BigQuery table to load data into. If <project> is not included, project will @@ -69,6 +71,7 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator): :type destination_project_dataset_table: string :param schema_fields: If set, the schema field list as defined here: https://cloud.google.com/bigquery/docs/reference/v2/jobs#configuration.load + Should not be set when source_format is 'DATASTORE_BACKUP'. :type schema_fields: list :param schema_object: If set, a GCS object path pointing to a .json file that contains the schema for the table. @@ -109,6 +112,8 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator): :param schema_update_options: Allows the schema of the desitination table to be updated as a side effect of the load job. :type schema_update_options: list + :param src_fmt_configs: configure optional fields specific to the source format + :type src_fmt_configs: dict """ super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args, **kwargs) @@ -135,12 +140,14 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator): self.delegate_to = delegate_to self.schema_update_options = schema_update_options + self.src_fmt_configs = src_fmt_configs def execute(self, context): bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id, delegate_to=self.delegate_to) - if not self.schema_fields and self.schema_object: + if not self.schema_fields and self.schema_object \ + and self.source_format != 'DATASTORE_BACKUP': gcs_hook = GoogleCloudStorageHook( google_cloud_storage_conn_id=self.google_cloud_storage_conn_id, delegate_to=self.delegate_to) @@ -166,7 +173,8 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator): max_bad_records=self.max_bad_records, quote_character=self.quote_character, allow_quoted_newlines=self.allow_quoted_newlines, - schema_update_options=self.schema_update_options) + schema_update_options=self.schema_update_options, + src_fmt_configs=self.src_fmt_configs) if self.max_id_key: cursor.execute('SELECT MAX({}) FROM {}'.format(
