This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2e95a2a4ca check sagemaker training job status before deferring `SageMakerTrainingOperator` (#36685)
2e95a2a4ca is described below

commit 2e95a2a4ca236df112679ec4b445dd99525c3f91
Author: Wei Lee <[email protected]>
AuthorDate: Mon Feb 5 10:21:13 2024 +0800

    check sagemaker training job status before deferring `SageMakerTrainingOperator` (#36685)
---
 airflow/providers/amazon/aws/hooks/logs.py         |  86 +++++++++++-
 airflow/providers/amazon/aws/hooks/sagemaker.py    | 145 +++++++++++++++++++--
 .../providers/amazon/aws/operators/sagemaker.py    |  89 ++++++++++---
 airflow/providers/amazon/aws/triggers/sagemaker.py |  83 +++++++++++-
 .../aws/operators/test_sagemaker_training.py       |  84 +++++++++++-
 tests/www/views/test_views_rendered.py             |   2 +-
 6 files changed, 459 insertions(+), 30 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/logs.py b/airflow/providers/amazon/aws/hooks/logs.py
index 5e38ad32e3..1e7473734f 100644
--- a/airflow/providers/amazon/aws/hooks/logs.py
+++ b/airflow/providers/amazon/aws/hooks/logs.py
@@ -17,8 +17,11 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import warnings
-from typing import Generator
+from typing import Any, AsyncGenerator, Generator
+
+from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -151,3 +154,88 @@ class AwsLogsHook(AwsBaseHook):
                 num_consecutive_empty_response = 0
 
             continuation_token.value = response["nextForwardToken"]
+
+    async def describe_log_streams_async(
+        self, log_group: str, stream_prefix: str, order_by: str, count: int
+    ) -> dict[str, Any] | None:
+        """Async function to get the list of log streams for the specified log group.
+
+        You can list all the log streams or filter the results by prefix. You can also control
+        how the results are ordered.
+
+        :param log_group: The name of the log group.
+        :param stream_prefix: The prefix to match.
+        :param order_by: If the value is LogStreamName , the results are ordered by log stream name.
+         If the value is LastEventTime , the results are ordered by the event time. The default value is LogStreamName.
+        :param count: The maximum number of items returned
+        """
+        async with self.async_conn as client:
+            try:
+                response: dict[str, Any] = await client.describe_log_streams(
+                    logGroupName=log_group,
+                    logStreamNamePrefix=stream_prefix,
+                    orderBy=order_by,
+                    limit=count,
+                )
+                return response
+            except ClientError as error:
+                # On the very first training job run on an account, there's no log group until
+                # the container starts logging, so ignore any errors thrown about that
+                if error.response["Error"]["Code"] == "ResourceNotFoundException":
+                    return None
+                raise error
+
+    async def get_log_events_async(
+        self,
+        log_group: str,
+        log_stream_name: str,
+        start_time: int = 0,
+        skip: int = 0,
+        start_from_head: bool = True,
+    ) -> AsyncGenerator[Any, dict[str, Any]]:
+        """A generator for log items in a single stream. This will yield all the items that are available.
+
+        :param log_group: The name of the log group.
+        :param log_stream_name: The name of the specific stream.
+        :param start_time: The time stamp value to start reading the logs from (default: 0).
+        :param skip: The number of log entries to skip at the start (default: 0).
+            This is for when there are multiple entries at the same timestamp.
+        :param start_from_head: whether to start from the beginning (True) of the log or
+            at the end of the log (False).
+        """
+        next_token = None
+        while True:
+            if next_token is not None:
+                token_arg: dict[str, str] = {"nextToken": next_token}
+            else:
+                token_arg = {}
+
+            async with self.async_conn as client:
+                response = await client.get_log_events(
+                    logGroupName=log_group,
+                    logStreamName=log_stream_name,
+                    startTime=start_time,
+                    startFromHead=start_from_head,
+                    **token_arg,
+                )
+
+                events = response["events"]
+                event_count = len(events)
+
+                if event_count > skip:
+                    events = events[skip:]
+                    skip = 0
+                else:
+                    skip -= event_count
+                    events = []
+
+                for event in events:
+                    await asyncio.sleep(1)
+                    yield event
+
+                if next_token != response["nextForwardToken"]:
+                    next_token = response["nextForwardToken"]
+                else:
+                    # CloudWatch hands back the same forward token once the stream is
+                    # exhausted; stop here instead of re-polling it in a tight loop.
+                    break
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 101d17f92b..74f5b0ae8b 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -26,8 +26,9 @@ import warnings
 from collections import Counter, namedtuple
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Generator, cast
+from typing import Any, AsyncGenerator, Callable, Generator, cast
 
+from asgiref.sync import sync_to_async
 from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -310,10 +311,12 @@ class SageMakerHook(AwsBaseHook):
                 max_ingestion_time,
             )
 
-            billable_time = (
-                describe_response["TrainingEndTime"] - describe_response["TrainingStartTime"]
-            ) * describe_response["ResourceConfig"]["InstanceCount"]
-            self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
+            billable_seconds = SageMakerHook.count_billable_seconds(
+                training_start_time=describe_response["TrainingStartTime"],
+                training_end_time=describe_response["TrainingEndTime"],
+                instance_count=describe_response["ResourceConfig"]["InstanceCount"],
+            )
+            self.log.info("Billable seconds: %d", billable_seconds)
 
         return response
 
@@ -811,10 +814,12 @@ class SageMakerHook(AwsBaseHook):
             if status in failed_states:
                 reason = last_description.get("FailureReason", "(No reason provided)")
                 raise AirflowException(f"Error training {job_name}: {status} Reason: {reason}")
-            billable_time = (
-                last_description["TrainingEndTime"] - last_description["TrainingStartTime"]
-            ) * instance_count
-            self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
+            billable_seconds = SageMakerHook.count_billable_seconds(
+                training_start_time=last_description["TrainingStartTime"],
+                training_end_time=last_description["TrainingEndTime"],
+                instance_count=instance_count,
+            )
+            self.log.info("Billable seconds: %d", billable_seconds)
 
     def list_training_jobs(
         self, name_contains: str | None = None, max_results: int | None = None, **kwargs
@@ -1300,3 +1305,123 @@ class SageMakerHook(AwsBaseHook):
             if "BestCandidate" in res:
                 return res["BestCandidate"]
         return None
+
+    @staticmethod
+    def count_billable_seconds(
+        training_start_time: datetime, training_end_time: datetime, instance_count: int
+    ) -> int:
+        billable_time = (training_end_time - training_start_time) * instance_count
+        return int(billable_time.total_seconds()) + 1
+
+    async def describe_training_job_async(self, job_name: str) -> dict[str, Any]:
+        """
+        Return the training job info associated with the name.
+
+        :param job_name: the name of the training job
+        """
+        async with self.async_conn as client:
+            response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
+            return response
+
+    async def describe_training_job_with_log_async(
+        self,
+        job_name: str,
+        positions: dict[str, Any],
+        stream_names: list[str],
+        instance_count: int,
+        state: int,
+        last_description: dict[str, Any],
+        last_describe_job_call: float,
+    ) -> tuple[int, dict[str, Any], float]:
+        """
+        Return the training job info associated with job_name and print CloudWatch logs.
+
+        :param job_name: name of the job to check status
+        :param positions: A list of pairs of (timestamp, skip) which represents the last record
+            read from each stream.
+        :param stream_names: A list of the log stream names. The position of the stream in this list is
+            the stream number.
+        :param instance_count: Count of the instance created for the job initially
+        :param state: log state
+        :param last_description: Latest description of the training job
+        :param last_describe_job_call: previous job called time
+        """
+        log_group = "/aws/sagemaker/TrainingJobs"
+
+        if len(stream_names) < instance_count:
+            logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+            streams = await logs_hook.describe_log_streams_async(
+                log_group=log_group,
+                stream_prefix=job_name + "/",
+                order_by="LogStreamName",
+                count=instance_count,
+            )
+
+            stream_names = [s["logStreamName"] for s in streams["logStreams"]] if streams else []
+            positions.update([(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions])
+
+        if len(stream_names) > 0:
+            async for idx, event in self.get_multi_stream(log_group, stream_names, positions):
+                self.log.info(event["message"])
+                ts, count = positions[stream_names[idx]]
+                if event["timestamp"] == ts:
+                    positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
+                else:
+                    positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1)
+
+        if state == LogState.COMPLETE:
+            return state, last_description, last_describe_job_call
+
+        if state == LogState.JOB_COMPLETE:
+            state = LogState.COMPLETE
+        elif time.time() - last_describe_job_call >= 30:
+            description = await self.describe_training_job_async(job_name)
+            last_describe_job_call = time.time()
+
+            if await sync_to_async(secondary_training_status_changed)(description, last_description):
+                self.log.info(
+                    await sync_to_async(secondary_training_status_message)(description, last_description)
+                )
+                last_description = description
+
+            status = description["TrainingJobStatus"]
+
+            if status not in self.non_terminal_states:
+                state = LogState.JOB_COMPLETE
+        return state, last_description, last_describe_job_call
+
+    async def get_multi_stream(
+        self, log_group: str, streams: list[str], positions: dict[str, Any]
+    ) -> AsyncGenerator[Any, tuple[int, Any | None]]:
+        """Iterate over the available events, interleaving the events from each stream so they're yielded in timestamp order.
+
+        :param log_group: The name of the log group.
+        :param streams: A list of the log stream names. The position of the stream in this list is
+            the stream number.
+        :param positions: A list of pairs of (timestamp, skip) which represents the last record
+            read from each stream.
+        """
+        positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
+        events: list[Any | None] = []
+
+        logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        event_iters = [
+            logs_hook.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
+            for s in streams
+        ]
+        for event_stream in event_iters:
+            try:
+                events.append(await event_stream.__anext__())
+            except StopAsyncIteration:
+                events.append(None)
+
+        # Merge only after every stream's head is loaded; merging inside the loop above
+        # would drain the streams one after another instead of interleaving them.
+        while any(events):
+            i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0
+            yield i, events[i]
+
+            try:
+                events[i] = await event_iters[i].__anext__()
+            except StopAsyncIteration:
+                events[i] = None
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 66a616811e..e8cfa26b29 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -29,9 +29,14 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.hooks.sagemaker import (
+    LogState,
+    SageMakerHook,
+    secondary_training_status_message,
+)
 from airflow.providers.amazon.aws.triggers.sagemaker import (
     SageMakerPipelineTrigger,
+    SageMakerTrainingPrintLogTrigger,
     SageMakerTrigger,
 )
 from airflow.providers.amazon.aws.utils import trim_none_values
@@ -899,9 +904,11 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
                     aws_conn_id=self.aws_conn_id,
                 ),
                 method_name="execute_complete",
-                timeout=datetime.timedelta(seconds=self.max_ingestion_time)
-                if self.max_ingestion_time is not None
-                else None,
+                timeout=(
+                    datetime.timedelta(seconds=self.max_ingestion_time)
+                    if self.max_ingestion_time is not None
+                    else None
+                ),
             )
             description = {}  # never executed but makes static checkers happy
         elif self.wait_for_completion:
