ashb closed pull request #4126: [AIRFLOW-2524] More AWS SageMaker operators, 
sensors for model, endpoint-config and endpoint
URL: https://github.com/apache/incubator-airflow/pull/4126
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/airflow/contrib/hooks/aws_hook.py 
b/airflow/contrib/hooks/aws_hook.py
index 265d4e56af..9d4a73e1c0 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -183,7 +183,7 @@ def get_session(self, region_name=None):
     def get_credentials(self, region_name=None):
         """Get the underlying `botocore.Credentials` object.
 
-        This contains the attributes: access_key, secret_key and token.
+        This contains the following authentication attributes: access_key, 
secret_key and token.
         """
         session, _ = self._get_credentials(region_name)
         # Credentials are refreshable, so accessing your access key and
@@ -193,8 +193,8 @@ def get_credentials(self, region_name=None):
 
     def expand_role(self, role):
         """
-        Expand an IAM role name to an IAM role ARN. If role is already an IAM 
ARN,
-        no change is made.
+        If the IAM role is a role name, get the Amazon Resource Name (ARN) for 
the role.
+        If IAM role is already an IAM role ARN, no change is made.
 
         :param role: IAM role name or ARN
         :return: IAM role ARN
diff --git a/airflow/contrib/operators/sagemaker_base_operator.py 
b/airflow/contrib/operators/sagemaker_base_operator.py
index cf1e59387a..08d6d0eb6a 100644
--- a/airflow/contrib/operators/sagemaker_base_operator.py
+++ b/airflow/contrib/operators/sagemaker_base_operator.py
@@ -79,7 +79,7 @@ def parse_config_integers(self):
             self.parse_integer(self.config, field)
 
     def expand_role(self):
-        raise NotImplementedError('Please implement expand_role() in sub 
class!')
+        pass
 
     def preprocess_config(self):
         self.log.info(
diff --git a/airflow/contrib/operators/sagemaker_endpoint_config_operator.py 
b/airflow/contrib/operators/sagemaker_endpoint_config_operator.py
new file mode 100644
index 0000000000..a94cf30229
--- /dev/null
+++ b/airflow/contrib/operators/sagemaker_endpoint_config_operator.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.contrib.operators.sagemaker_base_operator import 
SageMakerBaseOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.exceptions import AirflowException
+
+
+class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
+
+    """
+    Create a SageMaker endpoint config.
+
+    This operator returns The ARN of the endpoint config created in Amazon 
SageMaker
+
+    :param config: The configuration necessary to create an endpoint config.
+
+        For details of the configuration parameter, See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config
+    :type config: dict
+    :param aws_conn_id: The AWS connection ID to use.
+    :type aws_conn_id: str
+    """  # noqa: E501
+
+    integer_fields = [
+        ['ProductionVariants', 'InitialInstanceCount']
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 config,
+                 *args, **kwargs):
+        super(SageMakerEndpointConfigOperator, self).__init__(config=config,
+                                                              *args, **kwargs)
+
+        self.config = config
+
+    def execute(self, context):
+        self.preprocess_config()
+
+        self.log.info('Creating SageMaker Endpoint Config %s.', 
self.config['EndpointConfigName'])
+        response = self.hook.create_endpoint_config(self.config)
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+            raise AirflowException(
+                'Sagemaker endpoint config creation failed: %s' % response)
+        else:
+            return {
+                'EndpointConfig': self.hook.describe_endpoint_config(
+                    self.config['EndpointConfigName']
+                )
+            }
diff --git a/airflow/contrib/operators/sagemaker_endpoint_operator.py 
b/airflow/contrib/operators/sagemaker_endpoint_operator.py
new file mode 100644
index 0000000000..4094fbe59e
--- /dev/null
+++ b/airflow/contrib/operators/sagemaker_endpoint_operator.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+from airflow.contrib.operators.sagemaker_base_operator import 
SageMakerBaseOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.exceptions import AirflowException
+
+
+class SageMakerEndpointOperator(SageMakerBaseOperator):
+
+    """
+    Create a SageMaker endpoint.
+
+    This operator returns The ARN of the endpoint created in Amazon SageMaker
+
+    :param config:
+        The configuration necessary to create an endpoint.
+
+        If you need to create a SageMaker endpoint based on an existing 
SageMaker model and an existing SageMaker
+        endpoint config,
+
+            config = endpoint_configuration;
+
+        If you need to create all of SageMaker model, SageMaker 
endpoint-config and SageMaker endpoint,
+
+            config = {
+                'Model': model_configuration,
+
+                'EndpointConfig': endpoint_config_configuration,
+
+                'Endpoint': endpoint_configuration
+            }
+
+        For details of the configuration parameter of model_configuration, See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model
+
+        For details of the configuration parameter of 
endpoint_config_configuration, See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config
+
+        For details of the configuration parameter of endpoint_configuration, 
See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint
+    :type config: dict
+    :param aws_conn_id: The AWS connection ID to use.
+    :type aws_conn_id: str
+    :param wait_for_completion: Whether the operator should wait until the 
endpoint creation finishes.
+    :type wait_for_completion: bool
+    :param check_interval: If wait is set to True, this is the time interval, 
in seconds, that this operation waits
+        before polling the status of the endpoint creation.
+    :type check_interval: int
+    :param max_ingestion_time: If wait is set to True, this operation fails if 
the endpoint creation doesn't finish
+        within max_ingestion_time seconds. If you set this parameter to None 
it never times out.
+    :type max_ingestion_time: int
+    :param operation: Whether to create an endpoint or update an endpoint. 
Must be either 'create' or 'update'.
+    :type operation: str
+    """  # noqa: E501
+
+    @apply_defaults
+    def __init__(self,
+                 config,
+                 wait_for_completion=True,
+                 check_interval=30,
+                 max_ingestion_time=None,
+                 operation='create',
+                 *args, **kwargs):
+        super(SageMakerEndpointOperator, self).__init__(config=config,
+                                                        *args, **kwargs)
+
+        self.config = config
+        self.wait_for_completion = wait_for_completion
+        self.check_interval = check_interval
+        self.max_ingestion_time = max_ingestion_time
+        self.operation = operation.lower()
+        if self.operation not in ['create', 'update']:
+            raise ValueError('Invalid value! Argument operation has to be one 
of "create" and "update"')
+        self.create_integer_fields()
+
+    def create_integer_fields(self):
+        if 'EndpointConfig' in self.config:
+            self.integer_fields = [
+                ['EndpointConfig', 'ProductionVariants', 
'InitialInstanceCount']
+            ]
+
+    def expand_role(self):
+        if 'Model' not in self.config:
+            return
+        hook = AwsHook(self.aws_conn_id)
+        config = self.config['Model']
+        if 'ExecutionRoleArn' in config:
+            config['ExecutionRoleArn'] = 
hook.expand_role(config['ExecutionRoleArn'])
+
+    def execute(self, context):
+        self.preprocess_config()
+
+        model_info = self.config.get('Model')
+        endpoint_config_info = self.config.get('EndpointConfig')
+        endpoint_info = self.config.get('Endpoint', self.config)
+
+        if model_info:
+            self.log.info('Creating SageMaker model %s.', 
model_info['ModelName'])
+            self.hook.create_model(model_info)
+
+        if endpoint_config_info:
+            self.log.info('Creating endpoint config %s.', 
endpoint_config_info['EndpointConfigName'])
+            self.hook.create_endpoint_config(endpoint_config_info)
+
+        if self.operation == 'create':
+            sagemaker_operation = self.hook.create_endpoint
+            log_str = 'Creating'
+        elif self.operation == 'update':
+            sagemaker_operation = self.hook.update_endpoint
+            log_str = 'Updating'
+        else:
+            raise ValueError('Invalid value! Argument operation has to be one 
of "create" and "update"')
+
+        self.log.info('{} SageMaker endpoint {}.'.format(log_str, 
endpoint_info['EndpointName']))
+
+        response = sagemaker_operation(
+            endpoint_info,
+            wait_for_completion=self.wait_for_completion,
+            check_interval=self.check_interval,
+            max_ingestion_time=self.max_ingestion_time
+        )
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+            raise AirflowException(
+                'Sagemaker endpoint creation failed: %s' % response)
+        else:
+            return {
+                'EndpointConfig': self.hook.describe_endpoint_config(
+                    endpoint_info['EndpointConfigName']
+                ),
+                'Endpoint': self.hook.describe_endpoint(
+                    endpoint_info['EndpointName']
+                )
+            }
diff --git a/airflow/contrib/operators/sagemaker_model_operator.py 
b/airflow/contrib/operators/sagemaker_model_operator.py
new file mode 100644
index 0000000000..4332daa9db
--- /dev/null
+++ b/airflow/contrib/operators/sagemaker_model_operator.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+from airflow.contrib.operators.sagemaker_base_operator import 
SageMakerBaseOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.exceptions import AirflowException
+
+
+class SageMakerModelOperator(SageMakerBaseOperator):
+
+    """
+    Create a SageMaker model.
+
+    This operator returns The ARN of the model created in Amazon SageMaker
+
+    :param config: The configuration necessary to create a model.
+
+        For details of the configuration parameter, See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model
+    :type config: dict
+    :param aws_conn_id: The AWS connection ID to use.
+    :type aws_conn_id: str
+    """  # noqa: E501
+
+    @apply_defaults
+    def __init__(self,
+                 config,
+                 *args, **kwargs):
+        super(SageMakerModelOperator, self).__init__(config=config,
+                                                     *args, **kwargs)
+
+        self.config = config
+
+    def expand_role(self):
+        if 'ExecutionRoleArn' in self.config:
+            hook = AwsHook(self.aws_conn_id)
+            self.config['ExecutionRoleArn'] = 
hook.expand_role(self.config['ExecutionRoleArn'])
+
+    def execute(self, context):
+        self.preprocess_config()
+
+        self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
+        response = self.hook.create_model(self.config)
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+            raise AirflowException('Sagemaker model creation failed: %s' % 
response)
+        else:
+            return {
+                'Model': self.hook.describe_model(
+                    self.config['ModelName']
+                )
+            }
diff --git a/airflow/contrib/operators/sagemaker_training_operator.py 
b/airflow/contrib/operators/sagemaker_training_operator.py
index 69036925f3..d90f7e6555 100644
--- a/airflow/contrib/operators/sagemaker_training_operator.py
+++ b/airflow/contrib/operators/sagemaker_training_operator.py
@@ -29,23 +29,26 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
 
     This operator returns The ARN of the training job created in Amazon 
SageMaker.
 
-    :param config: The configuration necessary to start a training job 
(templated)
+    :param config: The configuration necessary to start a training job 
(templated).
+
+        For details of the configuration parameter, See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
     :type config: dict
     :param aws_conn_id: The AWS connection ID to use.
     :type aws_conn_id: str
-    :param wait_for_completion: if the operator should block until training 
job finishes
+    :param wait_for_completion: If set to True, the operation waits until the 
training job finishes.
     :type wait_for_completion: bool
     :param print_log: if the operator should print the cloudwatch log during 
training
     :type print_log: bool
     :param check_interval: if wait is set to be true, this is the time interval
         in seconds which the operator will check the status of the training job
     :type check_interval: int
-    :param max_ingestion_time: if wait is set to be true, the operator will 
fail
-        if the training job hasn't finish within the max_ingestion_time in 
seconds
-        (Caution: be careful to set this parameters because training can take 
very long)
-        Setting it to None implies no timeout.
+    :param max_ingestion_time: If wait is set to True, the operation fails if 
the training job
+        doesn't finish within max_ingestion_time seconds. If you set this 
parameter to None,
+        the operation does not timeout.
     :type max_ingestion_time: int
-    """
+    """  # noqa: E501
 
     integer_fields = [
         ['ResourceConfig', 'InstanceCount'],
@@ -87,8 +90,7 @@ def execute(self, context):
             max_ingestion_time=self.max_ingestion_time
         )
         if response['ResponseMetadata']['HTTPStatusCode'] != 200:
-            raise AirflowException(
-                'Sagemaker Training Job creation failed: %s' % response)
+            raise AirflowException('Sagemaker Training Job creation failed: 
%s' % response)
         else:
             return {
                 'Training': self.hook.describe_training_job(
diff --git a/airflow/contrib/operators/sagemaker_transform_operator.py 
b/airflow/contrib/operators/sagemaker_transform_operator.py
index 7be570cdac..9d1c665f9e 100644
--- a/airflow/contrib/operators/sagemaker_transform_operator.py
+++ b/airflow/contrib/operators/sagemaker_transform_operator.py
@@ -29,26 +29,39 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
 
     This operator returns The ARN of the model created in Amazon SageMaker.
 
-    :param config: The configuration necessary to start a transform job 
(templated)
+    :param config: The configuration necessary to start a transform job 
(templated).
+
+        If you need to create a SageMaker transform job based on an existing 
SageMaker model,
+
+            config = transform_config;
+
+        If you need to create both SageMaker model and SageMaker Transform job,
+
+            config = {
+                'Model': model_config,
+
+                'Transform': transform_config
+            }
+
+        For details of the configuration parameter of transform_config, See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_transform_job
+
+        For details of the configuration parameter of model_config, See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model
+
     :type config: dict
-    :param model_config:
-        The configuration necessary to create a SageMaker model, the default 
is none
-        which means the SageMaker model used for the SageMaker transform job 
already exists.
-        If given, it will be used to create a SageMaker model before creating
-        the SageMaker transform job
-    :type model_config: dict
     :param aws_conn_id: The AWS connection ID to use.
     :type aws_conn_id: string
-    :param wait_for_completion: if the program should keep running until job 
finishes
+    :param wait_for_completion: Set to True to wait until the transform job 
finishes.
     :type wait_for_completion: bool
-    :param check_interval: if wait is set to be true, this is the time interval
-        in seconds which the operator will check the status of the transform 
job
+    :param check_interval: If wait is set to True, the time interval, in 
seconds,
+        that this operation waits to check the status of the transform job.
     :type check_interval: int
-    :param max_ingestion_time: if wait is set to be true, the operator will 
fail
-        if the transform job hasn't finish within the max_ingestion_time in 
seconds
-        (Caution: be careful to set this parameters because transform can take 
very long)
+    :param max_ingestion_time: If wait is set to True, the operation fails
+        if the transform job doesn't finish within max_ingestion_time seconds. 
If you
+        set this parameter to None, the operation does not timeout.
     :type max_ingestion_time: int
-    """
+    """  # noqa: E501
 
     @apply_defaults
     def __init__(self,
diff --git a/airflow/contrib/operators/sagemaker_tuning_operator.py 
b/airflow/contrib/operators/sagemaker_tuning_operator.py
index 94c995072a..dc1282b5b8 100644
--- a/airflow/contrib/operators/sagemaker_tuning_operator.py
+++ b/airflow/contrib/operators/sagemaker_tuning_operator.py
@@ -25,24 +25,27 @@
 
 class SageMakerTuningOperator(SageMakerBaseOperator):
     """
