This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5f2ebb312b ECS Overrides for AWS Batch submit_job (#39903)
5f2ebb312b is described below
commit 5f2ebb312b08769b454a777280ddf5c43c38bb87
Author: Josh Dimarsky <[email protected]>
AuthorDate: Wed May 29 05:18:10 2024 -0400
ECS Overrides for AWS Batch submit_job (#39903)
---
airflow/providers/amazon/aws/hooks/batch_client.py | 3 +
airflow/providers/amazon/aws/operators/batch.py | 8 +++
tests/providers/amazon/aws/operators/test_batch.py | 73 ++++++++++++++++++++--
3 files changed, 79 insertions(+), 5 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index b419239a16..f024134560 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -102,6 +102,7 @@ class BatchProtocol(Protocol):
arrayProperties: dict,
parameters: dict,
containerOverrides: dict,
+ ecsPropertiesOverride: dict,
tags: dict,
) -> dict:
"""
@@ -119,6 +120,8 @@ class BatchProtocol(Protocol):
:param containerOverrides: the same parameter that boto3 will receive
+ :param ecsPropertiesOverride: the same parameter that boto3 will receive
+
:param tags: the same parameter that boto3 will receive
:return: an API response
diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 00b6287145..849fc19346 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -65,6 +65,7 @@ class BatchOperator(BaseOperator):
:param job_queue: the queue name on AWS Batch
:param overrides: DEPRECATED, use container_overrides instead with the same value.
:param container_overrides: the `containerOverrides` parameter for boto3 (templated)
+ :param ecs_properties_override: the `ecsPropertiesOverride` parameter for boto3 (templated)
:param node_overrides: the `nodeOverrides` parameter for boto3 (templated)
:param share_identifier: The share identifier for the job. Don't specify this parameter if the job queue
doesn't have a scheduling policy.
@@ -112,6 +113,7 @@ class BatchOperator(BaseOperator):
"job_queue",
"container_overrides",
"array_properties",
+ "ecs_properties_override",
"node_overrides",
"parameters",
"retry_strategy",
@@ -124,6 +126,7 @@ class BatchOperator(BaseOperator):
template_fields_renderers = {
"container_overrides": "json",
"parameters": "json",
+ "ecs_properties_override": "json",
"node_overrides": "json",
"retry_strategy": "json",
}
@@ -160,6 +163,7 @@ class BatchOperator(BaseOperator):
overrides: dict | None = None, # deprecated
container_overrides: dict | None = None,
array_properties: dict | None = None,
+ ecs_properties_override: dict | None = None,
node_overrides: dict | None = None,
share_identifier: str | None = None,
scheduling_priority_override: int | None = None,
@@ -201,6 +205,7 @@ class BatchOperator(BaseOperator):
stacklevel=2,
)
+ self.ecs_properties_override = ecs_properties_override
self.node_overrides = node_overrides
self.share_identifier = share_identifier
self.scheduling_priority_override = scheduling_priority_override
@@ -296,6 +301,8 @@ class BatchOperator(BaseOperator):
self.log.info("AWS Batch job - container overrides: %s", self.container_overrides)
if self.array_properties:
self.log.info("AWS Batch job - array properties: %s", self.array_properties)
+ if self.ecs_properties_override:
+ self.log.info("AWS Batch job - ECS properties: %s", self.ecs_properties_override)
if self.node_overrides:
self.log.info("AWS Batch job - node properties: %s", self.node_overrides)
@@ -307,6 +314,7 @@ class BatchOperator(BaseOperator):
"parameters": self.parameters,
"tags": self.tags,
"containerOverrides": self.container_overrides,
+ "ecsPropertiesOverride": self.ecs_properties_override,
"nodeOverrides": self.node_overrides,
"retryStrategy": self.retry_strategy,
"shareIdentifier": self.share_identifier,
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index f769c1baa8..27f86e279c 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -132,6 +132,7 @@ class TestBatchOperator:
assert batch_job.retry_strategy is None
assert batch_job.container_overrides is None
assert batch_job.array_properties is None
+ assert batch_job.ecs_properties_override is None
assert batch_job.node_overrides is None
assert batch_job.share_identifier is None
assert batch_job.scheduling_priority_override is None
@@ -149,6 +150,7 @@ class TestBatchOperator:
"job_queue",
"container_overrides",
"array_properties",
+ "ecs_properties_override",
"node_overrides",
"parameters",
"retry_strategy",
@@ -204,6 +206,62 @@ class TestBatchOperator:
tags={},
)
+ @mock.patch.object(BatchClientHook, "get_job_description")
+ @mock.patch.object(BatchClientHook, "wait_for_job")
+ @mock.patch.object(BatchClientHook, "check_job_success")
+ def test_execute_with_ecs_overrides(self, check_mock, wait_mock, job_description_mock):
+ self.batch.container_overrides = None
+ self.batch.ecs_properties_override = {
+ "taskProperties": [
+ {
+ "containers": [
+ {
+ "command": [
+ "string",
+ ],
+ "environment": [
+ {"name": "string", "value": "string"},
+ ],
+ "name": "string",
+ "resourceRequirements": [
+ {"value": "string", "type": "'GPU'|'VCPU'|'MEMORY'"},
+ ],
+ },
+ ]
+ },
+ ]
+ }
+ self.batch.execute(self.mock_context)
+
+ self.client_mock.submit_job.assert_called_once_with(
+ jobQueue="queue",
+ jobName=JOB_NAME,
+ jobDefinition="hello-world",
+ ecsPropertiesOverride={
+ "taskProperties": [
+ {
+ "containers": [
+ {
+ "command": [
+ "string",
+ ],
+ "environment": [
+ {"name": "string", "value": "string"},
+ ],
+ "name": "string",
+ "resourceRequirements": [
+ {"value": "string", "type": "'GPU'|'VCPU'|'MEMORY'"},
+ ],
+ },
+ ]
+ },
+ ]
+ },
+ parameters={},
+ retryStrategy={"attempts": 1},
+ tags={},
+ )
+
@mock.patch.object(BatchClientHook, "check_job_success")
def test_wait_job_complete_using_waiters(self, check_mock):
mock_waiters = mock.Mock()
@@ -238,7 +296,7 @@ class TestBatchOperator:
self.batch.on_kill()
self.client_mock.terminate_job.assert_called_once_with(jobId=JOB_ID, reason="Task killed by the user")
- @pytest.mark.parametrize("override", ["overrides", "node_overrides"])
+ @pytest.mark.parametrize("override", ["overrides", "node_overrides", "ecs_properties_override"])
@patch(
"airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.client",
new_callable=mock.PropertyMock,
@@ -269,10 +327,15 @@ class TestBatchOperator:
"parameters": {},
"tags": {},
}
- if override == "overrides":
- expected_args["containerOverrides"] = {"a": "a"}
- else:
- expected_args["nodeOverrides"] = {"a": "a"}
+
+ py2api = {
+ "overrides": "containerOverrides",
+ "node_overrides": "nodeOverrides",
+ "ecs_properties_override": "ecsPropertiesOverride",
+ }
+
+ expected_args[py2api[override]] = {"a": "a"}
+
client_mock().submit_job.assert_called_once_with(**expected_args)
def test_deprecated_override_param(self):