o-nikolas commented on code in PR #29447:
URL: https://github.com/apache/airflow/pull/29447#discussion_r1143988109


##########
airflow/providers/amazon/aws/operators/ecs.py:
##########
@@ -471,39 +470,40 @@ def __init__(
             self.awslogs_region = self.region
 
         self.arn: str | None = None
+        self.started_by: str | None = None
+
         self.retry_args = quota_retry
         self.task_log_fetcher: EcsTaskLogFetcher | None = None
         self.wait_for_completion = wait_for_completion
 
-    @provide_session
-    def execute(self, context, session=None):
+    def execute(self, context):
         self.log.info(
             "Running ECS Task - Task definition: %s - on cluster %s", 
self.task_definition, self.cluster
         )
         self.log.info("EcsOperator overrides: %s", self.overrides)
 
         if self.reattach:
-            self._try_reattach_task(context)
+            # Generate deterministic UUID which refers to unique 
TaskInstanceKey
+            ti: TaskInstance = context["ti"]
+            self.started_by = generate_uuid(*map(str, ti.key.primary))
+            if self.do_xcom_push:
+                ti.xcom_push("started_by", self.started_by)

Review Comment:
   I'm confused about why you're pushing this value to XCom. I don't see where it's 
read back: in `_try_reattach_task` it is read from `self.started_by`, not from XCom.



##########
airflow/providers/amazon/aws/operators/ecs.py:
##########
@@ -564,29 +561,17 @@ def _start_task(self, context):
         self.ecs_task_id = self.arn.split("/")[-1]
         self.log.info("ECS task ID is: %s", self.ecs_task_id)
 
-        if self.reattach:
-            # Save the task ARN in XCom to be able to reattach it if needed
-            self.xcom_push(context, key=self.REATTACH_XCOM_KEY, value=self.arn)
-
-    def _try_reattach_task(self, context):
-        task_def_resp = 
self.client.describe_task_definition(taskDefinition=self.task_definition)
-        ecs_task_family = task_def_resp["taskDefinition"]["family"]
-
+    def _try_reattach_task(self):
         list_tasks_resp = self.client.list_tasks(
-            cluster=self.cluster, desiredStatus="RUNNING", 
family=ecs_task_family
+            cluster=self.cluster, desiredStatus="RUNNING", 
startedBy=self.started_by

Review Comment:
   Don't you want to check whether there is already a value for `started_by` in XCom here?



##########
airflow/providers/amazon/aws/operators/ecs.py:
##########
@@ -471,39 +470,40 @@ def __init__(
             self.awslogs_region = self.region
 
         self.arn: str | None = None
+        self.started_by: str | None = None
+
         self.retry_args = quota_retry
         self.task_log_fetcher: EcsTaskLogFetcher | None = None
         self.wait_for_completion = wait_for_completion
 
-    @provide_session
-    def execute(self, context, session=None):
+    def execute(self, context):
         self.log.info(
             "Running ECS Task - Task definition: %s - on cluster %s", 
self.task_definition, self.cluster
         )
         self.log.info("EcsOperator overrides: %s", self.overrides)
 
         if self.reattach:
-            self._try_reattach_task(context)
+            # Generate deterministic UUID which refers to unique 
TaskInstanceKey
+            ti: TaskInstance = context["ti"]
+            self.started_by = generate_uuid(*map(str, ti.key.primary))
+            if self.do_xcom_push:
+                ti.xcom_push("started_by", self.started_by)
 
-        self._start_wait_check_task(context)
+            self._try_reattach_task()
 
-        self.log.info("ECS Task has been successfully executed")
+        self._start_wait_check_task()
 
-        if self.reattach:
-            # Clear the XCom value storing the ECS task ARN if the task has 
completed
-            # as we can't reattach it anymore
-            self._xcom_del(session, 
self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))

Review Comment:
   Don't you want the same cleanup behaviour in your implementation, rather than 
leaving the keys behind? They're small, but they will accumulate over time.



##########
airflow/providers/amazon/aws/operators/ecs.py:
##########
@@ -564,29 +561,17 @@ def _start_task(self, context):
         self.ecs_task_id = self.arn.split("/")[-1]
         self.log.info("ECS task ID is: %s", self.ecs_task_id)
 
-        if self.reattach:
-            # Save the task ARN in XCom to be able to reattach it if needed
-            self.xcom_push(context, key=self.REATTACH_XCOM_KEY, value=self.arn)
-
-    def _try_reattach_task(self, context):
-        task_def_resp = 
self.client.describe_task_definition(taskDefinition=self.task_definition)
-        ecs_task_family = task_def_resp["taskDefinition"]["family"]
-
+    def _try_reattach_task(self):
         list_tasks_resp = self.client.list_tasks(
-            cluster=self.cluster, desiredStatus="RUNNING", 
family=ecs_task_family
+            cluster=self.cluster, desiredStatus="RUNNING", 
startedBy=self.started_by
         )
         running_tasks = list_tasks_resp["taskArns"]
-
-        # Check if the ECS task previously launched is already running
-        previous_task_arn = self.xcom_pull(
-            context,
-            
task_ids=self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id),
-            key=self.REATTACH_XCOM_KEY,
-        )
-        if previous_task_arn in running_tasks:
-            self.arn = previous_task_arn
-            self.ecs_task_id = self.arn.split("/")[-1]
+        if running_tasks:
+            if len(running_tasks) > 1:
+                self.log.warning("Found more then one previously launched 
tasks: %s", running_tasks)

Review Comment:
   How is this possible if the `started_by` ID is unique? Is it because of task retries?
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to