troychen728 commented on a change in pull request #3658: [AIRFLOW-2524] Add 
Amazon SageMaker Training
URL: https://github.com/apache/incubator-airflow/pull/3658#discussion_r208671964
 
 

 ##########
 File path: airflow/contrib/hooks/sagemaker_hook.py
 ##########
 @@ -0,0 +1,239 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import copy
+import time
+from botocore.exceptions import ClientError
+
+from airflow.exceptions import AirflowException
+from airflow.contrib.hooks.aws_hook import AwsHook
+from airflow.hooks.S3_hook import S3Hook
+
+
+class SageMakerHook(AwsHook):
+    """
+    Interact with Amazon SageMaker.
+    sagemaker_conn_id is required for using
+    the config stored in db for training/tuning
+    """
+
+    def __init__(self,
+                 sagemaker_conn_id=None,
+                 use_db_config=False,
+                 region_name=None,
+                 check_interval=5,
+                 max_ingestion_time=None,
+                 *args, **kwargs):
+        """
+        Store the hook settings and create the SageMaker client.
+        :param sagemaker_conn_id: connection id kept on the hook
+        :param use_db_config: flag kept on the hook
+        :type use_db_config: bool
+        :param region_name: region handed to get_client_type by get_conn
+        :param check_interval: seconds check_status sleeps between polls
+        :param max_ingestion_time: seconds after which check_status
+        gives up; a falsy value disables the limit
+        :param args: positional arguments forwarded to the parent hook
+        :param kwargs: keyword arguments forwarded to the parent hook
+        """
+        super(SageMakerHook, self).__init__(*args, **kwargs)
+        # polling settings read by check_status
+        self.max_ingestion_time = max_ingestion_time
+        self.check_interval = check_interval
+        self.use_db_config = use_db_config
+        self.sagemaker_conn_id = sagemaker_conn_id
+        # region_name has to be in place before get_conn() reads it
+        self.region_name = region_name
+        self.conn = self.get_conn()
+
+    def check_for_url(self, s3url):
+        """
+        Check that the bucket and the key named by an S3 url both exist
+        :param s3url: S3 url
+        :type s3url: str
+        :return: True when both the bucket and the key exist
+        :rtype: bool
+        :raises AirflowException: if the bucket or the key is missing
+        """
+        bucket, key = S3Hook.parse_s3_url(s3url)
+        s3hook = S3Hook(aws_conn_id=self.aws_conn_id)
+        if not s3hook.check_for_bucket(bucket_name=bucket):
+            raise AirflowException(
+                "The input S3 Bucket {} does not exist ".format(bucket))
+        if not s3hook.check_for_key(key=key, bucket_name=bucket):
+            # one placeholder per argument, so both the key and the
+            # bucket show up in the message
+            raise AirflowException(
+                "The input S3 Key {} does not exist in the Bucket {}"
+                .format(key, bucket))
+        return True
+
+    def check_valid_training_input(self, training_config):
+        """
+        Verify every S3 input of a training config before training starts
+        :param training_config: training config; each entry of its
+        'InputDataConfig' list is checked with check_for_url
+        :type training_config: dict
+        :return: None
+        """
+        for input_channel in training_config['InputDataConfig']:
+            s3_source = input_channel['DataSource']['S3DataSource']
+            self.check_for_url(s3_source['S3Uri'])
+
+    def check_valid_tuning_input(self, tuning_config):
+        """
+        Verify every S3 input of a tuning config before tuning starts
+        :param tuning_config: tuning config; each entry of the
+        'InputDataConfig' list under 'TrainingJobDefinition' is
+        checked with check_for_url
+        :type tuning_config: dict
+        :return: None
+        """
+        job_definition = tuning_config['TrainingJobDefinition']
+        for input_channel in job_definition['InputDataConfig']:
+            s3_source = input_channel['DataSource']['S3DataSource']
+            self.check_for_url(s3_source['S3Uri'])
+
+    def check_status(self, non_terminal_states,
+                     failed_state, key,
+                     describe_function, *args):
+        """
+        Poll a job's status until it leaves the non-terminal states.
+        Sleeps self.check_interval seconds before every poll.
+        :param non_terminal_states: the set of non_terminal states
+        :type non_terminal_states: set
+        :param failed_state: the set of failed states
+        :type failed_state: set
+        :param key: the key of the response dict
+        that points to the state
+        :type key: string
+        :param describe_function: the function used to retrieve the status
+        :type describe_function: python callable
+        :param args: the arguments for the function
+        :return: None
+        :raises AirflowException: if max_ingestion_time is exceeded, the
+        status cannot be read, the AWS request fails, or the job ends
+        in one of the failed states
+        """
+        sec = 0
+
+        while True:
+
+            sec = sec + self.check_interval
+
+            if self.max_ingestion_time and sec > self.max_ingestion_time:
+                # stop waiting once the max ingestion time is exceeded;
+                # the message is formatted here because the exception
+                # does not interpolate logging-style arguments
+                raise AirflowException(
+                    "SageMaker job took more than {} seconds"
+                    .format(self.max_ingestion_time))
+
+            time.sleep(self.check_interval)
+            try:
+                status = describe_function(*args)[key]
+            except KeyError:
+                raise AirflowException(
+                    "Could not get status of the SageMaker job")
+            except ClientError:
+                # record the AWS error so the message below is accurate
+                self.log.exception("SageMaker describe call failed")
+                raise AirflowException(
+                    "AWS request failed, check log for more info")
+
+            # let the logger interpolate lazily
+            self.log.info("Job still running for %s seconds... "
+                          "current status is %s", sec, status)
+
+            # non-terminal states are tested first, as before, so a
+            # status listed in both collections keeps polling
+            if status in non_terminal_states:
+                continue
+            if status in failed_state:
+                raise AirflowException("SageMaker job failed")
+            break
+
+        self.log.info('SageMaker Job Completed')
+
+    def get_conn(self):
+        """
+        Create the client used to talk to SageMaker
+        :return: a boto3 SageMaker client for self.region_name
+        """
+        sagemaker_client = self.get_client_type(
+            'sagemaker', region_name=self.region_name)
+        return sagemaker_client
+
+    def list_training_job(self, name_contains=None, status_equals=None):
+        """
+        List the training jobs associated with the given input
+        :param name_contains: A string in the training job name
+        :type name_contains: str
+        :param status_equals: 'InProgress'|'Completed'
+        |'Failed'|'Stopping'|'Stopped'
+        :type status_equals: str
+        :return: the response of the list_training_jobs client call
+        :rtype: dict
+        """
+        # boto3 validates parameter types and rejects None for string
+        # parameters, so only the filters actually supplied are sent
+        filters = {}
+        if name_contains is not None:
+            filters['NameContains'] = name_contains
+        if status_equals is not None:
+            filters['StatusEquals'] = status_equals
+        return self.conn.list_training_jobs(**filters)
+
+    def list_tuning_job(self, name_contains=None, status_equals=None):
+        """
+        List the tuning jobs associated with the given input
+        :param name_contains: A string in the tuning job name
+        :type name_contains: str
+        :param status_equals: 'InProgress'|'Completed'
+        |'Failed'|'Stopping'|'Stopped'
+        :type status_equals: str
+        :return: the response of the list_hyper_parameter_tuning_jobs
+        client call
+        :rtype: dict
+        """
+        # boto3 validates parameter types and rejects None for string
+        # parameters, so only the filters actually supplied are sent
+        filters = {}
+        if name_contains is not None:
+            filters['NameContains'] = name_contains
+        if status_equals is not None:
+            filters['StatusEquals'] = status_equals
+        # the client method name is plural
+        return self.conn.list_hyper_parameter_tuning_jobs(**filters)
+
+    def create_training_job(self, training_job_config, wait=True):
 
 Review comment:
   Changed to wait_for_completion

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to