This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5f2ebb312b ECS Overrides for AWS Batch submit_job (#39903)
5f2ebb312b is described below

commit 5f2ebb312b08769b454a777280ddf5c43c38bb87
Author: Josh Dimarsky <[email protected]>
AuthorDate: Wed May 29 05:18:10 2024 -0400

    ECS Overrides for AWS Batch submit_job (#39903)
---
 airflow/providers/amazon/aws/hooks/batch_client.py |  3 +
 airflow/providers/amazon/aws/operators/batch.py    |  8 +++
 tests/providers/amazon/aws/operators/test_batch.py | 73 ++++++++++++++++++++--
 3 files changed, 79 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index b419239a16..f024134560 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -102,6 +102,7 @@ class BatchProtocol(Protocol):
         arrayProperties: dict,
         parameters: dict,
         containerOverrides: dict,
+        ecsPropertiesOverride: dict,
         tags: dict,
     ) -> dict:
         """
@@ -119,6 +120,8 @@ class BatchProtocol(Protocol):
 
         :param containerOverrides: the same parameter that boto3 will receive
 
+        :param ecsPropertiesOverride: the same parameter that boto3 will receive
+
         :param tags: the same parameter that boto3 will receive
 
         :return: an API response
diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 00b6287145..849fc19346 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -65,6 +65,7 @@ class BatchOperator(BaseOperator):
     :param job_queue: the queue name on AWS Batch
     :param overrides: DEPRECATED, use container_overrides instead with the same value.
     :param container_overrides: the `containerOverrides` parameter for boto3 (templated)
+    :param ecs_properties_override: the `ecsPropertiesOverride` parameter for boto3 (templated)
     :param node_overrides: the `nodeOverrides` parameter for boto3 (templated)
     :param share_identifier: The share identifier for the job. Don't specify this parameter if the job queue
         doesn't have a scheduling policy.
@@ -112,6 +113,7 @@ class BatchOperator(BaseOperator):
         "job_queue",
         "container_overrides",
         "array_properties",
+        "ecs_properties_override",
         "node_overrides",
         "parameters",
         "retry_strategy",
@@ -124,6 +126,7 @@ class BatchOperator(BaseOperator):
     template_fields_renderers = {
         "container_overrides": "json",
         "parameters": "json",
+        "ecs_properties_override": "json",
         "node_overrides": "json",
         "retry_strategy": "json",
     }
@@ -160,6 +163,7 @@ class BatchOperator(BaseOperator):
         overrides: dict | None = None,  # deprecated
         container_overrides: dict | None = None,
         array_properties: dict | None = None,
+        ecs_properties_override: dict | None = None,
         node_overrides: dict | None = None,
         share_identifier: str | None = None,
         scheduling_priority_override: int | None = None,
@@ -201,6 +205,7 @@ class BatchOperator(BaseOperator):
                 stacklevel=2,
             )
 
+        self.ecs_properties_override = ecs_properties_override
         self.node_overrides = node_overrides
         self.share_identifier = share_identifier
         self.scheduling_priority_override = scheduling_priority_override
@@ -296,6 +301,8 @@ class BatchOperator(BaseOperator):
             self.log.info("AWS Batch job - container overrides: %s", self.container_overrides)
         if self.array_properties:
             self.log.info("AWS Batch job - array properties: %s", self.array_properties)
+        if self.ecs_properties_override:
+            self.log.info("AWS Batch job - ECS properties: %s", self.ecs_properties_override)
         if self.node_overrides:
             self.log.info("AWS Batch job - node properties: %s", self.node_overrides)
 
@@ -307,6 +314,7 @@ class BatchOperator(BaseOperator):
             "parameters": self.parameters,
             "tags": self.tags,
             "containerOverrides": self.container_overrides,
+            "ecsPropertiesOverride": self.ecs_properties_override,
             "nodeOverrides": self.node_overrides,
             "retryStrategy": self.retry_strategy,
             "shareIdentifier": self.share_identifier,
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index f769c1baa8..27f86e279c 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -132,6 +132,7 @@ class TestBatchOperator:
         assert batch_job.retry_strategy is None
         assert batch_job.container_overrides is None
         assert batch_job.array_properties is None
+        assert batch_job.ecs_properties_override is None
         assert batch_job.node_overrides is None
         assert batch_job.share_identifier is None
         assert batch_job.scheduling_priority_override is None
@@ -149,6 +150,7 @@ class TestBatchOperator:
             "job_queue",
             "container_overrides",
             "array_properties",
+            "ecs_properties_override",
             "node_overrides",
             "parameters",
             "retry_strategy",
@@ -204,6 +206,62 @@ class TestBatchOperator:
             tags={},
         )
 
+    @mock.patch.object(BatchClientHook, "get_job_description")
+    @mock.patch.object(BatchClientHook, "wait_for_job")
+    @mock.patch.object(BatchClientHook, "check_job_success")
+    def test_execute_with_ecs_overrides(self, check_mock, wait_mock, job_description_mock):
+        self.batch.container_overrides = None
+        self.batch.ecs_properties_override = {
+            "taskProperties": [
+                {
+                    "containers": [
+                        {
+                            "command": [
+                                "string",
+                            ],
+                            "environment": [
+                                {"name": "string", "value": "string"},
+                            ],
+                            "name": "string",
+                            "resourceRequirements": [
+                                {"value": "string", "type": "'GPU'|'VCPU'|'MEMORY'"},
+                            ],
+                        },
+                    ]
+                },
+            ]
+        }
+        self.batch.execute(self.mock_context)
+
+        self.client_mock.submit_job.assert_called_once_with(
+            jobQueue="queue",
+            jobName=JOB_NAME,
+            jobDefinition="hello-world",
+            ecsPropertiesOverride={
+                "taskProperties": [
+                    {
+                        "containers": [
+                            {
+                                "command": [
+                                    "string",
+                                ],
+                                "environment": [
+                                    {"name": "string", "value": "string"},
+                                ],
+                                "name": "string",
+                                "resourceRequirements": [
+                                    {"value": "string", "type": "'GPU'|'VCPU'|'MEMORY'"},
+                                ],
+                            },
+                        ]
+                    },
+                ]
+            },
+            parameters={},
+            retryStrategy={"attempts": 1},
+            tags={},
+        )
+
     @mock.patch.object(BatchClientHook, "check_job_success")
     def test_wait_job_complete_using_waiters(self, check_mock):
         mock_waiters = mock.Mock()
@@ -238,7 +296,7 @@ class TestBatchOperator:
         self.batch.on_kill()
         self.client_mock.terminate_job.assert_called_once_with(jobId=JOB_ID, reason="Task killed by the user")
 
-    @pytest.mark.parametrize("override", ["overrides", "node_overrides"])
+    @pytest.mark.parametrize("override", ["overrides", "node_overrides", "ecs_properties_override"])
     @patch(
         "airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.client",
         new_callable=mock.PropertyMock,
@@ -269,10 +327,15 @@ class TestBatchOperator:
             "parameters": {},
             "tags": {},
         }
-        if override == "overrides":
-            expected_args["containerOverrides"] = {"a": "a"}
-        else:
-            expected_args["nodeOverrides"] = {"a": "a"}
+
+        py2api = {
+            "overrides": "containerOverrides",
+            "node_overrides": "nodeOverrides",
+            "ecs_properties_override": "ecsPropertiesOverride",
+        }
+
+        expected_args[py2api[override]] = {"a": "a"}
+
         client_mock().submit_job.assert_called_once_with(**expected_args)
 
     def test_deprecated_override_param(self):

Reply via email to