This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e01ff4749c Add realtime container execution logs for BatchOperator 
(#31837)
e01ff4749c is described below

commit e01ff4749cb2469b21f467a1b0089d0115f39368
Author: Anirudh Krishnan <[email protected]>
AuthorDate: Mon Jun 19 12:00:17 2023 -0500

    Add realtime container execution logs for BatchOperator (#31837)
    
    * Update param description for get_batch_log_fetcher in batch_waiters
    
    * Update the log fetcher with latest changes with continuation token
---
 airflow/providers/amazon/aws/hooks/batch_client.py |  23 +++-
 .../providers/amazon/aws/hooks/batch_waiters.py    |  30 ++++-
 airflow/providers/amazon/aws/hooks/ecs.py          |  89 -------------
 airflow/providers/amazon/aws/operators/batch.py    |  39 +++++-
 airflow/providers/amazon/aws/operators/ecs.py      |   8 +-
 .../providers/amazon/aws/utils/task_log_fetcher.py | 109 ++++++++++++++++
 .../amazon/aws/hooks/test_batch_client.py          |  25 ++++
 .../amazon/aws/hooks/test_batch_waiters.py         |  39 ++++++
 tests/providers/amazon/aws/hooks/test_ecs.py       | 138 +--------------------
 tests/providers/amazon/aws/operators/test_batch.py |  28 +++++
 tests/providers/amazon/aws/operators/test_ecs.py   |  12 +-
 .../test_ecs.py => utils/test_task_log_fetcher.py} |  64 +---------
 12 files changed, 298 insertions(+), 306 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py 
b/airflow/providers/amazon/aws/hooks/batch_client.py
index 4f6e217341..be24aadb67 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -28,6 +28,7 @@ from __future__ import annotations
 
 from random import uniform
 from time import sleep
+from typing import Callable
 
 import botocore.client
 import botocore.exceptions
@@ -35,6 +36,7 @@ import botocore.waiter
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.utils.task_log_fetcher import 
AwsTaskLogFetcher
 from airflow.typing_compat import Protocol, runtime_checkable
 
 
@@ -253,7 +255,12 @@ class BatchClientHook(AwsBaseHook):
 
         raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: 
{job}")
 
-    def wait_for_job(self, job_id: str, delay: int | float | None = None) -> 
None:
+    def wait_for_job(
+        self,
+        job_id: str,
+        delay: int | float | None = None,
+        get_batch_log_fetcher: Callable[[str], AwsTaskLogFetcher | None] | 
None = None,
+    ) -> None:
         """
         Wait for Batch job to complete.
 
@@ -261,11 +268,23 @@ class BatchClientHook(AwsBaseHook):
 
         :param delay: a delay before polling for job status
 
+        :param get_batch_log_fetcher : a method that returns batch_log_fetcher
+
         :raises: AirflowException
         """
         self.delay(delay)
         self.poll_for_job_running(job_id, delay)
-        self.poll_for_job_complete(job_id, delay)
+        batch_log_fetcher = None
+        try:
+            if get_batch_log_fetcher:
+                batch_log_fetcher = get_batch_log_fetcher(job_id)
+                if batch_log_fetcher:
+                    batch_log_fetcher.start()
+            self.poll_for_job_complete(job_id, delay)
+        finally:
+            if batch_log_fetcher:
+                batch_log_fetcher.stop()
+                batch_log_fetcher.join()
         self.log.info("AWS Batch job (%s) has completed", job_id)
 
     def poll_for_job_running(self, job_id: str, delay: int | float | None = 
None) -> None:
diff --git a/airflow/providers/amazon/aws/hooks/batch_waiters.py 
b/airflow/providers/amazon/aws/hooks/batch_waiters.py
index c746798dff..ec01fbadd9 100644
--- a/airflow/providers/amazon/aws/hooks/batch_waiters.py
+++ b/airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -29,6 +29,7 @@ import json
 import sys
 from copy import deepcopy
 from pathlib import Path
+from typing import Callable
 
 import botocore.client
 import botocore.exceptions
@@ -36,6 +37,7 @@ import botocore.waiter
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.utils.task_log_fetcher import 
AwsTaskLogFetcher
 
 
 class BatchWaitersHook(BatchClientHook):
@@ -184,7 +186,12 @@ class BatchWaitersHook(BatchClientHook):
         """
         return self.waiter_model.waiter_names
 
-    def wait_for_job(self, job_id: str, delay: int | float | None = None) -> 
None:
+    def wait_for_job(
+        self,
+        job_id: str,
+        delay: int | float | None = None,
+        get_batch_log_fetcher: Callable[[str], AwsTaskLogFetcher | None] | 
None = None,
+    ) -> None:
         """
         Wait for Batch job to complete.  This assumes that the 
``.waiter_model`` is configured
         using some variation of the ``.default_config`` so that it can 
generate waiters with the
@@ -194,6 +201,9 @@ class BatchWaitersHook(BatchClientHook):
 
         :param delay:  A delay before polling for job status
 
+        :param get_batch_log_fetcher: A method that returns batch_log_fetcher 
of
+            type AwsTaskLogFetcher or None when the CloudWatch log stream 
hasn't been created yet.
+
         :raises: AirflowException
 
         .. note::
@@ -216,10 +226,20 @@ class BatchWaitersHook(BatchClientHook):
             waiter.config.max_attempts = sys.maxsize  # timeout is managed by 
Airflow
             waiter.wait(jobs=[job_id])
 
-            waiter = self.get_waiter("JobComplete")
-            waiter.config.delay = self.add_jitter(waiter.config.delay, 
width=2, minima=1)
-            waiter.config.max_attempts = sys.maxsize  # timeout is managed by 
Airflow
-            waiter.wait(jobs=[job_id])
+            batch_log_fetcher = None
+            try:
+                if get_batch_log_fetcher:
+                    batch_log_fetcher = get_batch_log_fetcher(job_id)
+                    if batch_log_fetcher:
+                        batch_log_fetcher.start()
+                waiter = self.get_waiter("JobComplete")
+                waiter.config.delay = self.add_jitter(waiter.config.delay, 
width=2, minima=1)
+                waiter.config.max_attempts = sys.maxsize  # timeout is managed 
by Airflow
+                waiter.wait(jobs=[job_id])
+            finally:
+                if batch_log_fetcher:
+                    batch_log_fetcher.stop()
+                    batch_log_fetcher.join()
 
         except (botocore.exceptions.ClientError, 
botocore.exceptions.WaiterError) as err:
             raise AirflowException(err)
diff --git a/airflow/providers/amazon/aws/hooks/ecs.py 
b/airflow/providers/amazon/aws/hooks/ecs.py
index 94baeb0a9a..ad45da5d0a 100644
--- a/airflow/providers/amazon/aws/hooks/ecs.py
+++ b/airflow/providers/amazon/aws/hooks/ecs.py
@@ -17,19 +17,10 @@
 # under the License.
 from __future__ import annotations
 
-import time
-from collections import deque
-from datetime import datetime, timedelta
-from logging import Logger
-from threading import Event, Thread
-from typing import Generator
-
-from botocore.exceptions import ClientError, ConnectionClosedError
 from botocore.waiter import Waiter
 
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, 
EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
-from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.utils import _StringCompareEnum
 from airflow.typing_compat import Protocol, runtime_checkable
 
@@ -143,86 +134,6 @@ class EcsHook(AwsGenericHook):
         return self.conn.describe_tasks(cluster=cluster, 
tasks=[task])["tasks"][0]["lastStatus"]
 
 
-class EcsTaskLogFetcher(Thread):
-    """
-    Fetches Cloudwatch log events with specific interval as a thread
-    and sends the log events to the info channel of the provided logger.
-    """
-
-    def __init__(
-        self,
-        *,
-        log_group: str,
-        log_stream_name: str,
-        fetch_interval: timedelta,
-        logger: Logger,
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
-    ):
-        super().__init__()
-        self._event = Event()
-
-        self.fetch_interval = fetch_interval
-
-        self.logger = logger
-        self.log_group = log_group
-        self.log_stream_name = log_stream_name
-
-        self.hook = AwsLogsHook(aws_conn_id=aws_conn_id, 
region_name=region_name)
-
-    def run(self) -> None:
-        continuation_token = AwsLogsHook.ContinuationToken()
-        while not self.is_stopped():
-            time.sleep(self.fetch_interval.total_seconds())
-            log_events = self._get_log_events(continuation_token)
-            for log_event in log_events:
-                self.logger.info(self._event_to_str(log_event))
-
-    def _get_log_events(self, skip_token: AwsLogsHook.ContinuationToken | None 
= None) -> Generator:
-        if skip_token is None:
-            skip_token = AwsLogsHook.ContinuationToken()
-        try:
-            yield from self.hook.get_log_events(
-                self.log_group, self.log_stream_name, 
continuation_token=skip_token
-            )
-        except ClientError as error:
-            if error.response["Error"]["Code"] != "ResourceNotFoundException":
-                self.logger.warning("Error on retrieving Cloudwatch log 
events", error)
-            else:
-                self.logger.info(
-                    "Cannot find log stream yet, it can take a couple of 
seconds to show up. "
-                    "If this error persists, check that the log group and 
stream are correct: "
-                    "group: %s\tstream: %s",
-                    self.log_group,
-                    self.log_stream_name,
-                )
-            yield from ()
-        except ConnectionClosedError as error:
-            self.logger.warning("ConnectionClosedError on retrieving 
Cloudwatch log events", error)
-            yield from ()
-
-    def _event_to_str(self, event: dict) -> str:
-        event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0)
-        formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
-        message = event["message"]
-        return f"[{formatted_event_dt}] {message}"
-
-    def get_last_log_messages(self, number_messages) -> list:
-        return [log["message"] for log in deque(self._get_log_events(), 
maxlen=number_messages)]
-
-    def get_last_log_message(self) -> str | None:
-        try:
-            return self.get_last_log_messages(1)[0]
-        except IndexError:
-            return None
-
-    def is_stopped(self) -> bool:
-        return self._event.is_set()
-
-    def stop(self):
-        self._event.set()
-
-
 @runtime_checkable
 class EcsProtocol(Protocol):
     """
diff --git a/airflow/providers/amazon/aws/operators/batch.py 
b/airflow/providers/amazon/aws/operators/batch.py
index 2825ed5a01..88feb01311 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -25,6 +25,7 @@
 from __future__ import annotations
 
 import warnings
+from datetime import timedelta
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
@@ -39,6 +40,7 @@ from airflow.providers.amazon.aws.links.batch import (
 from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
 from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
 from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.providers.amazon.aws.utils.task_log_fetcher import 
AwsTaskLogFetcher
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -79,6 +81,10 @@ class BatchOperator(BaseOperator):
     :param tags: collection of tags to apply to the AWS Batch job submission
         if None, no tags are submitted
     :param deferrable: Run operator in the deferrable mode.
+    :param awslogs_enabled: Specifies whether logs from CloudWatch
+        should be printed or not, False.
+        If it is an array job, only the logs of the first task will be printed.
+    :param awslogs_fetch_interval: The interval with which cloudwatch logs are 
to be fetched, 30 sec.
     :param poll_interval: (Deferrable mode only) Time in seconds to wait 
between polling.
 
     .. note::
@@ -104,6 +110,8 @@ class BatchOperator(BaseOperator):
         "waiters",
         "tags",
         "wait_for_completion",
+        "awslogs_enabled",
+        "awslogs_fetch_interval",
     )
     template_fields_renderers = {
         "container_overrides": "json",
@@ -145,6 +153,8 @@ class BatchOperator(BaseOperator):
         wait_for_completion: bool = True,
         deferrable: bool = False,
         poll_interval: int = 30,
+        awslogs_enabled: bool = False,
+        awslogs_fetch_interval: timedelta = timedelta(seconds=30),
         **kwargs,
     ) -> None:
         BaseOperator.__init__(self, **kwargs)
@@ -179,6 +189,8 @@ class BatchOperator(BaseOperator):
         self.wait_for_completion = wait_for_completion
         self.deferrable = deferrable
         self.poll_interval = poll_interval
+        self.awslogs_enabled = awslogs_enabled
+        self.awslogs_fetch_interval = awslogs_fetch_interval
 
         # params for hook
         self.max_retries = max_retries
@@ -319,10 +331,16 @@ class BatchOperator(BaseOperator):
                 job_queue_arn=job_queue_arn,
             )
 
-        if self.waiters:
-            self.waiters.wait_for_job(self.job_id)
+        if self.awslogs_enabled:
+            if self.waiters:
+                self.waiters.wait_for_job(self.job_id, 
get_batch_log_fetcher=self._get_batch_log_fetcher)
+            else:
+                self.hook.wait_for_job(self.job_id, 
get_batch_log_fetcher=self._get_batch_log_fetcher)
         else:
-            self.hook.wait_for_job(self.job_id)
+            if self.waiters:
+                self.waiters.wait_for_job(self.job_id)
+            else:
+                self.hook.wait_for_job(self.job_id)
 
         awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
         if awslogs:
@@ -347,6 +365,21 @@ class BatchOperator(BaseOperator):
         self.hook.check_job_success(self.job_id)
         self.log.info("AWS Batch job (%s) succeeded", self.job_id)
 
+    def _get_batch_log_fetcher(self, job_id: str) -> AwsTaskLogFetcher | None:
+        awslog_info = self.hook.get_job_awslogs_info(job_id)
+
+        if not awslog_info:
+            return None
+
+        return AwsTaskLogFetcher(
+            aws_conn_id=self.aws_conn_id,
+            region_name=awslog_info["awslogs_region"],
+            log_group=awslog_info["awslogs_group"],
+            log_stream_name=awslog_info["awslogs_stream_name"],
+            fetch_interval=self.awslogs_fetch_interval,
+            logger=self.log,
+        )
+
 
 class BatchCreateComputeEnvironmentOperator(BaseOperator):
     """Create an AWS Batch compute environment.
diff --git a/airflow/providers/amazon/aws/operators/ecs.py 
b/airflow/providers/amazon/aws/operators/ecs.py
index a72b15c18a..bc8c4b70d7 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -33,9 +33,9 @@ from airflow.providers.amazon.aws.hooks.base_aws import 
AwsBaseHook
 from airflow.providers.amazon.aws.hooks.ecs import (
     EcsClusterStates,
     EcsHook,
-    EcsTaskLogFetcher,
     should_retry_eni,
 )
+from airflow.providers.amazon.aws.utils.task_log_fetcher import 
AwsTaskLogFetcher
 from airflow.utils.helpers import prune_dict
 from airflow.utils.session import provide_session
 
@@ -447,7 +447,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
 
         self.arn: str | None = None
         self.retry_args = quota_retry
-        self.task_log_fetcher: EcsTaskLogFetcher | None = None
+        self.task_log_fetcher: AwsTaskLogFetcher | None = None
         self.wait_for_completion = wait_for_completion
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
@@ -597,12 +597,12 @@ class EcsRunTaskOperator(EcsBaseOperator):
     def _aws_logs_enabled(self):
         return self.awslogs_group and self.awslogs_stream_prefix
 
-    def _get_task_log_fetcher(self) -> EcsTaskLogFetcher:
+    def _get_task_log_fetcher(self) -> AwsTaskLogFetcher:
         if not self.awslogs_group:
             raise ValueError("must specify awslogs_group to fetch task logs")
         log_stream_name = 
f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
 
-        return EcsTaskLogFetcher(
+        return AwsTaskLogFetcher(
             aws_conn_id=self.aws_conn_id,
             region_name=self.awslogs_region,
             log_group=self.awslogs_group,
diff --git a/airflow/providers/amazon/aws/utils/task_log_fetcher.py 
b/airflow/providers/amazon/aws/utils/task_log_fetcher.py
new file mode 100644
index 0000000000..97b43a67b2
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/task_log_fetcher.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import time
+from collections import deque
+from datetime import datetime, timedelta
+from logging import Logger
+from threading import Event, Thread
+from typing import Generator
+
+from botocore.exceptions import ClientError, ConnectionClosedError
+
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+
+
+class AwsTaskLogFetcher(Thread):
+    """
+    Fetches Cloudwatch log events with specific interval as a thread
+    and sends the log events to the info channel of the provided logger.
+    """
+
+    def __init__(
+        self,
+        *,
+        log_group: str,
+        log_stream_name: str,
+        fetch_interval: timedelta,
+        logger: Logger,
+        aws_conn_id: str | None = "aws_default",
+        region_name: str | None = None,
+    ):
+        super().__init__()
+        self._event = Event()
+
+        self.fetch_interval = fetch_interval
+
+        self.logger = logger
+        self.log_group = log_group
+        self.log_stream_name = log_stream_name
+
+        self.hook = AwsLogsHook(aws_conn_id=aws_conn_id, 
region_name=region_name)
+
+    def run(self) -> None:
+        continuation_token = AwsLogsHook.ContinuationToken()
+        while not self.is_stopped():
+            time.sleep(self.fetch_interval.total_seconds())
+            log_events = self._get_log_events(continuation_token)
+            for log_event in log_events:
+                self.logger.info(self._event_to_str(log_event))
+
+    def _get_log_events(self, skip_token: AwsLogsHook.ContinuationToken | None 
= None) -> Generator:
+        if skip_token is None:
+            skip_token = AwsLogsHook.ContinuationToken()
+        try:
+            yield from self.hook.get_log_events(
+                self.log_group, self.log_stream_name, 
continuation_token=skip_token
+            )
+        except ClientError as error:
+            if error.response["Error"]["Code"] != "ResourceNotFoundException":
+                self.logger.warning("Error on retrieving Cloudwatch log 
events", error)
+            else:
+                self.logger.info(
+                    "Cannot find log stream yet, it can take a couple of 
seconds to show up. "
+                    "If this error persists, check that the log group and 
stream are correct: "
+                    "group: %s\tstream: %s",
+                    self.log_group,
+                    self.log_stream_name,
+                )
+            yield from ()
+        except ConnectionClosedError as error:
+            self.logger.warning("ConnectionClosedError on retrieving 
Cloudwatch log events", error)
+            yield from ()
+
+    def _event_to_str(self, event: dict) -> str:
+        event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0)
+        formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
+        message = event["message"]
+        return f"[{formatted_event_dt}] {message}"
+
+    def get_last_log_messages(self, number_messages) -> list:
+        return [log["message"] for log in deque(self._get_log_events(), 
maxlen=number_messages)]
+
+    def get_last_log_message(self) -> str | None:
+        try:
+            return self.get_last_log_messages(1)[0]
+        except IndexError:
+            return None
+
+    def is_stopped(self) -> bool:
+        return self._event.is_set()
+
+    def stop(self):
+        self._event.set()
diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py 
b/tests/providers/amazon/aws/hooks/test_batch_client.py
index aef8be1d26..ea04cb2080 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_client.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import logging
+import time
 from unittest import mock
 
 import botocore.exceptions
@@ -25,6 +26,7 @@ import pytest
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.utils.task_log_fetcher import 
AwsTaskLogFetcher
 
 # Use dummy AWS credentials
 AWS_REGION = "eu-west-1"
@@ -115,6 +117,29 @@ class TestBatchClient:
 
         assert self.client_mock.describe_jobs.call_count == 4
 
+    def test_wait_for_job_with_logs(self):
+        self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": 
JOB_ID, "status": "SUCCEEDED"}]}
+
+        batch_log_fetcher = mock.Mock(spec=AwsTaskLogFetcher)
+        mock_get_batch_log_fetcher = mock.Mock(return_value=batch_log_fetcher)
+
+        thread_start = mock.Mock(side_effect=lambda: time.sleep(2))
+        thread_stop = mock.Mock(side_effect=lambda: time.sleep(2))
+        thread_join = mock.Mock(side_effect=lambda: time.sleep(2))
+
+        with mock.patch.object(
+            batch_log_fetcher, "start", thread_start
+        ) as mock_fetcher_start, mock.patch.object(
+            batch_log_fetcher, "stop", thread_stop
+        ) as mock_fetcher_stop, mock.patch.object(
+            batch_log_fetcher, "join", thread_join
+        ) as mock_fetcher_join:
+            self.batch_client.wait_for_job(JOB_ID, 
get_batch_log_fetcher=mock_get_batch_log_fetcher)
+            mock_get_batch_log_fetcher.assert_called_with(JOB_ID)
+            mock_fetcher_start.assert_called_once()
+            mock_fetcher_stop.assert_called_once()
+            mock_fetcher_join.assert_called_once()
+
     def test_poll_job_running_for_status_running(self):
         self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": 
JOB_ID, "status": "RUNNING"}]}
         self.batch_client.poll_for_job_running(JOB_ID)
diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py 
b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
index 285784d184..cdf581417c 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import inspect
 import itertools
+import time
 from unittest import mock
 
 import boto3
@@ -29,6 +30,7 @@ from moto import mock_batch
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_waiters import BatchWaitersHook
+from airflow.providers.amazon.aws.utils.task_log_fetcher import 
AwsTaskLogFetcher
 
 INTERMEDIATE_STATES = ("SUBMITTED", "PENDING", "RUNNABLE", "STARTING")
 RUNNING_STATE = "RUNNING"
@@ -145,6 +147,43 @@ class TestBatchWaiters:
             assert mock_config.delay == 0
             assert mock_config.max_attempts == sys.maxsize
 
+    def test_wait_for_job_with_cloudwatch_logs(self):
+
+        # mock delay for speedy test
+        mock_jitter = mock.Mock(return_value=0)
+        self.batch_waiters.add_jitter = mock_jitter
+
+        batch_log_fetcher = mock.Mock(spec=AwsTaskLogFetcher)
+        mock_get_batch_log_fetcher = mock.Mock(return_value=batch_log_fetcher)
+
+        thread_start = mock.Mock(side_effect=lambda: time.sleep(2))
+        thread_stop = mock.Mock(side_effect=lambda: time.sleep(2))
+        thread_join = mock.Mock(side_effect=lambda: time.sleep(2))
+
+        with mock.patch.object(self.batch_waiters, "get_waiter") as 
mock_get_waiter, mock.patch.object(
+            batch_log_fetcher, "start", thread_start
+        ) as mock_fetcher_start, mock.patch.object(
+            batch_log_fetcher, "stop", thread_stop
+        ) as mock_fetcher_stop, mock.patch.object(
+            batch_log_fetcher, "join", thread_join
+        ) as mock_fetcher_join:
+
+            # Run the wait_for_job method
+            self.batch_waiters.wait_for_job(self.job_id, 
get_batch_log_fetcher=mock_get_batch_log_fetcher)
+
+            # Assertions
+            assert mock_get_waiter.call_args_list == [
+                mock.call("JobExists"),
+                mock.call("JobRunning"),
+                mock.call("JobComplete"),
+            ]
+
+            
mock_get_waiter.return_value.wait.assert_called_with(jobs=[self.job_id])
+            mock_get_batch_log_fetcher.assert_called_with(self.job_id)
+            mock_fetcher_start.assert_called_once()
+            mock_fetcher_stop.assert_called_once()
+            mock_fetcher_join.assert_called_once()
+
     def test_wait_for_job_raises_for_client_error(self):
         # mock delay for speedy test
         mock_jitter = mock.Mock(return_value=0)
