This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2d8625253f Remove non-public interface usage in EcsRunTaskOperator 
(#29447)
2d8625253f is described below

commit 2d8625253f7101a9da7161a7856f4a4084457548
Author: Andrey Anshin <[email protected]>
AuthorDate: Thu Aug 24 15:52:58 2023 +0400

    Remove non-public interface usage in EcsRunTaskOperator (#29447)
    
    * Remove non-public interface usage in EcsOperator
    
    Co-authored-by: D. Ferruzzi <[email protected]>
    
    * Turn started_by into private operator attribute
    
    ---------
    
    Co-authored-by: D. Ferruzzi <[email protected]>
---
 airflow/providers/amazon/aws/operators/ecs.py      | 68 ++++++---------
 airflow/providers/amazon/aws/utils/identifiers.py  | 51 ++++++++++++
 tests/providers/amazon/aws/operators/test_ecs.py   | 97 +++++++++++-----------
 .../providers/amazon/aws/utils/test_identifiers.py | 74 +++++++++++++++++
 4 files changed, 196 insertions(+), 94 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/ecs.py 
b/airflow/providers/amazon/aws/operators/ecs.py
index 9025e85f52..27d63d042d 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -28,7 +28,7 @@ import boto3
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
-from airflow.models import BaseOperator, XCom
+from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, 
EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, 
should_retry_eni
@@ -38,11 +38,12 @@ from airflow.providers.amazon.aws.triggers.ecs import (
     ClusterInactiveTrigger,
     TaskDoneTrigger,
 )
+from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
 from airflow.providers.amazon.aws.utils.task_log_fetcher import 
AwsTaskLogFetcher
 from airflow.utils.helpers import prune_dict
-from airflow.utils.session import provide_session
 
 if TYPE_CHECKING:
+    from airflow.models import TaskInstance
     from airflow.utils.context import Context
 
 DEFAULT_CONN_ID = "aws_default"
@@ -450,8 +451,6 @@ class EcsRunTaskOperator(EcsBaseOperator):
         "network_configuration": "json",
         "tags": "json",
     }
-    REATTACH_XCOM_KEY = "ecs_task_arn"
-    REATTACH_XCOM_TASK_ID_TEMPLATE = "{task_id}_task_arn"
 
     def __init__(
         self,
@@ -507,6 +506,8 @@ class EcsRunTaskOperator(EcsBaseOperator):
             self.awslogs_region = self.region
 
         self.arn: str | None = None
+        self._started_by: str | None = None
+
         self.retry_args = quota_retry
         self.task_log_fetcher: AwsTaskLogFetcher | None = None
         self.wait_for_completion = wait_for_completion
@@ -525,19 +526,22 @@ class EcsRunTaskOperator(EcsBaseOperator):
             return None
         return task_arn.split("/")[-1]
 
-    @provide_session
-    def execute(self, context, session=None):
+    def execute(self, context):
         self.log.info(
             "Running ECS Task - Task definition: %s - on cluster %s", 
self.task_definition, self.cluster
         )
         self.log.info("EcsOperator overrides: %s", self.overrides)
 
         if self.reattach:
-            self._try_reattach_task(context)
+            # Generate deterministic UUID which refers to unique 
TaskInstanceKey
+            ti: TaskInstance = context["ti"]
+            self._started_by = generate_uuid(*map(str, ti.key.primary))
+            self.log.info("Try to find run with startedBy=%r", 
self._started_by)
+            self._try_reattach_task(started_by=self._started_by)
 
         if not self.arn:
             # start the task except if we reattached to an existing one just 
before.
-            self._start_task(context)
+            self._start_task()
 
         if self.deferrable:
             self.defer(
@@ -574,7 +578,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         else:
             self._wait_for_task_ended()
 
-        self._after_execution(session)
+        self._after_execution()
 
         if self.do_xcom_push and self.task_log_fetcher:
             return self.task_log_fetcher.get_last_log_message()
@@ -598,27 +602,15 @@ class EcsRunTaskOperator(EcsBaseOperator):
             if len(one_log["events"]) > 0:
                 return one_log["events"][0]["message"]
 
-    @provide_session
-    def _after_execution(self, session=None):
+    def _after_execution(self):
         self._check_success_task()
 
-        self.log.info("ECS Task has been successfully executed")
-
-        if self.reattach:
-            # Clear the XCom value storing the ECS task ARN if the task has 
completed
-            # as we can't reattach it anymore
-            self._xcom_del(session, 
self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))
-
-    def _xcom_del(self, session, task_id):
-        session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == 
task_id).delete()
-
-    @AwsBaseHook.retry(should_retry_eni)
-    def _start_task(self, context):
+    def _start_task(self):
         run_opts = {
             "cluster": self.cluster,
             "taskDefinition": self.task_definition,
             "overrides": self.overrides,
-            "startedBy": self.owner,
+            "startedBy": self._started_by or self.owner,
         }
 
         if self.capacity_provider_strategy:
@@ -650,27 +642,17 @@ class EcsRunTaskOperator(EcsBaseOperator):
         self.arn = response["tasks"][0]["taskArn"]
         self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))
 
