This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e01ff4749c Add realtime container execution logs for BatchOperator
(#31837)
e01ff4749c is described below
commit e01ff4749cb2469b21f467a1b0089d0115f39368
Author: Anirudh Krishnan <[email protected]>
AuthorDate: Mon Jun 19 12:00:17 2023 -0500
Add realtime container execution logs for BatchOperator (#31837)
* Update param description for get_batch_log_fetcher in batch_waiters
* Update the log fetcher with latest changes with continuation token
---
airflow/providers/amazon/aws/hooks/batch_client.py | 23 +++-
.../providers/amazon/aws/hooks/batch_waiters.py | 30 ++++-
airflow/providers/amazon/aws/hooks/ecs.py | 89 -------------
airflow/providers/amazon/aws/operators/batch.py | 39 +++++-
airflow/providers/amazon/aws/operators/ecs.py | 8 +-
.../providers/amazon/aws/utils/task_log_fetcher.py | 109 ++++++++++++++++
.../amazon/aws/hooks/test_batch_client.py | 25 ++++
.../amazon/aws/hooks/test_batch_waiters.py | 39 ++++++
tests/providers/amazon/aws/hooks/test_ecs.py | 138 +--------------------
tests/providers/amazon/aws/operators/test_batch.py | 28 +++++
tests/providers/amazon/aws/operators/test_ecs.py | 12 +-
.../test_ecs.py => utils/test_task_log_fetcher.py} | 64 +---------
12 files changed, 298 insertions(+), 306 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py
b/airflow/providers/amazon/aws/hooks/batch_client.py
index 4f6e217341..be24aadb67 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -28,6 +28,7 @@ from __future__ import annotations
from random import uniform
from time import sleep
+from typing import Callable
import botocore.client
import botocore.exceptions
@@ -35,6 +36,7 @@ import botocore.waiter
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
from airflow.typing_compat import Protocol, runtime_checkable
@@ -253,7 +255,12 @@ class BatchClientHook(AwsBaseHook):
raise AirflowException(f"AWS Batch job ({job_id}) has unknown status:
{job}")
- def wait_for_job(self, job_id: str, delay: int | float | None = None) ->
None:
+ def wait_for_job(
+ self,
+ job_id: str,
+ delay: int | float | None = None,
+ get_batch_log_fetcher: Callable[[str], AwsTaskLogFetcher | None] |
None = None,
+ ) -> None:
"""
Wait for Batch job to complete.
@@ -261,11 +268,23 @@ class BatchClientHook(AwsBaseHook):
:param delay: a delay before polling for job status
+    :param get_batch_log_fetcher: a method that returns a batch_log_fetcher
+
:raises: AirflowException
"""
self.delay(delay)
self.poll_for_job_running(job_id, delay)
- self.poll_for_job_complete(job_id, delay)
+ batch_log_fetcher = None
+ try:
+ if get_batch_log_fetcher:
+ batch_log_fetcher = get_batch_log_fetcher(job_id)
+ if batch_log_fetcher:
+ batch_log_fetcher.start()
+ self.poll_for_job_complete(job_id, delay)
+ finally:
+ if batch_log_fetcher:
+ batch_log_fetcher.stop()
+ batch_log_fetcher.join()
self.log.info("AWS Batch job (%s) has completed", job_id)
def poll_for_job_running(self, job_id: str, delay: int | float | None =
None) -> None:
diff --git a/airflow/providers/amazon/aws/hooks/batch_waiters.py
b/airflow/providers/amazon/aws/hooks/batch_waiters.py
index c746798dff..ec01fbadd9 100644
--- a/airflow/providers/amazon/aws/hooks/batch_waiters.py
+++ b/airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -29,6 +29,7 @@ import json
import sys
from copy import deepcopy
from pathlib import Path
+from typing import Callable
import botocore.client
import botocore.exceptions
@@ -36,6 +37,7 @@ import botocore.waiter
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
class BatchWaitersHook(BatchClientHook):
@@ -184,7 +186,12 @@ class BatchWaitersHook(BatchClientHook):
"""
return self.waiter_model.waiter_names
- def wait_for_job(self, job_id: str, delay: int | float | None = None) ->
None:
+ def wait_for_job(
+ self,
+ job_id: str,
+ delay: int | float | None = None,
+ get_batch_log_fetcher: Callable[[str], AwsTaskLogFetcher | None] |
None = None,
+ ) -> None:
"""
Wait for Batch job to complete. This assumes that the
``.waiter_model`` is configured
using some variation of the ``.default_config`` so that it can
generate waiters with the
@@ -194,6 +201,9 @@ class BatchWaitersHook(BatchClientHook):
:param delay: A delay before polling for job status
+    :param get_batch_log_fetcher: A method that returns a batch_log_fetcher of
+        type AwsTaskLogFetcher, or None when the CloudWatch log stream
hasn't been created yet.
+
+
:raises: AirflowException
.. note::
@@ -216,10 +226,20 @@ class BatchWaitersHook(BatchClientHook):
waiter.config.max_attempts = sys.maxsize # timeout is managed by
Airflow
waiter.wait(jobs=[job_id])
- waiter = self.get_waiter("JobComplete")
- waiter.config.delay = self.add_jitter(waiter.config.delay,
width=2, minima=1)
- waiter.config.max_attempts = sys.maxsize # timeout is managed by
Airflow
- waiter.wait(jobs=[job_id])
+ batch_log_fetcher = None
+ try:
+ if get_batch_log_fetcher:
+ batch_log_fetcher = get_batch_log_fetcher(job_id)
+ if batch_log_fetcher:
+ batch_log_fetcher.start()
+ waiter = self.get_waiter("JobComplete")
+ waiter.config.delay = self.add_jitter(waiter.config.delay,
width=2, minima=1)
+ waiter.config.max_attempts = sys.maxsize # timeout is managed
by Airflow
+ waiter.wait(jobs=[job_id])
+ finally:
+ if batch_log_fetcher:
+ batch_log_fetcher.stop()
+ batch_log_fetcher.join()
except (botocore.exceptions.ClientError,
botocore.exceptions.WaiterError) as err:
raise AirflowException(err)
diff --git a/airflow/providers/amazon/aws/hooks/ecs.py
b/airflow/providers/amazon/aws/hooks/ecs.py
index 94baeb0a9a..ad45da5d0a 100644
--- a/airflow/providers/amazon/aws/hooks/ecs.py
+++ b/airflow/providers/amazon/aws/hooks/ecs.py
@@ -17,19 +17,10 @@
# under the License.
from __future__ import annotations
-import time
-from collections import deque
-from datetime import datetime, timedelta
-from logging import Logger
-from threading import Event, Thread
-from typing import Generator
-
-from botocore.exceptions import ClientError, ConnectionClosedError
from botocore.waiter import Waiter
from airflow.providers.amazon.aws.exceptions import EcsOperatorError,
EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
-from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.utils import _StringCompareEnum
from airflow.typing_compat import Protocol, runtime_checkable
@@ -143,86 +134,6 @@ class EcsHook(AwsGenericHook):
return self.conn.describe_tasks(cluster=cluster,
tasks=[task])["tasks"][0]["lastStatus"]
-class EcsTaskLogFetcher(Thread):
- """
- Fetches Cloudwatch log events with specific interval as a thread
- and sends the log events to the info channel of the provided logger.
- """
-
- def __init__(
- self,
- *,
- log_group: str,
- log_stream_name: str,
- fetch_interval: timedelta,
- logger: Logger,
- aws_conn_id: str | None = "aws_default",
- region_name: str | None = None,
- ):
- super().__init__()
- self._event = Event()
-
- self.fetch_interval = fetch_interval
-
- self.logger = logger
- self.log_group = log_group
- self.log_stream_name = log_stream_name
-
- self.hook = AwsLogsHook(aws_conn_id=aws_conn_id,
region_name=region_name)
-
- def run(self) -> None:
- continuation_token = AwsLogsHook.ContinuationToken()
- while not self.is_stopped():
- time.sleep(self.fetch_interval.total_seconds())
- log_events = self._get_log_events(continuation_token)
- for log_event in log_events:
- self.logger.info(self._event_to_str(log_event))
-
- def _get_log_events(self, skip_token: AwsLogsHook.ContinuationToken | None
= None) -> Generator:
- if skip_token is None:
- skip_token = AwsLogsHook.ContinuationToken()
- try:
- yield from self.hook.get_log_events(
- self.log_group, self.log_stream_name,
continuation_token=skip_token
- )
- except ClientError as error:
- if error.response["Error"]["Code"] != "ResourceNotFoundException":
- self.logger.warning("Error on retrieving Cloudwatch log
events", error)
- else:
- self.logger.info(
- "Cannot find log stream yet, it can take a couple of
seconds to show up. "
- "If this error persists, check that the log group and
stream are correct: "
- "group: %s\tstream: %s",
- self.log_group,
- self.log_stream_name,
- )
- yield from ()
- except ConnectionClosedError as error:
- self.logger.warning("ConnectionClosedError on retrieving
Cloudwatch log events", error)
- yield from ()
-
- def _event_to_str(self, event: dict) -> str:
- event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0)
- formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
- message = event["message"]
- return f"[{formatted_event_dt}] {message}"
-
- def get_last_log_messages(self, number_messages) -> list:
- return [log["message"] for log in deque(self._get_log_events(),
maxlen=number_messages)]
-
- def get_last_log_message(self) -> str | None:
- try:
- return self.get_last_log_messages(1)[0]
- except IndexError:
- return None
-
- def is_stopped(self) -> bool:
- return self._event.is_set()
-
- def stop(self):
- self._event.set()
-
-
@runtime_checkable
class EcsProtocol(Protocol):
"""
diff --git a/airflow/providers/amazon/aws/operators/batch.py
b/airflow/providers/amazon/aws/operators/batch.py
index 2825ed5a01..88feb01311 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -25,6 +25,7 @@
from __future__ import annotations
import warnings
+from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
@@ -39,6 +40,7 @@ from airflow.providers.amazon.aws.links.batch import (
from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -79,6 +81,10 @@ class BatchOperator(BaseOperator):
:param tags: collection of tags to apply to the AWS Batch job submission
if None, no tags are submitted
:param deferrable: Run operator in the deferrable mode.
+    :param awslogs_enabled: Specifies whether logs from CloudWatch
+        should be printed or not; defaults to False.
+ If it is an array job, only the logs of the first task will be printed.
+    :param awslogs_fetch_interval: The interval with which CloudWatch logs
+        are to be fetched; defaults to 30 seconds.
:param poll_interval: (Deferrable mode only) Time in seconds to wait
between polling.
.. note::
@@ -104,6 +110,8 @@ class BatchOperator(BaseOperator):
"waiters",
"tags",
"wait_for_completion",
+ "awslogs_enabled",
+ "awslogs_fetch_interval",
)
template_fields_renderers = {
"container_overrides": "json",
@@ -145,6 +153,8 @@ class BatchOperator(BaseOperator):
wait_for_completion: bool = True,
deferrable: bool = False,
poll_interval: int = 30,
+ awslogs_enabled: bool = False,
+ awslogs_fetch_interval: timedelta = timedelta(seconds=30),
**kwargs,
) -> None:
BaseOperator.__init__(self, **kwargs)
@@ -179,6 +189,8 @@ class BatchOperator(BaseOperator):
self.wait_for_completion = wait_for_completion
self.deferrable = deferrable
self.poll_interval = poll_interval
+ self.awslogs_enabled = awslogs_enabled
+ self.awslogs_fetch_interval = awslogs_fetch_interval
# params for hook
self.max_retries = max_retries
@@ -319,10 +331,16 @@ class BatchOperator(BaseOperator):
job_queue_arn=job_queue_arn,
)
- if self.waiters:
- self.waiters.wait_for_job(self.job_id)
+ if self.awslogs_enabled:
+ if self.waiters:
+ self.waiters.wait_for_job(self.job_id,
get_batch_log_fetcher=self._get_batch_log_fetcher)
+ else:
+ self.hook.wait_for_job(self.job_id,
get_batch_log_fetcher=self._get_batch_log_fetcher)
else:
- self.hook.wait_for_job(self.job_id)
+ if self.waiters:
+ self.waiters.wait_for_job(self.job_id)
+ else:
+ self.hook.wait_for_job(self.job_id)
awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
if awslogs:
@@ -347,6 +365,21 @@ class BatchOperator(BaseOperator):
self.hook.check_job_success(self.job_id)
self.log.info("AWS Batch job (%s) succeeded", self.job_id)
+ def _get_batch_log_fetcher(self, job_id: str) -> AwsTaskLogFetcher | None:
+ awslog_info = self.hook.get_job_awslogs_info(job_id)
+
+ if not awslog_info:
+ return None
+
+ return AwsTaskLogFetcher(
+ aws_conn_id=self.aws_conn_id,
+ region_name=awslog_info["awslogs_region"],
+ log_group=awslog_info["awslogs_group"],
+ log_stream_name=awslog_info["awslogs_stream_name"],
+ fetch_interval=self.awslogs_fetch_interval,
+ logger=self.log,
+ )
+
class BatchCreateComputeEnvironmentOperator(BaseOperator):
"""Create an AWS Batch compute environment.
diff --git a/airflow/providers/amazon/aws/operators/ecs.py
b/airflow/providers/amazon/aws/operators/ecs.py
index a72b15c18a..bc8c4b70d7 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -33,9 +33,9 @@ from airflow.providers.amazon.aws.hooks.base_aws import
AwsBaseHook
from airflow.providers.amazon.aws.hooks.ecs import (
EcsClusterStates,
EcsHook,
- EcsTaskLogFetcher,
should_retry_eni,
)
+from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict
from airflow.utils.session import provide_session
@@ -447,7 +447,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
self.arn: str | None = None
self.retry_args = quota_retry
- self.task_log_fetcher: EcsTaskLogFetcher | None = None
+ self.task_log_fetcher: AwsTaskLogFetcher | None = None
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
@@ -597,12 +597,12 @@ class EcsRunTaskOperator(EcsBaseOperator):
def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix
- def _get_task_log_fetcher(self) -> EcsTaskLogFetcher:
+ def _get_task_log_fetcher(self) -> AwsTaskLogFetcher:
if not self.awslogs_group:
raise ValueError("must specify awslogs_group to fetch task logs")
log_stream_name =
f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
- return EcsTaskLogFetcher(
+ return AwsTaskLogFetcher(
aws_conn_id=self.aws_conn_id,
region_name=self.awslogs_region,
log_group=self.awslogs_group,
diff --git a/airflow/providers/amazon/aws/utils/task_log_fetcher.py
b/airflow/providers/amazon/aws/utils/task_log_fetcher.py
new file mode 100644
index 0000000000..97b43a67b2
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/task_log_fetcher.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import time
+from collections import deque
+from datetime import datetime, timedelta
+from logging import Logger
+from threading import Event, Thread
+from typing import Generator
+
+from botocore.exceptions import ClientError, ConnectionClosedError
+
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+
+
+class AwsTaskLogFetcher(Thread):
+ """
+ Fetches Cloudwatch log events with specific interval as a thread
+ and sends the log events to the info channel of the provided logger.
+ """
+
+ def __init__(
+ self,
+ *,
+ log_group: str,
+ log_stream_name: str,
+ fetch_interval: timedelta,
+ logger: Logger,
+ aws_conn_id: str | None = "aws_default",
+ region_name: str | None = None,
+ ):
+ super().__init__()
+ self._event = Event()
+
+ self.fetch_interval = fetch_interval
+
+ self.logger = logger
+ self.log_group = log_group
+ self.log_stream_name = log_stream_name
+
+ self.hook = AwsLogsHook(aws_conn_id=aws_conn_id,
region_name=region_name)
+
+ def run(self) -> None:
+ continuation_token = AwsLogsHook.ContinuationToken()
+ while not self.is_stopped():
+ time.sleep(self.fetch_interval.total_seconds())
+ log_events = self._get_log_events(continuation_token)
+ for log_event in log_events:
+ self.logger.info(self._event_to_str(log_event))
+
+ def _get_log_events(self, skip_token: AwsLogsHook.ContinuationToken | None
= None) -> Generator:
+ if skip_token is None:
+ skip_token = AwsLogsHook.ContinuationToken()
+ try:
+ yield from self.hook.get_log_events(
+ self.log_group, self.log_stream_name,
continuation_token=skip_token
+ )
+ except ClientError as error:
+ if error.response["Error"]["Code"] != "ResourceNotFoundException":
+ self.logger.warning("Error on retrieving Cloudwatch log
events", error)
+ else:
+ self.logger.info(
+ "Cannot find log stream yet, it can take a couple of
seconds to show up. "
+ "If this error persists, check that the log group and
stream are correct: "
+ "group: %s\tstream: %s",
+ self.log_group,
+ self.log_stream_name,
+ )
+ yield from ()
+ except ConnectionClosedError as error:
+ self.logger.warning("ConnectionClosedError on retrieving
Cloudwatch log events", error)
+ yield from ()
+
+ def _event_to_str(self, event: dict) -> str:
+ event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0)
+ formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
+ message = event["message"]
+ return f"[{formatted_event_dt}] {message}"
+
+ def get_last_log_messages(self, number_messages) -> list:
+ return [log["message"] for log in deque(self._get_log_events(),
maxlen=number_messages)]
+
+ def get_last_log_message(self) -> str | None:
+ try:
+ return self.get_last_log_messages(1)[0]
+ except IndexError:
+ return None
+
+ def is_stopped(self) -> bool:
+ return self._event.is_set()
+
+ def stop(self):
+ self._event.set()
diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py
b/tests/providers/amazon/aws/hooks/test_batch_client.py
index aef8be1d26..ea04cb2080 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_client.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import logging
+import time
from unittest import mock
import botocore.exceptions
@@ -25,6 +26,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
# Use dummy AWS credentials
AWS_REGION = "eu-west-1"
@@ -115,6 +117,29 @@ class TestBatchClient:
assert self.client_mock.describe_jobs.call_count == 4
+ def test_wait_for_job_with_logs(self):
+ self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId":
JOB_ID, "status": "SUCCEEDED"}]}
+
+ batch_log_fetcher = mock.Mock(spec=AwsTaskLogFetcher)
+ mock_get_batch_log_fetcher = mock.Mock(return_value=batch_log_fetcher)
+
+ thread_start = mock.Mock(side_effect=lambda: time.sleep(2))
+ thread_stop = mock.Mock(side_effect=lambda: time.sleep(2))
+ thread_join = mock.Mock(side_effect=lambda: time.sleep(2))
+
+ with mock.patch.object(
+ batch_log_fetcher, "start", thread_start
+ ) as mock_fetcher_start, mock.patch.object(
+ batch_log_fetcher, "stop", thread_stop
+ ) as mock_fetcher_stop, mock.patch.object(
+ batch_log_fetcher, "join", thread_join
+ ) as mock_fetcher_join:
+ self.batch_client.wait_for_job(JOB_ID,
get_batch_log_fetcher=mock_get_batch_log_fetcher)
+ mock_get_batch_log_fetcher.assert_called_with(JOB_ID)
+ mock_fetcher_start.assert_called_once()
+ mock_fetcher_stop.assert_called_once()
+ mock_fetcher_join.assert_called_once()
+
def test_poll_job_running_for_status_running(self):
self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId":
JOB_ID, "status": "RUNNING"}]}
self.batch_client.poll_for_job_running(JOB_ID)
diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py
b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
index 285784d184..cdf581417c 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import inspect
import itertools
+import time
from unittest import mock
import boto3
@@ -29,6 +30,7 @@ from moto import mock_batch
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_waiters import BatchWaitersHook
+from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
INTERMEDIATE_STATES = ("SUBMITTED", "PENDING", "RUNNABLE", "STARTING")
RUNNING_STATE = "RUNNING"
@@ -145,6 +147,43 @@ class TestBatchWaiters:
assert mock_config.delay == 0
assert mock_config.max_attempts == sys.maxsize
+ def test_wait_for_job_with_cloudwatch_logs(self):
+
+ # mock delay for speedy test
+ mock_jitter = mock.Mock(return_value=0)
+ self.batch_waiters.add_jitter = mock_jitter
+
+ batch_log_fetcher = mock.Mock(spec=AwsTaskLogFetcher)
+ mock_get_batch_log_fetcher = mock.Mock(return_value=batch_log_fetcher)
+
+ thread_start = mock.Mock(side_effect=lambda: time.sleep(2))
+ thread_stop = mock.Mock(side_effect=lambda: time.sleep(2))
+ thread_join = mock.Mock(side_effect=lambda: time.sleep(2))
+
+ with mock.patch.object(self.batch_waiters, "get_waiter") as
mock_get_waiter, mock.patch.object(
+ batch_log_fetcher, "start", thread_start
+ ) as mock_fetcher_start, mock.patch.object(
+ batch_log_fetcher, "stop", thread_stop
+ ) as mock_fetcher_stop, mock.patch.object(
+ batch_log_fetcher, "join", thread_join
+ ) as mock_fetcher_join:
+
+ # Run the wait_for_job method
+ self.batch_waiters.wait_for_job(self.job_id,
get_batch_log_fetcher=mock_get_batch_log_fetcher)
+
+ # Assertions
+ assert mock_get_waiter.call_args_list == [
+ mock.call("JobExists"),
+ mock.call("JobRunning"),
+ mock.call("JobComplete"),
+ ]
+
+
mock_get_waiter.return_value.wait.assert_called_with(jobs=[self.job_id])
+ mock_get_batch_log_fetcher.assert_called_with(self.job_id)
+ mock_fetcher_start.assert_called_once()
+ mock_fetcher_stop.assert_called_once()
+ mock_fetcher_join.assert_called_once()
+
def test_wait_for_job_raises_for_client_error(self):
# mock delay for speedy test
mock_jitter = mock.Mock(return_value=0)
diff --git a/tests/providers/amazon/aws/hooks/test_ecs.py
b/tests/providers/amazon/aws/hooks/test_ecs.py
index d90ec65e52..db1e494f94 100644
--- a/tests/providers/amazon/aws/hooks/test_ecs.py
+++ b/tests/providers/amazon/aws/hooks/test_ecs.py
@@ -16,16 +16,12 @@
# under the License.
from __future__ import annotations
-from datetime import timedelta
from unittest import mock
-from unittest.mock import PropertyMock
import pytest
-from botocore.exceptions import ClientError
from airflow.providers.amazon.aws.exceptions import EcsOperatorError,
EcsTaskFailToStart
-from airflow.providers.amazon.aws.hooks.ecs import EcsHook, EcsTaskLogFetcher,
should_retry, should_retry_eni
-from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.hooks.ecs import EcsHook, should_retry,
should_retry_eni
DEFAULT_CONN_ID: str = "aws_default"
REGION: str = "us-east-1"
@@ -82,135 +78,3 @@ class TestShouldRetryEni:
"ref pull has been retried 5 time(s): failed to resolve
reference"
)
)
-
-
-class TestEcsTaskLogFetcher:
- @mock.patch("logging.Logger")
- def set_up_log_fetcher(self, logger_mock):
- self.logger_mock = logger_mock
-
- self.log_fetcher = EcsTaskLogFetcher(
- log_group="test_log_group",
- log_stream_name="test_log_stream_name",
- fetch_interval=timedelta(milliseconds=1),
- logger=logger_mock,
- )
-
- def setup_method(self):
- self.set_up_log_fetcher()
-
- @mock.patch(
- "threading.Event.is_set",
- side_effect=(False, False, False, True),
- )
- @mock.patch(
- "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
- side_effect=(
- iter(
- [
- {"timestamp": 1617400267123, "message": "First"},
- {"timestamp": 1617400367456, "message": "Second"},
- ]
- ),
- iter(
- [
- {"timestamp": 1617400467789, "message": "Third"},
- ]
- ),
- iter([]),
- ),
- )
- def test_run(self, get_log_events_mock, event_is_set_mock):
-
- self.log_fetcher.run()
-
- self.logger_mock.info.assert_has_calls(
- [
- mock.call("[2021-04-02 21:51:07,123] First"),
- mock.call("[2021-04-02 21:52:47,456] Second"),
- mock.call("[2021-04-02 21:54:27,789] Third"),
- ]
- )
-
- @mock.patch(
- "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
- side_effect=ClientError({"Error": {"Code":
"ResourceNotFoundException"}}, None),
- )
- def test_get_log_events_with_expected_error(self, get_log_events_mock):
- with pytest.raises(StopIteration):
- next(self.log_fetcher._get_log_events())
-
- @mock.patch(
- "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
- side_effect=Exception(),
- )
- def test_get_log_events_with_unexpected_error(self, get_log_events_mock):
- with pytest.raises(Exception):
- next(self.log_fetcher._get_log_events())
-
- @mock.patch.object(AwsLogsHook, "conn", new_callable=PropertyMock)
- def test_get_log_events_updates_token(self, logs_conn_mock):
- logs_conn_mock().get_log_events.return_value = {
- "events": ["my_event"],
- "nextForwardToken": "my_next_token",
- }
-
- token = AwsLogsHook.ContinuationToken()
- list(self.log_fetcher._get_log_events(token))
-
- assert token.value == "my_next_token"
- # 2 calls expected, it's only on the second one that the stop
condition old_token == next_token is met
- assert logs_conn_mock().get_log_events.call_count == 2
-
- def test_event_to_str(self):
- events = [
- {"timestamp": 1617400267123, "message": "First"},
- {"timestamp": 1617400367456, "message": "Second"},
- {"timestamp": 1617400467789, "message": "Third"},
- ]
- assert [self.log_fetcher._event_to_str(event) for event in events] == (
- [
- "[2021-04-02 21:51:07,123] First",
- "[2021-04-02 21:52:47,456] Second",
- "[2021-04-02 21:54:27,789] Third",
- ]
- )
-
- @mock.patch(
- "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
- return_value=(),
- )
- def test_get_last_log_message_with_no_log_events(self, mock_log_events):
- assert self.log_fetcher.get_last_log_message() is None
-
- @mock.patch(
- "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
- return_value=iter(
- [
- {"timestamp": 1617400267123, "message": "First"},
- {"timestamp": 1617400367456, "message": "Second"},
- ]
- ),
- )
- def test_get_last_log_message_with_log_events(self, mock_log_events):
- assert self.log_fetcher.get_last_log_message() == "Second"
-
- @mock.patch(
- "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
- return_value=iter(
- [
- {"timestamp": 1617400267123, "message": "First"},
- {"timestamp": 1617400367456, "message": "Second"},
- {"timestamp": 1617400367458, "message": "Third"},
- ]
- ),
- )
- def test_get_last_log_messages_with_log_events(self, mock_log_events):
- assert self.log_fetcher.get_last_log_messages(2) == ["Second", "Third"]
-
- @mock.patch(
- "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
- return_value=(),
- )
- def test_get_last_log_messages_with_no_log_events(self, mock_log_events):
- assert self.log_fetcher.get_last_log_messages(2) == []
diff --git a/tests/providers/amazon/aws/operators/test_batch.py
b/tests/providers/amazon/aws/operators/test_batch.py
index f559424dff..a65e00d8db 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -120,6 +120,8 @@ class TestBatchOperator:
"waiters",
"tags",
"wait_for_completion",
+ "awslogs_enabled",
+ "awslogs_fetch_interval",
)
@mock.patch.object(BatchClientHook, "get_job_description")
@@ -273,6 +275,32 @@ class TestBatchOperator:
batch.execute(context=None)
assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger
is not a BatchOperatorTrigger"
+ @mock.patch.object(BatchClientHook, "get_job_description")
+ @mock.patch.object(BatchClientHook, "wait_for_job")
+ @mock.patch.object(BatchClientHook, "check_job_success")
+
@mock.patch("airflow.providers.amazon.aws.links.batch.BatchJobQueueLink.persist")
+
@mock.patch("airflow.providers.amazon.aws.links.batch.BatchJobDefinitionLink.persist")
+ def test_monitor_job_with_logs(
+ self, job_definition_persist_mock, job_queue_persist_mock, check_mock,
wait_mock, job_description_mock
+ ):
+ batch = BatchOperator(
+ task_id="task",
+ job_name=JOB_NAME,
+ job_queue="queue",
+ job_definition="hello-world",
+ awslogs_enabled=True,
+ )
+
+ batch.job_id = JOB_ID
+
+ batch.monitor_job(context=None)
+
+ job_description_mock.assert_called_with(job_id=JOB_ID)
+ job_definition_persist_mock.assert_called_once()
+ job_queue_persist_mock.assert_called_once()
+ wait_mock.assert_called_once()
+ assert len(wait_mock.call_args) == 2
+
class TestBatchCreateComputeEnvironmentOperator:
@mock.patch.object(BatchClientHook, "client")
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py
b/tests/providers/amazon/aws/operators/test_ecs.py
index 6366a879c2..8c99a02ce8 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -35,8 +35,8 @@ from airflow.providers.amazon.aws.operators.ecs import (
EcsDeregisterTaskDefinitionOperator,
EcsRegisterTaskDefinitionOperator,
EcsRunTaskOperator,
- EcsTaskLogFetcher,
)
+from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
from airflow.utils.types import NOTSET
CLUSTER_NAME = "test_cluster"
@@ -368,7 +368,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
client_mock.describe_tasks.assert_called_once_with(cluster="c",
tasks=["arn"])
@mock.patch.object(EcsBaseOperator, "client")
- @mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher")
+
@mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
def test_check_success_tasks_raises_cloudwatch_logs(self,
log_fetcher_mock, client_mock):
self.ecs.arn = "arn"
self.ecs.task_log_fetcher = log_fetcher_mock
@@ -387,7 +387,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
client_mock.describe_tasks.assert_called_once_with(cluster="c",
tasks=["arn"])
@mock.patch.object(EcsBaseOperator, "client")
- @mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher")
+
@mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
def test_check_success_tasks_raises_cloudwatch_logs_empty(self,
log_fetcher_mock, client_mock):
self.ecs.arn = "arn"
self.ecs.task_log_fetcher = log_fetcher_mock
@@ -624,7 +624,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
assert self.ecs.arn ==
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
@mock.patch.object(EcsBaseOperator, "client")
- @mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher")
+
@mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
def test_execute_xcom_with_log(self, log_fetcher_mock, client_mock):
self.ecs.do_xcom_push = True
self.ecs.task_log_fetcher = log_fetcher_mock
@@ -634,7 +634,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
assert self.ecs.execute(None) == "Log output"
@mock.patch.object(EcsBaseOperator, "client")
- @mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher")
+    @mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
def test_execute_xcom_with_no_log(self, log_fetcher_mock, client_mock):
self.ecs.do_xcom_push = True
self.ecs.task_log_fetcher = log_fetcher_mock
@@ -649,7 +649,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
assert self.ecs.execute(None) is None
@mock.patch.object(EcsBaseOperator, "client")
-    @mock.patch.object(EcsTaskLogFetcher, "get_last_log_message", return_value="Log output")
+    @mock.patch.object(AwsTaskLogFetcher, "get_last_log_message", return_value="Log output")
def test_execute_xcom_disabled(self, log_fetcher_mock, client_mock):
self.ecs.do_xcom_push = False
assert self.ecs.execute(None) is None
diff --git a/tests/providers/amazon/aws/hooks/test_ecs.py
b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py
similarity index 71%
copy from tests/providers/amazon/aws/hooks/test_ecs.py
copy to tests/providers/amazon/aws/utils/test_task_log_fetcher.py
index d90ec65e52..dbda751cfb 100644
--- a/tests/providers/amazon/aws/hooks/test_ecs.py
+++ b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from __future__ import annotations
from datetime import timedelta
@@ -23,73 +24,16 @@ from unittest.mock import PropertyMock
import pytest
from botocore.exceptions import ClientError
-from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
-from airflow.providers.amazon.aws.hooks.ecs import EcsHook, EcsTaskLogFetcher, should_retry, should_retry_eni
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
-
-DEFAULT_CONN_ID: str = "aws_default"
-REGION: str = "us-east-1"
-
-
[email protected]
-def mock_conn():
- with mock.patch.object(EcsHook, "conn") as _conn:
- yield _conn
-
-
-class TestEksHooks:
- def test_hook(self) -> None:
- hook = EcsHook(region_name=REGION)
- assert hook.conn is not None
- assert hook.aws_conn_id == DEFAULT_CONN_ID
- assert hook.region_name == REGION
-
- def test_get_cluster_state(self, mock_conn) -> None:
- mock_conn.describe_clusters.return_value = {"clusters": [{"status":
"ACTIVE"}]}
- assert EcsHook().get_cluster_state(cluster_name="cluster_name") ==
"ACTIVE"
-
- def test_get_task_definition_state(self, mock_conn) -> None:
- mock_conn.describe_task_definition.return_value = {"taskDefinition":
{"status": "ACTIVE"}}
- assert
EcsHook().get_task_definition_state(task_definition="task_name") == "ACTIVE"
-
- def test_get_task_state(self, mock_conn) -> None:
- mock_conn.describe_tasks.return_value = {"tasks": [{"lastStatus":
"ACTIVE"}]}
- assert EcsHook().get_task_state(cluster="cluster_name",
task="task_name") == "ACTIVE"
-
-
-class TestShouldRetry:
- def test_return_true_on_valid_reason(self):
- assert should_retry(EcsOperatorError([{"reason": "RESOURCE:MEMORY"}],
"Foo"))
-
- def test_return_false_on_invalid_reason(self):
- assert not should_retry(EcsOperatorError([{"reason":
"CLUSTER_NOT_FOUND"}], "Foo"))
-
-
-class TestShouldRetryEni:
- def test_return_true_on_valid_reason(self):
- assert should_retry_eni(
- EcsTaskFailToStart(
- "The task failed to start due to: "
- "Timeout waiting for network interface provisioning to
complete."
- )
- )
-
- def test_return_false_on_invalid_reason(self):
- assert not should_retry_eni(
- EcsTaskFailToStart(
- "The task failed to start due to: "
- "CannotPullContainerError: "
- "ref pull has been retried 5 time(s): failed to resolve
reference"
- )
- )
+from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
-class TestEcsTaskLogFetcher:
+class TestAwsTaskLogFetcher:
@mock.patch("logging.Logger")
def set_up_log_fetcher(self, logger_mock):
self.logger_mock = logger_mock
- self.log_fetcher = EcsTaskLogFetcher(
+ self.log_fetcher = AwsTaskLogFetcher(
log_group="test_log_group",
log_stream_name="test_log_stream_name",
fetch_interval=timedelta(milliseconds=1),