-    Initiate a SageMaker hyper-parameter tuning job.
+    Initiate a SageMaker hyperparameter tuning job.
 
     This operator returns The ARN of the tuning job created in Amazon 
SageMaker.
 
-    :param config: The configuration necessary to start a tuning job 
(templated)
+    :param config: The configuration necessary to start a tuning job 
(templated).
+
+        For details of the configuration parameter, See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_hyper_parameter_tuning_job
     :type config: dict
     :param aws_conn_id: The AWS connection ID to use.
     :type aws_conn_id: str
-    :param wait_for_completion: if the operator should block until tuning job 
finishes
+    :param wait_for_completion: Set to True to wait until the tuning job 
finishes.
     :type wait_for_completion: bool
-    :param check_interval: if wait is set to be true, this is the time interval
-        in seconds which the operator will check the status of the tuning job
+    :param check_interval: If wait is set to True, the time interval, in 
seconds,
+        that this operation waits to check the status of the tuning job.
     :type check_interval: int
-    :param max_ingestion_time: if wait is set to be true, the operator will 
fail
-        if the tuning job hasn't finish within the max_ingestion_time in 
seconds
-        (Caution: be careful to set this parameters because tuning can take 
very long)
+    :param max_ingestion_time: If wait is set to True, the operation fails
+        if the tuning job doesn't finish within max_ingestion_time seconds. If 
you
+        set this parameter to None, the operation does not timeout.
     :type max_ingestion_time: int
