This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e0f21f43c6 Various fixes on ECS run task operator (#31838)
e0f21f43c6 is described below
commit e0f21f43c63b13fd48f55aa660746edc37df1458
Author: Raphaël Vandon <[email protected]>
AuthorDate: Fri Jun 16 12:22:24 2023 -0700
Various fixes on ECS run task operator (#31838)
* The ECS Run Task operator should not try to fetch logs or check the task status when it is not waiting for the task to complete
---
airflow/providers/amazon/aws/hooks/ecs.py | 9 +++-
airflow/providers/amazon/aws/operators/ecs.py | 28 ++++++++-----
.../operators/ecs.rst | 2 +-
tests/providers/amazon/aws/operators/test_ecs.py | 7 ++--
tests/system/providers/amazon/aws/example_ecs.py | 48 ++++++++++------------
.../providers/amazon/aws/example_ecs_fargate.py | 22 ++++++++++
6 files changed, 75 insertions(+), 41 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/ecs.py
b/airflow/providers/amazon/aws/hooks/ecs.py
index 17119d5968..94baeb0a9a 100644
--- a/airflow/providers/amazon/aws/hooks/ecs.py
+++ b/airflow/providers/amazon/aws/hooks/ecs.py
@@ -188,7 +188,14 @@ class EcsTaskLogFetcher(Thread):
except ClientError as error:
if error.response["Error"]["Code"] != "ResourceNotFoundException":
self.logger.warning("Error on retrieving Cloudwatch log
events", error)
-
+ else:
+ self.logger.info(
+ "Cannot find log stream yet, it can take a couple of
seconds to show up. "
+ "If this error persists, check that the log group and
stream are correct: "
+ "group: %s\tstream: %s",
+ self.log_group,
+ self.log_stream_name,
+ )
yield from ()
except ConnectionClosedError as error:
self.logger.warning("ConnectionClosedError on retrieving
Cloudwatch log events", error)
diff --git a/airflow/providers/amazon/aws/operators/ecs.py
b/airflow/providers/amazon/aws/operators/ecs.py
index 149ae2c11c..3a62f389da 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -480,6 +480,17 @@ class EcsRunTaskOperator(EcsBaseOperator):
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
+ if self._aws_logs_enabled() and not self.wait_for_completion:
+ self.log.warning(
+ "Trying to get logs without waiting for the task to complete
is undefined behavior."
+ )
+
+ @staticmethod
+ def _get_ecs_task_id(task_arn: str | None) -> str | None:
+ if task_arn is None:
+ return None
+ return task_arn.split("/")[-1]
+
@provide_session
def execute(self, context, session=None):
self.log.info(
@@ -506,25 +517,24 @@ class EcsRunTaskOperator(EcsBaseOperator):
@AwsBaseHook.retry(should_retry_eni)
def _start_wait_check_task(self, context):
-
if not self.arn:
self._start_task(context)
+ if not self.wait_for_completion:
+ return
+
if self._aws_logs_enabled():
self.log.info("Starting ECS Task Log Fetcher")
self.task_log_fetcher = self._get_task_log_fetcher()
self.task_log_fetcher.start()
try:
- if self.wait_for_completion:
- self._wait_for_task_ended()
+ self._wait_for_task_ended()
finally:
self.task_log_fetcher.stop()
-
self.task_log_fetcher.join()
else:
- if self.wait_for_completion:
- self._wait_for_task_ended()
+ self._wait_for_task_ended()
self._check_success_task()
@@ -566,8 +576,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
self.log.info("ECS Task started: %s", response)
self.arn = response["tasks"][0]["taskArn"]
- self.ecs_task_id = self.arn.split("/")[-1]
- self.log.info("ECS task ID is: %s", self.ecs_task_id)
+ self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))
if self.reattach:
# Save the task ARN in XCom to be able to reattach it if needed
@@ -590,7 +599,6 @@ class EcsRunTaskOperator(EcsBaseOperator):
)
if previous_task_arn in running_tasks:
self.arn = previous_task_arn
- self.ecs_task_id = self.arn.split("/")[-1]
self.log.info("Reattaching previously launched task: %s", self.arn)
else:
self.log.info("No active previously launched task found to
reattach")
@@ -620,7 +628,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
def _get_task_log_fetcher(self) -> EcsTaskLogFetcher:
if not self.awslogs_group:
raise ValueError("must specify awslogs_group to fetch task logs")
- log_stream_name = f"{self.awslogs_stream_prefix}/{self.ecs_task_id}"
+ log_stream_name =
f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
return EcsTaskLogFetcher(
aws_conn_id=self.aws_conn_id,
diff --git a/docs/apache-airflow-providers-amazon/operators/ecs.rst
b/docs/apache-airflow-providers-amazon/operators/ecs.rst
index d513485a9a..e6b4385d36 100644
--- a/docs/apache-airflow-providers-amazon/operators/ecs.rst
+++ b/docs/apache-airflow-providers-amazon/operators/ecs.rst
@@ -250,7 +250,7 @@ both can be overridden with provided values. Raises an
AirflowException with
the failure reason if a failed state is provided and that state is reached
before the target state.
-.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ecs.py
+.. exampleinclude::
/../../tests/system/providers/amazon/aws/example_ecs_fargate.py
:language: python
:dedent: 4
:start-after: [START howto_sensor_ecs_task_state]
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py
b/tests/providers/amazon/aws/operators/test_ecs.py
index cadaa6e329..ca23931e90 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -304,7 +304,10 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
wait_mock.assert_called_once_with()
check_mock.assert_called_once_with()
assert self.ecs.arn ==
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
- assert self.ecs.ecs_task_id == TASK_ID
+
+ def test_task_id_parsing(self):
+ id =
EcsRunTaskOperator._get_ecs_task_id(f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}")
+ assert id == TASK_ID
@mock.patch.object(EcsBaseOperator, "client")
def test_execute_with_failures(self, client_mock):
@@ -571,7 +574,6 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
check_mock.assert_called_once_with()
xcom_del_mock.assert_called_once()
assert self.ecs.arn ==
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
- assert self.ecs.ecs_task_id == TASK_ID
@pytest.mark.parametrize(
"launch_type, tags",
@@ -620,7 +622,6 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
check_mock.assert_called_once_with()
xcom_del_mock.assert_called_once()
assert self.ecs.arn ==
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
- assert self.ecs.ecs_task_id == TASK_ID
@mock.patch.object(EcsBaseOperator, "client")
@mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher")
diff --git a/tests/system/providers/amazon/aws/example_ecs.py
b/tests/system/providers/amazon/aws/example_ecs.py
index 194b070b51..be90f8c96f 100644
--- a/tests/system/providers/amazon/aws/example_ecs.py
+++ b/tests/system/providers/amazon/aws/example_ecs.py
@@ -23,7 +23,7 @@ import boto3
from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates,
EcsTaskStates
+from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates
from airflow.providers.amazon.aws.operators.ecs import (
EcsCreateClusterOperator,
EcsDeleteClusterOperator,
@@ -34,7 +34,6 @@ from airflow.providers.amazon.aws.operators.ecs import (
from airflow.providers.amazon.aws.sensors.ecs import (
EcsClusterStateSensor,
EcsTaskDefinitionStateSensor,
- EcsTaskStateSensor,
)
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY,
SystemTestContextBuilder
@@ -67,6 +66,12 @@ def get_region():
return boto3.session.Session().region_name
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def clean_logs(group_name: str):
+ client = boto3.client("logs")
+ client.delete_log_group(logGroupName=group_name)
+
+
with DAG(
dag_id=DAG_ID,
schedule="@once",
@@ -85,6 +90,7 @@ with DAG(
asg_name = f"{env_id}-asg"
aws_region = get_region()
+ log_group_name = f"/ecs_test/{env_id}"
# [START howto_operator_ecs_create_cluster]
create_cluster = EcsCreateClusterOperator(
@@ -114,7 +120,16 @@ with DAG(
"workingDirectory": "/usr/bin",
"entryPoint": ["sh", "-c"],
"command": ["ls"],
- }
+ "logConfiguration": {
+ "logDriver": "awslogs",
+ "options": {
+ "awslogs-group": log_group_name,
+ "awslogs-region": aws_region,
+ "awslogs-create-group": "true",
+ "awslogs-stream-prefix": "ecs",
+ },
+ },
+ },
],
register_task_kwargs={
"cpu": "256",
@@ -140,38 +155,19 @@ with DAG(
"containerOverrides": [
{
"name": container_name,
- "command": ["echo", "hello", "world"],
+ "command": ["echo hello world"],
},
],
},
network_configuration={"awsvpcConfiguration": {"subnets":
existing_cluster_subnets}},
# [START howto_awslogs_ecs]
- awslogs_group="/ecs/hello-world",
+ awslogs_group=log_group_name,
awslogs_region=aws_region,
- awslogs_stream_prefix="ecs/hello-world-container",
+ awslogs_stream_prefix=f"ecs/{container_name}",
# [END howto_awslogs_ecs]
- # You must set `reattach=True` in order to get ecs_task_arn if you
plan to use a Sensor.
- reattach=True,
)
# [END howto_operator_ecs_run_task]
- # EcsRunTaskOperator waits by default, setting as False to test the Sensor
below.
- run_task.wait_for_completion = False
-
- # [START howto_sensor_ecs_task_state]
- # By default, EcsTaskStateSensor waits until the task has started, but the
- # demo task runs so fast that the sensor misses it. This sensor instead
- # demonstrates how to wait until the ECS Task has completed by providing
- # the target_state and failure_states parameters.
- await_task_finish = EcsTaskStateSensor(
- task_id="await_task_finish",
- cluster=existing_cluster_name,
- task=run_task.output["ecs_task_arn"],
- target_state=EcsTaskStates.STOPPED,
- failure_states={EcsTaskStates.NONE},
- )
- # [END howto_sensor_ecs_task_state]
-
# [START howto_operator_ecs_deregister_task_definition]
deregister_task = EcsDeregisterTaskDefinitionOperator(
task_id="deregister_task",
@@ -209,10 +205,10 @@ with DAG(
register_task,
await_task_definition,
run_task,
- await_task_finish,
deregister_task,
delete_cluster,
await_delete_cluster,
+ clean_logs(log_group_name),
)
from tests.system.utils.watcher import watcher
diff --git a/tests/system/providers/amazon/aws/example_ecs_fargate.py
b/tests/system/providers/amazon/aws/example_ecs_fargate.py
index 704bd91cdf..b23f85a956 100644
--- a/tests/system/providers/amazon/aws/example_ecs_fargate.py
+++ b/tests/system/providers/amazon/aws/example_ecs_fargate.py
@@ -23,7 +23,9 @@ import boto3
from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.hooks.ecs import EcsTaskStates
from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator
+from airflow.providers.amazon.aws.sensors.ecs import EcsTaskStateSensor
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import ENV_ID_KEY,
SystemTestContextBuilder
@@ -120,9 +122,28 @@ with DAG(
"assignPublicIp": "ENABLED",
},
},
+ # You must set `reattach=True` in order to get ecs_task_arn if you
plan to use a Sensor.
+ reattach=True,
)
# [END howto_operator_ecs]
+ # EcsRunTaskOperator waits by default, setting as False to test the Sensor
below.
+ hello_world.wait_for_completion = False
+
+ # [START howto_sensor_ecs_task_state]
+ # By default, EcsTaskStateSensor waits until the task has started, but the
+ # demo task runs so fast that the sensor misses it. This sensor instead
+ # demonstrates how to wait until the ECS Task has completed by providing
+ # the target_state and failure_states parameters.
+ await_task_finish = EcsTaskStateSensor(
+ task_id="await_task_finish",
+ cluster=cluster_name,
+ task=hello_world.output["ecs_task_arn"],
+ target_state=EcsTaskStates.STOPPED,
+ failure_states={EcsTaskStates.NONE},
+ )
+ # [END howto_sensor_ecs_task_state]
+
chain(
# TEST SETUP
test_context,
@@ -130,6 +151,7 @@ with DAG(
create_task_definition,
# TEST BODY
hello_world,
+ await_task_finish,
# TEST TEARDOWN
delete_task_definition(create_task_definition),
delete_cluster(cluster_name),