yangaws commented on a change in pull request #3767: [AIRFLOW-2524] Add SageMaker Batch Inference
URL: https://github.com/apache/incubator-airflow/pull/3767#discussion_r217487448
##########
File path: airflow/contrib/hooks/sagemaker_hook.py
##########
@@ -219,6 +219,52 @@ def create_tuning_job(self, tuning_job_config):
return self.conn.create_hyper_parameter_tuning_job(
**tuning_job_config)
+ def create_transform_job(self, transform_job_config,
wait_for_completion=True):
+ """
+ Create a transform job
+ :param transform_job_config: the config for transform job
+ :type transform_job_config: dict
+ :param wait_for_completion:
+ if the program should keep running until job finishes
+ :type wait_for_completion: bool
+ :return: A dict that contains ARN of the transform job.
+ """
+ if self.use_db_config:
+ if not self.sagemaker_conn_id:
+ raise AirflowException(
+ "SageMaker connection id must be present to \
+ read SageMaker transform job configuration.")
+
+ sagemaker_conn = self.get_connection(self.sagemaker_conn_id)
+
+ config = sagemaker_conn.extra_dejson.copy()
+ transform_job_config.update(config)
+
+ self.check_for_url(transform_job_config
+ ['TransformInput']['DataSource']
+ ['S3DataSource']['S3Uri'])
+
+ response = self.conn.create_transform_job(
+ **transform_job_config)
+ if wait_for_completion:
+ self.check_status(['InProgress', 'Stopping', 'Stopped'],
Review comment:
Made them into a set instead of a dict.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services