-    """
+    """  # noqa: E501
 
     integer_fields = [
         ['HyperParameterTuningJobConfig', 'ResourceLimits', 
'MaxNumberOfTrainingJobs'],
@@ -87,8 +90,7 @@ def execute(self, context):
             max_ingestion_time=self.max_ingestion_time
         )
         if response['ResponseMetadata']['HTTPStatusCode'] != 200:
-            raise AirflowException(
-                'Sagemaker Tuning Job creation failed: %s' % response)
+            raise AirflowException('Sagemaker Tuning Job creation failed: %s' 
% response)
         else:
             return {
                 'Tuning': self.hook.describe_tuning_job(
diff --git a/airflow/contrib/sensors/sagemaker_endpoint_sensor.py 
b/airflow/contrib/sensors/sagemaker_endpoint_sensor.py
new file mode 100644
index 0000000000..ceed9c1009
--- /dev/null
+++ b/airflow/contrib/sensors/sagemaker_endpoint_sensor.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
+from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
+from airflow.utils.decorators import apply_defaults
+
+
+class SageMakerEndpointSensor(SageMakerBaseSensor):
+    """
+    Asks for the state of the endpoint until it reaches a terminal state.
+    If it fails, the sensor errors and the task fails.
+
+    :param endpoint_name: Name of the endpoint instance to check the state of
+    :type endpoint_name: str
+    """
+
+    template_fields = ['endpoint_name']
+    template_ext = ()
+
+    @apply_defaults
+    def __init__(self,
+                 endpoint_name,
+                 *args,
+                 **kwargs):
+        super(SageMakerEndpointSensor, self).__init__(*args, **kwargs)
+        self.endpoint_name = endpoint_name
+
+    def non_terminal_states(self):
+        return SageMakerHook.endpoint_non_terminal_states
+
+    def failed_states(self):
+        return SageMakerHook.failed_states
+
+    def get_sagemaker_response(self):
+        sagemaker = SageMakerHook(aws_conn_id=self.aws_conn_id)
+
+        self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name)
+        return sagemaker.describe_endpoint(self.endpoint_name)
+
+    def get_failed_reason_from_response(self, response):
+        return response['FailureReason']
+
+    def state_from_response(self, response):
+        return response['EndpointStatus']
diff --git a/airflow/hooks/S3_hook.py b/airflow/hooks/S3_hook.py
index c9e81e9d73..f8ccb370f8 100644
--- a/airflow/hooks/S3_hook.py
+++ b/airflow/hooks/S3_hook.py
@@ -71,11 +71,11 @@ def get_bucket(self, bucket_name):
 
     def create_bucket(self, bucket_name, region_name=None):
         """
-        Creates a boto3.S3.Bucket object
+        Creates an Amazon S3 bucket.
 
-        :param bucket_name: the name of the bucket
+        :param bucket_name: The name of the bucket
         :type bucket_name: str
-        :param region__name: the name of the aws region
+        :param region_name: The name of the aws region in which to create the 
bucket.
         :type region_name: str
         """
         s3_conn = self.get_conn()
@@ -428,19 +428,19 @@ def load_file_obj(self,
                       replace=False,
                       encrypt=False):
         """
-        Loads file object to S3
+        Loads a file object to S3
 
-        :param file_obj: file-like object to set as content for the key.
+        :param file_obj: The file-like object to set as the content for the S3 
key.
         :type file_obj: file-like object
         :param key: S3 key that will point to the file
         :type key: str
         :param bucket_name: Name of the bucket in which to store the file
         :type bucket_name: str
-        :param replace: A flag to decide whether or not to overwrite the key
-            if it already exists
+        :param replace: A flag that indicates whether to overwrite the key
+            if it already exists.
         :type replace: bool
-        :param encrypt: If True, the file will be encrypted on the server-side
-            by S3 and will be stored in an encrypted form while at rest in S3.
+        :param encrypt: If True, S3 encrypts the file on the server,
+            and the file is stored in encrypted form at rest in S3.
         :type encrypt: bool
         """
         if not bucket_name:
diff --git a/docs/code.rst b/docs/code.rst
index cca3b5ff77..9e7ac7d272 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -190,9 +190,12 @@ Operators
 .. autoclass:: airflow.contrib.operators.s3_list_operator.S3ListOperator
 .. autoclass:: 
airflow.contrib.operators.s3_to_gcs_operator.S3ToGoogleCloudStorageOperator
 .. autoclass:: 
airflow.contrib.operators.sagemaker_base_operator.SageMakerBaseOperator
+.. autoclass:: 
airflow.contrib.operators.sagemaker_endpoint_operator.SageMakerEndpointOperator
+.. autoclass:: 
airflow.contrib.operators.sagemaker_endpoint_config_operator.SageMakerEndpointConfigOperator
+.. autoclass:: 
airflow.contrib.operators.sagemaker_model_operator.SageMakerModelOperator
 .. autoclass:: 
airflow.contrib.operators.sagemaker_training_operator.SageMakerTrainingOperator
 .. autoclass:: 
airflow.contrib.operators.sagemaker_transform_operator.SageMakerTransformOperator
-.. autoclass:: 
airflow.contrib.operators.sagemaker_tuning_operator.SagemakerTuningOperator
+.. autoclass:: 
airflow.contrib.operators.sagemaker_tuning_operator.SageMakerTuningOperator
 .. autoclass:: 
airflow.contrib.operators.segment_track_event_operator.SegmentTrackEventOperator
 .. autoclass:: airflow.contrib.operators.sftp_operator.SFTPOperator
 .. autoclass:: 
airflow.contrib.operators.slack_webhook_operator.SlackWebhookOperator
@@ -231,6 +234,7 @@ Sensors
 .. autoclass:: airflow.contrib.sensors.qubole_sensor.QuboleSensor
 .. autoclass:: airflow.contrib.sensors.redis_key_sensor.RedisKeySensor
 .. autoclass:: 
airflow.contrib.sensors.sagemaker_base_sensor.SageMakerBaseSensor
+.. autoclass:: 
airflow.contrib.sensors.sagemaker_endpoint_sensor.SageMakerEndpointSensor
 .. autoclass:: 
airflow.contrib.sensors.sagemaker_training_sensor.SageMakerTrainingSensor
 .. autoclass:: 
airflow.contrib.sensors.sagemaker_transform_sensor.SageMakerTransformSensor
 .. autoclass:: 
airflow.contrib.sensors.sagemaker_tuning_sensor.SageMakerTuningSensor
diff --git a/tests/contrib/operators/test_sagemaker_endpoint_config_operator.py 
b/tests/contrib/operators/test_sagemaker_endpoint_config_operator.py
new file mode 100644
index 0000000000..658fd813df
--- /dev/null
+++ b/tests/contrib/operators/test_sagemaker_endpoint_config_operator.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+from airflow import configuration
+from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
+from airflow.contrib.operators.sagemaker_endpoint_config_operator \
+    import SageMakerEndpointConfigOperator
+from airflow.exceptions import AirflowException
+
+model_name = 'test-model-name'
+config_name = 'test-config-name'
+
+create_endpoint_config_params = {
+    'EndpointConfigName': config_name,
+    'ProductionVariants': [
+        {
+            'VariantName': 'AllTraffic',
+            'ModelName': model_name,
+            'InitialInstanceCount': '1',
+            'InstanceType': 'ml.c4.xlarge'
+        }
+    ]
+}
+
+
+class TestSageMakerEndpointConfigOperator(unittest.TestCase):
+
+    def setUp(self):
+        configuration.load_test_config()
+        self.sagemaker = SageMakerEndpointConfigOperator(
+            task_id='test_sagemaker_operator',
+            aws_conn_id='sagemaker_test_id',
+            config=create_endpoint_config_params
+        )
+
+    def test_parse_config_integers(self):
+        self.sagemaker.parse_config_integers()
+        for variant in self.sagemaker.config['ProductionVariants']:
+            self.assertEqual(variant['InitialInstanceCount'],
+                             int(variant['InitialInstanceCount']))
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
+    def test_execute(self, mock_model, mock_client):
+        mock_model.return_value = {
+            'EndpointConfigArn': 'testarn',
+            'ResponseMetadata': {
+                'HTTPStatusCode': 200
+            }
+        }
+        self.sagemaker.execute(None)
+        mock_model.assert_called_once_with(create_endpoint_config_params)
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_model')
+    def test_execute_with_failure(self, mock_model, mock_client):
+        mock_model.return_value = {
+            'EndpointConfigArn': 'testarn',
+            'ResponseMetadata': {
+                'HTTPStatusCode': 200
+            }
+        }
+        self.assertRaises(AirflowException, self.sagemaker.execute, None)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/contrib/operators/test_sagemaker_endpoint_operator.py 
b/tests/contrib/operators/test_sagemaker_endpoint_operator.py
new file mode 100644
index 0000000000..710daceff5
--- /dev/null
+++ b/tests/contrib/operators/test_sagemaker_endpoint_operator.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+from airflow import configuration
+from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
+from airflow.contrib.operators.sagemaker_endpoint_operator \
+    import SageMakerEndpointOperator
+from airflow.exceptions import AirflowException
+
# Shared fixture values used to build the SageMaker API request payloads below.
role = 'arn:aws:iam:role/test-role'
bucket = 'test-bucket'
image = 'test-image'
output_url = 's3://{}/test/output'.format(bucket)
model_name = 'test-model-name'
config_name = 'test-endpoint-config-name'
endpoint_name = 'test-endpoint-name'

# Payload forwarded to SageMakerHook.create_model.
create_model_params = {
    'ModelName': model_name,
    'PrimaryContainer': {
        'Image': image,
        'ModelDataUrl': output_url,
    },
    'ExecutionRoleArn': role
}

# Payload forwarded to SageMakerHook.create_endpoint_config.
# InitialInstanceCount is deliberately a string so that
# test_parse_config_integers can verify the operator coerces it to int.
create_endpoint_config_params = {
    'EndpointConfigName': config_name,
    'ProductionVariants': [
        {
            'VariantName': 'AllTraffic',
            'ModelName': model_name,
            'InitialInstanceCount': '1',
            'InstanceType': 'ml.c4.xlarge'
        }
    ]
}

# Payload forwarded to SageMakerHook.create_endpoint.
create_endpoint_params = {
    'EndpointName': endpoint_name,
    'EndpointConfigName': config_name
}

# Combined config consumed by SageMakerEndpointOperator: one sub-config per
# resource it creates (model, endpoint config, endpoint).
config = {
    'Model': create_model_params,
    'EndpointConfig': create_endpoint_config_params,
    'Endpoint': create_endpoint_params
}
+
+
class TestSageMakerEndpointOperator(unittest.TestCase):
    """Tests for SageMakerEndpointOperator with the 'create' operation."""

    def setUp(self):
        configuration.load_test_config()
        self.sagemaker = SageMakerEndpointOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_id',
            config=config,
            wait_for_completion=False,
            check_interval=5,
            operation='create'
        )

    def test_parse_config_integers(self):
        """String-typed integer fields in the config are coerced to int."""
        self.sagemaker.parse_config_integers()
        variants = self.sagemaker.config['EndpointConfig']['ProductionVariants']
        for variant in variants:
            self.assertEqual(variant['InitialInstanceCount'],
                             int(variant['InitialInstanceCount']))

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    @mock.patch.object(SageMakerHook, 'create_endpoint')
    def test_execute(self, mock_endpoint, mock_endpoint_config,
                     mock_model, mock_client):
        """A 200 response creates model, endpoint config and endpoint."""
        response = {
            'EndpointArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 200},
        }
        mock_endpoint.return_value = response

        self.sagemaker.execute(None)

        # Each sub-config is forwarded, unmodified, to its hook method.
        mock_model.assert_called_once_with(create_model_params)
        mock_endpoint_config.assert_called_once_with(
            create_endpoint_config_params)
        mock_endpoint.assert_called_once_with(
            create_endpoint_params,
            wait_for_completion=False,
            check_interval=5,
            max_ingestion_time=None)

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    @mock.patch.object(SageMakerHook, 'create_endpoint_config')
    @mock.patch.object(SageMakerHook, 'create_endpoint')
    def test_execute_with_failure(self, mock_endpoint, mock_endpoint_config,
                                  mock_model, mock_client):
        """A non-200 response raises AirflowException."""
        mock_endpoint.return_value = {
            'EndpointArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 404},
        }
        self.assertRaises(AirflowException, self.sagemaker.execute, None)


