feluelle commented on a change in pull request #9594:
URL: https://github.com/apache/airflow/pull/9594#discussion_r449436572



##########
File path: airflow/providers/amazon/aws/hooks/sagemaker.py
##########
@@ -786,6 +825,48 @@ def list_training_jobs(
         )
         return results
 
+    def list_processing_jobs(
+        self, name_contains: Optional[str] = None, max_results: Optional[int] 
= None, **kwargs
+    ) -> List[Dict]:   # noqa: D402
+        """
+        This method wraps boto3's list_processing_jobs(). The processing job 
name and max results are
+        configurable via arguments. Other arguments are not, and should be 
provided via kwargs.

Review comment:
       What is the reason you added `name_contains` and `max_results` as 
separate arguments?
   
I think it would be much easier to just use kwargs. And if you really want 
to make sure that both args are specified, check that, and if one is not set, 
raise an exception. But offering two variables where only one can be set is (in 
my opinion) a bit confusing.

##########
File path: airflow/providers/amazon/aws/hooks/sagemaker.py
##########
@@ -786,6 +825,48 @@ def list_training_jobs(
         )
         return results
 
+    def list_processing_jobs(
+        self, name_contains: Optional[str] = None, max_results: Optional[int] 
= None, **kwargs
+    ) -> List[Dict]:   # noqa: D402

Review comment:
       hmm `D402` is `First line should not be the function’s “signature”` 
because you mentioned `list_processing_jobs()` 
(https://github.com/PyCQA/pydocstyle/issues/284). -> I think it is fine.

##########
File path: docs/operators-and-hooks-ref.rst
##########
@@ -465,7 +465,8 @@ These integrations allow you to perform various operations 
within the Amazon Web
        :mod:`airflow.providers.amazon.aws.operators.sagemaker_model`,
        :mod:`airflow.providers.amazon.aws.operators.sagemaker_training`,
        :mod:`airflow.providers.amazon.aws.operators.sagemaker_transform`,
-       :mod:`airflow.providers.amazon.aws.operators.sagemaker_tuning`
+       :mod:`airflow.providers.amazon.aws.operators.sagemaker_tuning`,
+       :mod:`airflow.providers.amazon.aws.operators.sagemaker_processing`,

Review comment:
       Please sort this alphabetically. :)

##########
File path: tests/providers/amazon/aws/operators/test_sagemaker_processing.py
##########
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+import mock
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.operators.sagemaker_processing import 
SageMakerProcessingOperator
+
+job_name = 'test-job-name'
+
+create_processing_params = \
+    {
+        "AppSpecification": {
+            "ContainerArguments": ["container_arg"],
+            "ContainerEntrypoint": ["container_entrypoint"],
+            "ImageUri": "{{ image_uri }}",
+        },
+        "Environment": {"{{ key }}": "{{ value }}"},
+        "ExperimentConfig": {
+            "ExperimentName": "ExperimentName",
+            "TrialComponentDisplayName": "TrialComponentDisplayName",
+            "TrialName": "TrialName",
+        },
+        "ProcessingInputs": [
+            {
+                "InputName": "AnalyticsInputName",
+                "S3Input": {
+                    "LocalPath": "{{ Local Path }}",
+                    "S3CompressionType": "None",
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3DataType": "S3Prefix",
+                    "S3InputMode": "File",
+                    "S3Uri": "{{ S3Uri }}",
+                },
+            }
+        ],
+        "ProcessingJobName": job_name,
+        "ProcessingOutputConfig": {
+            "KmsKeyId": "KmsKeyID",
+            "Outputs": [
+                {
+                    "OutputName": "AnalyticsOutputName",
+                    "S3Output": {
+                        "LocalPath": "{{ Local Path }}",
+                        "S3UploadMode": "EndOfJob",
+                        "S3Uri": "{{ S3Uri }}",
+                    },
+                }
+            ],
+        },
+        "ProcessingResources": {
+            "ClusterConfig": {
+                "InstanceCount": 2,
+                "InstanceType": "ml.p2.xlarge",
+                "VolumeSizeInGB": 30,
+                "VolumeKmsKeyId": "{{ kms_key }}",
+            }
+        },
+        "RoleArn": "arn:aws:iam::0122345678910:role/SageMakerPowerUser",
+        "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
+        "Tags": [{"{{ key }}": "{{ value }}"}],
+    }
+
+
+# noinspection PyUnusedLocal
+# pylint: disable=unused-argument
+class TestSageMakerProcessingOperator(unittest.TestCase):
+
+    def setUp(self):
+        self.sagemaker = SageMakerProcessingOperator(
+            task_id='test_sagemaker_operator',
+            aws_conn_id='sagemaker_test_id',
+            config=create_processing_params,
+            wait_for_completion=False,
+            check_interval=5
+        )
+
+    def test_parse_config_integers(self):
+        self.sagemaker.parse_config_integers()
+        
self.assertEqual(self.sagemaker.config['ProcessingResources']['ClusterConfig']['InstanceCount'],
+                         
int(self.sagemaker.config['ProcessingResources']['ClusterConfig']['InstanceCount']))
+        
self.assertEqual(self.sagemaker.config['ProcessingResources']['ClusterConfig']['VolumeSizeInGB'],
+                         
int(self.sagemaker.config['ProcessingResources']['ClusterConfig']['VolumeSizeInGB']))
+        
self.assertEqual(self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds'],
+                         
int(self.sagemaker.config['StoppingCondition']['MaxRuntimeInSeconds']))
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_processing_job')
+    def test_execute(self, mock_processing, mock_client):
+        mock_processing.return_value = {'ProcessingJobArn': 'testarn',
+                                        'ResponseMetadata': {'HTTPStatusCode': 
200}}
+        self.sagemaker.execute(None)
+        mock_processing.assert_called_once_with(create_processing_params,
+                                                wait_for_completion=False,
+                                                check_interval=5,
+                                                max_ingestion_time=None
+                                                )
+
+    @mock.patch.object(SageMakerHook, 'get_conn')
+    @mock.patch.object(SageMakerHook, 'create_processing_job')
+    def test_execute_with_failure(self, mock_processing, mock_client):
+        mock_processing.return_value = {'ProcessingJobArn': 'testarn',
+                                        'ResponseMetadata': {'HTTPStatusCode': 
404}}
+        self.assertRaises(AirflowException, self.sagemaker.execute, None)
+# pylint: enable=unused-argument
+
+    @mock.patch.object(SageMakerHook, "get_conn")
+    @mock.patch.object(SageMakerHook, "list_processing_jobs")
+    @mock.patch.object(SageMakerHook, "create_processing_job")
+    def test_execute_with_existing_job_increment(
+        self, mock_create_processing_job, mock_list_processing_jobs, 
mock_client
+    ):
+        self.sagemaker.action_if_job_exists = "increment"
+        mock_create_processing_job.return_value = {"ResponseMetadata": 
{"HTTPStatusCode": 200}}
+        mock_list_processing_jobs.return_value = [{"ProcessingJobName": 
job_name}]
+        self.sagemaker.execute(None)
+
+        expected_config = create_processing_params.copy()
+        # Expect to see ProcessingJobName suffixed with "-2" because we return 
one existing job
+        expected_config["ProcessingJobName"] = f"{job_name}-2"
+        mock_create_processing_job.assert_called_once_with(
+            expected_config,
+            wait_for_completion=False,
+            check_interval=5,
+            max_ingestion_time=None,
+        )
+
+    @mock.patch.object(SageMakerHook, "get_conn")
+    @mock.patch.object(SageMakerHook, "list_processing_jobs")
+    @mock.patch.object(SageMakerHook, "create_processing_job")
+    def test_execute_with_existing_job_fail(
+        self, mock_create_processing_job, mock_list_processing_jobs, 
mock_client
+    ):
+        self.sagemaker.action_if_job_exists = "fail"

Review comment:
       Maybe add another simple test to check if the validation for 
`action_if_job_exists` works as expected.

##########
File path: airflow/providers/amazon/aws/hooks/sagemaker.py
##########
@@ -123,7 +123,7 @@ def secondary_training_status_message(job_description, 
prev_description):
     return '\n'.join(status_strs)
 
 
-class SageMakerHook(AwsBaseHook):
+class SageMakerHook(AwsBaseHook):   # pylint: disable=R0904

Review comment:
       ```suggestion
   class SageMakerHook(AwsBaseHook):   # pylint: disable=too-many-public-methods
   ```

##########
File path: airflow/providers/amazon/aws/operators/sagemaker_processing.py
##########
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.operators.sagemaker_base import 
SageMakerBaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class SageMakerProcessingOperator(SageMakerBaseOperator):
+    """
+    Initiate a SageMaker processing job.
+
+    This operator returns The ARN of the processing job created in Amazon 
SageMaker.
+
+    :param config: The configuration necessary to start a processing job 
(templated).
+
+        For details of the configuration parameter see 
:py:meth:`SageMaker.Client.create_processing_job`
+    :type config: dict
+    :param aws_conn_id: The AWS connection ID to use.
+    :type aws_conn_id: str

