This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a809c91528 Add deferrable param in SageMakerTrainingOperator (#31042)
a809c91528 is described below
commit a809c91528bebacdb5c3bac75ae3c7bf33a99308
Author: Pankaj Koti <[email protected]>
AuthorDate: Mon May 8 23:49:55 2023 +0530
Add deferrable param in SageMakerTrainingOperator (#31042)
* Add deferrable param in SageMakerTrainingOperator
This allows SageMakerTrainingOperator to run asynchronously: the worker
only submits the training job and then defers to a trigger, which polls
until the job status reaches a terminal state. This way, the worker slot
is not occupied for the whole duration of the task.
---
.../providers/amazon/aws/operators/sagemaker.py | 45 +++++++++++-
.../aws/operators/test_sagemaker_training.py | 16 ++++-
.../amazon/aws/triggers/test_sagemaker.py | 81 ++++++++++++++++++++++
3 files changed, 139 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 4efb1a863b..587bf5b0a1 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -723,6 +723,8 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
:param print_log: if the operator should print the cloudwatch log during
training
:param check_interval: if wait is set to be true, this is the time interval
in seconds which the operator will check the status of the training job
+ :param max_attempts: Number of times to poll for query state before
returning the current state,
+ defaults to None.
:param max_ingestion_time: If wait is set to True, the operation fails if
the training job
doesn't finish within max_ingestion_time seconds. If you set this
parameter to None,
the operation does not timeout.
@@ -731,6 +733,8 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
:param action_if_job_exists: Behaviour if the job name already exists.
Possible options are "timestamp"
(default), "increment" (deprecated) and "fail".
This is only relevant if check_if_job_exists is True.
+ :param deferrable: Run operator in the deferrable mode. This is only
effective if wait_for_completion is
+ set to True.
:return Dict: Returns The ARN of the training job created in Amazon
SageMaker.
"""
@@ -742,15 +746,18 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
wait_for_completion: bool = True,
print_log: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
+ max_attempts: int | None = None,
max_ingestion_time: int | None = None,
check_if_job_exists: bool = True,
action_if_job_exists: str = "timestamp",
+ deferrable: bool = False,
**kwargs,
):
super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
self.wait_for_completion = wait_for_completion
self.print_log = print_log
self.check_interval = check_interval
+ self.max_attempts = max_attempts or 60
self.max_ingestion_time = max_ingestion_time
self.check_if_job_exists = check_if_job_exists
if action_if_job_exists in {"timestamp", "increment", "fail"}:
@@ -767,6 +774,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
f"Argument action_if_job_exists accepts only 'timestamp',
'increment' and 'fail'. \
Provided value: '{action_if_job_exists}'."
)
+ self.deferrable = deferrable
def expand_role(self) -> None:
"""Expands an IAM role name into an ARN."""
@@ -793,17 +801,50 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
)
self.log.info("Creating SageMaker training job %s.",
self.config["TrainingJobName"])
+
+ if self.deferrable and not self.wait_for_completion:
+ self.log.warning(
+ "Setting deferrable to True does not have effect when
wait_for_completion is set to False."
+ )
+
+ wait_for_completion = self.wait_for_completion
+ if self.deferrable and self.wait_for_completion:
+ # Set wait_for_completion to False so that it waits for the status
in the deferred task.
+ wait_for_completion = False
+
response = self.hook.create_training_job(
self.config,
- wait_for_completion=self.wait_for_completion,
+ wait_for_completion=wait_for_completion,
print_log=self.print_log,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker Training Job creation failed:
{response}")
+
+ if self.deferrable and self.wait_for_completion:
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=SageMakerTrigger(
+ job_name=self.config["TrainingJobName"],
+ job_type="Training",
+ poke_interval=self.check_interval,
+ max_attempts=self.max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ )
+
+ result = {"Training":
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
+ return result
+
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error while running job: {event}")
else:
- return {"Training":
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
+ self.log.info(event["message"])
+ result = {"Training":
serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
+ return result
class SageMakerDeleteModelOperator(SageMakerBaseOperator):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index 82b76fdee0..e551317d33 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -21,10 +21,11 @@ from unittest import mock
import pytest
from botocore.exceptions import ClientError
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerTrainingOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
EXPECTED_INTEGER_FIELDS: list[list[str]] = [
["ResourceConfig", "InstanceCount"],
@@ -113,3 +114,16 @@ class TestSageMakerTrainingOperator:
}
with pytest.raises(AirflowException):
self.sagemaker.execute(None)
+
+ @mock.patch.object(SageMakerHook, "create_training_job")
+ def test_operator_defer(self, mock_training):
+ mock_training.return_value = {
+ "TrainingJobArn": "test_arn",
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+ self.sagemaker.deferrable = True
+ self.sagemaker.wait_for_completion = True
+ self.sagemaker.check_if_job_exists = False
+ with pytest.raises(TaskDeferred) as exc:
+ self.sagemaker.execute(context=None)
+ assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is
not a SagemakerTrigger"
diff --git a/tests/providers/amazon/aws/triggers/test_sagemaker.py
b/tests/providers/amazon/aws/triggers/test_sagemaker.py
new file mode 100644
index 0000000000..d8f0bd1cee
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_sagemaker.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.triggers.base import TriggerEvent
+from tests.providers.amazon.aws.utils.compat import AsyncMock, async_mock
+
+JOB_NAME = "job_name"
+JOB_TYPE = "job_type"
+AWS_CONN_ID = "aws_sagemaker_conn"
+POKE_INTERVAL = 30
+MAX_ATTEMPTS = 60
+
+
+class TestSagemakerTrigger:
+ def test_sagemaker_trigger_serialize(self):
+ sagemaker_trigger = SageMakerTrigger(
+ job_name=JOB_NAME,
+ job_type=JOB_TYPE,
+ poke_interval=POKE_INTERVAL,
+ max_attempts=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ )
+ class_path, args = sagemaker_trigger.serialize()
+ assert class_path ==
"airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger"
+ assert args["job_name"] == JOB_NAME
+ assert args["job_type"] == JOB_TYPE
+ assert args["poke_interval"] == POKE_INTERVAL
+ assert args["max_attempts"] == MAX_ATTEMPTS
+ assert args["aws_conn_id"] == AWS_CONN_ID
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_waiter")
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.async_conn")
+
@async_mock.patch("airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter")
+ @async_mock.patch(
+
"airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter_job_name_arg"
+ )
+ async def test_sagemaker_trigger_run(
+ self,
+ mock_get_job_type_waiter_job_name_arg,
+ mock_get_job_type_waiter,
+ mock_async_conn,
+ mock_get_waiter,
+ ):
+ mock_get_job_type_waiter_job_name_arg.return_value = "job_name"
+ mock_get_job_type_waiter.return_value = "waiter"
+ mock = async_mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = mock
+
+ mock_get_waiter().wait = AsyncMock()
+
+ sagemaker_trigger = SageMakerTrigger(
+ job_name=JOB_NAME,
+ job_type=JOB_TYPE,
+ poke_interval=POKE_INTERVAL,
+ max_attempts=MAX_ATTEMPTS,
+ aws_conn_id=AWS_CONN_ID,
+ )
+
+ generator = sagemaker_trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent({"status": "success", "message": "Job
completed."})