-        if self.reattach:
-            # Save the task ARN in XCom to be able to reattach it if needed
-            self.xcom_push(context, key=self.REATTACH_XCOM_KEY, value=self.arn)
-
-    def _try_reattach_task(self, context):
-        task_def_resp = 
self.client.describe_task_definition(taskDefinition=self.task_definition)
-        ecs_task_family = task_def_resp["taskDefinition"]["family"]
-
+    def _try_reattach_task(self, started_by: str):
+        if not started_by:
+            raise AirflowException("`started_by` should not be empty or None")
         list_tasks_resp = self.client.list_tasks(
-            cluster=self.cluster, desiredStatus="RUNNING", 
family=ecs_task_family
+            cluster=self.cluster, desiredStatus="RUNNING", startedBy=started_by
         )
         running_tasks = list_tasks_resp["taskArns"]
-
-        # Check if the ECS task previously launched is already running
-        previous_task_arn = self.xcom_pull(
-            context,
-            
task_ids=self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id),
-            key=self.REATTACH_XCOM_KEY,
-        )
-        if previous_task_arn in running_tasks:
-            self.arn = previous_task_arn
+        if running_tasks:
+            if len(running_tasks) > 1:
+                self.log.warning("Found more than one previously launched 
task: %s", running_tasks)
+            self.arn = running_tasks[0]
             self.log.info("Reattaching previously launched task: %s", self.arn)
         else:
             self.log.info("No active previously launched task found to 
reattach")
@@ -690,8 +672,6 @@ class EcsRunTaskOperator(EcsBaseOperator):
             },
         )
 
-        return
-
     def _aws_logs_enabled(self):
         return self.awslogs_group and self.awslogs_stream_prefix
 
diff --git a/airflow/providers/amazon/aws/utils/identifiers.py 
b/airflow/providers/amazon/aws/utils/identifiers.py
new file mode 100644
index 0000000000..cac653e3c1
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/identifiers.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from uuid import NAMESPACE_OID, UUID, uuid5
+
+NIL_UUID = UUID(int=0)
+
+
+def generate_uuid(*values: str | None, namespace: UUID = NAMESPACE_OID) -> str:
+    """
+    Convert input values to deterministic UUID string representation.
+
+    This function is only intended to generate a hash which is used as an 
identifier, not for any security use.
+
+    Generates a UUID v5 (SHA-1 + Namespace) for each value provided,
+    and this UUID is used as the Namespace for the next element.
+
+    If only one non-None value is provided to the function, then the result of 
the function
+    would be the same as the result of ``uuid.uuid5``.
+
+    All ``None`` values are replaced by the NIL UUID.  If a single ``None`` 
value is provided, then the NIL UUID is returned.
+
+    :param namespace: Initial namespace value to pass into the ``uuid.uuid5`` 
function.
+    """
+    if not values:
+        raise ValueError("Expected at least 1 argument")
+
+    if len(values) == 1 and values[0] is None:
+        return str(NIL_UUID)
+
+    result = namespace
+    for item in values:
+        result = uuid5(result, item if item is not None else str(NIL_UUID))
+
+    return str(result)
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py 
b/tests/providers/amazon/aws/operators/test_ecs.py
index f324918171..c0f258365a 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -521,39 +521,47 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
             ["", {"testTagKey": "testTagValue"}],
         ],
     )
