This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2e95a2a4ca check sagemaker training job status before deferring
`SageMakerTrainingOperator` (#36685)
2e95a2a4ca is described below
commit 2e95a2a4ca236df112679ec4b445dd99525c3f91
Author: Wei Lee <[email protected]>
AuthorDate: Mon Feb 5 10:21:13 2024 +0800
check sagemaker training job status before deferring
`SageMakerTrainingOperator` (#36685)
---
airflow/providers/amazon/aws/hooks/logs.py | 86 +++++++++++-
airflow/providers/amazon/aws/hooks/sagemaker.py | 145 +++++++++++++++++++--
.../providers/amazon/aws/operators/sagemaker.py | 89 ++++++++++---
airflow/providers/amazon/aws/triggers/sagemaker.py | 83 +++++++++++-
.../aws/operators/test_sagemaker_training.py | 84 +++++++++++-
tests/www/views/test_views_rendered.py | 2 +-
6 files changed, 459 insertions(+), 30 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/logs.py
b/airflow/providers/amazon/aws/hooks/logs.py
index 5e38ad32e3..1e7473734f 100644
--- a/airflow/providers/amazon/aws/hooks/logs.py
+++ b/airflow/providers/amazon/aws/hooks/logs.py
@@ -17,8 +17,11 @@
# under the License.
from __future__ import annotations
+import asyncio
import warnings
-from typing import Generator
+from typing import Any, AsyncGenerator, Generator
+
+from botocore.exceptions import ClientError
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -151,3 +154,84 @@ class AwsLogsHook(AwsBaseHook):
num_consecutive_empty_response = 0
continuation_token.value = response["nextForwardToken"]
+
+    async def describe_log_streams_async(
+        self, log_group: str, stream_prefix: str, order_by: str, count: int
+    ) -> dict[str, Any] | None:
+        """Fetch the log streams of a log group through the asynchronous client.
+
+        The streams can be restricted to a name prefix, and the ordering of the result is selectable.
+
+        :param log_group: The name of the log group.
+        :param stream_prefix: The prefix to match.
+        :param order_by: ``LogStreamName`` orders the results by log stream name,
+            ``LastEventTime`` orders them by event time.
+        :param count: The maximum number of items returned.
+        :return: The ``describe_log_streams`` response, or ``None`` when the call fails with a
+            ``ResourceNotFoundException`` error code.
+        """
+        request: dict[str, Any] = {
+            "logGroupName": log_group,
+            "logStreamNamePrefix": stream_prefix,
+            "orderBy": order_by,
+            "limit": count,
+        }
+        async with self.async_conn as client:
+            try:
+                streams: dict[str, Any] = await client.describe_log_streams(**request)
+            except ClientError as error:
+                # On the very first training job run on an account, there's no log group until
+                # the container starts logging, so a missing resource is reported as "no streams".
+                if error.response["Error"]["Code"] != "ResourceNotFoundException":
+                    raise error
+                return None
+            return streams
+
+    async def get_log_events_async(
+        self,
+        log_group: str,
+        log_stream_name: str,
+        start_time: int = 0,
+        skip: int = 0,
+        start_from_head: bool = True,
+    ) -> AsyncGenerator[Any, None]:
+        """Yield the log events of a single stream, following pagination until the stream is exhausted.
+
+        The generator finishes once the service hands back the same ``nextForwardToken`` that was
+        sent, i.e. when no further page is available at the moment. Callers that want to pick up
+        events written later start a new generator from their last known position.
+
+        :param log_group: The name of the log group.
+        :param log_stream_name: The name of the specific stream.
+        :param start_time: The time stamp value to start reading the logs from (default: 0).
+        :param skip: The number of log entries to skip at the start (default: 0).
+            This is for when there are multiple entries at the same timestamp.
+        :param start_from_head: whether to start from the beginning (True) of the log or
+            at the end of the log (False).
+        """
+        next_token: str | None = None
+        while True:
+            token_arg: dict[str, str] = {} if next_token is None else {"nextToken": next_token}
+
+            async with self.async_conn as client:
+                response = await client.get_log_events(
+                    logGroupName=log_group,
+                    logStreamName=log_stream_name,
+                    startTime=start_time,
+                    startFromHead=start_from_head,
+                    **token_arg,
+                )
+
+            events = response["events"]
+            event_count = len(events)
+
+            # Consume the remaining ``skip`` budget before anything is handed to the caller.
+            if event_count > skip:
+                events = events[skip:]
+                skip = 0
+            else:
+                skip -= event_count
+                events = []
+
+            for event in events:
+                # NOTE(review): one-second pause per event kept from the original change; its
+                # purpose is not evident from this code - confirm before tuning.
+                await asyncio.sleep(1)
+                yield event
+
+            # An unchanged forward token means there is no further page. Without this exit the
+            # loop would re-poll the service forever and the generator would never finish.
+            if next_token == response["nextForwardToken"]:
+                return
+            next_token = response["nextForwardToken"]
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py
b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 101d17f92b..74f5b0ae8b 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -26,8 +26,9 @@ import warnings
from collections import Counter, namedtuple
from datetime import datetime
from functools import partial
-from typing import Any, Callable, Generator, cast
+from typing import Any, AsyncGenerator, Callable, Generator, cast
+from asgiref.sync import sync_to_async
from botocore.exceptions import ClientError
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
@@ -310,10 +311,12 @@ class SageMakerHook(AwsBaseHook):
max_ingestion_time,
)
- billable_time = (
- describe_response["TrainingEndTime"] -
describe_response["TrainingStartTime"]
- ) * describe_response["ResourceConfig"]["InstanceCount"]
- self.log.info("Billable seconds: %d",
int(billable_time.total_seconds()) + 1)
+ billable_seconds = SageMakerHook.count_billable_seconds(
+ training_start_time=describe_response["TrainingStartTime"],
+ training_end_time=describe_response["TrainingEndTime"],
+
instance_count=describe_response["ResourceConfig"]["InstanceCount"],
+ )
+ self.log.info("Billable seconds: %d", billable_seconds)
return response
@@ -811,10 +814,12 @@ class SageMakerHook(AwsBaseHook):
if status in failed_states:
reason = last_description.get("FailureReason", "(No reason
provided)")
raise AirflowException(f"Error training {job_name}: {status}
Reason: {reason}")
- billable_time = (
- last_description["TrainingEndTime"] -
last_description["TrainingStartTime"]
- ) * instance_count
- self.log.info("Billable seconds: %d",
int(billable_time.total_seconds()) + 1)
+ billable_seconds = SageMakerHook.count_billable_seconds(
+ training_start_time=last_description["TrainingStartTime"],
+ training_end_time=last_description["TrainingEndTime"],
+ instance_count=instance_count,
+ )
+ self.log.info("Billable seconds: %d", billable_seconds)
def list_training_jobs(
self, name_contains: str | None = None, max_results: int | None =
None, **kwargs
@@ -1300,3 +1305,125 @@ class SageMakerHook(AwsBaseHook):
if "BestCandidate" in res:
return res["BestCandidate"]
return None
+
+    @staticmethod
+    def count_billable_seconds(
+        training_start_time: datetime, training_end_time: datetime, instance_count: int
+    ) -> int:
+        """Return the billable seconds of a training run.
+
+        The elapsed time between start and end is multiplied by the number of instances,
+        truncated to whole seconds, and one second is added.
+
+        :param training_start_time: when the training started
+        :param training_end_time: when the training ended
+        :param instance_count: number of instances used by the job
+        """
+        elapsed = training_end_time - training_start_time
+        scaled_seconds = (elapsed * instance_count).total_seconds()
+        return int(scaled_seconds) + 1
+
+    async def describe_training_job_async(self, job_name: str) -> dict[str, Any]:
+        """
+        Look up a training job by name through the asynchronous client.
+
+        :param job_name: the name of the training job
+        :return: the response of the ``describe_training_job`` call
+        """
+        async with self.async_conn as sagemaker_client:
+            job_description: dict[str, Any] = await sagemaker_client.describe_training_job(
+                TrainingJobName=job_name
+            )
+        return job_description
+
+ async def describe_training_job_with_log_async(
+ self,
+ job_name: str,
+ positions: dict[str, Any],
+ stream_names: list[str],
+ instance_count: int,
+ state: int,
+ last_description: dict[str, Any],
+ last_describe_job_call: float,
+ ) -> tuple[int, dict[str, Any], float]:
+ """
+ Log the new CloudWatch events of a training job and refresh its description.
+
+ One call performs one polling step: it looks up the log streams of the job while
+ fewer than ``instance_count`` are known, writes every event it can read to the task
+ log, records how far each stream has been read in ``positions``, and advances the
+ log state. The job itself is re-described at most once every 30 seconds.
+
+ :param job_name: name of the job to check status
+ :param positions: mapping of log stream name to a ``Position(timestamp, skip)`` that
+ marks the last record read from that stream; it is updated in place.
+ :param stream_names: A list of the log stream names. The position of the stream in
+ this list is the stream number.
+ :param instance_count: Count of the instance created for the job initially
+ :param state: current ``LogState`` value of the polling loop
+ :param last_description: Latest description of the training job
+ :param last_describe_job_call: ``time.time()`` value of the previous describe call
+ :return: tuple of (new state, latest description, time of the latest describe call)
+ """
+ log_group = "/aws/sagemaker/TrainingJobs"
+
+ # Look the streams up (again) while fewer streams than instances are known.
+ if len(stream_names) < instance_count:
+ logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
+ streams = await logs_hook.describe_log_streams_async(
+ log_group=log_group,
+ stream_prefix=job_name + "/",
+ order_by="LogStreamName",
+ count=instance_count,
+ )
+
+ # NOTE(review): this assignment rebinds the local name only, so the list object
+ # passed in by the caller is not updated; ``positions`` below is mutated in place.
+ stream_names = [s["logStreamName"] for s in streams["logStreams"]]
if streams else []
+ positions.update([(s, Position(timestamp=0, skip=0)) for s in
stream_names if s not in positions])
+
+ if len(stream_names) > 0:
+ async for idx, event in self.get_multi_stream(log_group,
stream_names, positions):
+ self.log.info(event["message"])
+ ts, count = positions[stream_names[idx]]
+ # Same timestamp as the last record: one more entry to skip at that timestamp.
+ # New timestamp: restart the skip counter at 1 for the event just logged.
+ if event["timestamp"] == ts:
+ positions[stream_names[idx]] = Position(timestamp=ts,
skip=count + 1)
+ else:
+ positions[stream_names[idx]] =
Position(timestamp=event["timestamp"], skip=1)
+
+ if state == LogState.COMPLETE:
+ return state, last_description, last_describe_job_call
+
+ # JOB_COMPLETE is promoted to COMPLETE here; otherwise the job is described again
+ # once at least 30 seconds have passed since the previous describe call.
+ if state == LogState.JOB_COMPLETE:
+ state = LogState.COMPLETE
+ elif time.time() - last_describe_job_call >= 30:
+ description = await self.describe_training_job_async(job_name)
+ last_describe_job_call = time.time()
+
+ if await
sync_to_async(secondary_training_status_changed)(description, last_description):
+ self.log.info(
+ await
sync_to_async(secondary_training_status_message)(description, last_description)
+ )
+ last_description = description
+
+ status = description["TrainingJobStatus"]
+
+ # A status outside the non-terminal set moves the log state to JOB_COMPLETE.
+ if status not in self.non_terminal_states:
+ state = LogState.JOB_COMPLETE
+ return state, last_description, last_describe_job_call
+
+    async def get_multi_stream(
+        self, log_group: str, streams: list[str], positions: dict[str, Any]
+    ) -> AsyncGenerator[tuple[int, Any | None], None]:
+        """Interleave the events of several log streams into one sequence.
+
+        One event is buffered per stream; on every step the index chosen by ``argmin`` over
+        the buffered events' timestamps is yielded together with its event, and that slot
+        is refilled from the same stream.
+
+        :param log_group: The name of the log group.
+        :param streams: A list of the log stream names. The position of the stream in this list is
+            the stream number.
+        :param positions: mapping of stream name to a ``Position(timestamp, skip)`` which represents
+            the last record read from each stream. When empty, every stream is read from the start.
+        :return: an async generator of ``(stream index, event)`` pairs
+        """
+        positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
+
+        logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        event_iters = [
+            logs_hook.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
+            for s in streams
+        ]
+
+        async def _next_event(event_iter: Any) -> Any | None:
+            """Return the next event of ``event_iter`` or ``None`` once it is exhausted."""
+            try:
+                return await event_iter.__anext__()
+            except StopAsyncIteration:
+                return None
+
+        # The original guarded each iterator with ``if not event_stream``; async generator
+        # objects are always truthy, so that branch could never run and is dropped.
+        events: list[Any | None] = []
+        for event_iter in event_iters:
+            events.append(await _next_event(event_iter))
+
+        while any(events):
+            i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0
+            yield i, events[i]
+            events[i] = await _next_event(event_iters[i])
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py
b/airflow/providers/amazon/aws/operators/sagemaker.py
index 66a616811e..e8cfa26b29 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -29,9 +29,14 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.hooks.sagemaker import (
+ LogState,
+ SageMakerHook,
+ secondary_training_status_message,
+)
from airflow.providers.amazon.aws.triggers.sagemaker import (
SageMakerPipelineTrigger,
+ SageMakerTrainingPrintLogTrigger,
SageMakerTrigger,
)
from airflow.providers.amazon.aws.utils import trim_none_values
@@ -899,9 +904,11 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
- timeout=datetime.timedelta(seconds=self.max_ingestion_time)
- if self.max_ingestion_time is not None
- else None,
+ timeout=(
+ datetime.timedelta(seconds=self.max_ingestion_time)
+ if self.max_ingestion_time is not None
+ else None
+ ),
)
description = {} # never executed but makes static checkers happy
elif self.wait_for_completion:
@@ -1085,28 +1092,80 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
raise AirflowException(f"Sagemaker Training Job creation failed:
{response}")
if self.deferrable and self.wait_for_completion:
- self.defer(
- timeout=self.execution_timeout,
- trigger=SageMakerTrigger(
+ description =
self.hook.describe_training_job(self.config["TrainingJobName"])
+ status = description["TrainingJobStatus"]
+
+ if self.print_log:
+ instance_count = description["ResourceConfig"]["InstanceCount"]
+ last_describe_job_call = time.monotonic()
+ job_already_completed = status not in
self.hook.non_terminal_states
+ _, description, last_describe_job_call =
self.hook.describe_training_job_with_log(
+ self.config["TrainingJobName"],
+ {},
+ [],
+ instance_count,
+ LogState.COMPLETE if job_already_completed else
LogState.TAILING,
+ description,
+ last_describe_job_call,
+ )
+ self.log.info(secondary_training_status_message(description,
None))
+
+ if status in self.hook.failed_states:
+ reason = description.get("FailureReason", "(No reason
provided)")
+ raise AirflowException(f"SageMaker job failed because
{reason}")
+ elif status == "Completed":
+ log_message = f"{self.task_id} completed successfully."
+ if self.print_log:
+ billable_seconds = SageMakerHook.count_billable_seconds(
+ training_start_time=description["TrainingStartTime"],
+ training_end_time=description["TrainingEndTime"],
+ instance_count=instance_count,
+ )
+ log_message = f"Billable seconds:
{billable_seconds}\n{log_message}"
+ self.log.info(log_message)
+ return {"Training": serialize(description)}
+
+ timeout = self.execution_timeout
+ if self.max_ingestion_time:
+ timeout = datetime.timedelta(seconds=self.max_ingestion_time)
+
+ trigger: SageMakerTrainingPrintLogTrigger | SageMakerTrigger
+ if self.print_log:
+ trigger = SageMakerTrainingPrintLogTrigger(
+ job_name=self.config["TrainingJobName"],
+ poke_interval=self.check_interval,
+ aws_conn_id=self.aws_conn_id,
+ )
+ else:
+ trigger = SageMakerTrigger(
job_name=self.config["TrainingJobName"],
job_type="Training",
poke_interval=self.check_interval,
max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
- ),
+ )
+
+ self.defer(
+ timeout=timeout,
+ trigger=trigger,
method_name="execute_complete",
)
- self.serialized_training_data = serialize(
- self.hook.describe_training_job(self.config["TrainingJobName"])
- )
- return {"Training": self.serialized_training_data}
+ return self.serialize_result()
+
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> dict[str, dict]:
+ if event is None:
+ err_msg = "Trigger error: event is None"
+ self.log.error(err_msg)
+ raise AirflowException(err_msg)
- def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
- else:
- self.log.info(event["message"])
+
+ self.log.info(event["message"])
+ return self.serialize_result()
+
+ def serialize_result(self) -> dict[str, dict]:
self.serialized_training_data = serialize(
self.hook.describe_training_job(self.config["TrainingJobName"])
)
diff --git a/airflow/providers/amazon/aws/triggers/sagemaker.py
b/airflow/providers/amazon/aws/triggers/sagemaker.py
index b6ae8e8842..c6425d28e0 100644
--- a/airflow/providers/amazon/aws/triggers/sagemaker.py
+++ b/airflow/providers/amazon/aws/triggers/sagemaker.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import asyncio
+import time
from collections import Counter
from enum import IntEnum
from functools import cached_property
@@ -26,7 +27,7 @@ from typing import Any, AsyncIterator
from botocore.exceptions import WaiterError
from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.hooks.sagemaker import LogState,
SageMakerHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -196,3 +197,83 @@ class SageMakerPipelineTrigger(BaseTrigger):
await asyncio.sleep(int(self.waiter_delay))
raise AirflowException("Waiter error: max attempts reached")
+
+
+class SageMakerTrainingPrintLogTrigger(BaseTrigger):
+    """
+    Trigger that follows a SageMaker training job and prints its logs while waiting for it.
+
+    The trigger polls the job through ``SageMakerHook.describe_training_job_with_log_async``
+    and fires exactly one event: ``success`` when the job reaches a terminal, non-failed
+    status, ``error`` when the job fails or polling raises.
+
+    :param job_name: name of the job to check status
+    :param poke_interval: polling period in seconds to check for the status
+    :param aws_conn_id: AWS connection ID for sagemaker
+    """
+
+    def __init__(
+        self,
+        job_name: str,
+        poke_interval: float,
+        aws_conn_id: str = "aws_default",
+    ):
+        super().__init__()
+        self.job_name = job_name
+        self.poke_interval = poke_interval
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize SageMakerTrainingPrintLogTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger",
+            {
+                "poke_interval": self.poke_interval,
+                "aws_conn_id": self.aws_conn_id,
+                "job_name": self.job_name,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> SageMakerHook:
+        """Return the ``SageMakerHook`` bound to ``aws_conn_id``, created on first access."""
+        return SageMakerHook(aws_conn_id=self.aws_conn_id)
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Poll the training job, print its logs, and fire one event when it finishes or fails."""
+        stream_names: list[str] = []  # The list of log streams
+        positions: dict[str, Any] = {}  # The current position in each stream, map of stream name -> position
+
+        last_description = await self.hook.describe_training_job_async(self.job_name)
+        instance_count = last_description["ResourceConfig"]["InstanceCount"]
+        status = last_description["TrainingJobStatus"]
+        job_already_completed = status not in self.hook.non_terminal_states
+        state = LogState.COMPLETE if job_already_completed else LogState.TAILING
+        last_describe_job_call = time.time()
+        while True:
+            try:
+                (
+                    state,
+                    last_description,
+                    last_describe_job_call,
+                ) = await self.hook.describe_training_job_with_log_async(
+                    self.job_name,
+                    positions,
+                    stream_names,
+                    instance_count,
+                    state,
+                    last_description,
+                    last_describe_job_call,
+                )
+                status = last_description["TrainingJobStatus"]
+                if status in self.hook.non_terminal_states:
+                    await asyncio.sleep(self.poke_interval)
+                    continue
+
+                # Terminal status: fire one event and stop. Without the ``return`` the loop
+                # would poll again immediately, with no sleep, and keep re-emitting events.
+                if status in self.hook.failed_states:
+                    reason = last_description.get("FailureReason", "(No reason provided)")
+                    error_message = f"SageMaker job failed because {reason}"
+                    yield TriggerEvent({"status": "error", "message": error_message})
+                    return
+
+                billable_seconds = SageMakerHook.count_billable_seconds(
+                    training_start_time=last_description["TrainingStartTime"],
+                    training_end_time=last_description["TrainingEndTime"],
+                    instance_count=instance_count,
+                )
+                self.log.info("Billable seconds: %d", billable_seconds)
+                yield TriggerEvent({"status": "success", "message": last_description})
+                return
+            except Exception as e:
+                # Report the failure to the operator once instead of retrying in a tight loop.
+                yield TriggerEvent({"status": "error", "message": str(e)})
+                return
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
index dab87f1ff2..9d3ad5aee2 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+from datetime import datetime
from unittest import mock
import pytest
@@ -23,10 +24,12 @@ from botocore.exceptions import ClientError
from openlineage.client.run import Dataset
from airflow.exceptions import AirflowException, TaskDeferred
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.hooks.sagemaker import LogState,
SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerBaseOperator, SageMakerTrainingOperator
-from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.amazon.aws.triggers.sagemaker import (
+ SageMakerTrigger,
+)
from airflow.providers.openlineage.extractors import OperatorLineage
EXPECTED_INTEGER_FIELDS: list[list[str]] = [
@@ -117,8 +120,39 @@ class TestSageMakerTrainingOperator:
with pytest.raises(AirflowException):
self.sagemaker.execute(None)
+
@mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator.defer")
+ @mock.patch.object(
+ SageMakerHook,
+ "describe_training_job_with_log",
+ return_value=(
+ LogState.JOB_COMPLETE,
+ {
+ "TrainingJobStatus": "Completed",
+ "ResourceConfig": {"InstanceCount": 1},
+ "TrainingEndTime": datetime(2023, 5, 15),
+ "TrainingStartTime": datetime(2023, 5, 16),
+ },
+ 50,
+ ),
+ )
+ @mock.patch.object(
+ SageMakerHook,
+ "describe_training_job",
+ return_value={
+ "TrainingJobStatus": "Completed",
+ "ResourceConfig": {"InstanceCount": 1},
+ "TrainingEndTime": datetime(2023, 5, 15),
+ "TrainingStartTime": datetime(2023, 5, 16),
+ },
+ )
@mock.patch.object(SageMakerHook, "create_training_job")
- def test_operator_defer(self, mock_training):
+ def test_operator_complete_before_defer(
+ self,
+ mock_training,
+ mock_describe_training_job,
+ mock_describe_training_job_with_log,
+ mock_defer,
+ ):
mock_training.return_value = {
"TrainingJobArn": "test_arn",
"ResponseMetadata": {"HTTPStatusCode": 200},
@@ -126,6 +160,50 @@ class TestSageMakerTrainingOperator:
self.sagemaker.deferrable = True
self.sagemaker.wait_for_completion = True
self.sagemaker.check_if_job_exists = False
+
+ self.sagemaker.execute(context=None)
+ assert not mock_defer.called
+
+ @mock.patch.object(
+ SageMakerHook,
+ "describe_training_job_with_log",
+ return_value=(
+ LogState.WAIT_IN_PROGRESS,
+ {
+ "TrainingJobStatus": "Training",
+ "ResourceConfig": {"InstanceCount": 1},
+ "TrainingEndTime": datetime(2023, 5, 15),
+ "TrainingStartTime": datetime(2023, 5, 16),
+ },
+ 50,
+ ),
+ )
+ @mock.patch.object(
+ SageMakerHook,
+ "describe_training_job",
+ return_value={
+ "TrainingJobStatus": "Training",
+ "ResourceConfig": {"InstanceCount": 1},
+ "TrainingEndTime": datetime(2023, 5, 15),
+ "TrainingStartTime": datetime(2023, 5, 16),
+ },
+ )
+ @mock.patch.object(SageMakerHook, "create_training_job")
+ def test_operator_defer(
+ self,
+ mock_training,
+ mock_describe_training_job,
+ mock_describe_training_job_with_log,
+ ):
+ mock_training.return_value = {
+ "TrainingJobArn": "test_arn",
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+ self.sagemaker.deferrable = True
+ self.sagemaker.wait_for_completion = True
+ self.sagemaker.check_if_job_exists = False
+ self.sagemaker.print_log = False
+
with pytest.raises(TaskDeferred) as exc:
self.sagemaker.execute(context=None)
assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is
not a SagemakerTrigger"
diff --git a/tests/www/views/test_views_rendered.py
b/tests/www/views/test_views_rendered.py
index 5aa8e6ba2f..7aabec7ef6 100644
--- a/tests/www/views/test_views_rendered.py
+++ b/tests/www/views/test_views_rendered.py
@@ -256,7 +256,7 @@ def test_rendered_template_secret(admin_client,
create_dag_run, task_secret):
if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
# Handle collection of the test by non-db case
- Variable = mock.MagicMock() # type: ignore[misc] # noqa: F811
+ Variable = mock.MagicMock() # type: ignore[misc] # noqa: F811
else:
initial_db_init()