This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e8be1bf8b22 Allow to pass container_name parameter to EcsRunTaskOperator (#46152)
e8be1bf8b22 is described below

commit e8be1bf8b2260f6dbe64b1f879a8d017a9b77e56
Author: Evgeny Liskovets <[email protected]>
AuthorDate: Tue Feb 4 12:03:17 2025 -0500

    Allow to pass container_name parameter to EcsRunTaskOperator (#46152)
---
 providers/src/airflow/providers/amazon/aws/operators/ecs.py |  8 ++++++--
 providers/tests/amazon/aws/operators/test_ecs.py            | 12 ++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/providers/src/airflow/providers/amazon/aws/operators/ecs.py b/providers/src/airflow/providers/amazon/aws/operators/ecs.py
index 0ca8e2be8f9..bf59d1b1b7c 100644
--- a/providers/src/airflow/providers/amazon/aws/operators/ecs.py
+++ b/providers/src/airflow/providers/amazon/aws/operators/ecs.py
@@ -375,6 +375,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
     :param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait
         in between each Cloudwatch logs fetches.
         If deferrable is set to True, that parameter is ignored and waiter_delay is used instead.
+    :param container_name: The name of the container to fetch logs from. If not set, the first container is used.
     :param quota_retry: Config if and how to retry the launch of a new ECS task, to handle
         transient errors.
     :param reattach: If set to True, will check if the task previously launched by the task_instance
@@ -414,6 +415,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         "awslogs_region",
         "awslogs_stream_prefix",
         "awslogs_fetch_interval",
+        "container_name",
         "propagate_tags",
         "reattach",
         "number_logs_exception",
@@ -445,6 +447,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         awslogs_region: str | None = None,
         awslogs_stream_prefix: str | None = None,
         awslogs_fetch_interval: timedelta = timedelta(seconds=30),
+        container_name: str | None = None,
         propagate_tags: str | None = None,
         quota_retry: dict | None = None,
         reattach: bool = False,
@@ -484,7 +487,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
             self.awslogs_region = self.region_name
 
         self.arn: str | None = None
-        self.container_name: str | None = None
+        self.container_name: str | None = container_name
         self._started_by: str | None = None
 
         self.retry_args = quota_retry
@@ -628,7 +631,8 @@ class EcsRunTaskOperator(EcsBaseOperator):
         self.log.info("ECS Task started: %s", response)
 
         self.arn = response["tasks"][0]["taskArn"]
-        self.container_name = response["tasks"][0]["containers"][0]["name"]
+        if not self.container_name:
+            self.container_name = response["tasks"][0]["containers"][0]["name"]
         self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))
 
     def _try_reattach_task(self, started_by: str):
diff --git a/providers/tests/amazon/aws/operators/test_ecs.py b/providers/tests/amazon/aws/operators/test_ecs.py
index 2290e49d2df..824ba291f0b 100644
--- a/providers/tests/amazon/aws/operators/test_ecs.py
+++ b/providers/tests/amazon/aws/operators/test_ecs.py
@@ -184,6 +184,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
             "awslogs_region",
             "awslogs_stream_prefix",
             "awslogs_fetch_interval",
+            "container_name",
             "propagate_tags",
             "reattach",
             "number_logs_exception",
@@ -752,6 +753,17 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         # task gets described to assert its success
         client_mock().describe_tasks.assert_called_once_with(cluster="test_cluster", tasks=["my_arn"])
 
+    @mock.patch.object(EcsBaseOperator, "client")
+    @mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
+    def test_container_name_in_log_stream(self, client_mock, log_fetcher_mock):
+        container_name = "container-name"
+        prefix = "prefix"
+        self.set_up_operator(
+            awslogs_group="awslogs-group", awslogs_stream_prefix=prefix, container_name=container_name
+        )
+
+        assert self.ecs._get_logs_stream_name().startswith(f"{prefix}/{container_name}/")
+
 
 class TestEcsCreateClusterOperator(EcsBaseTestCase):
     @pytest.mark.parametrize("waiter_delay, waiter_max_attempts", WAITERS_TEST_CASES)

Reply via email to