@@ -1085,28 +1092,81 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
             raise AirflowException(f"Sagemaker Training Job creation failed: {response}")
 
         if self.deferrable and self.wait_for_completion:
-            self.defer(
-                timeout=self.execution_timeout,
-                trigger=SageMakerTrigger(
+            description = self.hook.describe_training_job(self.config["TrainingJobName"])
+            status = description["TrainingJobStatus"]
+
+            if self.print_log:
+                instance_count = description["ResourceConfig"]["InstanceCount"]
+                last_describe_job_call = time.monotonic()
+                job_already_completed = status not in self.hook.non_terminal_states
+                _, description, last_describe_job_call = self.hook.describe_training_job_with_log(
+                    self.config["TrainingJobName"],
+                    {},
+                    [],
+                    instance_count,
+                    LogState.COMPLETE if job_already_completed else LogState.TAILING,
+                    description,
+                    last_describe_job_call,
+                )
+                self.log.info(secondary_training_status_message(description, None))
+
+            if status in self.hook.failed_states:
+                reason = description.get("FailureReason", "(No reason provided)")
+                raise AirflowException(f"SageMaker job failed because {reason}")
+            elif status == "Completed":
+                log_message = f"{self.task_id} completed successfully."
+                if self.print_log:
+                    billable_seconds = SageMakerHook.count_billable_seconds(
+                        training_start_time=description["TrainingStartTime"],
+                        training_end_time=description["TrainingEndTime"],
+                        instance_count=instance_count,
+                    )
+                    log_message = f"Billable seconds: {billable_seconds}\n{log_message}"
+                self.log.info(log_message)
+                # Keep serialized_training_data populated like serialize_result() does on every other exit.
+                self.serialized_training_data = serialize(description)
+                return {"Training": self.serialized_training_data}
+
+            timeout = self.execution_timeout
+            if self.max_ingestion_time:
+                timeout = datetime.timedelta(seconds=self.max_ingestion_time)
+
+            trigger: SageMakerTrainingPrintLogTrigger | SageMakerTrigger
+            if self.print_log:
+                trigger = SageMakerTrainingPrintLogTrigger(
+                    job_name=self.config["TrainingJobName"],
+                    poke_interval=self.check_interval,
+                    aws_conn_id=self.aws_conn_id,
+                )
+            else:
+                trigger = SageMakerTrigger(
                     job_name=self.config["TrainingJobName"],
                     job_type="Training",
                     poke_interval=self.check_interval,
                     max_attempts=self.max_attempts,
                     aws_conn_id=self.aws_conn_id,
-                ),
+                )
+
+            self.defer(
+                timeout=timeout,
+                trigger=trigger,
                 method_name="execute_complete",
             )
 
-        self.serialized_training_data = serialize(
-            self.hook.describe_training_job(self.config["TrainingJobName"])
-        )
-        return {"Training": self.serialized_training_data}
+        return self.serialize_result()
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+        if event is None:
+            err_msg = "Trigger error: event is None"
+            self.log.error(err_msg)
+            raise AirflowException(err_msg)
 
-    def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        else:
-            self.log.info(event["message"])
+
+        self.log.info(event["message"])
+        return self.serialize_result()
+
+    def serialize_result(self) -> dict[str, dict]:
         self.serialized_training_data = serialize(
             self.hook.describe_training_job(self.config["TrainingJobName"])
         )
diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py b/airflow/providers/amazon/aws/triggers/sagemaker.py
index b6ae8e8842..c6425d28e0 100644
--- a/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import asyncio
+import time
 from collections import Counter
 from enum import IntEnum
 from functools import cached_property
@@ -26,7 +27,7 @@ from typing import Any, AsyncIterator
 from botocore.exceptions import WaiterError
 
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
@@ -196,3 +197,87 @@ class SageMakerPipelineTrigger(BaseTrigger):
                     await asyncio.sleep(int(self.waiter_delay))
 
             raise AirflowException("Waiter error: max attempts reached")
+
+
+class SageMakerTrainingPrintLogTrigger(BaseTrigger):
+    """
+    SageMakerTrainingPrintLogTrigger is fired as deferred class with params to run the task in triggerer.
+
+    :param job_name: name of the job to check status
+    :param poke_interval:  polling period in seconds to check for the status
+    :param aws_conn_id: AWS connection ID for sagemaker
+    """
+
+    def __init__(
+        self,
+        job_name: str,
+        poke_interval: float,
+        aws_conn_id: str = "aws_default",
+    ):
+        super().__init__()
+        self.job_name = job_name
+        self.poke_interval = poke_interval
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes SageMakerTrainingPrintLogTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger",
+            {
+                "poke_interval": self.poke_interval,
+                "aws_conn_id": self.aws_conn_id,
+                "job_name": self.job_name,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> SageMakerHook:
+        return SageMakerHook(aws_conn_id=self.aws_conn_id)
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Makes async connection to sagemaker async hook and gets job status for a job submitted by the operator."""
+        stream_names: list[str] = []  # The list of log streams
+        positions: dict[str, Any] = {}  # The current position in each stream, map of stream name -> position
+
+        last_description = await self.hook.describe_training_job_async(self.job_name)
+        instance_count = last_description["ResourceConfig"]["InstanceCount"]
+        status = last_description["TrainingJobStatus"]
+        job_already_completed = status not in self.hook.non_terminal_states
+        state = LogState.COMPLETE if job_already_completed else LogState.TAILING
+        last_describe_job_call = time.time()
+        # Every terminal branch below returns, so the trigger emits exactly one event.
+        while True:
+            try:
+                (
+                    state,
+                    last_description,
+                    last_describe_job_call,
+                ) = await self.hook.describe_training_job_with_log_async(
+                    self.job_name,
+                    positions,
+                    stream_names,
+                    instance_count,
+                    state,
+                    last_description,
+                    last_describe_job_call,
+                )
+                status = last_description["TrainingJobStatus"]
+                if status in self.hook.non_terminal_states:
+                    await asyncio.sleep(self.poke_interval)
+                elif status in self.hook.failed_states:
+                    reason = last_description.get("FailureReason", "(No reason provided)")
+                    error_message = f"SageMaker job failed because {reason}"
+                    yield TriggerEvent({"status": "error", "message": error_message})
+                    return
+                else:
+                    billable_seconds = SageMakerHook.count_billable_seconds(
+                        training_start_time=last_description["TrainingStartTime"],
+                        training_end_time=last_description["TrainingEndTime"],
+                        instance_count=instance_count,
+                    )
+                    self.log.info("Billable seconds: %d", billable_seconds)
+                    yield TriggerEvent({"status": "success", "message": last_description})
+                    return
+            except Exception as e:
+                yield TriggerEvent({"status": "error", "message": str(e)})
+                return
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index dab87f1ff2..9d3ad5aee2 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import datetime
 from unittest import mock
 
 import pytest
@@ -23,10 +24,12 @@ from botocore.exceptions import ClientError
 from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException, TaskDeferred
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerBaseOperator, SageMakerTrainingOperator
-from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.amazon.aws.triggers.sagemaker import (
+    SageMakerTrigger,
+)
 from airflow.providers.openlineage.extractors import OperatorLineage
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
@@ -117,8 +120,39 @@ class TestSageMakerTrainingOperator:
         with pytest.raises(AirflowException):
             self.sagemaker.execute(None)
 
+    @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator.defer")
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_training_job_with_log",
+        return_value=(
+            LogState.JOB_COMPLETE,
+            {
+                "TrainingJobStatus": "Completed",
+                "ResourceConfig": {"InstanceCount": 1},
+                "TrainingEndTime": datetime(2023, 5, 16),
+                "TrainingStartTime": datetime(2023, 5, 15),
+            },
+            50,
+        ),
+    )
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_training_job",
+        return_value={
+            "TrainingJobStatus": "Completed",
+            "ResourceConfig": {"InstanceCount": 1},
+            "TrainingEndTime": datetime(2023, 5, 16),
+            "TrainingStartTime": datetime(2023, 5, 15),
+        },
+    )
     @mock.patch.object(SageMakerHook, "create_training_job")
-    def test_operator_defer(self, mock_training):
+    def test_operator_complete_before_defer(
+        self,
+        mock_training,
+        mock_describe_training_job,
+        mock_describe_training_job_with_log,
+        mock_defer,
+    ):
         mock_training.return_value = {
             "TrainingJobArn": "test_arn",
             "ResponseMetadata": {"HTTPStatusCode": 200},
@@ -126,6 +160,50 @@ class TestSageMakerTrainingOperator:
         self.sagemaker.deferrable = True
         self.sagemaker.wait_for_completion = True
         self.sagemaker.check_if_job_exists = False
+
+        self.sagemaker.execute(context=None)
+        assert not mock_defer.called
+
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_training_job_with_log",
+        return_value=(
+            LogState.WAIT_IN_PROGRESS,
+            {
+                "TrainingJobStatus": "Training",
+                "ResourceConfig": {"InstanceCount": 1},
+                "TrainingEndTime": datetime(2023, 5, 16),
+                "TrainingStartTime": datetime(2023, 5, 15),
+            },
+            50,
+        ),
+    )
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_training_job",
+        return_value={
+            "TrainingJobStatus": "Training",
+            "ResourceConfig": {"InstanceCount": 1},
+            "TrainingEndTime": datetime(2023, 5, 16),
+            "TrainingStartTime": datetime(2023, 5, 15),
+        },
+    )
+    @mock.patch.object(SageMakerHook, "create_training_job")
+    def test_operator_defer(
+        self,
+        mock_training,
+        mock_describe_training_job,
+        mock_describe_training_job_with_log,
+    ):
+        mock_training.return_value = {
+            "TrainingJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        self.sagemaker.deferrable = True
+        self.sagemaker.wait_for_completion = True
+        self.sagemaker.check_if_job_exists = False
+        self.sagemaker.print_log = False
+
         with pytest.raises(TaskDeferred) as exc:
             self.sagemaker.execute(context=None)
         assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
diff --git a/tests/www/views/test_views_rendered.py b/tests/www/views/test_views_rendered.py
index 5aa8e6ba2f..7aabec7ef6 100644
--- a/tests/www/views/test_views_rendered.py
+++ b/tests/www/views/test_views_rendered.py
@@ -256,7 +256,7 @@ def test_rendered_template_secret(admin_client, create_dag_run, task_secret):
 
 if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
     # Handle collection of the test by non-db case
-    Variable = mock.MagicMock()  # type: ignore[misc] # noqa: F811
+    Variable = mock.MagicMock()  # type: ignore[misc]  # noqa: F811
 else:
     initial_db_init()
 

Reply via email to