-    @mock.patch.object(EcsRunTaskOperator, "_xcom_del")
-    @mock.patch.object(
-        EcsRunTaskOperator,
-        "xcom_pull",
-        return_value=f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
+    @pytest.mark.parametrize(
+        "arns, expected_arn",
+        [
+            pytest.param(
+                [
+                    f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
+                    
"arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b54",
+                ],
+                f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
+                id="multiple-arns",
+            ),
+            pytest.param(
+                [
+                    f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
+                ],
+                f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
+                id="single-arn",
+            ),
+        ],
     )
+    @mock.patch("airflow.providers.amazon.aws.operators.ecs.generate_uuid")
     @mock.patch.object(EcsRunTaskOperator, "_wait_for_task_ended")
     @mock.patch.object(EcsRunTaskOperator, "_check_success_task")
     @mock.patch.object(EcsRunTaskOperator, "_start_task")
     @mock.patch.object(EcsBaseOperator, "client")
     def test_reattach_successful(
-        self,
-        client_mock,
-        start_mock,
-        check_mock,
-        wait_mock,
-        xcom_pull_mock,
-        xcom_del_mock,
-        launch_type,
-        tags,
+        self, client_mock, start_mock, check_mock, wait_mock, uuid_mock, 
launch_type, tags, arns, expected_arn
     ):
+        """Test reattach on first running Task ARN."""
+        mock_ti = mock.MagicMock(name="MockedTaskInstance")
+        mock_ti.key.primary = ("mock_dag", "mock_ti", "mock_runid", 42)
+        fake_uuid = "01-02-03-04"
+        uuid_mock.return_value = fake_uuid
 
         self.set_up_operator(launch_type=launch_type, tags=tags)
-        client_mock.describe_task_definition.return_value = {"taskDefinition": 
{"family": "f"}}
-        client_mock.list_tasks.return_value = {
-            "taskArns": [
-                
"arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b54",
-                f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
-            ]
-        }
+        client_mock.list_tasks.return_value = {"taskArns": arns}
 
         self.ecs.reattach = True
-        self.ecs.execute(self.mock_context)
+        self.ecs.execute({"ti": mock_ti})
+
+        uuid_mock.assert_called_once_with("mock_dag", "mock_ti", "mock_runid", 
"42")
 
         extend_args = {}
         if launch_type:
@@ -563,20 +571,14 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         if tags:
             extend_args["tags"] = [{"key": k, "value": v} for (k, v) in 
tags.items()]
 
-        
client_mock.describe_task_definition.assert_called_once_with(taskDefinition="t")
-
-        client_mock.list_tasks.assert_called_once_with(cluster="c", 
desiredStatus="RUNNING", family="f")
+        client_mock.list_tasks.assert_called_once_with(
+            cluster="c", desiredStatus="RUNNING", startedBy=fake_uuid
+        )
 
         start_mock.assert_not_called()
-        xcom_pull_mock.assert_called_once_with(
-            self.mock_context,
-            key=self.ecs.REATTACH_XCOM_KEY,
-            
task_ids=self.ecs.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.ecs.task_id),
-        )
         wait_mock.assert_called_once_with()
         check_mock.assert_called_once_with()
-        xcom_del_mock.assert_called_once()
-        assert self.ecs.arn == 
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
+        assert self.ecs.arn == expected_arn
 
     @pytest.mark.parametrize(
         "launch_type, tags",
@@ -587,29 +589,25 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
             ["", {"testTagKey": "testTagValue"}],
         ],
     )
-    @mock.patch.object(EcsRunTaskOperator, "_xcom_del")
-    @mock.patch.object(EcsRunTaskOperator, "_try_reattach_task")
+    @mock.patch("airflow.providers.amazon.aws.operators.ecs.generate_uuid")
     @mock.patch.object(EcsRunTaskOperator, "_wait_for_task_ended")
     @mock.patch.object(EcsRunTaskOperator, "_check_success_task")
     @mock.patch.object(EcsBaseOperator, "client")
     def test_reattach_save_task_arn_xcom(
-        self,
-        client_mock,
-        check_mock,
-        wait_mock,
-        reattach_mock,
-        xcom_del_mock,
-        launch_type,
-        tags,
+        self, client_mock, check_mock, wait_mock, uuid_mock, launch_type, 
tags, caplog
     ):
+        """Test no reattach if no running Task was started by this Task 
ID."""
+        mock_ti = mock.MagicMock(name="MockedTaskInstance")
+        mock_ti.key.primary = ("mock_dag", "mock_ti", "mock_runid", 42)
+        fake_uuid = "01-02-03-04"
+        uuid_mock.return_value = fake_uuid
 
         self.set_up_operator(launch_type=launch_type, tags=tags)
-        client_mock.describe_task_definition.return_value = {"taskDefinition": 
{"family": "f"}}
         client_mock.list_tasks.return_value = {"taskArns": []}
         client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
 
         self.ecs.reattach = True
-        self.ecs.execute(self.mock_context)
+        self.ecs.execute({"ti": mock_ti})
 
         extend_args = {}
         if launch_type:
@@ -619,12 +617,14 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         if tags:
             extend_args["tags"] = [{"key": k, "value": v} for (k, v) in 
tags.items()]
 
-        reattach_mock.assert_called_once()
+        client_mock.list_tasks.assert_called_once_with(
+            cluster="c", desiredStatus="RUNNING", startedBy=fake_uuid
+        )
         client_mock.run_task.assert_called_once()
         wait_mock.assert_called_once_with()
         check_mock.assert_called_once_with()
-        xcom_del_mock.assert_called_once()
         assert self.ecs.arn == 
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
+        assert "No active previously launched task found to reattach" in 
caplog.messages
 
     @mock.patch.object(EcsBaseOperator, "client")
     
@mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
@@ -670,8 +670,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         assert deferred.value.trigger.task_arn == 
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
 
     @mock.patch.object(EcsRunTaskOperator, "client", new_callable=PropertyMock)
-    @mock.patch.object(EcsRunTaskOperator, "_xcom_del")
-    def test_execute_complete(self, xcom_del_mock: MagicMock, client_mock):
+    def test_execute_complete(self, client_mock):
         event = {"status": "success", "task_arn": "my_arn"}
         self.ecs.reattach = True
 
@@ -679,8 +678,6 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
 
         # task gets described to assert its success
         client_mock().describe_tasks.assert_called_once_with(cluster="c", 
tasks=["my_arn"])
-        # if reattach mode, xcom value is deleted on success
-        xcom_del_mock.assert_called_once()
 
 
 class TestEcsCreateClusterOperator(EcsBaseTestCase):
diff --git a/tests/providers/amazon/aws/utils/test_identifiers.py 
b/tests/providers/amazon/aws/utils/test_identifiers.py
new file mode 100644
index 0000000000..e0334a3472
--- /dev/null
+++ b/tests/providers/amazon/aws/utils/test_identifiers.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import random
+import string
+import uuid
+
+import pytest
+
+from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
+from airflow.utils.types import NOTSET
+
+
+class TestGenerateUuid:
+    @pytest.fixture(
+        autouse=True,
+        params=[
+            pytest.param(NOTSET, id="default-namespace"),
+            pytest.param(uuid.UUID(int=42), id="custom-namespace"),
+        ],
+    )
+    def setup_namespace(self, request):
+        self.default_namespace = request.param is NOTSET
+        self.namespace = uuid.NAMESPACE_OID if self.default_namespace else 
request.param
+        self.kwargs = {"namespace": self.namespace} if not 
self.default_namespace else {}
+
+    def test_deterministic(self):
+        """Test that result is deterministic and a valid UUID object"""
+        args = [
+            "".join(random.choice(string.ascii_letters) for _ in 
range(random.randint(3, 13)))
+            for _ in range(100)
+        ]
+        result = generate_uuid(*args, **self.kwargs)
+        assert result == generate_uuid(*args, **self.kwargs)
+        assert uuid.UUID(result).version == 5, "Should generate UUID v5"
+
+    def test_nil_uuid(self):
+        """Test that the result of a single None is the NIL UUID, regardless 
of namespace."""
+        assert generate_uuid(None, **self.kwargs) == 
"00000000-0000-0000-0000-000000000000"
+
+    def test_single_uuid_value(self):
+        """Test that the result of a single non-None value is the same as uuid5."""
+        assert generate_uuid("", **self.kwargs) == 
str(uuid.uuid5(self.namespace, ""))
+        assert generate_uuid("Airflow", **self.kwargs) == 
str(uuid.uuid5(self.namespace, "Airflow"))
+
+    def test_multiple_none_value(self):
+        """Test that the result of multiple None values is not the NIL UUID, 
regardless of namespace."""
+        multi_none = generate_uuid(None, None, **self.kwargs)
+        assert multi_none != "00000000-0000-0000-0000-000000000000"
+        assert uuid.UUID(multi_none).version == 5
+
+        # Test that None values not skipped
+        assert generate_uuid(None, "1", None, **self.kwargs) != 
generate_uuid("1", **self.kwargs)
+        assert generate_uuid(None, "1", **self.kwargs) != generate_uuid("1", 
**self.kwargs)
+        assert generate_uuid("1", None, **self.kwargs) != generate_uuid("1", 
**self.kwargs)
+
+    def test_no_args_value(self):
+        with pytest.raises(ValueError, match="Expected at least 1 argument"):
+            generate_uuid(**self.kwargs)

Reply via email to