This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2d8625253f Remove non-public interface usage in EcsRunTaskOperator
(#29447)
2d8625253f is described below
commit 2d8625253f7101a9da7161a7856f4a4084457548
Author: Andrey Anshin <[email protected]>
AuthorDate: Thu Aug 24 15:52:58 2023 +0400
Remove non-public interface usage in EcsRunTaskOperator (#29447)
* Remove non-public interface usage in EcsOperator
Co-authored-by: D. Ferruzzi <[email protected]>
* Turn started_by into private operator attribute
---------
Co-authored-by: D. Ferruzzi <[email protected]>
---
airflow/providers/amazon/aws/operators/ecs.py | 68 ++++++---------
airflow/providers/amazon/aws/utils/identifiers.py | 51 ++++++++++++
tests/providers/amazon/aws/operators/test_ecs.py | 97 +++++++++++-----------
.../providers/amazon/aws/utils/test_identifiers.py | 74 +++++++++++++++++
4 files changed, 196 insertions(+), 94 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/ecs.py
b/airflow/providers/amazon/aws/operators/ecs.py
index 9025e85f52..27d63d042d 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -28,7 +28,7 @@ import boto3
from airflow.configuration import conf
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
-from airflow.models import BaseOperator, XCom
+from airflow.models import BaseOperator
from airflow.providers.amazon.aws.exceptions import EcsOperatorError,
EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook,
should_retry_eni
@@ -38,11 +38,12 @@ from airflow.providers.amazon.aws.triggers.ecs import (
ClusterInactiveTrigger,
TaskDoneTrigger,
)
+from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
from airflow.providers.amazon.aws.utils.task_log_fetcher import
AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict
-from airflow.utils.session import provide_session
if TYPE_CHECKING:
+ from airflow.models import TaskInstance
from airflow.utils.context import Context
DEFAULT_CONN_ID = "aws_default"
@@ -450,8 +451,6 @@ class EcsRunTaskOperator(EcsBaseOperator):
"network_configuration": "json",
"tags": "json",
}
- REATTACH_XCOM_KEY = "ecs_task_arn"
- REATTACH_XCOM_TASK_ID_TEMPLATE = "{task_id}_task_arn"
def __init__(
self,
@@ -507,6 +506,8 @@ class EcsRunTaskOperator(EcsBaseOperator):
self.awslogs_region = self.region
self.arn: str | None = None
+ self._started_by: str | None = None
+
self.retry_args = quota_retry
self.task_log_fetcher: AwsTaskLogFetcher | None = None
self.wait_for_completion = wait_for_completion
@@ -525,19 +526,22 @@ class EcsRunTaskOperator(EcsBaseOperator):
return None
return task_arn.split("/")[-1]
- @provide_session
- def execute(self, context, session=None):
+ def execute(self, context):
self.log.info(
"Running ECS Task - Task definition: %s - on cluster %s",
self.task_definition, self.cluster
)
self.log.info("EcsOperator overrides: %s", self.overrides)
if self.reattach:
- self._try_reattach_task(context)
+ # Generate deterministic UUID which refers to unique
TaskInstanceKey
+ ti: TaskInstance = context["ti"]
+ self._started_by = generate_uuid(*map(str, ti.key.primary))
+ self.log.info("Try to find run with startedBy=%r",
self._started_by)
+ self._try_reattach_task(started_by=self._started_by)
if not self.arn:
# start the task except if we reattached to an existing one just
before.
- self._start_task(context)
+ self._start_task()
if self.deferrable:
self.defer(
@@ -574,7 +578,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
else:
self._wait_for_task_ended()
- self._after_execution(session)
+ self._after_execution()
if self.do_xcom_push and self.task_log_fetcher:
return self.task_log_fetcher.get_last_log_message()
@@ -598,27 +602,15 @@ class EcsRunTaskOperator(EcsBaseOperator):
if len(one_log["events"]) > 0:
return one_log["events"][0]["message"]
- @provide_session
- def _after_execution(self, session=None):
+ def _after_execution(self):
self._check_success_task()
- self.log.info("ECS Task has been successfully executed")
-
- if self.reattach:
- # Clear the XCom value storing the ECS task ARN if the task has
completed
- # as we can't reattach it anymore
- self._xcom_del(session,
self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))
-
- def _xcom_del(self, session, task_id):
- session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id ==
task_id).delete()
-
- @AwsBaseHook.retry(should_retry_eni)
- def _start_task(self, context):
+ def _start_task(self):
run_opts = {
"cluster": self.cluster,
"taskDefinition": self.task_definition,
"overrides": self.overrides,
- "startedBy": self.owner,
+ "startedBy": self._started_by or self.owner,
}
if self.capacity_provider_strategy:
@@ -650,27 +642,17 @@ class EcsRunTaskOperator(EcsBaseOperator):
self.arn = response["tasks"][0]["taskArn"]
self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))
- if self.reattach:
- # Save the task ARN in XCom to be able to reattach it if needed
- self.xcom_push(context, key=self.REATTACH_XCOM_KEY, value=self.arn)
-
- def _try_reattach_task(self, context):
- task_def_resp =
self.client.describe_task_definition(taskDefinition=self.task_definition)
- ecs_task_family = task_def_resp["taskDefinition"]["family"]
-
+ def _try_reattach_task(self, started_by: str):
+ if not started_by:
+ raise AirflowException("`started_by` should not be empty or None")
list_tasks_resp = self.client.list_tasks(
- cluster=self.cluster, desiredStatus="RUNNING",
family=ecs_task_family
+ cluster=self.cluster, desiredStatus="RUNNING", startedBy=started_by
)
running_tasks = list_tasks_resp["taskArns"]
-
- # Check if the ECS task previously launched is already running
- previous_task_arn = self.xcom_pull(
- context,
-
task_ids=self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id),
- key=self.REATTACH_XCOM_KEY,
- )
- if previous_task_arn in running_tasks:
- self.arn = previous_task_arn
+ if running_tasks:
+ if len(running_tasks) > 1:
+                self.log.warning("Found more than one previously launched
task: %s", running_tasks)
+ self.arn = running_tasks[0]
self.log.info("Reattaching previously launched task: %s", self.arn)
else:
self.log.info("No active previously launched task found to
reattach")
@@ -690,8 +672,6 @@ class EcsRunTaskOperator(EcsBaseOperator):
},
)
- return
-
def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix
diff --git a/airflow/providers/amazon/aws/utils/identifiers.py
b/airflow/providers/amazon/aws/utils/identifiers.py
new file mode 100644
index 0000000000..cac653e3c1
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/identifiers.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from uuid import NAMESPACE_OID, UUID, uuid5
+
+NIL_UUID = UUID(int=0)
+
+
+def generate_uuid(*values: str | None, namespace: UUID = NAMESPACE_OID) -> str:
+ """
+ Convert input values to deterministic UUID string representation.
+
+ This function is only intended to generate a hash which is used as an
identifier, not for any security use.
+
+ Generates a UUID v5 (SHA-1 + Namespace) for each value provided,
+ and this UUID is used as the Namespace for the next element.
+
+ If only one non-None value is provided to the function, then the result of
the function
+ would be the same as the result of ``uuid.uuid5``.
+
+ All ``None`` values are replaced by the NIL UUID. If a single ``None`` value is
provided then the NIL UUID is returned.
+
+ :param namespace: Initial namespace value to pass into the ``uuid.uuid5``
function.
+ """
+ if not values:
+ raise ValueError("Expected at least 1 argument")
+
+ if len(values) == 1 and values[0] is None:
+ return str(NIL_UUID)
+
+ result = namespace
+ for item in values:
+ result = uuid5(result, item if item is not None else str(NIL_UUID))
+
+ return str(result)
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py
b/tests/providers/amazon/aws/operators/test_ecs.py
index f324918171..c0f258365a 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -521,39 +521,47 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
["", {"testTagKey": "testTagValue"}],
],
)
- @mock.patch.object(EcsRunTaskOperator, "_xcom_del")
- @mock.patch.object(
- EcsRunTaskOperator,
- "xcom_pull",
- return_value=f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
+ @pytest.mark.parametrize(
+ "arns, expected_arn",
+ [
+ pytest.param(
+ [
+ f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
+
"arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b54",
+ ],
+ f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
+ id="multiple-arns",
+ ),
+ pytest.param(
+ [
+ f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
+ ],
+ f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
+                id="single-arn",
+ ),
+ ],
)
+ @mock.patch("airflow.providers.amazon.aws.operators.ecs.generate_uuid")
@mock.patch.object(EcsRunTaskOperator, "_wait_for_task_ended")
@mock.patch.object(EcsRunTaskOperator, "_check_success_task")
@mock.patch.object(EcsRunTaskOperator, "_start_task")
@mock.patch.object(EcsBaseOperator, "client")
def test_reattach_successful(
- self,
- client_mock,
- start_mock,
- check_mock,
- wait_mock,
- xcom_pull_mock,
- xcom_del_mock,
- launch_type,
- tags,
+ self, client_mock, start_mock, check_mock, wait_mock, uuid_mock,
launch_type, tags, arns, expected_arn
):
+ """Test reattach on first running Task ARN."""
+ mock_ti = mock.MagicMock(name="MockedTaskInstance")
+ mock_ti.key.primary = ("mock_dag", "mock_ti", "mock_runid", 42)
+ fake_uuid = "01-02-03-04"
+ uuid_mock.return_value = fake_uuid
self.set_up_operator(launch_type=launch_type, tags=tags)
- client_mock.describe_task_definition.return_value = {"taskDefinition":
{"family": "f"}}
- client_mock.list_tasks.return_value = {
- "taskArns": [
-
"arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b54",
- f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}",
- ]
- }
+ client_mock.list_tasks.return_value = {"taskArns": arns}
self.ecs.reattach = True
- self.ecs.execute(self.mock_context)
+ self.ecs.execute({"ti": mock_ti})
+
+ uuid_mock.assert_called_once_with("mock_dag", "mock_ti", "mock_runid",
"42")
extend_args = {}
if launch_type:
@@ -563,20 +571,14 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
if tags:
extend_args["tags"] = [{"key": k, "value": v} for (k, v) in
tags.items()]
-
client_mock.describe_task_definition.assert_called_once_with(taskDefinition="t")
-
- client_mock.list_tasks.assert_called_once_with(cluster="c",
desiredStatus="RUNNING", family="f")
+ client_mock.list_tasks.assert_called_once_with(
+ cluster="c", desiredStatus="RUNNING", startedBy=fake_uuid
+ )
start_mock.assert_not_called()
- xcom_pull_mock.assert_called_once_with(
- self.mock_context,
- key=self.ecs.REATTACH_XCOM_KEY,
-
task_ids=self.ecs.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.ecs.task_id),
- )
wait_mock.assert_called_once_with()
check_mock.assert_called_once_with()
- xcom_del_mock.assert_called_once()
- assert self.ecs.arn ==
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
+ assert self.ecs.arn == expected_arn
@pytest.mark.parametrize(
"launch_type, tags",
@@ -587,29 +589,25 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
["", {"testTagKey": "testTagValue"}],
],
)
- @mock.patch.object(EcsRunTaskOperator, "_xcom_del")
- @mock.patch.object(EcsRunTaskOperator, "_try_reattach_task")
+ @mock.patch("airflow.providers.amazon.aws.operators.ecs.generate_uuid")
@mock.patch.object(EcsRunTaskOperator, "_wait_for_task_ended")
@mock.patch.object(EcsRunTaskOperator, "_check_success_task")
@mock.patch.object(EcsBaseOperator, "client")
def test_reattach_save_task_arn_xcom(
- self,
- client_mock,
- check_mock,
- wait_mock,
- reattach_mock,
- xcom_del_mock,
- launch_type,
- tags,
+ self, client_mock, check_mock, wait_mock, uuid_mock, launch_type,
tags, caplog
):
+        """Test no reattach when no running Task was started by this Task ID."""
+ mock_ti = mock.MagicMock(name="MockedTaskInstance")
+ mock_ti.key.primary = ("mock_dag", "mock_ti", "mock_runid", 42)
+ fake_uuid = "01-02-03-04"
+ uuid_mock.return_value = fake_uuid
self.set_up_operator(launch_type=launch_type, tags=tags)
- client_mock.describe_task_definition.return_value = {"taskDefinition":
{"family": "f"}}
client_mock.list_tasks.return_value = {"taskArns": []}
client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
self.ecs.reattach = True
- self.ecs.execute(self.mock_context)
+ self.ecs.execute({"ti": mock_ti})
extend_args = {}
if launch_type:
@@ -619,12 +617,14 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
if tags:
extend_args["tags"] = [{"key": k, "value": v} for (k, v) in
tags.items()]
- reattach_mock.assert_called_once()
+ client_mock.list_tasks.assert_called_once_with(
+ cluster="c", desiredStatus="RUNNING", startedBy=fake_uuid
+ )
client_mock.run_task.assert_called_once()
wait_mock.assert_called_once_with()
check_mock.assert_called_once_with()
- xcom_del_mock.assert_called_once()
assert self.ecs.arn ==
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
+ assert "No active previously launched task found to reattach" in
caplog.messages
@mock.patch.object(EcsBaseOperator, "client")
@mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
@@ -670,8 +670,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
assert deferred.value.trigger.task_arn ==
f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
@mock.patch.object(EcsRunTaskOperator, "client", new_callable=PropertyMock)
- @mock.patch.object(EcsRunTaskOperator, "_xcom_del")
- def test_execute_complete(self, xcom_del_mock: MagicMock, client_mock):
+ def test_execute_complete(self, client_mock):
event = {"status": "success", "task_arn": "my_arn"}
self.ecs.reattach = True
@@ -679,8 +678,6 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
# task gets described to assert its success
client_mock().describe_tasks.assert_called_once_with(cluster="c",
tasks=["my_arn"])
- # if reattach mode, xcom value is deleted on success
- xcom_del_mock.assert_called_once()
class TestEcsCreateClusterOperator(EcsBaseTestCase):
diff --git a/tests/providers/amazon/aws/utils/test_identifiers.py
b/tests/providers/amazon/aws/utils/test_identifiers.py
new file mode 100644
index 0000000000..e0334a3472
--- /dev/null
+++ b/tests/providers/amazon/aws/utils/test_identifiers.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import random
+import string
+import uuid
+
+import pytest
+
+from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
+from airflow.utils.types import NOTSET
+
+
+class TestGenerateUuid:
+ @pytest.fixture(
+ autouse=True,
+ params=[
+ pytest.param(NOTSET, id="default-namespace"),
+ pytest.param(uuid.UUID(int=42), id="custom-namespace"),
+ ],
+ )
+ def setup_namespace(self, request):
+ self.default_namespace = request.param is NOTSET
+ self.namespace = uuid.NAMESPACE_OID if self.default_namespace else
request.param
+ self.kwargs = {"namespace": self.namespace} if not
self.default_namespace else {}
+
+ def test_deterministic(self):
+ """Test that result is deterministic and a valid UUID object"""
+ args = [
+ "".join(random.choice(string.ascii_letters) for _ in
range(random.randint(3, 13)))
+ for _ in range(100)
+ ]
+ result = generate_uuid(*args, **self.kwargs)
+ assert result == generate_uuid(*args, **self.kwargs)
+ assert uuid.UUID(result).version == 5, "Should generate UUID v5"
+
+ def test_nil_uuid(self):
+        """Test that the result of a single None is the NIL UUID, regardless of
namespace."""
+ assert generate_uuid(None, **self.kwargs) ==
"00000000-0000-0000-0000-000000000000"
+
+ def test_single_uuid_value(self):
+ """Test that result of single not None value are the same as uuid5."""
+ assert generate_uuid("", **self.kwargs) ==
str(uuid.uuid5(self.namespace, ""))
+ assert generate_uuid("Airflow", **self.kwargs) ==
str(uuid.uuid5(self.namespace, "Airflow"))
+
+ def test_multiple_none_value(self):
+        """Test that multiple None values produce a non-NIL UUID and are not
skipped."""
+ multi_none = generate_uuid(None, None, **self.kwargs)
+ assert multi_none != "00000000-0000-0000-0000-000000000000"
+ assert uuid.UUID(multi_none).version == 5
+
+ # Test that None values not skipped
+ assert generate_uuid(None, "1", None, **self.kwargs) !=
generate_uuid("1", **self.kwargs)
+ assert generate_uuid(None, "1", **self.kwargs) != generate_uuid("1",
**self.kwargs)
+ assert generate_uuid("1", None, **self.kwargs) != generate_uuid("1",
**self.kwargs)
+
+ def test_no_args_value(self):
+ with pytest.raises(ValueError, match="Expected at least 1 argument"):
+ generate_uuid(**self.kwargs)