diff --git a/tests/providers/amazon/aws/hooks/test_ecs.py 
b/tests/providers/amazon/aws/hooks/test_ecs.py
index d90ec65e52..db1e494f94 100644
--- a/tests/providers/amazon/aws/hooks/test_ecs.py
+++ b/tests/providers/amazon/aws/hooks/test_ecs.py
@@ -16,16 +16,12 @@
 # under the License.
 from __future__ import annotations
 
-from datetime import timedelta
 from unittest import mock
-from unittest.mock import PropertyMock
 
 import pytest
-from botocore.exceptions import ClientError
 
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, 
EcsTaskFailToStart
-from airflow.providers.amazon.aws.hooks.ecs import EcsHook, EcsTaskLogFetcher, 
should_retry, should_retry_eni
-from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.hooks.ecs import EcsHook, should_retry, 
should_retry_eni
 
 DEFAULT_CONN_ID: str = "aws_default"
 REGION: str = "us-east-1"
@@ -82,135 +78,3 @@ class TestShouldRetryEni:
                 "ref pull has been retried 5 time(s): failed to resolve 
reference"
             )
         )
-
-
-class TestEcsTaskLogFetcher:
-    @mock.patch("logging.Logger")
-    def set_up_log_fetcher(self, logger_mock):
-        self.logger_mock = logger_mock
-
-        self.log_fetcher = EcsTaskLogFetcher(
-            log_group="test_log_group",
-            log_stream_name="test_log_stream_name",
-            fetch_interval=timedelta(milliseconds=1),
-            logger=logger_mock,
-        )
-
-    def setup_method(self):
-        self.set_up_log_fetcher()
-
-    @mock.patch(
-        "threading.Event.is_set",
-        side_effect=(False, False, False, True),
-    )
-    @mock.patch(
-        "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
-        side_effect=(
-            iter(
-                [
-                    {"timestamp": 1617400267123, "message": "First"},
-                    {"timestamp": 1617400367456, "message": "Second"},
-                ]
-            ),
-            iter(
-                [
-                    {"timestamp": 1617400467789, "message": "Third"},
-                ]
-            ),
-            iter([]),
-        ),
-    )
-    def test_run(self, get_log_events_mock, event_is_set_mock):
-
-        self.log_fetcher.run()
-
-        self.logger_mock.info.assert_has_calls(
-            [
-                mock.call("[2021-04-02 21:51:07,123] First"),
-                mock.call("[2021-04-02 21:52:47,456] Second"),
-                mock.call("[2021-04-02 21:54:27,789] Third"),
-            ]
-        )
-
-    @mock.patch(
-        "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
-        side_effect=ClientError({"Error": {"Code": 
"ResourceNotFoundException"}}, None),
-    )
-    def test_get_log_events_with_expected_error(self, get_log_events_mock):
-        with pytest.raises(StopIteration):
-            next(self.log_fetcher._get_log_events())
-
-    @mock.patch(
-        "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
-        side_effect=Exception(),
-    )
-    def test_get_log_events_with_unexpected_error(self, get_log_events_mock):
-        with pytest.raises(Exception):
-            next(self.log_fetcher._get_log_events())
-
-    @mock.patch.object(AwsLogsHook, "conn", new_callable=PropertyMock)
-    def test_get_log_events_updates_token(self, logs_conn_mock):
-        logs_conn_mock().get_log_events.return_value = {
-            "events": ["my_event"],
-            "nextForwardToken": "my_next_token",
-        }
-
-        token = AwsLogsHook.ContinuationToken()
-        list(self.log_fetcher._get_log_events(token))
-
-        assert token.value == "my_next_token"
-        # 2 calls expected, it's only on the second one that the stop 
condition old_token == next_token is met
-        assert logs_conn_mock().get_log_events.call_count == 2
-
-    def test_event_to_str(self):
-        events = [
-            {"timestamp": 1617400267123, "message": "First"},
-            {"timestamp": 1617400367456, "message": "Second"},
-            {"timestamp": 1617400467789, "message": "Third"},
-        ]
-        assert [self.log_fetcher._event_to_str(event) for event in events] == (
-            [
-                "[2021-04-02 21:51:07,123] First",
-                "[2021-04-02 21:52:47,456] Second",
-                "[2021-04-02 21:54:27,789] Third",
-            ]
-        )
-
-    @mock.patch(
-        "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
-        return_value=(),
-    )
-    def test_get_last_log_message_with_no_log_events(self, mock_log_events):
-        assert self.log_fetcher.get_last_log_message() is None
-
-    @mock.patch(
-        "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
-        return_value=iter(
-            [
-                {"timestamp": 1617400267123, "message": "First"},
-                {"timestamp": 1617400367456, "message": "Second"},
-            ]
-        ),
-    )
-    def test_get_last_log_message_with_log_events(self, mock_log_events):
-        assert self.log_fetcher.get_last_log_message() == "Second"
-
-    @mock.patch(
-        "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
-        return_value=iter(
-            [
-                {"timestamp": 1617400267123, "message": "First"},
-                {"timestamp": 1617400367456, "message": "Second"},
-                {"timestamp": 1617400367458, "message": "Third"},
-            ]
-        ),
-    )
-    def test_get_last_log_messages_with_log_events(self, mock_log_events):
-        assert self.log_fetcher.get_last_log_messages(2) == ["Second", "Third"]
-
-    @mock.patch(
-        "airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events",
-        return_value=(),
-    )
-    def test_get_last_log_messages_with_no_log_events(self, mock_log_events):
-        assert self.log_fetcher.get_last_log_messages(2) == []
diff --git a/tests/providers/amazon/aws/operators/test_batch.py 
b/tests/providers/amazon/aws/operators/test_batch.py
index f559424dff..a65e00d8db 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -120,6 +120,8 @@ class TestBatchOperator:
             "waiters",
             "tags",
             "wait_for_completion",
