This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 12944002aa Add fallback `region_name` value to AWS Executors (#38704)
12944002aa is described below

commit 12944002aa66c8eda5c2e6e99c7924ede5831bd1
Author: Andrey Anshin <[email protected]>
AuthorDate: Wed Apr 3 18:29:03 2024 +0400

    Add fallback `region_name` value to AWS Executors (#38704)
    
    * Add fallback `region_name` value to AWS Executors
    
    * Get rid of os.environ in test_ecs_executor.py
---
 .../amazon/aws/executors/batch/batch_executor.py   |   2 +-
 .../amazon/aws/executors/ecs/ecs_executor.py       |   2 +-
 .../aws/executors/batch/test_batch_executor.py     |  17 +-
 .../amazon/aws/executors/ecs/test_ecs_executor.py  | 207 +++++++++------------
 4 files changed, 101 insertions(+), 127 deletions(-)
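
A note on the one-line executor change below: Airflow's `conf.get` raises for a missing option unless a `fallback` is supplied, while `fallback=None` makes it return None instead. A minimal sketch of the difference, where the literal section and key names are assumptions standing in for `CONFIG_GROUP_NAME` and `AllBatchConfigKeys.REGION_NAME` from the diff:

    from airflow.configuration import conf

    # Before: a missing option raises at executor start-up
    # ("aws_batch_executor"/"region_name" are illustrative names here).
    region_name = conf.get("aws_batch_executor", "region_name")

    # After: a missing option yields None instead of an exception, so the
    # hook can fall back to its own region resolution.
    region_name = conf.get("aws_batch_executor", "region_name", fallback=None)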

diff --git a/airflow/providers/amazon/aws/executors/batch/batch_executor.py b/airflow/providers/amazon/aws/executors/batch/batch_executor.py
index 1bd5135dea..5fb0076966 100644
--- a/airflow/providers/amazon/aws/executors/batch/batch_executor.py
+++ b/airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -153,7 +153,7 @@ class AwsBatchExecutor(BaseExecutor):
             AllBatchConfigKeys.AWS_CONN_ID,
             fallback=CONFIG_DEFAULTS[AllBatchConfigKeys.AWS_CONN_ID],
         )
-        region_name = conf.get(CONFIG_GROUP_NAME, AllBatchConfigKeys.REGION_NAME)
+        region_name = conf.get(CONFIG_GROUP_NAME, AllBatchConfigKeys.REGION_NAME, fallback=None)
         self.batch = BatchClientHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
         self.attempts_since_last_successful_connection += 1
         self.last_connection_reload = timezone.utcnow()
diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index 87b9abe435..aec7762d87 100644
--- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -183,7 +183,7 @@ class AwsEcsExecutor(BaseExecutor):
             AllEcsConfigKeys.AWS_CONN_ID,
             fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
         )
-        region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME)
+        region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None)
         self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
         self.attempts_since_last_successful_connection += 1
         self.last_connection_reload = timezone.utcnow()
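
The same change is applied to both executors: passing `region_name=None` to BatchClientHook or EcsHook defers region selection to the Airflow connection and, ultimately, to boto3's default resolution chain (environment variables such as AWS_DEFAULT_REGION, the shared config file, and so on). A hedged illustration of that fallback, assuming standard botocore behavior:

    import boto3

    # With region_name=None, botocore resolves the region from its default
    # chain rather than requiring an explicit value.
    session = boto3.session.Session(region_name=None)
    print(session.region_name)  # e.g. "eu-central-1" if AWS_DEFAULT_REGION is set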
diff --git a/tests/providers/amazon/aws/executors/batch/test_batch_executor.py b/tests/providers/amazon/aws/executors/batch/test_batch_executor.py
index c34b696210..e8ad0e4592 100644
--- a/tests/providers/amazon/aws/executors/batch/test_batch_executor.py
+++ b/tests/providers/amazon/aws/executors/batch/test_batch_executor.py
@@ -41,6 +41,7 @@ from airflow.providers.amazon.aws.executors.batch.utils import (
 )
 from airflow.utils.helpers import convert_camel_to_snake
 from airflow.utils.state import State
+from tests.test_utils.config import conf_vars
 
 ARN1 = "arn1"
 
@@ -49,11 +50,15 @@ MOCK_JOB_ID = "batch-job-id"
 
 @pytest.fixture
 def set_env_vars():
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.REGION_NAME}".upper()] = "us-west-1"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.JOB_NAME}".upper()] = "some-job-name"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.JOB_QUEUE}".upper()] = "some-job-queue"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.JOB_DEFINITION}".upper()] = "some-job-def"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllBatchConfigKeys.MAX_SUBMIT_JOB_ATTEMPTS}".upper()] = "3"
+    overrides: dict[tuple[str, str], str] = {
+        (CONFIG_GROUP_NAME, AllBatchConfigKeys.REGION_NAME): "us-east-1",
+        (CONFIG_GROUP_NAME, AllBatchConfigKeys.JOB_NAME): "some-job-name",
+        (CONFIG_GROUP_NAME, AllBatchConfigKeys.JOB_QUEUE): "some-job-queue",
+        (CONFIG_GROUP_NAME, AllBatchConfigKeys.JOB_DEFINITION): "some-job-def",
+        (CONFIG_GROUP_NAME, AllBatchConfigKeys.MAX_SUBMIT_JOB_ATTEMPTS): "3",
+    }
+    with conf_vars(overrides):
+        yield
 
 
 @pytest.fixture
@@ -514,7 +519,7 @@ class TestAwsBatchExecutor:
     @mock.patch(
         "airflow.providers.amazon.aws.executors.batch.boto_schema.BatchDescribeJobsResponseSchema.load"
     )
-    def test_health_check_failure(self, mock_executor):
+    def test_health_check_failure(self, mock_executor, set_env_vars):
         mock_executor.batch.describe_jobs.side_effect = Exception("Test_failure")
         executor = AwsBatchExecutor()
         batch_mock = mock.Mock(spec=executor.batch)
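
The test changes above swap raw `os.environ` writes for `conf_vars`, the Airflow test helper that applies `(section, key): value` overrides for the duration of a context (or a decorated test) and restores the previous configuration afterwards, so no manual teardown is needed. A minimal usage sketch, assuming the helper behaves as this diff implies; the literal section and key names stand in for `CONFIG_GROUP_NAME` and `AllBatchConfigKeys.REGION_NAME`:

    from tests.test_utils.config import conf_vars

    # Context-manager form: the override is visible only inside the block.
    with conf_vars({("aws_batch_executor", "region_name"): "us-east-1"}):
        ...  # code under test sees the override

    # Decorator form, as used later in this diff for a single test.
    @conf_vars({("aws_batch_executor", "region_name"): "us-east-1"})
    def test_uses_override():
        ...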
diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
index 8762480821..4110483162 100644
--- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -54,6 +54,7 @@ from airflow.providers.amazon.aws.hooks.ecs import EcsHook
 from airflow.utils.helpers import convert_camel_to_snake
 from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.timezone import utcnow
+from tests.test_utils.config import conf_vars
 
 pytestmark = pytest.mark.db_test
 
@@ -130,16 +131,20 @@ def mock_config():
 
 @pytest.fixture
 def set_env_vars():
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.REGION_NAME}".upper()] = "us-west-1"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CLUSTER}".upper()] = "some-cluster"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}".upper()] = "container-name"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.TASK_DEFINITION}".upper()] = "some-task-def"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}".upper()] = "FARGATE"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.PLATFORM_VERSION}".upper()] = "LATEST"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.ASSIGN_PUBLIC_IP}".upper()] = "False"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SECURITY_GROUPS}".upper()] = "sg1,sg2"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SUBNETS}".upper()] = "sub1,sub2"
-    os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS}".upper()] = "3"
+    overrides: dict[tuple[str, str], str] = {
+        (CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME): "us-west-1",
+        (CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER): "some-cluster",
+        (CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME): "container-name",
+        (CONFIG_GROUP_NAME, AllEcsConfigKeys.TASK_DEFINITION): "some-task-def",
+        (CONFIG_GROUP_NAME, AllEcsConfigKeys.LAUNCH_TYPE): "FARGATE",
+        (CONFIG_GROUP_NAME, AllEcsConfigKeys.PLATFORM_VERSION): "LATEST",
+        (CONFIG_GROUP_NAME, AllEcsConfigKeys.ASSIGN_PUBLIC_IP): "False",
+        (CONFIG_GROUP_NAME, AllEcsConfigKeys.SECURITY_GROUPS): "sg1,sg2",
+        (CONFIG_GROUP_NAME, AllEcsConfigKeys.SUBNETS): "sub1,sub2",
+        (CONFIG_GROUP_NAME, AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS): "3",
+    }
+    with conf_vars(overrides):
+        yield
 
 
 @pytest.fixture
@@ -362,9 +367,6 @@ class TestEcsExecutorTask:
 class TestAwsEcsExecutor:
     """Tests the AWS ECS Executor."""
 
-    def teardown_method(self) -> None:
-        self._unset_conf()
-
     def test_execute(self, mock_airflow_key, mock_executor):
         """Test execution from end-to-end."""
         airflow_key = mock_airflow_key()
@@ -1011,12 +1013,6 @@ class TestAwsEcsExecutor:
         )
         assert 0 == len(mock_executor.pending_tasks)
 
-    @staticmethod
-    def _unset_conf():
-        for env in os.environ:
-            if env.startswith(f"AIRFLOW__{CONFIG_GROUP_NAME.upper()}__"):
-                os.environ.pop(env)
-
     def _mock_sync(
         self,
         executor: AwsEcsExecutor,
@@ -1193,18 +1189,16 @@ class TestAwsEcsExecutor:
 class TestEcsExecutorConfig:
     @pytest.fixture
     def assign_subnets(self):
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SUBNETS}".upper()] = "sub1,sub2"
+        with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.SUBNETS): "sub1,sub2"}):
+            yield
 
-    @staticmethod
-    def teardown_method() -> None:
-        for env in os.environ:
-            if env.startswith(f"AIRFLOW__{CONFIG_GROUP_NAME}__".upper()):
-                os.environ.pop(env)
+    @pytest.fixture
+    def assign_container_name(self):
+        with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME): "foobar"}):
+            yield
 
     def test_flatten_dict(self):
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SUBNETS}".upper()] = "sub1,sub2"
         nested_dict = {"a": "a", "b": "b", "c": {"d": "d"}}
-
         assert _recursive_flatten_dict(nested_dict) == {"a": "a", "b": "b", "d": "d"}
 
     def test_validate_config_defaults(self):
@@ -1224,29 +1218,24 @@ class TestEcsExecutorConfig:
             assert file_defaults[key] == CONFIG_DEFAULTS[key]
 
     def test_subnets_required(self):
-        assert f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SUBNETS}".upper() not in os.environ
-
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.REGION_NAME}".upper()] = "us-west-1"
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CLUSTER}".upper()] = "some-cluster"
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}".upper()] = (
-            "container-name"
-        )
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.TASK_DEFINITION}".upper()] = (
-            "some-task-def"
-        )
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}".upper()] = "FARGATE"
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.PLATFORM_VERSION}".upper()] = "LATEST"
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.ASSIGN_PUBLIC_IP}".upper()] = "False"
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SECURITY_GROUPS}".upper()] = "sg1,sg2"
-
-        with pytest.raises(ValueError) as raised:
-            ecs_executor_config.build_task_kwargs()
+        conf_overrides = {
+            (CONFIG_GROUP_NAME, AllEcsConfigKeys.SUBNETS): None,
+            (CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME): "us-west-1",
+            (CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER): "some-cluster",
+            (CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME): "container-name",
+            (CONFIG_GROUP_NAME, AllEcsConfigKeys.TASK_DEFINITION): "some-task-def",
+            (CONFIG_GROUP_NAME, AllEcsConfigKeys.LAUNCH_TYPE): "FARGATE",
+            (CONFIG_GROUP_NAME, AllEcsConfigKeys.PLATFORM_VERSION): "LATEST",
+            (CONFIG_GROUP_NAME, AllEcsConfigKeys.ASSIGN_PUBLIC_IP): "False",
+            (CONFIG_GROUP_NAME, AllEcsConfigKeys.SECURITY_GROUPS): "sg1,sg2",
+        }
+        with conf_vars(conf_overrides):
+            with pytest.raises(ValueError) as raised:
+                ecs_executor_config.build_task_kwargs()
         assert raised.match("At least one subnet is required to run a task.")
 
+    @conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME): "container-name"})
     def test_config_defaults_are_applied(self, assign_subnets):
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}".upper()] = (
-            "container-name"
-        )
         from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
         task_kwargs = _recursive_flatten_dict(ecs_executor_config.build_task_kwargs())
@@ -1268,54 +1257,50 @@ class TestEcsExecutorConfig:
                     expected_value = parse_assign_public_ip(expected_value)
                 assert expected_value == task_kwargs[found_keys[expected_key]]
 
-    def test_provided_values_override_defaults(self, assign_subnets):
+    def test_provided_values_override_defaults(self, assign_subnets, assign_container_name, monkeypatch):
         """
         Expected precedence is default values are overwritten by values provided explicitly,
         and those values are overwritten by those provided in run_task_kwargs.
         """
-        default_version = CONFIG_DEFAULTS[AllEcsConfigKeys.PLATFORM_VERSION]
-        templated_version = "1"
-        first_explicit_version = "2"
-        second_explicit_version = "3"
-
         run_task_kwargs_env_key = f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper()
         platform_version_env_key = (
             f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.PLATFORM_VERSION}".upper()
         )
-        # Required param which doesn't have a default
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}".upper()] = "foobar"
+        default_version = CONFIG_DEFAULTS[AllEcsConfigKeys.PLATFORM_VERSION]
+        templated_version = "1"
+        first_explicit_version = "2"
+        second_explicit_version = "3"
 
         # Confirm the default value is applied when no value is provided.
-        assert run_task_kwargs_env_key not in os.environ
-        assert platform_version_env_key not in os.environ
+        monkeypatch.delenv(platform_version_env_key, raising=False)
+        monkeypatch.delenv(run_task_kwargs_env_key, raising=False)
         from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
         task_kwargs = ecs_executor_config.build_task_kwargs()
-
         assert task_kwargs["platformVersion"] == default_version
 
         # Provide a new value explicitly and assert that it is applied over the default.
-        os.environ[platform_version_env_key] = first_explicit_version
+        monkeypatch.setenv(platform_version_env_key, first_explicit_version)
         task_kwargs = ecs_executor_config.build_task_kwargs()
-
         assert task_kwargs["platformVersion"] == first_explicit_version
 
         # Provide a value via template and assert that it is applied over the explicit value.
-        os.environ[run_task_kwargs_env_key] = json.dumps(
-            {AllEcsConfigKeys.PLATFORM_VERSION: templated_version}
+        monkeypatch.setenv(
+            run_task_kwargs_env_key,
+            json.dumps({AllEcsConfigKeys.PLATFORM_VERSION: templated_version}),
         )
         task_kwargs = ecs_executor_config.build_task_kwargs()
-
         assert task_kwargs["platformVersion"] == templated_version
 
         # Provide a new value explicitly and assert it is not applied over the templated values.
-        os.environ[platform_version_env_key] = second_explicit_version
+        monkeypatch.setenv(platform_version_env_key, second_explicit_version)
         task_kwargs = ecs_executor_config.build_task_kwargs()
-
         assert task_kwargs["platformVersion"] == templated_version
 
     @mock.patch.object(EcsHook, "conn")
-    def test_count_can_not_be_modified_by_the_user(self, _, assign_subnets):
+    def test_count_can_not_be_modified_by_the_user(
+        self, _, assign_subnets, assign_container_name, monkeypatch
+    ):
         """The ``count`` parameter must always be 1; verify that the user can 
not override this value."""
         templated_version = "1"
         templated_cluster = "templated_cluster_name"
@@ -1325,13 +1310,12 @@ class TestEcsExecutorConfig:
             "count": 2,  # The user should not be allowed to overwrite count, 
it must be value of 1
         }
 
-        run_task_kwargs_env_key = f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper()
-        # Required param which doesn't have a default
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}".upper()] = "foobar"
-
         # Provide values via task run kwargs template and assert that they are applied,
         # which verifies that the OTHER values were changed.
-        os.environ[run_task_kwargs_env_key] = json.dumps(provided_run_task_kwargs)
+        monkeypatch.setenv(
+            f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper(),
+            json.dumps(provided_run_task_kwargs),
+        )
         task_kwargs = ecs_executor_config.build_task_kwargs()
         assert task_kwargs["platformVersion"] == templated_version
         assert task_kwargs["cluster"] == templated_cluster
@@ -1339,7 +1323,7 @@ class TestEcsExecutorConfig:
         # Assert that count was NOT overridden when the others were applied.
         assert task_kwargs["count"] == 1
 
-    def test_verify_tags_are_used_as_provided(self, assign_subnets):
+    def test_verify_tags_are_used_as_provided(self, assign_subnets, assign_container_name, monkeypatch):
         """Confirm that the ``tags`` provided are not converted to camelCase."""
         templated_tags = {"Apache": "Airflow"}
 
@@ -1348,16 +1332,13 @@ class TestEcsExecutorConfig:
         }
 
         run_task_kwargs_env_key = f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper()
-        # Required param which doesn't have a default
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}".upper()] = "foobar"
-
-        os.environ[run_task_kwargs_env_key] = json.dumps(provided_run_task_kwargs)
+        monkeypatch.setenv(run_task_kwargs_env_key, json.dumps(provided_run_task_kwargs))
         task_kwargs = ecs_executor_config.build_task_kwargs()
 
         # Verify that tag names are exempt from the camel-case conversion.
         assert task_kwargs["tags"] == templated_tags
 
-    def test_that_provided_kwargs_are_moved_to_correct_nesting(self, assign_subnets):
+    def test_that_provided_kwargs_are_moved_to_correct_nesting(self, monkeypatch):
         """
         kwargs such as subnets, security groups,  public ip, and container name are valid run task kwargs,
         but they are not placed at the root of the kwargs dict, they should be nested in various sub dicts.
@@ -1370,7 +1351,7 @@ class TestEcsExecutorConfig:
             AllEcsConfigKeys.SUBNETS: "sub1,sub2",
         }
         for key, value in kwargs_to_test.items():
-            os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{key}".upper()] = value
+            monkeypatch.setenv(f"AIRFLOW__{CONFIG_GROUP_NAME}__{key}".upper(), value)
 
         run_task_kwargs = ecs_executor_config.build_task_kwargs()
         run_task_kwargs_network_config = run_task_kwargs["networkConfiguration"]["awsvpcConfiguration"]
@@ -1450,43 +1431,35 @@ class TestEcsExecutorConfig:
 
         executor.ecs = ecs_mock
 
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP}".upper()] = (
-            "False"
-        )
-
-        executor.start()
+        with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP): "False"}):
+            executor.start()
 
         ecs_mock.stop_task.assert_not_called()
 
-    def test_providing_both_capacity_provider_and_launch_type_fails(self, set_env_vars):
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CAPACITY_PROVIDER_STRATEGY}".upper()] = (
-            "[{'capacityProvider': 'cp1', 'weight': 5}, {'capacityProvider': 'cp2', 'weight': 1}]"
-        )
-        expected_error = (
+    def test_providing_both_capacity_provider_and_launch_type_fails(self, set_env_vars, monkeypatch):
+        cps = "[{'capacityProvider': 'cp1', 'weight': 5}, {'capacityProvider': 'cp2', 'weight': 1}]"
+        expected_error = re.escape(
             "capacity_provider_strategy and launch_type are mutually 
exclusive, you can not provide both."
         )
-
-        with pytest.raises(ValueError, match=expected_error):
-            AwsEcsExecutor()
+        with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.CAPACITY_PROVIDER_STRATEGY): cps}):
+            with pytest.raises(ValueError, match=expected_error):
+                AwsEcsExecutor()
 
     def test_providing_capacity_provider(self, set_env_vars):
         # If a capacity provider strategy is supplied without a launch type, use the strategy.
-
         valid_capacity_provider = (
             "[{'capacityProvider': 'cp1', 'weight': 5}, {'capacityProvider': 
'cp2', 'weight': 1}]"
         )
+        conf_overrides = {
+            (CONFIG_GROUP_NAME, AllEcsConfigKeys.CAPACITY_PROVIDER_STRATEGY): valid_capacity_provider,
+            (CONFIG_GROUP_NAME, AllEcsConfigKeys.LAUNCH_TYPE): None,
+        }
+        with conf_vars(conf_overrides):
+            from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-        os.environ[f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CAPACITY_PROVIDER_STRATEGY}".upper()] = (
-            valid_capacity_provider
-        )
-        os.environ.pop(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}".upper())
-
-        from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
-
-        task_kwargs = ecs_executor_config.build_task_kwargs()
-
-        assert "launchType" not in task_kwargs
-        assert task_kwargs["capacityProviderStrategy"] == valid_capacity_provider
+            task_kwargs = ecs_executor_config.build_task_kwargs()
+            assert "launchType" not in task_kwargs
+            assert task_kwargs["capacityProviderStrategy"] == valid_capacity_provider
 
     @mock.patch.object(EcsHook, "conn")
     def test_providing_no_capacity_provider_no_lunch_type_with_cluster_default(self, mock_conn, set_env_vars):
@@ -1495,28 +1468,24 @@ class TestEcsExecutorConfig:
         mock_conn.describe_clusters.return_value = {
             "clusters": [{"defaultCapacityProviderStrategy": 
["some_strategy"]}]
         }
-        os.environ.pop(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}".upper())
+        with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.LAUNCH_TYPE): None}):
+            from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-        from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
-
-        task_kwargs = ecs_executor_config.build_task_kwargs()
-        assert "launchType" not in task_kwargs
-        assert "capacityProviderStrategy" not in task_kwargs
-        mock_conn.describe_clusters.assert_called_once()
+            task_kwargs = ecs_executor_config.build_task_kwargs()
+            assert "launchType" not in task_kwargs
+            assert "capacityProviderStrategy" not in task_kwargs
+            mock_conn.describe_clusters.assert_called_once()
 
     @mock.patch.object(EcsHook, "conn")
     def test_providing_no_capacity_provider_no_lunch_type_no_cluster_default(self, mock_conn, set_env_vars):
         # If no capacity provider strategy is supplied and no launch type, and the cluster
         # does not have a default capacity provider strategy, use the FARGATE launch type.
-
         mock_conn.describe_clusters.return_value = {"clusters": [{"status": "ACTIVE"}]}
+        with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.LAUNCH_TYPE): None}):
+            from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-        os.environ.pop(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}".upper())
-
-        from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
-
-        task_kwargs = ecs_executor_config.build_task_kwargs()
-        assert task_kwargs["launchType"] == "FARGATE"
+            task_kwargs = ecs_executor_config.build_task_kwargs()
+            assert task_kwargs["launchType"] == "FARGATE"
 
     @pytest.mark.parametrize(
         "run_task_kwargs, exec_config, expected_result",
@@ -1721,10 +1690,10 @@ class TestEcsExecutorConfig:
         ],
     )
     def test_run_task_kwargs_exec_config_overrides(
-        self, set_env_vars, run_task_kwargs, exec_config, expected_result
+        self, set_env_vars, run_task_kwargs, exec_config, expected_result, monkeypatch
     ):
         run_task_kwargs_env_key = f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper()
-        os.environ[run_task_kwargs_env_key] = json.dumps(run_task_kwargs)
+        monkeypatch.setenv(run_task_kwargs_env_key, json.dumps(run_task_kwargs))
 
         mock_ti_key = mock.Mock(spec=TaskInstanceKey)
         command = ["command"]
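
Two self-cleaning patterns recur through the ECS test changes: `conf_vars` with a value of None, which removes the option for the scope of the override (used above to unset SUBNETS and LAUNCH_TYPE), and pytest's built-in `monkeypatch` fixture, whose environment changes are undone automatically when each test finishes. A short sketch of the monkeypatch side; the env var names are illustrative and assume the ECS config group resolves to `aws_ecs_executor`:

    def test_example(monkeypatch):
        # Set an Airflow-style env var for this test only; pytest restores
        # the environment afterwards, so no teardown method is required.
        monkeypatch.setenv("AIRFLOW__AWS_ECS_EXECUTOR__RUN_TASK_KWARGS", "{}")
        # Remove a variable if present; raising=False tolerates its absence.
        monkeypatch.delenv("AIRFLOW__AWS_ECS_EXECUTOR__PLATFORM_VERSION", raising=False)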