Review comment:
      Can you add `aws_conn_id` as a kwarg to `__init__` like you did with 
`config`? It is more explicit.

##########
File path: tests/providers/amazon/aws/operators/test_sagemaker_processing.py
##########
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+import mock
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.operators.sagemaker_processing import 
SageMakerProcessingOperator
+
+job_name = 'test-job-name'
+
+create_processing_params = \
+    {

Review comment:
       ```suggestion
   create_processing_params = {
   ```

##########
File path: airflow/providers/amazon/aws/operators/sagemaker_processing.py
##########
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.operators.sagemaker_base import 
SageMakerBaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class SageMakerProcessingOperator(SageMakerBaseOperator):
+    """
+    Initiate a SageMaker processing job.
+
+    This operator returns The ARN of the processing job created in Amazon 
SageMaker.
+
+    :param config: The configuration necessary to start a processing job 
(templated).
+
+        For details of the configuration parameter see 
:py:meth:`SageMaker.Client.create_processing_job`
+    :type config: dict
+    :param aws_conn_id: The AWS connection ID to use.
+    :type aws_conn_id: str
+    :param wait_for_completion: If wait is set to True, the time interval, in 
seconds,
+        that the operation waits to check the status of the processing job.
+    :type wait_for_completion: bool
+    :param print_log: if the operator should print the cloudwatch log during 
processing
+    :type print_log: bool
+    :param check_interval: if wait is set to be true, this is the time interval
+        in seconds which the operator will check the status of the processing 
job
+    :type check_interval: int
+    :param max_ingestion_time: If wait is set to True, the operation fails if 
the processing job
+        doesn't finish within max_ingestion_time seconds. If you set this 
parameter to None,
+        the operation does not timeout.
+    :type max_ingestion_time: int
+    :param action_if_job_exists: Behaviour if the job name already exists. 
Possible options are "increment"
+        (default) and "fail".
+    :type action_if_job_exists: str
+    """
+
+    integer_fields = [
+        ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
+        ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
+        ['StoppingCondition', 'MaxRuntimeInSeconds']
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 config,
+                 wait_for_completion=True,
+                 print_log=True,
+                 check_interval=30,
+                 max_ingestion_time=None,
+                 action_if_job_exists: str = "increment",  # TODO use 
typing.Literal for this in Python 3.8
+                 *args, **kwargs):
+        super().__init__(config=config, *args, **kwargs)
+
+        self.wait_for_completion = wait_for_completion
+        self.print_log = print_log
+        self.check_interval = check_interval
+        self.max_ingestion_time = max_ingestion_time
+
+        if action_if_job_exists in ("increment", "fail"):
+            self.action_if_job_exists = action_if_job_exists
+        else:
+            raise AirflowException(
+                "Argument action_if_job_exists accepts only 'increment' and 
'fail'. "
+                f"Provided value: '{action_if_job_exists}'."
+            )

Review comment:
       ```suggestion
           if action_if_job_exists not in ("increment", "fail"):
               raise AirflowException(
                   "Argument action_if_job_exists accepts only 'increment' and 
'fail'. "
                   f"Provided value: '{action_if_job_exists}'."
               )
           self.action_if_job_exists = action_if_job_exists
   ```
   And you want to put the validation at the very top, because it is essential. 
Without a proper value, this method doesn't need to go on.

##########
File path: tests/providers/amazon/aws/operators/test_sagemaker_processing.py
##########
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+import mock
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.operators.sagemaker_processing import 
SageMakerProcessingOperator
+
+job_name = 'test-job-name'
+
+create_processing_params = \
+    {
+        "AppSpecification": {
+            "ContainerArguments": ["container_arg"],
+            "ContainerEntrypoint": ["container_entrypoint"],
+            "ImageUri": "{{ image_uri }}",
+        },
+        "Environment": {"{{ key }}": "{{ value }}"},
+        "ExperimentConfig": {
+            "ExperimentName": "ExperimentName",
+            "TrialComponentDisplayName": "TrialComponentDisplayName",
+            "TrialName": "TrialName",
+        },
+        "ProcessingInputs": [
+            {
+                "InputName": "AnalyticsInputName",
+                "S3Input": {
+                    "LocalPath": "{{ Local Path }}",
+                    "S3CompressionType": "None",
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3DataType": "S3Prefix",
+                    "S3InputMode": "File",
+                    "S3Uri": "{{ S3Uri }}",
+                },
+            }
+        ],
+        "ProcessingJobName": job_name,
+        "ProcessingOutputConfig": {
+            "KmsKeyId": "KmsKeyID",
+            "Outputs": [
+                {
+                    "OutputName": "AnalyticsOutputName",
+                    "S3Output": {
+                        "LocalPath": "{{ Local Path }}",
+                        "S3UploadMode": "EndOfJob",
+                        "S3Uri": "{{ S3Uri }}",
+                    },
+                }
+            ],
+        },
+        "ProcessingResources": {
+            "ClusterConfig": {
+                "InstanceCount": 2,
+                "InstanceType": "ml.p2.xlarge",
+                "VolumeSizeInGB": 30,
+                "VolumeKmsKeyId": "{{ kms_key }}",
+            }
+        },
+        "RoleArn": "arn:aws:iam::0122345678910:role/SageMakerPowerUser",
+        "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
+        "Tags": [{"{{ key }}": "{{ value }}"}],
+    }
+
+
+# noinspection PyUnusedLocal
+# pylint: disable=unused-argument

Review comment:
      I think you did this because of `mock_client`? Then I would recommend 
just checking that the mock has been called (once, or however often you expect 
it to be called). Even better would be with the correct args if there are any.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to