+            "awslogs_enabled",
+            "awslogs_fetch_interval",
         )
 
     @mock.patch.object(BatchClientHook, "get_job_description")
@@ -273,6 +275,32 @@ class TestBatchOperator:
             batch.execute(context=None)
+        assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger is not a BatchOperatorTrigger"
 
+    @mock.patch.object(BatchClientHook, "get_job_description")
+    @mock.patch.object(BatchClientHook, "wait_for_job")
+    @mock.patch.object(BatchClientHook, "check_job_success")
+    @mock.patch("airflow.providers.amazon.aws.links.batch.BatchJobQueueLink.persist")
+    @mock.patch("airflow.providers.amazon.aws.links.batch.BatchJobDefinitionLink.persist")
+    def test_monitor_job_with_logs(
+        self, job_definition_persist_mock, job_queue_persist_mock, check_mock, wait_mock, job_description_mock
+    ):
+        batch = BatchOperator(
+            task_id="task",
+            job_name=JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            awslogs_enabled=True,
+        )
+
+        batch.job_id = JOB_ID
+
+        batch.monitor_job(context=None)
+
+        job_description_mock.assert_called_with(job_id=JOB_ID)
+        job_definition_persist_mock.assert_called_once()
+        job_queue_persist_mock.assert_called_once()
+        wait_mock.assert_called_once()
+        assert len(wait_mock.call_args) == 2
+
 
 class TestBatchCreateComputeEnvironmentOperator:
     @mock.patch.object(BatchClientHook, "client")
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py
index 6366a879c2..8c99a02ce8 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -35,8 +35,8 @@ from airflow.providers.amazon.aws.operators.ecs import (
     EcsDeregisterTaskDefinitionOperator,
     EcsRegisterTaskDefinitionOperator,
     EcsRunTaskOperator,
-    EcsTaskLogFetcher,
 )