if __name__ == '__main__':
    unittest.main()
diff --git a/tests/contrib/operators/test_sagemaker_model_operator.py 
b/tests/contrib/operators/test_sagemaker_model_operator.py
new file mode 100644
index 0000000000..6b6d12de10
--- /dev/null
+++ b/tests/contrib/operators/test_sagemaker_model_operator.py
@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+from airflow import configuration
+from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
+from airflow.contrib.operators.sagemaker_model_operator \
+    import SageMakerModelOperator
+from airflow.exceptions import AirflowException
+
# Fixture values used to build the create_model request payload below.
role = 'arn:aws:iam:role/test-role'

bucket = 'test-bucket'

model_name = 'test-model-name'

image = 'test-image'

output_url = 's3://{}/test/output'.format(bucket)
# Payload forwarded to SageMakerHook.create_model by the operator under test.
create_model_params = {
    'ModelName': model_name,
    'PrimaryContainer': {
        'Image': image,
        'ModelDataUrl': output_url,
    },
    'ExecutionRoleArn': role
}
+
+
class TestSageMakerModelOperator(unittest.TestCase):
    """Tests for SageMakerModelOperator."""

    def setUp(self):
        configuration.load_test_config()
        self.sagemaker = SageMakerModelOperator(
            task_id='test_sagemaker_operator',
            aws_conn_id='sagemaker_test_id',
            config=create_model_params
        )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    def test_execute(self, mock_model, mock_client):
        """A 200 response forwards the config to create_model exactly once."""
        mock_model.return_value = {
            'ModelArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 200},
        }
        self.sagemaker.execute(None)
        mock_model.assert_called_once_with(create_model_params)

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_model')
    def test_execute_with_failure(self, mock_model, mock_client):
        """A non-200 response raises AirflowException."""
        mock_model.return_value = {
            'ModelArn': 'testarn',
            'ResponseMetadata': {'HTTPStatusCode': 404},
        }
        self.assertRaises(AirflowException, self.sagemaker.execute, None)


