This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e2c73fd35ca AIP-67 - Multi-team: Per team executor config (env var only) (#55003)
e2c73fd35ca is described below
commit e2c73fd35cab3274307b0f4a08615fdbc905f53c
Author: Niko Oliveira <[email protected]>
AuthorDate: Sun Sep 7 17:08:51 2025 -0700
AIP-67 - Multi-team: Per team executor config (env var only) (#55003)
* Multi-team: per team executor config (env var only)
Configuration for teams can now be specified via environment variables
using the triple-underscore syntax outlined in AIP-67. This applies to
any configuration option, but it is specifically required for
executor-based configuration.
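For illustration, a minimal sketch of the resulting lookup behavior (the
values here are hypothetical; the naming pattern and team_name argument
match the test added in this commit):

    # AIRFLOW__{SECTION}__{KEY}               -> global value (double underscore)
    # AIRFLOW__{TEAM_NAME}___{SECTION}__{KEY} -> team-scoped value (triple underscore)
    import os

    os.environ["AIRFLOW__CELERY__RESULT_BACKEND"] = "global-backend"
    os.environ["AIRFLOW__TEAM_A___CELERY__RESULT_BACKEND"] = "team-a-backend"

    from airflow.configuration import conf

    conf.get("celery", "result_backend")                      # "global-backend"
    conf.get("celery", "result_backend", team_name="team_a")  # "team-a-backend"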
A small shim has been added to BaseExecutor to allow easier access to
team-based config.
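A hedged sketch of what that shim enables in a concrete executor (the
executor class and config section below are hypothetical, for
illustration only; ExecutorConf and the self.conf attribute come from
this commit):

    from airflow.executors.base_executor import BaseExecutor

    class MyExecutor(BaseExecutor):  # hypothetical executor, not part of this commit
        def start(self):
            # self.conf is the ExecutorConf shim set in BaseExecutor.__init__;
            # it forwards this executor's team_name on every lookup, so a
            # team-scoped env var wins over the global one when both are set.
            cluster = self.conf.get("my_executor", "cluster", fallback=None)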
The ECS executor has been converted to this new shim as a proof of
concept for the mechanism.
* PR Feedback: comment fixup
---
airflow-core/src/airflow/configuration.py | 21 ++--
.../src/airflow/executors/base_executor.py | 23 ++++
airflow-core/tests/unit/core/test_configuration.py | 11 ++
.../amazon/aws/executors/ecs/ecs_executor.py | 45 ++++----
.../aws/executors/ecs/ecs_executor_config.py | 17 +--
.../amazon/aws/executors/ecs/test_ecs_executor.py | 117 ++++++++++++++++-----
6 files changed, 177 insertions(+), 57 deletions(-)
diff --git a/airflow-core/src/airflow/configuration.py b/airflow-core/src/airflow/configuration.py
index 8c384e64266..24a439e848e 100644
--- a/airflow-core/src/airflow/configuration.py
+++ b/airflow-core/src/airflow/configuration.py
@@ -877,12 +877,14 @@ class AirflowConfigParser(ConfigParser):
mask_secret_core(value)
mask_secret_sdk(value)
- def _env_var_name(self, section: str, key: str) -> str:
- return f"{ENV_VAR_PREFIX}{section.replace('.',
'_').upper()}__{key.upper()}"
-
- def _get_env_var_option(self, section: str, key: str):
- # must have format AIRFLOW__{SECTION}__{KEY} (note double underscore)
- env_var = self._env_var_name(section, key)
+ def _env_var_name(self, section: str, key: str, team_name: str | None =
None) -> str:
+ team_component: str = f"{team_name.upper()}___" if team_name else ""
+ return f"{ENV_VAR_PREFIX}{team_component}{section.replace('.',
'_').upper()}__{key.upper()}"
+
+ def _get_env_var_option(self, section: str, key: str, team_name: str |
None = None):
+ # must have format AIRFLOW__{SECTION}__{KEY} (note double underscore)
OR for team based
+ # configuration must have the format
AIRFLOW__{TEAM_NAME}___{SECTION}__{KEY}
+ env_var: str = self._env_var_name(section, key, team_name=team_name)
if env_var in os.environ:
return expand_env_var(os.environ[env_var])
# alternatively AIRFLOW__{SECTION}__{KEY}_CMD (for a command)
@@ -982,6 +984,7 @@ class AirflowConfigParser(ConfigParser):
suppress_warnings: bool = False,
lookup_from_deprecated: bool = True,
_extra_stacklevel: int = 0,
+ team_name: str | None = None,
**kwargs,
) -> str | None:
section = section.lower()
@@ -1044,6 +1047,7 @@ class AirflowConfigParser(ConfigParser):
section,
issue_warning=not warning_emitted,
extra_stacklevel=_extra_stacklevel,
+ team_name=team_name,
)
if option is not None:
return option
@@ -1170,13 +1174,14 @@ class AirflowConfigParser(ConfigParser):
section: str,
issue_warning: bool = True,
extra_stacklevel: int = 0,
+ team_name: str | None = None,
) -> str | None:
- option = self._get_env_var_option(section, key)
+ option = self._get_env_var_option(section, key, team_name=team_name)
if option is not None:
return option
if deprecated_section and deprecated_key:
with self.suppress_future_warnings():
- option = self._get_env_var_option(deprecated_section, deprecated_key)
+ option = self._get_env_var_option(deprecated_section, deprecated_key, team_name=team_name)
if option is not None:
if issue_warning:
self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel)
diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py
index 1723d94708b..70b9ba9ef08 100644
--- a/airflow-core/src/airflow/executors/base_executor.py
+++ b/airflow-core/src/airflow/executors/base_executor.py
@@ -97,6 +97,28 @@ class RunningRetryAttemptType:
return True
+class ExecutorConf:
+ """
+ This class is used to fetch configuration for an executor for a particular team_name.
+
+ It wraps the implementation of the configuration.get() to look for the particular section and key
+ prefixed with the team_name. This makes it easy for child classes (i.e. concrete executors) to fetch
+ configuration values for a particular team_name without having to worry about passing through the
+ team_name for every call to get configuration.
+
+ Currently config only supports environment variables for team specific configuration.
+ """
+
+ def __init__(self, team_name: str | None = None) -> None:
+ self.team_name: str | None = team_name
+
+ def get(self, *args, **kwargs) -> str | None:
+ return conf.get(*args, **kwargs, team_name=self.team_name)
+
+ def getboolean(self, *args, **kwargs) -> bool:
+ return conf.getboolean(*args, **kwargs, team_name=self.team_name)
+
+
class BaseExecutor(LoggingMixin):
"""
Base class to inherit for concrete executors such as Celery, Kubernetes, Local, etc.
@@ -150,6 +172,7 @@ class BaseExecutor(LoggingMixin):
self.running: set[TaskInstanceKey] = set()
self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
self._task_event_logs: deque[Log] = deque()
+ self.conf = ExecutorConf(team_name)
if self.parallelism <= 0:
raise ValueError("parallelism is set to 0 or lower")
diff --git a/airflow-core/tests/unit/core/test_configuration.py b/airflow-core/tests/unit/core/test_configuration.py
index 3e518294a43..757804428a4 100644
--- a/airflow-core/tests/unit/core/test_configuration.py
+++ b/airflow-core/tests/unit/core/test_configuration.py
@@ -158,6 +158,17 @@ class TestConf:
assert conf.has_option("testsection", "testkey")
+ def test_env_team(self):
+ with patch(
+ "os.environ",
+ {
+ "AIRFLOW__CELERY__RESULT_BACKEND": "FOO",
+ "AIRFLOW__UNIT_TEST_TEAM___CELERY__RESULT_BACKEND": "BAR",
+ },
+ ):
+ assert conf.get("celery", "result_backend") == "FOO"
+ assert conf.get("celery", "result_backend",
team_name="unit_test_team") == "BAR"
+
@conf_vars({("core", "percent"): "with%%inside"})
def test_conf_as_dict(self):
cfg_dict = conf.as_dict()
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index a7fe158f14c..396e998af61 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -32,7 +32,6 @@ from typing import TYPE_CHECKING
from botocore.exceptions import ClientError, NoCredentialsError
-from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoDescribeTasksSchema, BotoRunTaskSchema
@@ -98,13 +97,6 @@ class AwsEcsExecutor(BaseExecutor):
Airflow TaskInstance's executor_config.
"""
- # Maximum number of retries to run an ECS task.
- MAX_RUN_TASK_ATTEMPTS = conf.get(
- CONFIG_GROUP_NAME,
- AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS,
- fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS],
- )
-
# AWS limits the maximum number of ARNs in the describe_tasks function.
DESCRIBE_TASKS_BATCH_SIZE = 99
@@ -118,8 +110,18 @@ class AwsEcsExecutor(BaseExecutor):
self.active_workers: EcsTaskCollection = EcsTaskCollection()
self.pending_tasks: deque = deque()
- self.cluster = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
- self.container_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
+ # Check if self has the ExecutorConf set on the self.conf attribute, and if not, set it to the global
+ # configuration object. This allows the changes to be backwards compatible with older versions of
+ # Airflow.
+ # Can be removed when minimum supported provider version is equal to the version of core airflow
+ # which introduces multi-team configuration.
+ if not hasattr(self, "conf"):
+ from airflow.configuration import conf
+
+ self.conf = conf
+
+ self.cluster = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
+ self.container_name = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
self.attempts_since_last_successful_connection = 0
self.load_ecs_connection(check_connection=False)
@@ -127,6 +129,13 @@ class AwsEcsExecutor(BaseExecutor):
self.run_task_kwargs = self._load_run_kwargs()
+ # Maximum number of retries to run an ECS task.
+ self.max_run_task_attempts = self.conf.get(
+ CONFIG_GROUP_NAME,
+ AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS,
+ fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS],
+ )
+
def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
from airflow.executors import workloads
@@ -154,7 +163,7 @@ class AwsEcsExecutor(BaseExecutor):
def start(self):
"""Call this when the Executor is run for the first time by the
scheduler."""
- check_health = conf.getboolean(
+ check_health = self.conf.getboolean(
CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP,
fallback=False
)
@@ -218,12 +227,12 @@ class AwsEcsExecutor(BaseExecutor):
def load_ecs_connection(self, check_connection: bool = True):
self.log.info("Loading Connection information")
- aws_conn_id = conf.get(
+ aws_conn_id = self.conf.get(
CONFIG_GROUP_NAME,
AllEcsConfigKeys.AWS_CONN_ID,
fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
)
- region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None)
+ region_name = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None)
self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
self.attempts_since_last_successful_connection += 1
self.last_connection_reload = timezone.utcnow()
@@ -340,13 +349,13 @@ class AwsEcsExecutor(BaseExecutor):
queue = task_info.queue
exec_info = task_info.config
failure_count = self.active_workers.failure_count_by_key(task_key)
- if int(failure_count) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
+ if int(failure_count) < int(self.max_run_task_attempts):
self.log.warning(
"Airflow task %s failed due to %s. Failure %s out of %s
occurred on %s. Rescheduling.",
task_key,
reason,
failure_count,
- self.__class__.MAX_RUN_TASK_ATTEMPTS,
+ self.max_run_task_attempts,
task_arn,
)
self.pending_tasks.append(
@@ -416,8 +425,8 @@ class AwsEcsExecutor(BaseExecutor):
failure_reasons.extend([f["reason"] for f in
run_task_response["failures"]])
if failure_reasons:
- # Make sure the number of attempts does not exceed
MAX_RUN_TASK_ATTEMPTS
- if int(attempt_number) <
int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
+ # Make sure the number of attempts does not exceed
max_run_task_attempts
+ if int(attempt_number) < int(self.max_run_task_attempts):
ecs_task.attempt_number += 1
ecs_task.next_attempt_time = timezone.utcnow() +
calculate_next_attempt_delay(
attempt_number
@@ -545,7 +554,7 @@ class AwsEcsExecutor(BaseExecutor):
def _load_run_kwargs(self) -> dict:
from airflow.providers.amazon.aws.executors.ecs.ecs_executor_config import build_task_kwargs
- ecs_executor_run_task_kwargs = build_task_kwargs()
+ ecs_executor_run_task_kwargs = build_task_kwargs(self.conf)
try:
self.get_container(ecs_executor_run_task_kwargs["overrides"]["containerOverrides"])["command"]
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
index bcbbbfc9e8c..f7753903a2f 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
@@ -32,7 +32,6 @@ from __future__ import annotations
import json
from json import JSONDecodeError
-from airflow.configuration import conf
from airflow.providers.amazon.aws.executors.ecs.utils import (
CONFIG_GROUP_NAME,
ECS_LAUNCH_TYPE_EC2,
@@ -46,23 +45,27 @@ from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.utils.helpers import prune_dict
-def _fetch_templated_kwargs() -> dict[str, str]:
- run_task_kwargs_value = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.RUN_TASK_KWARGS, fallback=dict())
+def _fetch_templated_kwargs(conf) -> dict[str, str]:
+ run_task_kwargs_value = conf.get(
+ CONFIG_GROUP_NAME,
+ AllEcsConfigKeys.RUN_TASK_KWARGS,
+ fallback=dict(),
+ )
return json.loads(str(run_task_kwargs_value))
-def _fetch_config_values() -> dict[str, str]:
+def _fetch_config_values(conf) -> dict[str, str]:
return prune_dict(
{key: conf.get(CONFIG_GROUP_NAME, key, fallback=None) for key in RunTaskKwargsConfigKeys()}
)
-def build_task_kwargs() -> dict:
+def build_task_kwargs(conf) -> dict:
all_config_keys = AllEcsConfigKeys()
# This will put some kwargs at the root of the dictionary that do NOT belong there. However,
# the code below expects them to be there and will rearrange them as necessary.
- task_kwargs = _fetch_config_values()
- task_kwargs.update(_fetch_templated_kwargs())
+ task_kwargs = _fetch_config_values(conf)
+ task_kwargs.update(_fetch_templated_kwargs(conf))
has_launch_type: bool = all_config_keys.LAUNCH_TYPE in task_kwargs
has_capacity_provider: bool = all_config_keys.CAPACITY_PROVIDER_STRATEGY in task_kwargs
diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
index 6d0b0055906..8f5fe85bf10 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -25,7 +25,7 @@ import time
from collections.abc import Callable
from functools import partial
from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
import pytest
import yaml
@@ -33,7 +33,9 @@ from botocore.exceptions import ClientError
from inflection import camelize
from semver import VersionInfo
+from airflow.configuration import conf
from airflow.exceptions import AirflowException
+from airflow.executors import base_executor
from airflow.executors.base_executor import BaseExecutor
from airflow.models import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -538,7 +540,7 @@ class TestAwsEcsExecutor:
mock_executor.execute_async(mock_airflow_key, mock_cmd)
# No matter what, don't schedule until run_task becomes successful.
- for _ in range(int(mock_executor.MAX_RUN_TASK_ATTEMPTS) * 2):
+ for _ in range(int(mock_executor.max_run_task_attempts) * 2):
mock_executor.attempt_task_runs()
# Task is not stored in active workers.
assert len(mock_executor.active_workers) == 0
@@ -555,7 +557,7 @@ class TestAwsEcsExecutor:
mock_executor.execute_async(mock_airflow_key, mock_cmd)
# No matter what, don't schedule until run_task becomes successful.
- for _ in range(int(mock_executor.MAX_RUN_TASK_ATTEMPTS) * 2):
+ for _ in range(int(mock_executor.max_run_task_attempts) * 2):
mock_executor.attempt_task_runs()
# Task is not stored in active workers.
assert len(mock_executor.active_workers) == 0
@@ -567,7 +569,7 @@ class TestAwsEcsExecutor:
The executor should attempt each task exactly once per sync() iteration.
It should preserve the order of tasks, and attempt each task up to
- `MAX_RUN_TASK_ATTEMPTS` times before dropping the task.
+ `max_run_task_attempts` times before dropping the task.
"""
airflow_keys = [
TaskInstanceKey("a", "task_a", "c", 1, -1),
@@ -627,7 +629,7 @@ class TestAwsEcsExecutor:
The executor should attempt each task exactly once per sync() iteration.
It should preserve the order of tasks, and attempt each task up to
- `MAX_RUN_TASK_ATTEMPTS` times before dropping the task. If a task succeeds, the task
+ `max_run_task_attempts` times before dropping the task. If a task succeeds, the task
should be removed from pending_jobs and into active_workers.
"""
airflow_keys = [
@@ -705,7 +707,7 @@ class TestAwsEcsExecutor:
"""
Test API failure retries.
"""
- AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "2"
+ mock_executor.max_run_task_attempts = "2"
airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"]
airflow_commands = [_generate_mock_cmd(), _generate_mock_cmd()]
@@ -834,7 +836,7 @@ class TestAwsEcsExecutor:
@mock.patch.object(BaseExecutor, "success")
def test_failed_sync(self, success_mock, fail_mock, mock_executor):
"""Test success and failure states."""
- AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "1"
+ mock_executor.max_run_task_attempts = "1"
self._mock_sync(mock_executor, State.FAILED)
mock_executor.sync()
@@ -850,7 +852,7 @@ class TestAwsEcsExecutor:
@mock.patch.object(BaseExecutor, "fail")
def test_removed_sync(self, fail_mock, success_mock, mock_executor):
"""A removed task will be treated as a failed task."""
- AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "1"
+ mock_executor.max_run_task_attempts = "1"
self._mock_sync(mock_executor, expected_state=State.REMOVED, set_task_state=State.REMOVED)
mock_executor.sync_running_tasks()
@@ -868,7 +870,7 @@ class TestAwsEcsExecutor:
self, _, success_mock, fail_mock, mock_airflow_key, mock_executor, mock_cmd
):
"""Test that failure_count/attempt_number is cumulative for pending
tasks and active workers."""
- AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "5"
+ mock_executor.max_run_task_attempts = "5"
mock_executor.ecs.run_task.return_value = {
"tasks": [],
"failures": [
@@ -980,8 +982,8 @@ class TestAwsEcsExecutor:
assert len(mock_executor.active_workers.get_all_arns()) == 1
task_key = mock_executor.active_workers.arn_to_key[ARN1]
- # Call Sync 2 times with failures. The task can only fail MAX_RUN_TASK_ATTEMPTS times.
- for check_count in range(1, int(AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS)):
+ # Call Sync 2 times with failures. The task can only fail max_run_task_attempts times.
+ for check_count in range(1, int(mock_executor.max_run_task_attempts)):
mock_executor.sync_running_tasks()
assert mock_executor.ecs.describe_tasks.call_count == check_count
@@ -1215,7 +1217,7 @@ class TestAwsEcsExecutor:
mock_success_function.assert_called_once()
def test_update_running_tasks_failed(self, mock_executor, caplog):
- AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "1"
+ mock_executor.max_run_task_attempts = "1"
caplog.set_level(logging.WARNING)
self._add_mock_task(mock_executor, ARN1)
test_response_task_json = {
@@ -1343,14 +1345,81 @@ class TestEcsExecutorConfig:
}
with conf_vars(conf_overrides):
with pytest.raises(ValueError) as raised:
- ecs_executor_config.build_task_kwargs()
+ ecs_executor_config.build_task_kwargs(conf)
assert raised.match("At least one subnet is required to run a task.")
+ # TODO: When merged this needs updating to the actually supported version
+ @pytest.mark.skipif(
+ not hasattr(base_executor, "ExecutorConf"),
+ reason="Test requires a version of airflow which includes updates to
support multi team",
+ )
+ def test_team_config(self):
+ # Team name to be used throughout
+ team_name = "team_a"
+ # Patch environment to include two sets of configs for the ECS executor. One that is related to a
+ # team and one that is not. Then we will create two executors (one with a team and one without) and
+ # ensure the correct configs are used.
+ config_overrides = [
+ (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CLUSTER}",
"some_cluster"),
+
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}",
"container_name"),
+
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.TASK_DEFINITION}",
"some_task_def"),
+ (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}",
"FARGATE"),
+
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.PLATFORM_VERSION}",
"LATEST"),
+
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.ASSIGN_PUBLIC_IP}", "False"),
+
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SECURITY_GROUPS}",
"sg1,sg2"),
+ (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SUBNETS}",
"sub1,sub2"),
+ (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.REGION_NAME}",
"us-west-1"),
+ # team Config
+
(f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CLUSTER}",
"team_a_cluster"),
+ (
+
f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}",
+ "team_a_container",
+ ),
+ (
+
f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.TASK_DEFINITION}",
+ "team_a_task_def",
+ ),
+
(f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}",
"EC2"),
+ (
+
f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SECURITY_GROUPS}",
+ "team_a_sg1,team_a_sg2",
+ ),
+ (
+
f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SUBNETS}",
+ "team_a_sub1,team_a_sub2",
+ ),
+
(f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.REGION_NAME}",
"us-west-2"),
+ ]
+ with patch("os.environ", {key.upper(): value for key, value in
config_overrides}):
+ team_executor = AwsEcsExecutor(team_name=team_name)
+ task_kwargs = ecs_executor_config.build_task_kwargs(team_executor.conf)
+
+ assert task_kwargs["cluster"] == "team_a_cluster"
+ assert task_kwargs["overrides"]["containerOverrides"][0]["name"] == "team_a_container"
+ assert task_kwargs["networkConfiguration"]["awsvpcConfiguration"] == {
+ "subnets": ["team_a_sub1", "team_a_sub2"],
+ "securityGroups": ["team_a_sg1", "team_a_sg2"],
+ }
+ assert task_kwargs["launchType"] == "EC2"
+ assert task_kwargs["taskDefinition"] == "team_a_task_def"
+ # Now create an executor without a team and ensure the non-team configs are used.
+ non_team_executor = AwsEcsExecutor()
+ task_kwargs = ecs_executor_config.build_task_kwargs(non_team_executor.conf)
+ assert task_kwargs["cluster"] == "some_cluster"
+ assert task_kwargs["overrides"]["containerOverrides"][0]["name"] == "container_name"
+ assert task_kwargs["networkConfiguration"]["awsvpcConfiguration"] == {
+ "subnets": ["sub1", "sub2"],
+ "securityGroups": ["sg1", "sg2"],
+ "assignPublicIp": "DISABLED",
+ }
+ assert task_kwargs["launchType"] == "FARGATE"
+ assert task_kwargs["taskDefinition"] == "some_task_def"
+
@conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME): "container-name"})
def test_config_defaults_are_applied(self, assign_subnets):
from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
- task_kwargs = _recursive_flatten_dict(ecs_executor_config.build_task_kwargs())
+ task_kwargs = _recursive_flatten_dict(ecs_executor_config.build_task_kwargs(conf))
found_keys = {convert_camel_to_snake(key): key for key in task_kwargs.keys()}
for expected_key, expected_value in CONFIG_DEFAULTS.items():
@@ -1388,12 +1457,12 @@ class TestEcsExecutorConfig:
monkeypatch.delenv(run_task_kwargs_env_key, raising=False)
from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
- task_kwargs = ecs_executor_config.build_task_kwargs()
+ task_kwargs = ecs_executor_config.build_task_kwargs(conf)
assert task_kwargs["platformVersion"] == default_version
# Provide a new value explicitly and assert that it is applied over the default.
monkeypatch.setenv(platform_version_env_key, first_explicit_version)
- task_kwargs = ecs_executor_config.build_task_kwargs()
+ task_kwargs = ecs_executor_config.build_task_kwargs(conf)
assert task_kwargs["platformVersion"] == first_explicit_version
# Provide a value via template and assert that it is applied over the explicit value.
@@ -1401,12 +1470,12 @@ class TestEcsExecutorConfig:
run_task_kwargs_env_key,
json.dumps({AllEcsConfigKeys.PLATFORM_VERSION: templated_version}),
)
- task_kwargs = ecs_executor_config.build_task_kwargs()
+ task_kwargs = ecs_executor_config.build_task_kwargs(conf)
assert task_kwargs["platformVersion"] == templated_version
# Provide a new value explicitly and assert it is not applied over the templated values.
monkeypatch.setenv(platform_version_env_key, second_explicit_version)
- task_kwargs = ecs_executor_config.build_task_kwargs()
+ task_kwargs = ecs_executor_config.build_task_kwargs(conf)
assert task_kwargs["platformVersion"] == templated_version
@mock.patch.object(EcsHook, "conn")
@@ -1428,7 +1497,7 @@ class TestEcsExecutorConfig:
f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper(),
json.dumps(provided_run_task_kwargs),
)
- task_kwargs = ecs_executor_config.build_task_kwargs()
+ task_kwargs = ecs_executor_config.build_task_kwargs(conf)
assert task_kwargs["platformVersion"] == templated_version
assert task_kwargs["cluster"] == templated_cluster
@@ -1445,7 +1514,7 @@ class TestEcsExecutorConfig:
run_task_kwargs_env_key = f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper()
monkeypatch.setenv(run_task_kwargs_env_key, json.dumps(provided_run_task_kwargs))
- task_kwargs = ecs_executor_config.build_task_kwargs()
+ task_kwargs = ecs_executor_config.build_task_kwargs(conf)
# Verify that tag names are exempt from the camel-case conversion.
assert task_kwargs["tags"] == templated_tags
@@ -1465,7 +1534,7 @@ class TestEcsExecutorConfig:
for key, value in kwargs_to_test.items():
monkeypatch.setenv(f"AIRFLOW__{CONFIG_GROUP_NAME}__{key}".upper(),
value)
- run_task_kwargs = ecs_executor_config.build_task_kwargs()
+ run_task_kwargs = ecs_executor_config.build_task_kwargs(conf)
run_task_kwargs_network_config = run_task_kwargs["networkConfiguration"]["awsvpcConfiguration"]
for key, value in kwargs_to_test.items():
# Assert that the values are not at the root of the kwargs
@@ -1569,7 +1638,7 @@ class TestEcsExecutorConfig:
with conf_vars(conf_overrides):
from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
- task_kwargs = ecs_executor_config.build_task_kwargs()
+ task_kwargs = ecs_executor_config.build_task_kwargs(conf)
assert "launchType" not in task_kwargs
assert task_kwargs["capacityProviderStrategy"] ==
valid_capacity_provider
@@ -1583,7 +1652,7 @@ class TestEcsExecutorConfig:
with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.LAUNCH_TYPE): None}):
from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
- task_kwargs = ecs_executor_config.build_task_kwargs()
+ task_kwargs = ecs_executor_config.build_task_kwargs(conf)
assert "launchType" not in task_kwargs
assert "capacityProviderStrategy" not in task_kwargs
mock_conn.describe_clusters.assert_called_once()
@@ -1596,7 +1665,7 @@ class TestEcsExecutorConfig:
with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.LAUNCH_TYPE): None}):
from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
- task_kwargs = ecs_executor_config.build_task_kwargs()
+ task_kwargs = ecs_executor_config.build_task_kwargs(conf)
assert task_kwargs["launchType"] == "FARGATE"
@pytest.mark.parametrize(