+from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 from airflow.utils.types import NOTSET
 
 CLUSTER_NAME = "test_cluster"
@@ -368,7 +368,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         client_mock.describe_tasks.assert_called_once_with(cluster="c", tasks=["arn"])
 
     @mock.patch.object(EcsBaseOperator, "client")
-    @mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher")
+    @mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
     def test_check_success_tasks_raises_cloudwatch_logs(self, log_fetcher_mock, client_mock):
         self.ecs.arn = "arn"
         self.ecs.task_log_fetcher = log_fetcher_mock
@@ -387,7 +387,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         client_mock.describe_tasks.assert_called_once_with(cluster="c", tasks=["arn"])
 
     @mock.patch.object(EcsBaseOperator, "client")
-    @mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher")
+    @mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
     def test_check_success_tasks_raises_cloudwatch_logs_empty(self, log_fetcher_mock, client_mock):
         self.ecs.arn = "arn"
         self.ecs.task_log_fetcher = log_fetcher_mock
@@ -624,7 +624,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
 
     @mock.patch.object(EcsBaseOperator, "client")
-    @mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher")
+    @mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
     def test_execute_xcom_with_log(self, log_fetcher_mock, client_mock):
         self.ecs.do_xcom_push = True
         self.ecs.task_log_fetcher = log_fetcher_mock