if __name__ == '__main__':
    unittest.main()
diff --git a/tests/contrib/sensors/test_sagemaker_endpoint_sensor.py 
b/tests/contrib/sensors/test_sagemaker_endpoint_sensor.py
new file mode 100644
index 0000000000..2e438e4e19
--- /dev/null
+++ b/tests/contrib/sensors/test_sagemaker_endpoint_sensor.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+from airflow import configuration
+from airflow.contrib.sensors.sagemaker_endpoint_sensor \
+    import SageMakerEndpointSensor
+from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
+from airflow.exceptions import AirflowException
+
# Canned describe_endpoint responses, one per endpoint lifecycle state the
# sensor can observe.  All carry HTTP 200 — a 'Failed' *status* (not a
# failed HTTP call) is what makes the sensor raise.
DESCRIBE_ENDPOINT_CREATING_RESPONSE = {
    'EndpointStatus': 'Creating',
    'ResponseMetadata': {
        'HTTPStatusCode': 200,
    }
}
DESCRIBE_ENDPOINT_INSERVICE_RESPONSE = {
    'EndpointStatus': 'InService',
    'ResponseMetadata': {
        'HTTPStatusCode': 200,
    }
}

DESCRIBE_ENDPOINT_FAILED_RESPONSE = {
    'EndpointStatus': 'Failed',
    'ResponseMetadata': {
        'HTTPStatusCode': 200,
    },
    'FailureReason': 'Unknown'
}

