This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8c0e5354660 Add `EksPodTrigger` (#64187)
8c0e5354660 is described below
commit 8c0e5354660bd5a169fda8543c81cc662b42ba97
Author: Niko Oliveira <[email protected]>
AuthorDate: Tue Apr 7 00:27:21 2026 -0700
Add `EksPodTrigger` (#64187)
* Add a new EKS-specific trigger
This replicates the paradigm used for the EksPodOperator, which wraps
the KubernetesPodOperator. In the same fashion, the EksPodTrigger
generates credentials and then calls the KubernetesPodTrigger that it
wraps (a usage sketch follows this change list).
* Don't use AirflowException
* More AirflowException --> RuntimeError
* Revert one exception change
* Remove unnecessary mocking
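
For context, a minimal sketch of how the new trigger is exercised from a
DAG (illustrative only, not part of this commit; the cluster name and
image are placeholders, and `deferrable` is the flag inherited from the
wrapped KubernetesPodOperator):

    # Hypothetical usage inside a DAG: with deferrable=True the operator
    # defers to EksPodTrigger, which regenerates credentials on the
    # triggerer side before reconnecting to the pod.
    from airflow.providers.amazon.aws.operators.eks import EksPodOperator

    run_pod = EksPodOperator(
        task_id="run_pod",
        cluster_name="my-eks-cluster",  # placeholder cluster name
        pod_name="run_pod",
        image="amazon/aws-cli:latest",
        cmds=["sh", "-c", "echo hello"],
        deferrable=True,
    )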
---
.../airflow/providers/amazon/aws/operators/eks.py | 77 ++++++++++++
.../airflow/providers/amazon/aws/triggers/eks.py | 130 +++++++++++++++++++++
.../tests/unit/amazon/aws/operators/test_eks.py | 58 +++++++++
.../tests/unit/amazon/aws/triggers/test_eks.py | 128 +++++++++++++++++++-
4 files changed, 392 insertions(+), 1 deletion(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py
index f024ab056d2..cf6d147bdcf 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py
@@ -39,6 +39,7 @@ from airflow.providers.amazon.aws.triggers.eks import (
EksDeleteClusterTrigger,
EksDeleteFargateProfileTrigger,
EksDeleteNodegroupTrigger,
+ EksPodTrigger,
)
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
@@ -1119,6 +1120,82 @@ class EksPodOperator(KubernetesPodOperator):
if self.config_file:
            raise AirflowException("The config_file is not an allowed parameter for the EksPodOperator.")
+ def invoke_defer_method(self, last_log_time=None, context=None) -> None:
+        """Override to use EksPodTrigger which regenerates kubeconfig with fresh credentials."""
+ import datetime
+
+        from airflow.providers.cncf.kubernetes.triggers.pod import ContainerState
+        from airflow.providers.common.compat.sdk import AirflowNotFoundException
+
+ self.convert_config_file_to_dict()
+
+ connection_extras = None
+ if self.kubernetes_conn_id:
+ try:
+ try:
+ from airflow.sdk import BaseHook
+ except ImportError:
+                    from airflow.hooks.base import BaseHook  # type: ignore[attr-defined, no-redef]
+
+ conn = BaseHook.get_connection(self.kubernetes_conn_id)
+ except AirflowNotFoundException:
+ self.log.warning(
+                    "Could not resolve connection extras for deferral: connection `%s` not found. "
+                    "Triggerer will try to resolve it from its own environment.",
+                    self.kubernetes_conn_id,
+                )
+ else:
+ connection_extras = conn.extra_dejson
+                self.log.info("Successfully resolved connection extras for deferral.")
+
+ trigger_start_time = datetime.datetime.now(tz=datetime.timezone.utc)
+
+ if self.pod is None or self.pod.metadata is None:
+            raise RuntimeError("Pod must be created with metadata before deferring")
+
+ trigger = EksPodTrigger(
+ eks_cluster_name=self.cluster_name,
+ aws_conn_id=self.aws_conn_id,
+ region=self.region,
+ pod_name=self.pod.metadata.name,
+ pod_namespace=self.pod.metadata.namespace,
+ trigger_start_time=trigger_start_time,
+ kubernetes_conn_id=self.kubernetes_conn_id,
+ connection_extras=connection_extras,
+ cluster_context=self.cluster_context,
+ config_dict=self._config_dict,
+ in_cluster=self.in_cluster,
+ poll_interval=self.poll_interval,
+ get_logs=self.get_logs,
+ startup_timeout=self.startup_timeout_seconds,
+ startup_check_interval=self.startup_check_interval_seconds,
+ schedule_timeout=self.schedule_timeout_seconds,
+ base_container_name=self.base_container_name,
+ on_finish_action=self.on_finish_action.value,
+ on_kill_action=self.on_kill_action.value,
+ termination_grace_period=self.termination_grace_period,
+ last_log_time=last_log_time,
+ logging_interval=self.logging_interval,
+ trigger_kwargs=self.trigger_kwargs,
+ )
+        container_state = trigger.define_container_state(self.pod) if self.pod else None
+ if context and (
+            container_state == ContainerState.TERMINATED or container_state == ContainerState.FAILED
+ ):
+            self.log.info("Skipping deferral as pod is already in a terminal state")
+ self.trigger_reentry(
+ context=context,
+ event={
+                    "status": "success" if container_state == ContainerState.TERMINATED else "failed",
+ "namespace": self.pod.metadata.namespace,
+ "name": self.pod.metadata.name,
+ "last_log_time": last_log_time,
+ **(self.trigger_kwargs or {}),
+ },
+ )
+ else:
+ self.defer(trigger=trigger, method_name="trigger_reentry")
+
def execute(self, context: Context):
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py
index 3ce4c7f5fb0..a386f96da1b 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import datetime
from typing import TYPE_CHECKING, Any
from botocore.exceptions import ClientError
@@ -23,10 +24,13 @@ from botocore.exceptions import ClientError
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
from airflow.providers.common.compat.sdk import AirflowException
from airflow.triggers.base import TriggerEvent
if TYPE_CHECKING:
+ from pendulum import DateTime
+
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
@@ -89,6 +93,132 @@ class EksCreateClusterTrigger(AwsBaseWaiterTrigger):
yield TriggerEvent({"status": "success"})
+class EksPodTrigger(KubernetesPodTrigger):
+ """
+    KubernetesPodTrigger for EKS that generates fresh kubeconfig with new credentials.
+
+    When ``EksPodOperator`` defers, the kubeconfig stored in ``config_dict`` contains
+    an exec command that references a temporary credentials file. That file is cleaned
+    up when the operator's context managers exit (on deferral). By the time the trigger
+    runs — whether in a real triggerer process or inline via ``dag.test()`` — the file
+    is gone.
+
+    This trigger solves the problem by regenerating the kubeconfig with fresh AWS
+    credentials before executing. The temporary files are kept alive for the entire
+    duration of the trigger's ``run()`` method.
+
+ :param eks_cluster_name: The name of the Amazon EKS Cluster.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param region: Which AWS region the connection should use.
+ """
+
+ def __init__(
+ self,
+ *,
+ eks_cluster_name: str,
+ aws_conn_id: str | None = None,
+ region: str | None = None,
+ pod_name: str,
+ pod_namespace: str,
+ trigger_start_time: datetime.datetime,
+ base_container_name: str,
+ kubernetes_conn_id: str | None = None,
+ connection_extras: dict | None = None,
+ poll_interval: float = 2,
+ cluster_context: str | None = None,
+ config_dict: dict | None = None,
+ in_cluster: bool | None = None,
+ get_logs: bool = True,
+ startup_timeout: int = 120,
+ startup_check_interval: float = 5,
+ schedule_timeout: int = 120,
+ on_finish_action: str = "delete_pod",
+ on_kill_action: str = "delete_pod",
+ termination_grace_period: int | None = None,
+ last_log_time: DateTime | None = None,
+ logging_interval: int | None = None,
+ trigger_kwargs: dict | None = None,
+ ):
+ super().__init__(
+ pod_name=pod_name,
+ pod_namespace=pod_namespace,
+ trigger_start_time=trigger_start_time,
+ base_container_name=base_container_name,
+ kubernetes_conn_id=kubernetes_conn_id,
+ connection_extras=connection_extras,
+ poll_interval=poll_interval,
+ cluster_context=cluster_context,
+ config_dict=config_dict,
+ in_cluster=in_cluster,
+ get_logs=get_logs,
+ startup_timeout=startup_timeout,
+ startup_check_interval=startup_check_interval,
+ schedule_timeout=schedule_timeout,
+ on_finish_action=on_finish_action,
+ on_kill_action=on_kill_action,
+ termination_grace_period=termination_grace_period,
+ last_log_time=last_log_time,
+ logging_interval=logging_interval,
+ trigger_kwargs=trigger_kwargs,
+ )
+ self.eks_cluster_name = eks_cluster_name
+ self._aws_conn_id = aws_conn_id
+ self.region = region
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize EksPodTrigger arguments and classpath."""
+ _, kwargs = super().serialize()
+ kwargs["eks_cluster_name"] = self.eks_cluster_name
+ kwargs["aws_conn_id"] = self._aws_conn_id
+ kwargs["region"] = self.region
+ return (
+ "airflow.providers.amazon.aws.triggers.eks.EksPodTrigger",
+ kwargs,
+ )
+
+ async def run(self):
+ """Generate fresh kubeconfig, then delegate to the parent trigger."""
+ from airflow.utils import yaml
+
+ eks_hook = EksHook(
+ aws_conn_id=self._aws_conn_id,
+ region_name=self.region,
+ )
+ session = eks_hook.get_session()
+ credentials_obj = session.get_credentials()
+ if credentials_obj is None:
+ raise RuntimeError(
+ "Unable to retrieve AWS credentials for EKS trigger. "
+ "Credentials may have expired or not been configured."
+ )
+ credentials = credentials_obj.get_frozen_credentials()
+
+ # Create fresh credential and kubeconfig files. The context managers
+ # keep the temp files alive for the entire duration of the trigger.
+ with eks_hook._secure_credential_context(
+ credentials.access_key, credentials.secret_key, credentials.token
+ ) as credentials_file:
+ with eks_hook.generate_config_file(
+ eks_cluster_name=self.eks_cluster_name,
+ pod_namespace=self.pod_namespace,
+ credentials_file=credentials_file,
+ ) as config_file_path:
+                # Reading a small local temp file created by the context manager above.
+                # Blocking I/O is acceptable here as the file is tiny and local.
+ from pathlib import Path
+
+ self.config_dict = yaml.safe_load(
+ Path(config_file_path).read_text() # noqa: ASYNC240
+ )
+
+ # Invalidate any previously cached hook so the new config_dict
+ # is picked up when the parent creates the AsyncKubernetesHook.
+ self.__dict__.pop("hook", None)
+
+ async for event in super().run():
+ yield event
+
+
class EksDeleteClusterTrigger(AwsBaseWaiterTrigger):
"""
Trigger for EksDeleteClusterOperator.
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py b/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py
index 2e1c850ba94..5003e56d76a 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py
@@ -39,6 +39,7 @@ from airflow.providers.amazon.aws.triggers.eks import (
EksCreateFargateProfileTrigger,
EksCreateNodegroupTrigger,
EksDeleteFargateProfileTrigger,
+ EksPodTrigger,
)
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
from airflow.providers.common.compat.sdk import TaskDeferred
@@ -1116,3 +1117,60 @@ class TestEksPodOperator:
        # Verify super()._refresh_cached_properties() was NOT called since we raised
mock_super_refresh.assert_not_called()
+
+ @mock.patch(
+        "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.convert_config_file_to_dict"
+ )
+ def test_invoke_defer_method_uses_eks_trigger(self, mock_convert_config):
+        """invoke_defer_method should create an EksPodTrigger and call defer."""
+ op = EksPodOperator(
+ task_id="run_pod",
+ pod_name="run_pod",
+ cluster_name=CLUSTER_NAME,
+ image="amazon/aws-cli:latest",
+ cmds=["sh", "-c", "ls"],
+ labels={"demo": "hello_world"},
+ get_logs=True,
+ on_finish_action="delete_pod",
+ )
+
+ # Set up pod metadata as it would be after execute creates the pod
+ mock_pod = mock.MagicMock()
+ mock_pod.metadata.name = "test-pod-abc123"
+ mock_pod.metadata.namespace = "default"
+        # Set status to None so define_container_state returns UNDEFINED (not terminal)
+ mock_pod.status = None
+ op.pod = mock_pod
+
+ with pytest.raises(TaskDeferred) as exc:
+ op.invoke_defer_method()
+
+        # Verify the trigger is an EksPodTrigger (not the base KubernetesPodTrigger)
+ trigger = exc.value.trigger
+ assert isinstance(trigger, EksPodTrigger)
+ assert trigger.eks_cluster_name == CLUSTER_NAME
+ assert trigger._aws_conn_id == "aws_default"
+ assert trigger.pod_name == "test-pod-abc123"
+ assert trigger.pod_namespace == "default"
+
+ @mock.patch(
+        "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.convert_config_file_to_dict"
+ )
+    def test_invoke_defer_method_raises_when_pod_is_none(self, mock_convert_config):
+ """invoke_defer_method should raise RuntimeError when pod is None."""
+ op = EksPodOperator(
+ task_id="run_pod",
+ pod_name="run_pod",
+ cluster_name=CLUSTER_NAME,
+ image="amazon/aws-cli:latest",
+ cmds=["sh", "-c", "ls"],
+ labels={"demo": "hello_world"},
+ get_logs=True,
+ on_finish_action="delete_pod",
+ )
+
+ # pod is None by default
+ op.pod = None
+
+        with pytest.raises(RuntimeError, match="Pod must be created with metadata before deferring"):
+ op.invoke_defer_method()
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py
index e8fe934be8f..fcdd712cc56 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py
@@ -16,7 +16,8 @@
# under the License.
from __future__ import annotations
-from unittest.mock import AsyncMock, Mock, call, patch
+import datetime
+from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
import pytest
from botocore.exceptions import ClientError
@@ -24,6 +25,7 @@ from botocore.exceptions import ClientError
from airflow.providers.amazon.aws.triggers.eks import (
EksCreateClusterTrigger,
EksDeleteClusterTrigger,
+ EksPodTrigger,
)
from airflow.providers.common.compat.sdk import AirflowException
from airflow.triggers.base import TriggerEvent
@@ -318,3 +320,127 @@ class TestEksDeleteClusterTriggerDeleteNodegroupsAndFargateProfiles(TestEksTrigg
self.trigger.log.info.assert_called_once_with(
"No Fargate profiles associated with cluster %s", CLUSTER_NAME
)
+
+
+class TestEksPodTrigger:
+ """Tests for EksPodTrigger."""
+
+    TRIGGER_START_TIME = datetime.datetime(2026, 1, 1, tzinfo=datetime.timezone.utc)
+
+ def _create_trigger(self, **overrides):
+ """Create an EksPodTrigger with sensible defaults."""
+ defaults = {
+ "eks_cluster_name": CLUSTER_NAME,
+ "aws_conn_id": AWS_CONN_ID,
+ "region": REGION_NAME,
+ "pod_name": "test-pod",
+ "pod_namespace": "default",
+ "trigger_start_time": self.TRIGGER_START_TIME,
+ "base_container_name": "base",
+ "config_dict": {"old": "stale-config"},
+ }
+ defaults.update(overrides)
+ return EksPodTrigger(**defaults)
+
+ def test_serialize_includes_eks_fields(self):
+        """serialize() should include eks_cluster_name, aws_conn_id, and region."""
+ trigger = self._create_trigger()
+ classpath, kwargs = trigger.serialize()
+
+        assert classpath == "airflow.providers.amazon.aws.triggers.eks.EksPodTrigger"
+ assert kwargs["eks_cluster_name"] == CLUSTER_NAME
+ assert kwargs["aws_conn_id"] == AWS_CONN_ID
+ assert kwargs["region"] == REGION_NAME
+ # Also verify parent fields are present
+ assert kwargs["pod_name"] == "test-pod"
+ assert kwargs["pod_namespace"] == "default"
+
+ def test_serialize_roundtrip(self):
+        """A trigger created from serialized kwargs should serialize identically."""
+ trigger = self._create_trigger()
+ classpath, kwargs = trigger.serialize()
+
+ trigger2 = EksPodTrigger(**kwargs)
+ classpath2, kwargs2 = trigger2.serialize()
+
+ assert classpath == classpath2
+ assert kwargs == kwargs2
+
+ @pytest.mark.asyncio
+    @patch("airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger.run")
+    @patch("airflow.providers.amazon.aws.hooks.eks.EksHook.generate_config_file")
+    @patch("airflow.providers.amazon.aws.hooks.eks.EksHook._secure_credential_context")
+    @patch("airflow.providers.amazon.aws.hooks.eks.EksHook.get_session")
+    @patch("airflow.providers.amazon.aws.hooks.eks.EksHook.__init__", return_value=None)
+ async def test_run_generates_fresh_kubeconfig(
+ self,
+ mock_eks_hook_init,
+ mock_get_session,
+ mock_secure_credential_context,
+ mock_generate_config_file,
+ mock_parent_run,
+ ):
+        """run() should get fresh credentials, generate kubeconfig, and delegate to parent."""
+ # Set up credential mocks
+ mock_session = MagicMock()
+ mock_credentials = MagicMock()
+ mock_frozen = MagicMock()
+ mock_frozen.access_key = "AKIATEST"
+ mock_frozen.secret_key = "secret123"
+ mock_frozen.token = "token456"
+ mock_get_session.return_value = mock_session
+ mock_session.get_credentials.return_value = mock_credentials
+ mock_credentials.get_frozen_credentials.return_value = mock_frozen
+
+ # Set up context manager mocks
+        mock_secure_credential_context.return_value.__enter__.return_value = "/tmp/test.aws_creds"
+        mock_generate_config_file.return_value.__enter__.return_value = "/tmp/test_kubeconfig"
+
+ # Mock reading the kubeconfig file
+        with patch("pathlib.Path.read_text", return_value="apiVersion: v1\nkind: Config\nclusters: []"):
+
+ async def mock_gen():
+ yield TriggerEvent({"status": "success"})
+
+ mock_parent_run.return_value = mock_gen()
+
+ trigger = self._create_trigger()
+ events = []
+ async for event in trigger.run():
+ events.append(event)
+
+ assert len(events) == 1
+ assert events[0] == TriggerEvent({"status": "success"})
+
+ # Verify credentials were fetched
+        mock_eks_hook_init.assert_called_once_with(aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME)
+ mock_get_session.assert_called_once()
+ mock_session.get_credentials.assert_called_once()
+ mock_credentials.get_frozen_credentials.assert_called_once()
+
+ # Verify credential context and config generation
+        mock_secure_credential_context.assert_called_once_with("AKIATEST", "secret123", "token456")
+ mock_generate_config_file.assert_called_once_with(
+ eks_cluster_name=CLUSTER_NAME,
+ pod_namespace="default",
+ credentials_file="/tmp/test.aws_creds",
+ )
+
+ @pytest.mark.asyncio
+ @patch("airflow.providers.amazon.aws.hooks.eks.EksHook.get_session")
+    @patch("airflow.providers.amazon.aws.hooks.eks.EksHook.__init__", return_value=None)
+ async def test_run_raises_when_credentials_unavailable(
+ self,
+ mock_eks_hook_init,
+ mock_get_session,
+ ):
+        """run() should raise RuntimeError when credentials cannot be retrieved."""
+ mock_session = MagicMock()
+ mock_get_session.return_value = mock_session
+ mock_session.get_credentials.return_value = None
+
+ trigger = self._create_trigger()
+
+        with pytest.raises(RuntimeError, match="Unable to retrieve AWS credentials"):
+ async for _ in trigger.run():
+ pass
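
As a distilled illustration of the pattern in this commit (a sketch, not
code from the diff; ParentTrigger and refresh_config are placeholder
names): a wrapping trigger overrides run(), refreshes whatever state has
expired, invalidates any cached hook, and then delegates to the wrapped
trigger's async generator:

    # Hypothetical minimal form of the wrap-and-refresh pattern that
    # EksPodTrigger applies to KubernetesPodTrigger.
    class RefreshingTrigger(ParentTrigger):
        async def run(self):
            self.config_dict = refresh_config()  # regenerate expiring state
            self.__dict__.pop("hook", None)      # drop the stale cached hook
            async for event in super().run():    # delegate to the parent
                yield event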