@@ -634,7 +634,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         assert self.ecs.execute(None) == "Log output"
 
     @mock.patch.object(EcsBaseOperator, "client")
-    @mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher")
+    @mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
     def test_execute_xcom_with_no_log(self, log_fetcher_mock, client_mock):
         self.ecs.do_xcom_push = True
         self.ecs.task_log_fetcher = log_fetcher_mock
@@ -649,7 +649,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         assert self.ecs.execute(None) is None
 
     @mock.patch.object(EcsBaseOperator, "client")
-    @mock.patch.object(EcsTaskLogFetcher, "get_last_log_message", return_value="Log output")
+    @mock.patch.object(AwsTaskLogFetcher, "get_last_log_message", return_value="Log output")
     def test_execute_xcom_disabled(self, log_fetcher_mock, client_mock):
         self.ecs.do_xcom_push = False
         assert self.ecs.execute(None) is None
diff --git a/tests/providers/amazon/aws/hooks/test_ecs.py b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py
similarity index 71%
copy from tests/providers/amazon/aws/hooks/test_ecs.py
copy to tests/providers/amazon/aws/utils/test_task_log_fetcher.py
index d90ec65e52..dbda751cfb 100644
--- a/tests/providers/amazon/aws/hooks/test_ecs.py
+++ b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 from __future__ import annotations
 
 from datetime import timedelta
