[
https://issues.apache.org/jira/browse/AIRFLOW-3078?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16644750#comment-16644750
]
ASF GitHub Bot commented on AIRFLOW-3078:
-----------------------------------------
kaxil closed pull request #4022: [AIRFLOW-3078] Basic operators for Google
Compute Engine
URL: https://github.com/apache/incubator-airflow/pull/4022
This is a PR merged from a forked repository (a foreign pull request).
As GitHub hides the original diff of such a pull request on merge, the
diff is displayed below for the sake of provenance:
diff --git a/airflow/contrib/example_dags/example_gcp_compute.py
b/airflow/contrib/example_dags/example_gcp_compute.py
new file mode 100644
index 0000000000..e4abe2e152
--- /dev/null
+++ b/airflow/contrib/example_dags/example_gcp_compute.py
@@ -0,0 +1,108 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG that starts, stops and sets the machine type of a Google
Compute
+Engine instance.
+
+This DAG relies on the following Airflow variables
+https://airflow.apache.org/concepts.html#variables
+* PROJECT_ID - Google Cloud Platform project where the Compute Engine instance
exists.
+* LOCATION - Google Cloud Platform zone where the instance exists.
+* INSTANCE - Name of the Compute Engine instance.
+* SHORT_MACHINE_TYPE_NAME - Machine type resource name to set, e.g.
'n1-standard-1'.
+ See https://cloud.google.com/compute/docs/machine-types
+"""
+
+import datetime
+
+import airflow
+from airflow import models
+from airflow.contrib.operators.gcp_compute_operator import
GceInstanceStartOperator, \
+ GceInstanceStopOperator, GceSetMachineTypeOperator
+
+# [START howto_operator_gce_args]
+PROJECT_ID = models.Variable.get('PROJECT_ID', '')
+LOCATION = models.Variable.get('LOCATION', '')
+INSTANCE = models.Variable.get('INSTANCE', '')
+SHORT_MACHINE_TYPE_NAME = models.Variable.get('SHORT_MACHINE_TYPE_NAME', '')
+SET_MACHINE_TYPE_BODY = {
+ 'machineType': 'zones/{}/machineTypes/{}'.format(LOCATION,
SHORT_MACHINE_TYPE_NAME)
+}
+
+default_args = {
+ 'start_date': airflow.utils.dates.days_ago(1)
+}
+# [END howto_operator_gce_args]
+
+with models.DAG(
+ 'example_gcp_compute',
+ default_args=default_args,
+ schedule_interval=datetime.timedelta(days=1)
+) as dag:
+ # [START howto_operator_gce_start]
+ gce_instance_start = GceInstanceStartOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id=INSTANCE,
+ task_id='gcp_compute_start_task'
+ )
+ # [END howto_operator_gce_start]
+ # Duplicate start for idempotence testing
+ gce_instance_start2 = GceInstanceStartOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id=INSTANCE,
+ task_id='gcp_compute_start_task2'
+ )
+ # [START howto_operator_gce_stop]
+ gce_instance_stop = GceInstanceStopOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id=INSTANCE,
+ task_id='gcp_compute_stop_task'
+ )
+ # [END howto_operator_gce_stop]
+ # Duplicate stop for idempotence testing
+ gce_instance_stop2 = GceInstanceStopOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id=INSTANCE,
+ task_id='gcp_compute_stop_task2'
+ )
+ # [START howto_operator_gce_set_machine_type]
+ gce_set_machine_type = GceSetMachineTypeOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id=INSTANCE,
+ body=SET_MACHINE_TYPE_BODY,
+ task_id='gcp_compute_set_machine_type'
+ )
+ # [END howto_operator_gce_set_machine_type]
+ # Duplicate set machine type for idempotence testing
+ gce_set_machine_type2 = GceSetMachineTypeOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id=INSTANCE,
+ body=SET_MACHINE_TYPE_BODY,
+ task_id='gcp_compute_set_machine_type2'
+ )
+
+ gce_instance_start >> gce_instance_start2 >> gce_instance_stop >> \
+ gce_instance_stop2 >> gce_set_machine_type >> gce_set_machine_type2
diff --git a/airflow/contrib/hooks/gcp_compute_hook.py
b/airflow/contrib/hooks/gcp_compute_hook.py
new file mode 100644
index 0000000000..5fa088942b
--- /dev/null
+++ b/airflow/contrib/hooks/gcp_compute_hook.py
@@ -0,0 +1,167 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import time
+from googleapiclient.discovery import build
+
+from airflow import AirflowException
+from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
+
+# Number of retries - used by googleapiclient method calls to perform retries
+# For requests that are "retriable"
+NUM_RETRIES = 5
+
+# Time to sleep between active checks of the operation results
+TIME_TO_SLEEP_IN_SECONDS = 1
+
+
+class GceOperationStatus:
+ PENDING = "PENDING"
+ RUNNING = "RUNNING"
+ DONE = "DONE"
+
+
+# noinspection PyAbstractClass
+class GceHook(GoogleCloudBaseHook):
+ """
+ Hook for Google Compute Engine APIs.
+ """
+ _conn = None
+
+ def __init__(self,
+ api_version,
+ gcp_conn_id='google_cloud_default',
+ delegate_to=None):
+ super(GceHook, self).__init__(gcp_conn_id, delegate_to)
+ self.api_version = api_version
+
+ def get_conn(self):
+ """
+ Retrieves connection to Google Compute Engine.
+
+ :return: Google Compute Engine services object
+ :rtype: dict
+ """
+ if not self._conn:
+ http_authorized = self._authorize()
+ self._conn = build('compute', self.api_version,
+ http=http_authorized, cache_discovery=False)
+ return self._conn
+
+ def start_instance(self, project_id, zone, resource_id):
+ """
+ Starts an existing instance defined by project_id, zone and
resource_id.
+
+ :param project_id: Google Cloud Platform project where the Compute
Engine
+ instance exists.
+ :type project_id: str
+ :param zone: Google Cloud Platform zone where the instance exists.
+ :type zone: str
+ :param resource_id: Name of the Compute Engine instance resource.
+ :type resource_id: str
+ :return: True if the operation succeeded, raises an error otherwise
+ :rtype: bool
+ """
+ response = self.get_conn().instances().start(
+ project=project_id,
+ zone=zone,
+ instance=resource_id
+ ).execute(num_retries=NUM_RETRIES)
+ operation_name = response["name"]
+ return self._wait_for_operation_to_complete(project_id, zone,
operation_name)
+
+ def stop_instance(self, project_id, zone, resource_id):
+ """
+ Stops an instance defined by project_id, zone and resource_id.
+
+ :param project_id: Google Cloud Platform project where the Compute
Engine
+ instance exists.
+ :type project_id: str
+ :param zone: Google Cloud Platform zone where the instance exists.
+ :type zone: str
+ :param resource_id: Name of the Compute Engine instance resource.
+ :type resource_id: str
+ :return: True if the operation succeeded, raises an error otherwise
+ :rtype: bool
+ """
+ response = self.get_conn().instances().stop(
+ project=project_id,
+ zone=zone,
+ instance=resource_id
+ ).execute(num_retries=NUM_RETRIES)
+ operation_name = response["name"]
+ return self._wait_for_operation_to_complete(project_id, zone,
operation_name)
+
+ def set_machine_type(self, project_id, zone, resource_id, body):
+ """
+ Sets machine type of an instance defined by project_id, zone and
resource_id.
+
+ :param project_id: Google Cloud Platform project where the Compute
Engine
+ instance exists.
+ :type project_id: str
+ :param zone: Google Cloud Platform zone where the instance exists.
+ :type zone: str
+ :param resource_id: Name of the Compute Engine instance resource.
+ :type resource_id: str
+ :param body: Body required by the Compute Engine setMachineType API,
+ as described in
+
https://cloud.google.com/compute/docs/reference/rest/v1/instances/setMachineType
+ :type body: dict
+ :return: True if the operation succeeded, raises an error otherwise
+ :rtype: bool
+ """
+ response = self._execute_set_machine_type(project_id, zone,
resource_id, body)
+ operation_name = response["name"]
+ return self._wait_for_operation_to_complete(project_id, zone,
operation_name)
+
+ def _execute_set_machine_type(self, project_id, zone, resource_id, body):
+ return self.get_conn().instances().setMachineType(
+ project=project_id, zone=zone, instance=resource_id, body=body)\
+ .execute(num_retries=NUM_RETRIES)
+
+ def _wait_for_operation_to_complete(self, project_id, zone,
operation_name):
+ """
+ Waits for the named operation to complete - checks status of the
+ asynchronous call.
+
+ :param operation_name: name of the operation
+ :type operation_name: str
+ :return: True if the operation succeeded, raises an error otherwise
+ :rtype: bool
+ """
+ service = self.get_conn()
+ while True:
+ operation_response = self._check_operation_status(
+ service, operation_name, project_id, zone)
+ if operation_response.get("status") == GceOperationStatus.DONE:
+ error = operation_response.get("error")
+ if error:
+ code = operation_response.get("httpErrorStatusCode")
+ msg = operation_response.get("httpErrorMessage")
+ # Extracting the errors list as string and trimming square
braces
+ error_msg = str(error.get("errors"))[1:-1]
+ raise AirflowException("{} {}: ".format(code, msg) +
error_msg)
+ # No meaningful info to return from the response in case of
success
+ return True
+ time.sleep(TIME_TO_SLEEP_IN_SECONDS)
+
+ def _check_operation_status(self, service, operation_name, project_id,
zone):
+ return service.zoneOperations().get(
+ project=project_id, zone=zone, operation=operation_name).execute(
+ num_retries=NUM_RETRIES)
diff --git a/airflow/contrib/operators/gcp_compute_operator.py
b/airflow/contrib/operators/gcp_compute_operator.py
new file mode 100644
index 0000000000..a2fd545294
--- /dev/null
+++ b/airflow/contrib/operators/gcp_compute_operator.py
@@ -0,0 +1,183 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow import AirflowException
+from airflow.contrib.hooks.gcp_compute_hook import GceHook
+from airflow.contrib.utils.gcp_field_validator import GcpBodyFieldValidator
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class GceBaseOperator(BaseOperator):
+ """
+ Abstract base operator for Google Compute Engine operators to inherit from.
+ """
+ @apply_defaults
+ def __init__(self,
+ project_id,
+ zone,
+ resource_id,
+ gcp_conn_id='google_cloud_default',
+ api_version='v1',
+ *args, **kwargs):
+ self.project_id = project_id
+ self.zone = zone
+ self.full_location = 'projects/{}/zones/{}'.format(self.project_id,
+ self.zone)
+ self.resource_id = resource_id
+ self.gcp_conn_id = gcp_conn_id
+ self.api_version = api_version
+ self._validate_inputs()
+ self._hook = GceHook(gcp_conn_id=self.gcp_conn_id,
api_version=self.api_version)
+ super(GceBaseOperator, self).__init__(*args, **kwargs)
+
+ def _validate_inputs(self):
+ if not self.project_id:
+ raise AirflowException("The required parameter 'project_id' is
missing")
+ if not self.zone:
+ raise AirflowException("The required parameter 'zone' is missing")
+ if not self.resource_id:
+ raise AirflowException("The required parameter 'resource_id' is
missing")
+
+ def execute(self, context):
+ pass
+
+
+class GceInstanceStartOperator(GceBaseOperator):
+ """
+ Start an instance in Google Compute Engine.
+
+ :param project_id: Google Cloud Platform project where the Compute Engine
+ instance exists.
+ :type project_id: str
+ :param zone: Google Cloud Platform zone where the instance exists.
+ :type zone: str
+ :param resource_id: Name of the Compute Engine instance resource.
+ :type resource_id: str
+ :param gcp_conn_id: The connection ID used to connect to Google Cloud
Platform.
+ :type gcp_conn_id: str
+ :param api_version: API version used (e.g. v1).
+ :type api_version: str
+ """
+ template_fields = ('project_id', 'zone', 'resource_id', 'gcp_conn_id',
'api_version')
+
+ @apply_defaults
+ def __init__(self,
+ project_id,
+ zone,
+ resource_id,
+ gcp_conn_id='google_cloud_default',
+ api_version='v1',
+ *args, **kwargs):
+ super(GceInstanceStartOperator, self).__init__(
+ project_id=project_id, zone=zone, resource_id=resource_id,
+ gcp_conn_id=gcp_conn_id, api_version=api_version, *args, **kwargs)
+
+ def execute(self, context):
+ return self._hook.start_instance(self.project_id, self.zone,
self.resource_id)
+
+
+class GceInstanceStopOperator(GceBaseOperator):
+ """
+ Stop an instance in Google Compute Engine.
+
+ :param project_id: Google Cloud Platform project where the Compute Engine
+ instance exists.
+ :type project_id: str
+ :param zone: Google Cloud Platform zone where the instance exists.
+ :type zone: str
+ :param resource_id: Name of the Compute Engine instance resource.
+ :type resource_id: str
+ :param gcp_conn_id: The connection ID used to connect to Google Cloud
Platform.
+ :type gcp_conn_id: str
+ :param api_version: API version used (e.g. v1).
+ :type api_version: str
+ """
+ template_fields = ('project_id', 'zone', 'resource_id', 'gcp_conn_id',
'api_version')
+
+ @apply_defaults
+ def __init__(self,
+ project_id,
+ zone,
+ resource_id,
+ gcp_conn_id='google_cloud_default',
+ api_version='v1',
+ *args, **kwargs):
+ super(GceInstanceStopOperator, self).__init__(
+ project_id=project_id, zone=zone, resource_id=resource_id,
+ gcp_conn_id=gcp_conn_id, api_version=api_version, *args, **kwargs)
+
+ def execute(self, context):
+ return self._hook.stop_instance(self.project_id, self.zone,
self.resource_id)
+
+
+SET_MACHINE_TYPE_VALIDATION_SPECIFICATION = [
+ dict(name="machineType", regexp="^.+$"),
+]
+
+
+class GceSetMachineTypeOperator(GceBaseOperator):
+ """
+ Changes the machine type for a stopped instance to the machine type
specified in
+ the request.
+
+ :param project_id: Google Cloud Platform project where the Compute Engine
+ instance exists.
+ :type project_id: str
+ :param zone: Google Cloud Platform zone where the instance exists.
+ :type zone: str
+ :param resource_id: Name of the Compute Engine instance resource.
+ :type resource_id: str
+ :param body: Body required by the Compute Engine setMachineType API, as
described in
+
https://cloud.google.com/compute/docs/reference/rest/v1/instances/setMachineType#request-body
+ :type body: dict
+ :param gcp_conn_id: The connection ID used to connect to Google Cloud
Platform.
+ :type gcp_conn_id: str
+ :param api_version: API version used (e.g. v1).
+ :type api_version: str
+ """
+ template_fields = ('project_id', 'zone', 'resource_id', 'gcp_conn_id',
'api_version')
+
+ @apply_defaults
+ def __init__(self,
+ project_id,
+ zone,
+ resource_id,
+ body,
+ gcp_conn_id='google_cloud_default',
+ api_version='v1',
+ validate_body=True,
+ *args, **kwargs):
+ self.body = body
+ self._field_validator = None
+ if validate_body:
+ self._field_validator = GcpBodyFieldValidator(
+ SET_MACHINE_TYPE_VALIDATION_SPECIFICATION,
api_version=api_version)
+ super(GceSetMachineTypeOperator, self).__init__(
+ project_id=project_id, zone=zone, resource_id=resource_id,
+ gcp_conn_id=gcp_conn_id, api_version=api_version, *args, **kwargs)
+
+ def _validate_all_body_fields(self):
+ if self._field_validator:
+ self._field_validator.validate(self.body)
+
+ def execute(self, context):
+ self._validate_all_body_fields()
+ return self._hook.set_machine_type(self.project_id, self.zone,
+ self.resource_id, self.body)
diff --git a/airflow/contrib/operators/gcp_function_operator.py
b/airflow/contrib/operators/gcp_function_operator.py
index 4455307c93..8207b9d084 100644
--- a/airflow/contrib/operators/gcp_function_operator.py
+++ b/airflow/contrib/operators/gcp_function_operator.py
@@ -20,277 +20,23 @@
from googleapiclient.errors import HttpError
-from airflow import AirflowException, LoggingMixin
+from airflow import AirflowException
+from airflow.contrib.utils.gcp_field_validator import GcpBodyFieldValidator, \
+ GcpFieldValidationException
from airflow.version import version
from airflow.models import BaseOperator
from airflow.contrib.hooks.gcp_function_hook import GcfHook
from airflow.utils.decorators import apply_defaults
-# TODO: This whole section should be extracted later to
contrib/tools/field_validator.py
-
-COMPOSITE_FIELD_TYPES = ['union', 'dict']
-
-
-class FieldValidationException(AirflowException):
- """
- Thrown when validation finds dictionary field not valid according to
specification.
- """
-
- def __init__(self, message):
- super(FieldValidationException, self).__init__(message)
-
-
-class ValidationSpecificationException(AirflowException):
- """
- Thrown when validation specification is wrong
- (rather than dictionary being validated).
- This should only happen during development as ideally
- specification itself should not be invalid ;) .
- """
-
- def __init__(self, message):
- super(ValidationSpecificationException, self).__init__(message)
-
-
-# TODO: make better description, add some examples
-# TODO: move to contrib/utils folder when we reuse it.
-class BodyFieldValidator(LoggingMixin):
- """
- Validates correctness of request body according to specification.
- The specification can describe various type of
- fields including custom validation, and union of fields. This validator is
meant
- to be reusable by various operators
- in the near future, but for now it is left as part of the Google Cloud
Function,
- so documentation about the
- validator is not yet complete. To see what kind of specification can be
used,
- please take a look at
- gcp_function_operator.CLOUD_FUNCTION_VALIDATION which specifies validation
- for GCF deploy operator.
-
- :param validation_specs: dictionary describing validation specification
- :type validation_specs: [dict]
- :param api_version: Version of the api used (for example v1)
- :type api_version: str
-
- """
- def __init__(self, validation_specs, api_version):
- # type: ([dict], str) -> None
- super(BodyFieldValidator, self).__init__()
- self._validation_specs = validation_specs
- self._api_version = api_version
-
- @staticmethod
- def _get_field_name_with_parent(field_name, parent):
- if parent:
- return parent + '.' + field_name
- return field_name
-
- @staticmethod
- def _sanity_checks(children_validation_specs, field_type, full_field_path,
- regexp, custom_validation, value):
- # type: (dict, str, str, str, function, object) -> None
- if value is None and field_type != 'union':
- raise FieldValidationException(
- "The required body field '{}' is missing. Please add it.".
- format(full_field_path))
- if regexp and field_type:
- raise ValidationSpecificationException(
- "The validation specification entry '{}' has both type and
regexp. "
- "The regexp is only allowed without type (i.e. assume type is
'str' "
- "that can be validated with regexp)".format(full_field_path))
- if children_validation_specs and field_type not in
COMPOSITE_FIELD_TYPES:
- raise ValidationSpecificationException(
- "Nested fields are specified in field '{}' of type '{}'. "
- "Nested fields are only allowed for fields of those types:
('{}').".
- format(full_field_path, field_type, COMPOSITE_FIELD_TYPES))
- if custom_validation and field_type:
- raise ValidationSpecificationException(
- "The validation specification field '{}' has both type and "
- "custom_validation. Custom validation is only allowed without
type.".
- format(full_field_path))
-
- @staticmethod
- def _validate_regexp(full_field_path, regexp, value):
- # type: (str, str, str) -> None
- if not re.match(regexp, value):
- # Note matching of only the beginning as we assume the regexps
all-or-nothing
- raise FieldValidationException(
- "The body field '{}' of value '{}' does not match the field "
- "specification regexp: '{}'.".
- format(full_field_path, value, regexp))
-
- def _validate_dict(self, children_validation_specs, full_field_path,
value):
- # type: (dict, str, dict) -> None
- for child_validation_spec in children_validation_specs:
- self._validate_field(validation_spec=child_validation_spec,
- dictionary_to_validate=value,
- parent=full_field_path)
- for field_name in value.keys():
- if field_name not in [spec['name'] for spec in
children_validation_specs]:
- self.log.warning(
- "The field '{}' is in the body, but is not specified in
the "
- "validation specification '{}'. "
- "This might be because you are using newer API version and
"
- "new field names defined for that version. Then the
warning "
- "can be safely ignored, or you might want to upgrade the
operator"
- "to the version that supports the new API version.".format(
- self._get_field_name_with_parent(field_name,
full_field_path),
- children_validation_specs))
-
- def _validate_union(self, children_validation_specs, full_field_path,
- dictionary_to_validate):
- # type: (dict, str, dict) -> None
- field_found = False
- found_field_name = None
- for child_validation_spec in children_validation_specs:
- # Forcing optional so that we do not have to type optional = True
- # in specification for all union fields
- new_field_found = self._validate_field(
- validation_spec=child_validation_spec,
- dictionary_to_validate=dictionary_to_validate,
- parent=full_field_path,
- force_optional=True)
- field_name = child_validation_spec['name']
- if new_field_found and field_found:
- raise FieldValidationException(
- "The mutually exclusive fields '{}' and '{}' belonging to
the "
- "union '{}' are both present. Please remove one".
- format(field_name, found_field_name, full_field_path))
- if new_field_found:
- field_found = True
- found_field_name = field_name
- if not field_found:
- self.log.warning(
- "There is no '{}' union defined in the body {}. "
- "Validation expected one of '{}' but could not find any. It's
possible "
- "that you are using newer API version and there is another
union variant "
- "defined for that version. Then the warning can be safely
ignored, "
- "or you might want to upgrade the operator to the version that
"
- "supports the new API version.".format(
- full_field_path,
- dictionary_to_validate,
- [field['name'] for field in children_validation_specs]))
-
- def _validate_field(self, validation_spec, dictionary_to_validate,
parent=None,
- force_optional=False):
- """
- Validates if field is OK.
- :param validation_spec: specification of the field
- :type validation_spec: dict
- :param dictionary_to_validate: dictionary where the field should be
present
- :type dictionary_to_validate: dict
- :param parent: full path of parent field
- :type parent: str
- :param force_optional: forces the field to be optional
- (all union fields have force_optional set to True)
- :type force_optional: bool
- :return: True if the field is present
- """
- field_name = validation_spec['name']
- field_type = validation_spec.get('type')
- optional = validation_spec.get('optional')
- regexp = validation_spec.get('regexp')
- children_validation_specs = validation_spec.get('fields')
- required_api_version = validation_spec.get('api_version')
- custom_validation = validation_spec.get('custom_validation')
-
- full_field_path =
self._get_field_name_with_parent(field_name=field_name,
- parent=parent)
- if required_api_version and required_api_version != self._api_version:
- self.log.debug(
- "Skipping validation of the field '{}' for API version '{}' "
- "as it is only valid for API version '{}'".
- format(field_name, self._api_version, required_api_version))
- return False
- value = dictionary_to_validate.get(field_name)
-
- if (optional or force_optional) and value is None:
- self.log.debug("The optional field '{}' is missing. That's
perfectly OK.".
- format(full_field_path))
- return False
-
- # Certainly down from here the field is present (value is not None)
- # so we should only return True from now on
-
-
self._sanity_checks(children_validation_specs=children_validation_specs,
- field_type=field_type,
- full_field_path=full_field_path,
- regexp=regexp,
- custom_validation=custom_validation,
- value=value)
-
- if regexp:
- self._validate_regexp(full_field_path, regexp, value)
- elif field_type == 'dict':
- if not isinstance(value, dict):
- raise FieldValidationException(
- "The field '{}' should be dictionary type according to "
- "specification '{}' but it is '{}'".
- format(full_field_path, validation_spec, value))
- if children_validation_specs is None:
- self.log.debug(
- "The dict field '{}' has no nested fields defined in the "
- "specification '{}'. That's perfectly ok - it's content
will "
- "not be validated."
- .format(full_field_path, validation_spec))
- else:
- self._validate_dict(children_validation_specs,
full_field_path, value)
- elif field_type == 'union':
- if not children_validation_specs:
- raise ValidationSpecificationException(
- "The union field '{}' has no nested fields "
- "defined in specification '{}'. Unions should have at
least one "
- "nested field defined.".format(full_field_path,
validation_spec))
- self._validate_union(children_validation_specs, full_field_path,
- dictionary_to_validate)
- elif custom_validation:
- try:
- custom_validation(value)
- except Exception as e:
- raise FieldValidationException(
- "Error while validating custom field '{}' specified by
'{}': '{}'".
- format(full_field_path, validation_spec, e))
- elif field_type is None:
- self.log.debug("The type of field '{}' is not specified in '{}'. "
- "Not validating its content.".
- format(full_field_path, validation_spec))
- else:
- raise ValidationSpecificationException(
- "The field '{}' is of type '{}' in specification '{}'."
- "This type is unknown to validation!".format(
- full_field_path, field_type, validation_spec))
- return True
-
- def validate(self, body_to_validate):
- """
- Validates if the body (dictionary) follows specification that the
validator was
- instantiated with. Raises ValidationSpecificationException or
- ValidationFieldException in case of problems with specification or the
- body not conforming to the specification respectively.
- :param body_to_validate: body that must follow the specification
- :type body_to_validate: dict
- :return: None
- """
- try:
- for validation_spec in self._validation_specs:
- self._validate_field(validation_spec=validation_spec,
- dictionary_to_validate=body_to_validate)
- except FieldValidationException as e:
- raise FieldValidationException(
- "There was an error when validating: field '{}': '{}'".
- format(body_to_validate, e))
-
-# TODO End of field validator to be extracted
-
def _validate_available_memory_in_mb(value):
if int(value) <= 0:
- raise FieldValidationException("The available memory has to be greater
than 0")
+ raise GcpFieldValidationException("The available memory has to be
greater than 0")
def _validate_max_instances(value):
if int(value) <= 0:
- raise FieldValidationException(
+ raise GcpFieldValidationException(
"The max instances parameter has to be greater than 0")
@@ -378,9 +124,10 @@ def __init__(self,
self.api_version = api_version
self.zip_path = zip_path
self.zip_path_preprocessor = ZipPathPreprocessor(body, zip_path)
- self.validate_body = validate_body
- self._field_validator = BodyFieldValidator(CLOUD_FUNCTION_VALIDATION,
- api_version=api_version)
+ self._field_validator = None
+ if validate_body:
+ self._field_validator =
GcpBodyFieldValidator(CLOUD_FUNCTION_VALIDATION,
+
api_version=api_version)
self._hook = GcfHook(gcp_conn_id=self.gcp_conn_id,
api_version=self.api_version)
self._validate_inputs()
super(GcfFunctionDeployOperator, self).__init__(*args, **kwargs)
@@ -395,7 +142,8 @@ def _validate_inputs(self):
self.zip_path_preprocessor.preprocess_body()
def _validate_all_body_fields(self):
- self._field_validator.validate(self.body)
+ if self._field_validator:
+ self._field_validator.validate(self.body)
def _create_new_function(self):
self._hook.create_new_function(self.full_location, self.body)
@@ -406,8 +154,8 @@ def _update_function(self):
def _check_if_function_exists(self):
name = self.body.get('name')
if not name:
- raise FieldValidationException("The 'name' field should be present
in "
- "body: '{}'.".format(self.body))
+ raise GcpFieldValidationException("The 'name' field should be
present in "
+ "body: '{}'.".format(self.body))
try:
self._hook.get_function(name)
except HttpError as e:
@@ -430,8 +178,7 @@ def _set_airflow_version_label(self):
def execute(self, context):
if self.zip_path_preprocessor.should_upload_function():
self.body[SOURCE_UPLOAD_URL] = self._upload_source_code()
- if self.validate_body:
- self._validate_all_body_fields()
+ self._validate_all_body_fields()
self._set_airflow_version_label()
if not self._check_if_function_exists():
self._create_new_function()
diff --git a/airflow/contrib/utils/gcp_field_validator.py
b/airflow/contrib/utils/gcp_field_validator.py
new file mode 100644
index 0000000000..20f72d94b8
--- /dev/null
+++ b/airflow/contrib/utils/gcp_field_validator.py
@@ -0,0 +1,417 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Validator for body fields sent via GCP API.
+
+The validator performs validation of the body (being dictionary of fields) that
+is sent in the API request to Google Cloud (via googleclient API usually).
+
+Context
+-------
+The specification mostly focuses on helping Airflow DAG developers in the
development
+phase. You can build your own GCP operator (such as GcfFunctionDeployOperator for
example) which
+can have a built-in validation specification for the particular API. It's super
helpful
+when a developer plays with different fields and their values at the initial
phase of
+DAG development. Most of the Google Cloud APIs perform their own validation on
the
+server side, but most of the requests are asynchronous and you need to wait
for result
+of the operation. This takes precious time and slows
+down iteration over the API. GcpBodyFieldValidator is meant to be used on the
client side
+and it should therefore provide instant feedback to the developer on
misspelled or
+wrongly typed parameters.
+
+The validation should be performed in "execute()" method call in order to allow
+template parameters to be expanded before validation is performed.
+
+Types of fields
+---------------
+
+Specification is an array of dictionaries - each dictionary describes field,
its type,
+validation, optionality, api_version supported and nested fields (for unions
and dicts).
+
+Typically (for clarity and in order to aid syntax highlighting) the array of
+dicts should be defined as series of dict() executions. Fragment of example
+specification might look as follows:
+
+```
+SPECIFICATION =[
+ dict(name="an_union", type="union", optional=True, fields=[
+ dict(name="variant_1", type="dict"),
+ dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'),
+ ]),
+ dict(name="a_dict", type="dict", fields=[
+ dict(name="field_1", type="dict"),
+ dict(name="field_2", regexp=r'^.+$'),
+ ]),
+ ...
+]
+```
+
+Each field should have key = "name" indicating field name. The field can be of
one of the
+following types:
+
+* Dict fields: (key = "type", value="dict"):
+ Field of this type should contain nested fields in form of an array of dicts.
+ Each of the fields in the array is then expected (unless marked as optional)
+ and validated recursively. If an extra field is present in the dictionary,
warning is
+ printed in log file (but the validation succeeds - see the
Forward-compatibility notes)
+* Union fields (key = "type", value="union"): field of this type should
contain nested
+ fields in form of an array of dicts. One of the fields (and only one) should
be
+ present (unless the union is marked as optional). If more than one union
field is
+ present, GcpFieldValidationException is raised. If none of the union fields is
+ present - a warning is printed in the log (see the Forward-compatibility
notes below).
+* Regexp-validated fields: (key = "regexp") - fields of this type are assumed
to be
+ strings and they are validated with the regexp specified. Remember that the
regexps
+ should ideally contain ^ at the beginning and $ at the end to make sure that
+ the whole field content is validated. Typically such regexp
+ validations should be used carefully and sparingly (see Forward-compatibility
+ notes below). Most of regexp validation should be at most r'^.+$'.
+* Custom-validated fields: (key = "custom_validation") - fields of this type
are validated
+ using the method specified via the custom_validation field. Any exception thrown in
the custom
+ validation will be turned into GcpFieldValidationException and will cause
validation to
+ fail. Such custom validations might be used to check numeric fields
(including
+ ranges of values), booleans or any other types of fields.
+* API version: (key="api_version") if API version is specified, then the field
will only
+ be validated when api_version used at field validator initialization matches
exactly the
+ version specified. If you want to declare fields that are available in
several
+ versions of the APIs, you should specify the field as many times as many API
versions
+ should be supported (each time with different API version).
+* if none of the keys ("type", "regexp", "custom_validation") is present - the field is
not validated
+
+You can see some of the field examples in EXAMPLE_VALIDATION_SPECIFICATION.
+
+
+Forward-compatibility notes
+---------------------------
+Certain decisions are crucial to allow the client APIs to work also with
future API
+versions. Since the attached body is passed to the API call, it is entirely
+possible to pass through any new fields in the body (for future API versions) -
+albeit without validation on the client side - they can and will still be
validated
+on the server side usually.
+
+Here are the guidelines that you should follow to make validation
forward-compatible:
+
+* most of the fields are not validated for their content. It's possible to use
regexp
+ in some specific cases that are guaranteed not to change in the future, but
for most
+ fields regexp validation should be r'^.+$' indicating check for non-emptiness
+* api_version is not validated - user can pass any future version of the api
here. The API
+ version is only used to filter parameters that are marked as present in this
api version.
+ Any new (not present in the specification) fields in the body are allowed
(not verified).
+ For dictionaries, new fields can be added to dictionaries by future calls.
However if an
+ unknown field in dictionary is added, a warning is logged by the client (but
validation
+ remains successful). This is a very nice feature to protect against typos in
names.
+* For unions, newly added union variants can be added by future calls and they
will
+ pass validation, however the content or presence of those fields will not be
validated.
+ This means that it’s possible to send a new non-validated union field
together with an
+ old validated field and this problem will not be detected by the client. In
such case
+ warning will be printed.
+* When you add validator to an operator, you should also add ``validate_body``
parameter
+ (default = True) to __init__ of such operators - when it is set to False,
+ no validation should be performed. This is a safeguard for totally
unpredicted and
+ backwards-incompatible changes that might sometimes occur in the APIs.
+
+"""
+
+import re
+
+from airflow import LoggingMixin, AirflowException
+
+COMPOSITE_FIELD_TYPES = ['union', 'dict']
+
+
+class GcpFieldValidationException(AirflowException):
+ """Thrown when validation finds dictionary field not valid according to
specification.
+ """
+
+ def __init__(self, message):
+ super(GcpFieldValidationException, self).__init__(message)
+
+
+class GcpValidationSpecificationException(AirflowException):
+ """Thrown when validation specification is wrong.
+
+ This should only happen during development as ideally
+ specification itself should not be invalid ;) .
+ """
+
+ def __init__(self, message):
+ super(GcpValidationSpecificationException, self).__init__(message)
+
+
+def _int_greater_than_zero(value):
+ if int(value) <= 0:
+ raise GcpFieldValidationException("The available memory has to be
greater than 0")
+
+
+EXAMPLE_VALIDATION_SPECIFICATION = [
+ dict(name="name", regexp="^.+$"),
+ dict(name="description", regexp="^.+$", optional=True),
+ dict(name="availableMemoryMb", custom_validation=_int_greater_than_zero,
+ optional=True),
+ dict(name="labels", optional=True, type="dict"),
+ dict(name="an_union", type="union", fields=[
+ dict(name="variant_1", regexp=r'^.+$'),
+ dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'),
+ dict(name="variant_3", type="dict", fields=[
+ dict(name="url", regexp=r'^.+$')
+ ]),
+ dict(name="variant_4")
+ ]),
+]
+
+
+class GcpBodyFieldValidator(LoggingMixin):
+ """Validates correctness of request body according to specification.
+
+ The specification can describe various type of
+ fields including custom validation, and union of fields. This validator is
+ to be reusable by various operators. See the
EXAMPLE_VALIDATION_SPECIFICATION
+ for some examples and explanations of how to create specification.
+
+ :param validation_specs: dictionary describing validation specification
+ :type validation_specs: [dict]
+ :param api_version: Version of the api used (for example v1)
+ :type api_version: str
+
+ """
+ def __init__(self, validation_specs, api_version):
+ # type: ([dict], str) -> None
+ super(GcpBodyFieldValidator, self).__init__()
+ self._validation_specs = validation_specs
+ self._api_version = api_version
+
+ @staticmethod
+ def _get_field_name_with_parent(field_name, parent):
+ if parent:
+ return parent + '.' + field_name
+ return field_name
+
+ @staticmethod
+ def _sanity_checks(children_validation_specs, field_type, full_field_path,
+ regexp, custom_validation, value):
+ # type: (dict, str, str, str, function, object) -> None
+ if value is None and field_type != 'union':
+ raise GcpFieldValidationException(
+ "The required body field '{}' is missing. Please add it.".
+ format(full_field_path))
+ if regexp and field_type:
+ raise GcpValidationSpecificationException(
+ "The validation specification entry '{}' has both type and
regexp. "
+ "The regexp is only allowed without type (i.e. assume type is
'str' "
+ "that can be validated with regexp)".format(full_field_path))
+ if children_validation_specs and field_type not in
COMPOSITE_FIELD_TYPES:
+ raise GcpValidationSpecificationException(
+ "Nested fields are specified in field '{}' of type '{}'. "
+ "Nested fields are only allowed for fields of those types:
('{}').".
+ format(full_field_path, field_type, COMPOSITE_FIELD_TYPES))
+ if custom_validation and field_type:
+ raise GcpValidationSpecificationException(
+ "The validation specification field '{}' has both type and "
+ "custom_validation. Custom validation is only allowed without
type.".
+ format(full_field_path))
+
+ @staticmethod
+ def _validate_regexp(full_field_path, regexp, value):
+ # type: (str, str, str) -> None
+ if not re.match(regexp, value):
+ # Note matching of only the beginning as we assume the regexps
all-or-nothing
+ raise GcpFieldValidationException(
+ "The body field '{}' of value '{}' does not match the field "
+ "specification regexp: '{}'.".
+ format(full_field_path, value, regexp))
+
+ def _validate_dict(self, children_validation_specs, full_field_path,
value):
+ # type: (dict, str, dict) -> None
+ for child_validation_spec in children_validation_specs:
+ self._validate_field(validation_spec=child_validation_spec,
+ dictionary_to_validate=value,
+ parent=full_field_path)
+ all_dict_keys = [spec['name'] for spec in children_validation_specs]
+ for field_name in value.keys():
+ if field_name not in all_dict_keys:
+ self.log.warning(
+ "The field '{}' is in the body, but is not specified in
the "
+ "validation specification '{}'. "
+ "This might be because you are using newer API version and
"
+ "new field names defined for that version. Then the
warning "
+ "can be safely ignored, or you might want to upgrade the
operator"
+ "to the version that supports the new API version.".format(
+ self._get_field_name_with_parent(field_name,
full_field_path),
+ children_validation_specs))
+
+ def _validate_union(self, children_validation_specs, full_field_path,
+ dictionary_to_validate):
+ # type: (dict, str, dict) -> None
+ field_found = False
+ found_field_name = None
+ for child_validation_spec in children_validation_specs:
+ # Forcing optional so that we do not have to type optional = True
+ # in specification for all union fields
+ new_field_found = self._validate_field(
+ validation_spec=child_validation_spec,
+ dictionary_to_validate=dictionary_to_validate,
+ parent=full_field_path,
+ force_optional=True)
+ field_name = child_validation_spec['name']
+ if new_field_found and field_found:
+ raise GcpFieldValidationException(
+ "The mutually exclusive fields '{}' and '{}' belonging to
the "
+ "union '{}' are both present. Please remove one".
+ format(field_name, found_field_name, full_field_path))
+ if new_field_found:
+ field_found = True
+ found_field_name = field_name
+ if not field_found:
+ self.log.warning(
+ "There is no '{}' union defined in the body {}. "
+ "Validation expected one of '{}' but could not find any. It's
possible "
+ "that you are using newer API version and there is another
union variant "
+ "defined for that version. Then the warning can be safely
ignored, "
+ "or you might want to upgrade the operator to the version that
"
+ "supports the new API version.".format(
+ full_field_path,
+ dictionary_to_validate,
+ [field['name'] for field in children_validation_specs]))
+
+ def _validate_field(self, validation_spec, dictionary_to_validate,
parent=None,
+ force_optional=False):
+ """
+ Validates if field is OK.
+ :param validation_spec: specification of the field
+ :type validation_spec: dict
+ :param dictionary_to_validate: dictionary where the field should be
present
+ :type dictionary_to_validate: dict
+ :param parent: full path of parent field
+ :type parent: str
+ :param force_optional: forces the field to be optional
+ (all union fields have force_optional set to True)
+ :type force_optional: bool
+ :return: True if the field is present
+ """
+ field_name = validation_spec['name']
+ field_type = validation_spec.get('type')
+ optional = validation_spec.get('optional')
+ regexp = validation_spec.get('regexp')
+ children_validation_specs = validation_spec.get('fields')
+ required_api_version = validation_spec.get('api_version')
+ custom_validation = validation_spec.get('custom_validation')
+
+ full_field_path =
self._get_field_name_with_parent(field_name=field_name,
+ parent=parent)
+ if required_api_version and required_api_version != self._api_version:
+ self.log.debug(
+ "Skipping validation of the field '{}' for API version '{}' "
+ "as it is only valid for API version '{}'".
+ format(field_name, self._api_version, required_api_version))
+ return False
+ value = dictionary_to_validate.get(field_name)
+
+ if (optional or force_optional) and value is None:
+ self.log.debug("The optional field '{}' is missing. That's
perfectly OK.".
+ format(full_field_path))
+ return False
+
+ # Certainly down from here the field is present (value is not None)
+ # so we should only return True from now on
+
+
self._sanity_checks(children_validation_specs=children_validation_specs,
+ field_type=field_type,
+ full_field_path=full_field_path,
+ regexp=regexp,
+ custom_validation=custom_validation,
+ value=value)
+
+ if regexp:
+ self._validate_regexp(full_field_path, regexp, value)
+ elif field_type == 'dict':
+ if not isinstance(value, dict):
+ raise GcpFieldValidationException(
+ "The field '{}' should be dictionary type according to "
+ "specification '{}' but it is '{}'".
+ format(full_field_path, validation_spec, value))
+ if children_validation_specs is None:
+ self.log.debug(
+ "The dict field '{}' has no nested fields defined in the "
+ "specification '{}'. That's perfectly ok - it's content
will "
+ "not be validated."
+ .format(full_field_path, validation_spec))
+ else:
+ self._validate_dict(children_validation_specs,
full_field_path, value)
+ elif field_type == 'union':
+ if not children_validation_specs:
+ raise GcpValidationSpecificationException(
+ "The union field '{}' has no nested fields "
+ "defined in specification '{}'. Unions should have at
least one "
+ "nested field defined.".format(full_field_path,
validation_spec))
+ self._validate_union(children_validation_specs, full_field_path,
+ dictionary_to_validate)
+ elif custom_validation:
+ try:
+ custom_validation(value)
+ except Exception as e:
+ raise GcpFieldValidationException(
+ "Error while validating custom field '{}' specified by
'{}': '{}'".
+ format(full_field_path, validation_spec, e))
+ elif field_type is None:
+ self.log.debug("The type of field '{}' is not specified in '{}'. "
+ "Not validating its content.".
+ format(full_field_path, validation_spec))
+ else:
+ raise GcpValidationSpecificationException(
+ "The field '{}' is of type '{}' in specification '{}'."
+ "This type is unknown to validation!".format(
+ full_field_path, field_type, validation_spec))
+ return True
+
+ def validate(self, body_to_validate):
+ """
+ Validates if the body (dictionary) follows specification that the
validator was
+ instantiated with. Raises GcpValidationSpecificationException or
+ GcpFieldValidationException in case of problems with specification or the
+ body not conforming to the specification respectively.
+ :param body_to_validate: body that must follow the specification
+ :type body_to_validate: dict
+ :return: None
+ """
+ try:
+ for validation_spec in self._validation_specs:
+ self._validate_field(validation_spec=validation_spec,
+ dictionary_to_validate=body_to_validate)
+ except GcpFieldValidationException as e:
+ raise GcpFieldValidationException(
+ "There was an error when validating: body '{}': '{}'".
+ format(body_to_validate, e))
+ all_field_names = [spec['name'] for spec in self._validation_specs
+ if spec.get('type') != 'union' and
+ spec.get('api_version') != self._api_version]
+ all_union_fields = [spec for spec in self._validation_specs
+ if spec.get('type') == 'union']
+ for union_field in all_union_fields:
+ all_field_names.extend(
+ [nested_union_spec['name'] for nested_union_spec in
union_field['fields']
+ if nested_union_spec.get('type') != 'union' and
+ nested_union_spec.get('api_version') != self._api_version])
+ for field_name in body_to_validate.keys():
+ if field_name not in all_field_names:
+ self.log.warning(
+ "The field '{}' is in the body, but is not specified in
the "
+ "validation specification '{}'. "
+ "This might be because you are using newer API version and
"
+ "new field names defined for that version. Then the
warning "
+ "can be safely ignored, or you might want to upgrade the
operator"
+ "to the version that supports the new API version.".format(
+ field_name, self._validation_specs))
diff --git a/docs/howto/operator.rst b/docs/howto/operator.rst
index 0d973a391c..549d677570 100644
--- a/docs/howto/operator.rst
+++ b/docs/howto/operator.rst
@@ -102,6 +102,62 @@ to execute a BigQuery load job.
:start-after: [START howto_operator_gcs_to_bq]
:end-before: [END howto_operator_gcs_to_bq]
+GceInstanceStartOperator
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Starts an existing Google Compute Engine instance.
+
+In this example parameter values are extracted from Airflow variables.
+Moreover, the ``default_args`` dict is used to pass common arguments to all
operators in a single DAG.
+
+.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py
+ :language: python
+ :start-after: [START howto_operator_gce_args]
+ :end-before: [END howto_operator_gce_args]
+
+
+Define the :class:`~airflow.contrib.operators.gcp_compute_operator
+.GceInstanceStartOperator` by passing the required arguments to the
constructor.
+
+.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_gce_start]
+ :end-before: [END howto_operator_gce_start]
+
+GceInstanceStopOperator
+^^^^^^^^^^^^^^^^^^^^^^^
+
+Stops an existing Google Compute Engine instance.
+
+For parameter definition take a look at
:class:`~airflow.contrib.operators.gcp_compute_operator.GceInstanceStartOperator`
above.
+
+Define the :class:`~airflow.contrib.operators.gcp_compute_operator
+.GceInstanceStopOperator` by passing the required arguments to the constructor.
+
+.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_gce_stop]
+ :end-before: [END howto_operator_gce_stop]
+
+GceSetMachineTypeOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Changes the machine type of a stopped instance to the specified
machine type.
+
+For parameter definition take a look at
:class:`~airflow.contrib.operators.gcp_compute_operator.GceInstanceStartOperator`
above.
+
+Define the :class:`~airflow.contrib.operators.gcp_compute_operator
+.GceSetMachineTypeOperator` by passing the required arguments to the
constructor.
+
+.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_gce_set_machine_type]
+ :end-before: [END howto_operator_gce_set_machine_type]
+
+
GcfFunctionDeleteOperator
^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/docs/integration.rst b/docs/integration.rst
index 6ef7bd8398..3a5a3c3e05 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -457,6 +457,37 @@ BigQueryHook
.. autoclass:: airflow.contrib.hooks.bigquery_hook.BigQueryHook
:members:
+Compute Engine
+''''''''''''''
+
+Compute Engine Operators
+""""""""""""""""""""""""
+
+- :ref:`GceInstanceStartOperator` : start an existing Google Compute Engine
instance.
+- :ref:`GceInstanceStopOperator` : stop an existing Google Compute Engine
instance.
+- :ref:`GceSetMachineTypeOperator` : change the machine type for a stopped
instance.
+
+.. _GceInstanceStartOperator:
+
+GceInstanceStartOperator
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass::
airflow.contrib.operators.gcp_compute_operator.GceInstanceStartOperator
+
+.. _GceInstanceStopOperator:
+
+GceInstanceStopOperator
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass::
airflow.contrib.operators.gcp_compute_operator.GceInstanceStopOperator
+
+.. _GceSetMachineTypeOperator:
+
+GceSetMachineTypeOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass::
airflow.contrib.operators.gcp_compute_operator.GceSetMachineTypeOperator
+
Cloud Functions
'''''''''''''''
diff --git a/tests/contrib/operators/test_gcp_compute_operator.py
b/tests/contrib/operators/test_gcp_compute_operator.py
new file mode 100644
index 0000000000..449c4e015f
--- /dev/null
+++ b/tests/contrib/operators/test_gcp_compute_operator.py
@@ -0,0 +1,377 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import ast
+import unittest
+
+from airflow import AirflowException, configuration
+from airflow.contrib.operators.gcp_compute_operator import
GceInstanceStartOperator, \
+ GceInstanceStopOperator, GceSetMachineTypeOperator
+from airflow.models import TaskInstance, DAG
+from airflow.utils import timezone
+
+try:
+ # noinspection PyProtectedMember
+ from unittest import mock
+except ImportError:
+ try:
+ import mock
+ except ImportError:
+ mock = None
+
+PROJECT_ID = 'project-id'
+LOCATION = 'zone'
+RESOURCE_ID = 'resource-id'
+SHORT_MACHINE_TYPE_NAME = 'n1-machine-type'
+SET_MACHINE_TYPE_BODY = {
+ 'machineType': 'zones/{}/machineTypes/{}'.format(LOCATION,
SHORT_MACHINE_TYPE_NAME)
+}
+
+DEFAULT_DATE = timezone.datetime(2017, 1, 1)
+
+
+class GceInstanceStartTest(unittest.TestCase):
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_instance_start(self, mock_hook):
+ mock_hook.return_value.start_instance.return_value = True
+ op = GceInstanceStartOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id=RESOURCE_ID,
+ task_id='id'
+ )
+ result = op.execute(None)
+ mock_hook.assert_called_once_with(api_version='v1',
+ gcp_conn_id='google_cloud_default')
+ mock_hook.return_value.start_instance.assert_called_once_with(
+ PROJECT_ID, LOCATION, RESOURCE_ID
+ )
+ self.assertTrue(result)
+
+ # Setting all of the operator's input parameters as templated dag_ids
+ # (could be anything else) just to test if the templating works for all
fields
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_instance_start_with_templates(self, mock_hook):
+ dag_id = 'test_dag_id'
+ configuration.load_test_config()
+ args = {
+ 'start_date': DEFAULT_DATE
+ }
+ self.dag = DAG(dag_id, default_args=args)
+ op = GceInstanceStartOperator(
+ project_id='{{ dag.dag_id }}',
+ zone='{{ dag.dag_id }}',
+ resource_id='{{ dag.dag_id }}',
+ gcp_conn_id='{{ dag.dag_id }}',
+ api_version='{{ dag.dag_id }}',
+ task_id='id',
+ dag=self.dag
+ )
+ ti = TaskInstance(op, DEFAULT_DATE)
+ ti.render_templates()
+ self.assertEqual(dag_id, getattr(op, 'project_id'))
+ self.assertEqual(dag_id, getattr(op, 'zone'))
+ self.assertEqual(dag_id, getattr(op, 'resource_id'))
+ self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
+ self.assertEqual(dag_id, getattr(op, 'api_version'))
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_start_should_throw_ex_when_missing_project_id(self, mock_hook):
+ with self.assertRaises(AirflowException) as cm:
+ op = GceInstanceStartOperator(
+ project_id="",
+ zone=LOCATION,
+ resource_id=RESOURCE_ID,
+ task_id='id'
+ )
+ op.execute(None)
+ err = cm.exception
+ self.assertIn("The required parameter 'project_id' is missing",
str(err))
+ mock_hook.assert_not_called()
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_start_should_throw_ex_when_missing_zone(self, mock_hook):
+ with self.assertRaises(AirflowException) as cm:
+ op = GceInstanceStartOperator(
+ project_id=PROJECT_ID,
+ zone="",
+ resource_id=RESOURCE_ID,
+ task_id='id'
+ )
+ op.execute(None)
+ err = cm.exception
+ self.assertIn("The required parameter 'zone' is missing", str(err))
+ mock_hook.assert_not_called()
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_start_should_throw_ex_when_missing_resource_id(self, mock_hook):
+ with self.assertRaises(AirflowException) as cm:
+ op = GceInstanceStartOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id="",
+ task_id='id'
+ )
+ op.execute(None)
+ err = cm.exception
+ self.assertIn("The required parameter 'resource_id' is missing",
str(err))
+ mock_hook.assert_not_called()
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_instance_stop(self, mock_hook):
+ mock_hook.return_value.stop_instance.return_value = True
+ op = GceInstanceStopOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id=RESOURCE_ID,
+ task_id='id'
+ )
+ result = op.execute(None)
+ mock_hook.assert_called_once_with(api_version='v1',
+ gcp_conn_id='google_cloud_default')
+ mock_hook.return_value.stop_instance.assert_called_once_with(
+ PROJECT_ID, LOCATION, RESOURCE_ID
+ )
+ self.assertTrue(result)
+
+ # Setting all of the operator's input parameters as templated dag_ids
+ # (could be anything else) just to test if the templating works for all
fields
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_instance_stop_with_templates(self, mock_hook):
+ dag_id = 'test_dag_id'
+ configuration.load_test_config()
+ args = {
+ 'start_date': DEFAULT_DATE
+ }
+ self.dag = DAG(dag_id, default_args=args)
+ op = GceInstanceStopOperator(
+ project_id='{{ dag.dag_id }}',
+ zone='{{ dag.dag_id }}',
+ resource_id='{{ dag.dag_id }}',
+ gcp_conn_id='{{ dag.dag_id }}',
+ api_version='{{ dag.dag_id }}',
+ task_id='id',
+ dag=self.dag
+ )
+ ti = TaskInstance(op, DEFAULT_DATE)
+ ti.render_templates()
+ self.assertEqual(dag_id, getattr(op, 'project_id'))
+ self.assertEqual(dag_id, getattr(op, 'zone'))
+ self.assertEqual(dag_id, getattr(op, 'resource_id'))
+ self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
+ self.assertEqual(dag_id, getattr(op, 'api_version'))
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_stop_should_throw_ex_when_missing_project_id(self, mock_hook):
+ with self.assertRaises(AirflowException) as cm:
+ op = GceInstanceStopOperator(
+ project_id="",
+ zone=LOCATION,
+ resource_id=RESOURCE_ID,
+ task_id='id'
+ )
+ op.execute(None)
+ err = cm.exception
+ self.assertIn("The required parameter 'project_id' is missing",
str(err))
+ mock_hook.assert_not_called()
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_stop_should_throw_ex_when_missing_zone(self, mock_hook):
+ with self.assertRaises(AirflowException) as cm:
+ op = GceInstanceStopOperator(
+ project_id=PROJECT_ID,
+ zone="",
+ resource_id=RESOURCE_ID,
+ task_id='id'
+ )
+ op.execute(None)
+ err = cm.exception
+ self.assertIn("The required parameter 'zone' is missing", str(err))
+ mock_hook.assert_not_called()
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_stop_should_throw_ex_when_missing_resource_id(self, mock_hook):
+ with self.assertRaises(AirflowException) as cm:
+ op = GceInstanceStopOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id="",
+ task_id='id'
+ )
+ op.execute(None)
+ err = cm.exception
+ self.assertIn("The required parameter 'resource_id' is missing",
str(err))
+ mock_hook.assert_not_called()
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_set_machine_type(self, mock_hook):
+ mock_hook.return_value.set_machine_type.return_value = True
+ op = GceSetMachineTypeOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id=RESOURCE_ID,
+ body=SET_MACHINE_TYPE_BODY,
+ task_id='id'
+ )
+ result = op.execute(None)
+ mock_hook.assert_called_once_with(api_version='v1',
+ gcp_conn_id='google_cloud_default')
+ mock_hook.return_value.set_machine_type.assert_called_once_with(
+ PROJECT_ID, LOCATION, RESOURCE_ID, SET_MACHINE_TYPE_BODY
+ )
+ self.assertTrue(result)
+
+ # Setting all of the operator's input parameters as templated dag_ids
+ # (could be anything else) just to test if the templating works for all
fields
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_set_machine_type_with_templates(self, mock_hook):
+ dag_id = 'test_dag_id'
+ configuration.load_test_config()
+ args = {
+ 'start_date': DEFAULT_DATE
+ }
+ self.dag = DAG(dag_id, default_args=args)
+ op = GceSetMachineTypeOperator(
+ project_id='{{ dag.dag_id }}',
+ zone='{{ dag.dag_id }}',
+ resource_id='{{ dag.dag_id }}',
+ body={},
+ gcp_conn_id='{{ dag.dag_id }}',
+ api_version='{{ dag.dag_id }}',
+ task_id='id',
+ dag=self.dag
+ )
+ ti = TaskInstance(op, DEFAULT_DATE)
+ ti.render_templates()
+ self.assertEqual(dag_id, getattr(op, 'project_id'))
+ self.assertEqual(dag_id, getattr(op, 'zone'))
+ self.assertEqual(dag_id, getattr(op, 'resource_id'))
+ self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
+ self.assertEqual(dag_id, getattr(op, 'api_version'))
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_set_machine_type_should_throw_ex_when_missing_project_id(self,
mock_hook):
+ with self.assertRaises(AirflowException) as cm:
+ op = GceSetMachineTypeOperator(
+ project_id="",
+ zone=LOCATION,
+ resource_id=RESOURCE_ID,
+ body=SET_MACHINE_TYPE_BODY,
+ task_id='id'
+ )
+ op.execute(None)
+ err = cm.exception
+ self.assertIn("The required parameter 'project_id' is missing",
str(err))
+ mock_hook.assert_not_called()
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_set_machine_type_should_throw_ex_when_missing_zone(self,
mock_hook):
+ with self.assertRaises(AirflowException) as cm:
+ op = GceSetMachineTypeOperator(
+ project_id=PROJECT_ID,
+ zone="",
+ resource_id=RESOURCE_ID,
+ body=SET_MACHINE_TYPE_BODY,
+ task_id='id'
+ )
+ op.execute(None)
+ err = cm.exception
+ self.assertIn("The required parameter 'zone' is missing", str(err))
+ mock_hook.assert_not_called()
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_set_machine_type_should_throw_ex_when_missing_resource_id(self,
mock_hook):
+ with self.assertRaises(AirflowException) as cm:
+ op = GceSetMachineTypeOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id="",
+ body=SET_MACHINE_TYPE_BODY,
+ task_id='id'
+ )
+ op.execute(None)
+ err = cm.exception
+ self.assertIn("The required parameter 'resource_id' is missing",
str(err))
+ mock_hook.assert_not_called()
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
+ def test_set_machine_type_should_throw_ex_when_missing_machine_type(self,
mock_hook):
+ with self.assertRaises(AirflowException) as cm:
+ op = GceSetMachineTypeOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id=RESOURCE_ID,
+ body={},
+ task_id='id'
+ )
+ op.execute(None)
+ err = cm.exception
+ self.assertIn(
+ "The required body field 'machineType' is missing. Please add
it.", str(err))
+ mock_hook.assert_called_once_with(api_version='v1',
+ gcp_conn_id='google_cloud_default')
+
+ MOCK_OP_RESPONSE = "{'kind': 'compute#operation', 'id':
'8529919847974922736', " \
+ "'name': " \
+
"'operation-1538578207537-577542784f769-7999ab71-94f9ec1d', " \
+ "'zone':
'https://www.googleapis.com/compute/v1/projects/polidea" \
+ "-airflow/zones/europe-west3-b', 'operationType': " \
+ "'setMachineType', 'targetLink': " \
+
"'https://www.googleapis.com/compute/v1/projects/polidea-airflow" \
+ "/zones/europe-west3-b/instances/pa-1', 'targetId': " \
+ "'2480086944131075860', 'status': 'DONE', 'user': " \
+ "'[email protected]', "
\
+ "'progress': 100, 'insertTime':
'2018-10-03T07:50:07.951-07:00', "\
+ "'startTime': '2018-10-03T07:50:08.324-07:00',
'endTime': " \
+ "'2018-10-03T07:50:08.484-07:00', 'error': {'errors':
[{'code': " \
+ "'UNSUPPORTED_OPERATION', 'message': \"Machine type
with name " \
+ "'machine-type-1' does not exist in zone
'europe-west3-b'.\"}]}, "\
+ "'httpErrorStatusCode': 400, 'httpErrorMessage': 'BAD
REQUEST', " \
+ "'selfLink': " \
+
"'https://www.googleapis.com/compute/v1/projects/polidea-airflow" \
+
"/zones/europe-west3-b/operations/operation-1538578207537" \
+ "-577542784f769-7999ab71-94f9ec1d'} "
+
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook'
+ '._check_operation_status')
+ @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook'
+ '._execute_set_machine_type')
+
@mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook.get_conn')
+ def test_set_machine_type_should_handle_and_trim_gce_error(
+ self, get_conn, _execute_set_machine_type,
_check_operation_status):
+ get_conn.return_value = {}
+ _execute_set_machine_type.return_value = {"name": "test-operation"}
+ _check_operation_status.return_value =
ast.literal_eval(self.MOCK_OP_RESPONSE)
+ with self.assertRaises(AirflowException) as cm:
+ op = GceSetMachineTypeOperator(
+ project_id=PROJECT_ID,
+ zone=LOCATION,
+ resource_id=RESOURCE_ID,
+ body=SET_MACHINE_TYPE_BODY,
+ task_id='id'
+ )
+ op.execute(None)
+ err = cm.exception
+ _check_operation_status.assert_called_once_with(
+ {}, "test-operation", PROJECT_ID, LOCATION)
+ _execute_set_machine_type.assert_called_once_with(
+ PROJECT_ID, LOCATION, RESOURCE_ID, SET_MACHINE_TYPE_BODY)
+ # Checking the full message was sometimes failing due to different order
+ # of keys in the serialized JSON
+ self.assertIn("400 BAD REQUEST: {", str(err)) # checking the square
bracket trim
+ self.assertIn("UNSUPPORTED_OPERATION", str(err))
diff --git a/tests/contrib/operators/test_gcp_function_operator.py
b/tests/contrib/operators/test_gcp_function_operator.py
index d7585ae66f..4192560dd9 100644
--- a/tests/contrib/operators/test_gcp_function_operator.py
+++ b/tests/contrib/operators/test_gcp_function_operator.py
@@ -519,6 +519,23 @@ def test_valid_trigger_union_field(self, trigger,
mock_hook):
)
mock_hook.reset_mock()
+ @mock.patch('airflow.contrib.operators.gcp_function_operator.GcfHook')
+ def test_extra_parameter(self, mock_hook):
+ mock_hook.return_value.list_functions.return_value = []
+ mock_hook.return_value.create_new_function.return_value = True
+ body = deepcopy(VALID_BODY)
+ body['extra_parameter'] = 'extra'
+ op = GcfFunctionDeployOperator(
+ project_id="test_project_id",
+ location="test_region",
+ body=body,
+ task_id="id"
+ )
+ op.execute(None)
+ mock_hook.assert_called_once_with(api_version='v1',
+ gcp_conn_id='google_cloud_default')
+ mock_hook.reset_mock()
+
class GcfFunctionDeleteTest(unittest.TestCase):
_FUNCTION_NAME =
'projects/project_name/locations/project_location/functions' \
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
> Basic operators for Google Compute Engine
> -----------------------------------------
>
> Key: AIRFLOW-3078
> URL: https://issues.apache.org/jira/browse/AIRFLOW-3078
> Project: Apache Airflow
> Issue Type: New Feature
> Components: contrib, gcp
> Reporter: Jarek Potiuk
> Assignee: Jarek Potiuk
> Priority: Trivial
>
> In order to interact with raw Google Compute Engine, we need operators
> that are able to do the following.
> For managing individual machines:
> * Start Instance:
> ([https://cloud.google.com/compute/docs/reference/rest/v1/instances/start])
> * Set Machine Type:
> ([https://cloud.google.com/compute/docs/reference/rest/v1/instances/setMachineType])
> * Stop Instance:
> ([https://cloud.google.com/compute/docs/reference/rest/v1/instances/stop])
>
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)