DESCRIBE_ENDPOINT_UPDATING_RESPONSE = {
    'EndpointStatus': 'Updating',
    'ResponseMetadata': {
        'HTTPStatusCode': 200,
    }
}
+
+
class TestSageMakerEndpointSensor(unittest.TestCase):
    """Tests for SageMakerEndpointSensor."""

    def setUp(self):
        configuration.load_test_config()

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'describe_endpoint')
    def test_sensor_with_failure(self, mock_describe, mock_client):
        """A 'Failed' endpoint status raises after a single describe call."""
        mock_describe.side_effect = [DESCRIBE_ENDPOINT_FAILED_RESPONSE]
        sensor = SageMakerEndpointSensor(
            task_id='test_task',
            poke_interval=1,
            aws_conn_id='aws_test',
            endpoint_name='test_job_name'
        )
        self.assertRaises(AirflowException, sensor.execute, None)
        mock_describe.assert_called_once_with('test_job_name')

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, '__init__')
    @mock.patch.object(SageMakerHook, 'describe_endpoint')
    def test_sensor(self, mock_describe, hook_init, mock_client):
        """The sensor keeps poking until the endpoint reaches 'InService'."""
        hook_init.return_value = None
        mock_describe.side_effect = [
            DESCRIBE_ENDPOINT_CREATING_RESPONSE,
            DESCRIBE_ENDPOINT_UPDATING_RESPONSE,
            DESCRIBE_ENDPOINT_INSERVICE_RESPONSE,
        ]
        sensor = SageMakerEndpointSensor(
            task_id='test_task',
            poke_interval=1,
            aws_conn_id='aws_test',
            endpoint_name='test_job_name'
        )

        sensor.execute(None)

        # Exactly three poke cycles: Creating -> Updating -> InService
        # (the sensor terminates on the terminal state).
        self.assertEqual(mock_describe.call_count, 3)
        # The hook must be constructed with the configured connection id.
        hook_init.assert_called_with(aws_conn_id='aws_test')


