This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e8be1bf8b22 Allow to pass container_name parameter to EcsRunTaskOperator (#46152)
e8be1bf8b22 is described below
commit e8be1bf8b2260f6dbe64b1f879a8d017a9b77e56
Author: Evgeny Liskovets <[email protected]>
AuthorDate: Tue Feb 4 12:03:17 2025 -0500
Allow to pass container_name parameter to EcsRunTaskOperator (#46152)
---
providers/src/airflow/providers/amazon/aws/operators/ecs.py | 8 ++++++--
providers/tests/amazon/aws/operators/test_ecs.py | 12 ++++++++++++
2 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/providers/src/airflow/providers/amazon/aws/operators/ecs.py b/providers/src/airflow/providers/amazon/aws/operators/ecs.py
index 0ca8e2be8f9..bf59d1b1b7c 100644
--- a/providers/src/airflow/providers/amazon/aws/operators/ecs.py
+++ b/providers/src/airflow/providers/amazon/aws/operators/ecs.py
@@ -375,6 +375,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
:param awslogs_fetch_interval: the interval that the ECS task log fetcher
should wait
in between each Cloudwatch logs fetches.
If deferrable is set to True, that parameter is ignored and
waiter_delay is used instead.
+ :param container_name: The name of the container to fetch logs from. If
not set, the first container is used.
:param quota_retry: Config if and how to retry the launch of a new ECS
task, to handle
transient errors.
:param reattach: If set to True, will check if the task previously
launched by the task_instance
@@ -414,6 +415,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
"awslogs_region",
"awslogs_stream_prefix",
"awslogs_fetch_interval",
+ "container_name",
"propagate_tags",
"reattach",
"number_logs_exception",
@@ -445,6 +447,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
awslogs_region: str | None = None,
awslogs_stream_prefix: str | None = None,
awslogs_fetch_interval: timedelta = timedelta(seconds=30),
+ container_name: str | None = None,
propagate_tags: str | None = None,
quota_retry: dict | None = None,
reattach: bool = False,
@@ -484,7 +487,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
self.awslogs_region = self.region_name
self.arn: str | None = None
- self.container_name: str | None = None
+ self.container_name: str | None = container_name
self._started_by: str | None = None
self.retry_args = quota_retry
@@ -628,7 +631,8 @@ class EcsRunTaskOperator(EcsBaseOperator):
self.log.info("ECS Task started: %s", response)
self.arn = response["tasks"][0]["taskArn"]
- self.container_name = response["tasks"][0]["containers"][0]["name"]
+ if not self.container_name:
+ self.container_name = response["tasks"][0]["containers"][0]["name"]
self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))
def _try_reattach_task(self, started_by: str):
diff --git a/providers/tests/amazon/aws/operators/test_ecs.py b/providers/tests/amazon/aws/operators/test_ecs.py
index 2290e49d2df..824ba291f0b 100644
--- a/providers/tests/amazon/aws/operators/test_ecs.py
+++ b/providers/tests/amazon/aws/operators/test_ecs.py
@@ -184,6 +184,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
"awslogs_region",
"awslogs_stream_prefix",
"awslogs_fetch_interval",
+ "container_name",
"propagate_tags",
"reattach",
"number_logs_exception",
@@ -752,6 +753,20 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
# task gets described to assert its success
client_mock().describe_tasks.assert_called_once_with(cluster="test_cluster",
tasks=["my_arn"])
+ # mock.patch decorators inject their mocks bottom-up, so the innermost patch
+ # (AwsTaskLogFetcher) arrives as the first argument after ``self``.
+ @mock.patch.object(EcsBaseOperator, "client")
+ @mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
+ def test_container_name_in_log_stream(self, log_fetcher_mock, client_mock):
+ """An explicitly passed container_name must appear in the log stream name."""
+ container_name = "container-name"
+ prefix = "prefix"
+ self.set_up_operator(
+ awslogs_group="awslogs-group", awslogs_stream_prefix=prefix, container_name=container_name
+ )
+
+ assert self.ecs._get_logs_stream_name().startswith(f"{prefix}/{container_name}/")
+
class TestEcsCreateClusterOperator(EcsBaseTestCase):
@pytest.mark.parametrize("waiter_delay, waiter_max_attempts",
WAITERS_TEST_CASES)