@@ -23,73 +24,16 @@ from unittest.mock import PropertyMock
 import pytest
 from botocore.exceptions import ClientError
 
-from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
-from airflow.providers.amazon.aws.hooks.ecs import EcsHook, EcsTaskLogFetcher, should_retry, should_retry_eni
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
-
-DEFAULT_CONN_ID: str = "aws_default"
-REGION: str = "us-east-1"
-
-
-@pytest.fixture
-def mock_conn():
-    with mock.patch.object(EcsHook, "conn") as _conn:
-        yield _conn
-
-
-class TestEksHooks:
-    def test_hook(self) -> None:
-        hook = EcsHook(region_name=REGION)
-        assert hook.conn is not None
-        assert hook.aws_conn_id == DEFAULT_CONN_ID
-        assert hook.region_name == REGION
-
-    def test_get_cluster_state(self, mock_conn) -> None:
-        mock_conn.describe_clusters.return_value = {"clusters": [{"status": "ACTIVE"}]}
-        assert EcsHook().get_cluster_state(cluster_name="cluster_name") == "ACTIVE"
-
-    def test_get_task_definition_state(self, mock_conn) -> None:
-        mock_conn.describe_task_definition.return_value = {"taskDefinition": {"status": "ACTIVE"}}
-        assert EcsHook().get_task_definition_state(task_definition="task_name") == "ACTIVE"
-
-    def test_get_task_state(self, mock_conn) -> None:
-        mock_conn.describe_tasks.return_value = {"tasks": [{"lastStatus": "ACTIVE"}]}
-        assert EcsHook().get_task_state(cluster="cluster_name", task="task_name") == "ACTIVE"
-
-
-class TestShouldRetry:
-    def test_return_true_on_valid_reason(self):
-        assert should_retry(EcsOperatorError([{"reason": "RESOURCE:MEMORY"}], "Foo"))
-
-    def test_return_false_on_invalid_reason(self):
-        assert not should_retry(EcsOperatorError([{"reason": "CLUSTER_NOT_FOUND"}], "Foo"))
-
-
-class TestShouldRetryEni:
-    def test_return_true_on_valid_reason(self):
-        assert should_retry_eni(
-            EcsTaskFailToStart(
-                "The task failed to start due to: "
-                "Timeout waiting for network interface provisioning to complete."
-            )
-        )
-
-    def test_return_false_on_invalid_reason(self):
-        assert not should_retry_eni(
-            EcsTaskFailToStart(
-                "The task failed to start due to: "
-                "CannotPullContainerError: "
-                "ref pull has been retried 5 time(s): failed to resolve reference"
-            )
-        )
+from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 
 
-class TestEcsTaskLogFetcher:
+class TestAwsTaskLogFetcher:
     @mock.patch("logging.Logger")
     def set_up_log_fetcher(self, logger_mock):
         self.logger_mock = logger_mock
 
-        self.log_fetcher = EcsTaskLogFetcher(
+        self.log_fetcher = AwsTaskLogFetcher(
             log_group="test_log_group",
             log_stream_name="test_log_stream_name",
             fetch_interval=timedelta(milliseconds=1),


Reply via email to