if __name__ == '__main__':
    unittest.main()
diff --git a/tests/contrib/sensors/test_sagemaker_training_sensor.py 
b/tests/contrib/sensors/test_sagemaker_training_sensor.py
index 5861d7a6fd..6642449c6a 100644
--- a/tests/contrib/sensors/test_sagemaker_training_sensor.py
+++ b/tests/contrib/sensors/test_sagemaker_training_sensor.py
@@ -101,7 +101,7 @@ def test_sensor(self, mock_describe_job, hook_init, 
mock_client):
 
         sensor.execute(None)
 
-        # make sure we called 4 times(terminated when its compeleted)
+        # make sure we called 3 times (terminated when it's completed)
         self.assertEqual(mock_describe_job.call_count, 3)
 
         # make sure the hook was initialized with the specific params
diff --git a/tests/contrib/sensors/test_sagemaker_transform_sensor.py 
b/tests/contrib/sensors/test_sagemaker_transform_sensor.py
index 1394920d5d..810680683c 100644
--- a/tests/contrib/sensors/test_sagemaker_transform_sensor.py
+++ b/tests/contrib/sensors/test_sagemaker_transform_sensor.py
@@ -97,7 +97,7 @@ def test_sensor(self, mock_describe_job, hook_init, 
mock_client):
 
         sensor.execute(None)
 
-        # make sure we called 4 times(terminated when its compeleted)
+        # make sure we called 3 times (terminated when it's completed)
         self.assertEqual(mock_describe_job.call_count, 3)
 
         # make sure the hook was initialized with the specific params
diff --git a/tests/contrib/sensors/test_sagemaker_tuning_sensor.py 
b/tests/contrib/sensors/test_sagemaker_tuning_sensor.py
index 8c0ba11380..beb62dc47b 100644
--- a/tests/contrib/sensors/test_sagemaker_tuning_sensor.py
+++ b/tests/contrib/sensors/test_sagemaker_tuning_sensor.py
@@ -100,7 +100,7 @@ def test_sensor(self, mock_describe_job, hook_init, 
mock_client):
 
         sensor.execute(None)
 
-        # make sure we called 4 times(terminated when its compeleted)
+        # make sure we called 3 times (terminated when it's completed)
         self.assertEqual(mock_describe_job.call_count, 3)
 
         # make sure the hook was initialized with the specific params


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to