This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8c0e5354660 Add `EksPodTrigger` (#64187)
8c0e5354660 is described below

commit 8c0e5354660bd5a169fda8543c81cc662b42ba97
Author: Niko Oliveira <[email protected]>
AuthorDate: Tue Apr 7 00:27:21 2026 -0700

    Add `EksPodTrigger` (#64187)
    
    * Add a new EKS-specific Triggerer
    
    This replicates the paradigm used for the EksPodOperator which wraps
    the KubernetesPodOperator. In a similar fashion to the operator, the
    EksPodTrigger will generate credentials and then call the
    KubernetesPodTrigger which it wraps.
    
    * Don't use AirflowException
    
    * More AirflowException --> RuntimeError
    
    * Revert one exception change
    
    * remove unnecessary mocking
---
 .../airflow/providers/amazon/aws/operators/eks.py  |  77 ++++++++++++
 .../airflow/providers/amazon/aws/triggers/eks.py   | 130 +++++++++++++++++++++
 .../tests/unit/amazon/aws/operators/test_eks.py    |  58 +++++++++
 .../tests/unit/amazon/aws/triggers/test_eks.py     | 128 +++++++++++++++++++-
 4 files changed, 392 insertions(+), 1 deletion(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py 
b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py
index f024ab056d2..cf6d147bdcf 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py
@@ -39,6 +39,7 @@ from airflow.providers.amazon.aws.triggers.eks import (
     EksDeleteClusterTrigger,
     EksDeleteFargateProfileTrigger,
     EksDeleteNodegroupTrigger,
+    EksPodTrigger,
 )
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
@@ -1119,6 +1120,82 @@ class EksPodOperator(KubernetesPodOperator):
         if self.config_file:
             raise AirflowException("The config_file is not an allowed 
parameter for the EksPodOperator.")
 
    def invoke_defer_method(self, last_log_time=None, context=None) -> None:
        """
        Defer execution by handing the pod off to an ``EksPodTrigger``.

        Overrides the ``KubernetesPodOperator`` hook point so that deferral uses
        ``EksPodTrigger`` — which regenerates the kubeconfig with fresh AWS
        credentials inside the triggerer — instead of the base
        ``KubernetesPodTrigger``, whose kubeconfig would reference a temporary
        credentials file that no longer exists by the time the trigger runs.

        :param last_log_time: Timestamp of the last log line already emitted, so
            the trigger can resume log streaming without duplication.
        :param context: Airflow task context. When provided and the pod is
            already in a terminal state, deferral is skipped and the operator
            re-enters synchronously with a synthesized event.
        :raises RuntimeError: If the pod (or its metadata) has not been created yet.
        """
        # Local imports keep these provider/compat dependencies out of module
        # import time — NOTE(review): presumed rationale, confirm.
        import datetime

        from airflow.providers.cncf.kubernetes.triggers.pod import ContainerState
        from airflow.providers.common.compat.sdk import AirflowNotFoundException

        # Populate self._config_dict from the config file; it is passed to the
        # trigger as config_dict below.
        self.convert_config_file_to_dict()

        # Resolve the Kubernetes connection extras here, in the worker, so the
        # triggerer does not have to. If the connection is not visible from this
        # environment, leave it to the triggerer to resolve on its side.
        connection_extras = None
        if self.kubernetes_conn_id:
            try:
                # BaseHook lives in airflow.sdk on newer Airflow versions; fall
                # back to the legacy location for older ones.
                try:
                    from airflow.sdk import BaseHook
                except ImportError:
                    from airflow.hooks.base import BaseHook  # type: ignore[attr-defined, no-redef]

                conn = BaseHook.get_connection(self.kubernetes_conn_id)
            except AirflowNotFoundException:
                self.log.warning(
                    "Could not resolve connection extras for deferral: connection `%s` not found. "
                    "Triggerer will try to resolve it from its own environment.",
                    self.kubernetes_conn_id,
                )
            else:
                connection_extras = conn.extra_dejson
                self.log.info("Successfully resolved connection extras for deferral.")

        trigger_start_time = datetime.datetime.now(tz=datetime.timezone.utc)

        # The pod must already have been created (with metadata) before deferring;
        # the trigger needs its name and namespace to reattach.
        if self.pod is None or self.pod.metadata is None:
            raise RuntimeError("Pod must be created with metadata before deferring")

        trigger = EksPodTrigger(
            eks_cluster_name=self.cluster_name,
            aws_conn_id=self.aws_conn_id,
            region=self.region,
            pod_name=self.pod.metadata.name,
            pod_namespace=self.pod.metadata.namespace,
            trigger_start_time=trigger_start_time,
            kubernetes_conn_id=self.kubernetes_conn_id,
            connection_extras=connection_extras,
            cluster_context=self.cluster_context,
            config_dict=self._config_dict,
            in_cluster=self.in_cluster,
            poll_interval=self.poll_interval,
            get_logs=self.get_logs,
            startup_timeout=self.startup_timeout_seconds,
            startup_check_interval=self.startup_check_interval_seconds,
            schedule_timeout=self.schedule_timeout_seconds,
            base_container_name=self.base_container_name,
            on_finish_action=self.on_finish_action.value,
            on_kill_action=self.on_kill_action.value,
            termination_grace_period=self.termination_grace_period,
            last_log_time=last_log_time,
            logging_interval=self.logging_interval,
            trigger_kwargs=self.trigger_kwargs,
        )
        # If the pod is already terminal there is nothing to wait for: skip the
        # deferral round-trip and re-enter directly with a synthesized event.
        container_state = trigger.define_container_state(self.pod) if self.pod else None
        if context and (
            container_state == ContainerState.TERMINATED or container_state == ContainerState.FAILED
        ):
            self.log.info("Skipping deferral as pod is already in a terminal state")
            self.trigger_reentry(
                context=context,
                event={
                    "status": "success" if container_state == ContainerState.TERMINATED else "failed",
                    "namespace": self.pod.metadata.namespace,
                    "name": self.pod.metadata.name,
                    "last_log_time": last_log_time,
                    **(self.trigger_kwargs or {}),
                },
            )
        else:
            self.defer(trigger=trigger, method_name="trigger_reentry")
+
     def execute(self, context: Context):
         eks_hook = EksHook(
             aws_conn_id=self.aws_conn_id,
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py 
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py
index 3ce4c7f5fb0..a386f96da1b 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/eks.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import datetime
 from typing import TYPE_CHECKING, Any
 
 from botocore.exceptions import ClientError
@@ -23,10 +24,13 @@ from botocore.exceptions import ClientError
 from airflow.providers.amazon.aws.hooks.eks import EksHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
 from airflow.providers.common.compat.sdk import AirflowException
 from airflow.triggers.base import TriggerEvent
 
 if TYPE_CHECKING:
+    from pendulum import DateTime
+
     from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 
 
@@ -89,6 +93,132 @@ class EksCreateClusterTrigger(AwsBaseWaiterTrigger):
                 yield TriggerEvent({"status": "success"})
 
 
class EksPodTrigger(KubernetesPodTrigger):
    """
    KubernetesPodTrigger for EKS that generates fresh kubeconfig with new credentials.

    When ``EksPodOperator`` defers, the kubeconfig stored in ``config_dict`` contains
    an exec command that references a temporary credentials file. That file is cleaned
    up when the operator's context managers exit (on deferral). By the time the trigger
    runs — whether in a real triggerer process or inline via ``dag.test()`` — the file
    is gone.

    This trigger solves the problem by regenerating the kubeconfig with fresh AWS
    credentials before executing. The temporary files are kept alive for the entire
    duration of the trigger's ``run()`` method.

    All other constructor parameters are forwarded unchanged to
    ``KubernetesPodTrigger``.

    :param eks_cluster_name: The name of the Amazon EKS Cluster.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region: Which AWS region the connection should use.
    """

    def __init__(
        self,
        *,
        eks_cluster_name: str,
        aws_conn_id: str | None = None,
        region: str | None = None,
        pod_name: str,
        pod_namespace: str,
        trigger_start_time: datetime.datetime,
        base_container_name: str,
        kubernetes_conn_id: str | None = None,
        connection_extras: dict | None = None,
        poll_interval: float = 2,
        cluster_context: str | None = None,
        config_dict: dict | None = None,
        in_cluster: bool | None = None,
        get_logs: bool = True,
        startup_timeout: int = 120,
        startup_check_interval: float = 5,
        schedule_timeout: int = 120,
        on_finish_action: str = "delete_pod",
        on_kill_action: str = "delete_pod",
        termination_grace_period: int | None = None,
        last_log_time: DateTime | None = None,
        logging_interval: int | None = None,
        trigger_kwargs: dict | None = None,
    ):
        super().__init__(
            pod_name=pod_name,
            pod_namespace=pod_namespace,
            trigger_start_time=trigger_start_time,
            base_container_name=base_container_name,
            kubernetes_conn_id=kubernetes_conn_id,
            connection_extras=connection_extras,
            poll_interval=poll_interval,
            cluster_context=cluster_context,
            config_dict=config_dict,
            in_cluster=in_cluster,
            get_logs=get_logs,
            startup_timeout=startup_timeout,
            startup_check_interval=startup_check_interval,
            schedule_timeout=schedule_timeout,
            on_finish_action=on_finish_action,
            on_kill_action=on_kill_action,
            termination_grace_period=termination_grace_period,
            last_log_time=last_log_time,
            logging_interval=logging_interval,
            trigger_kwargs=trigger_kwargs,
        )
        self.eks_cluster_name = eks_cluster_name
        # Stored under a private name; serialize() re-emits it as "aws_conn_id"
        # so the round-trip through __init__ works.
        self._aws_conn_id = aws_conn_id
        self.region = region

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize EksPodTrigger arguments and classpath."""
        # Start from the parent's kwargs so all KubernetesPodTrigger fields
        # survive the round-trip, then layer on the EKS-specific ones.
        _, kwargs = super().serialize()
        kwargs["eks_cluster_name"] = self.eks_cluster_name
        kwargs["aws_conn_id"] = self._aws_conn_id
        kwargs["region"] = self.region
        return (
            "airflow.providers.amazon.aws.triggers.eks.EksPodTrigger",
            kwargs,
        )

    async def run(self):
        """Generate fresh kubeconfig, then delegate to the parent trigger."""
        from airflow.utils import yaml

        eks_hook = EksHook(
            aws_conn_id=self._aws_conn_id,
            region_name=self.region,
        )
        session = eks_hook.get_session()
        credentials_obj = session.get_credentials()
        if credentials_obj is None:
            raise RuntimeError(
                "Unable to retrieve AWS credentials for EKS trigger. "
                "Credentials may have expired or not been configured."
            )
        # Freeze to an immutable snapshot so the key/secret/token triple is
        # consistent while we write the credential file.
        credentials = credentials_obj.get_frozen_credentials()

        # Create fresh credential and kubeconfig files.  The context managers
        # keep the temp files alive for the entire duration of the trigger.
        with eks_hook._secure_credential_context(
            credentials.access_key, credentials.secret_key, credentials.token
        ) as credentials_file:
            with eks_hook.generate_config_file(
                eks_cluster_name=self.eks_cluster_name,
                pod_namespace=self.pod_namespace,
                credentials_file=credentials_file,
            ) as config_file_path:
                # Reading a small local temp file created by the context manager above.
                # Blocking I/O is acceptable here as the file is tiny and local.
                from pathlib import Path

                self.config_dict = yaml.safe_load(
                    Path(config_file_path).read_text()  # noqa: ASYNC240
                )

                # Invalidate any previously cached hook so the new config_dict
                # is picked up when the parent creates the AsyncKubernetesHook.
                self.__dict__.pop("hook", None)

                # Delegate the actual pod monitoring to KubernetesPodTrigger,
                # re-yielding its events while the temp files stay alive.
                async for event in super().run():
                    yield event
+
+
 class EksDeleteClusterTrigger(AwsBaseWaiterTrigger):
     """
     Trigger for EksDeleteClusterOperator.
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py 
b/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py
index 2e1c850ba94..5003e56d76a 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py
@@ -39,6 +39,7 @@ from airflow.providers.amazon.aws.triggers.eks import (
     EksCreateFargateProfileTrigger,
     EksCreateNodegroupTrigger,
     EksDeleteFargateProfileTrigger,
+    EksPodTrigger,
 )
 from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
 from airflow.providers.common.compat.sdk import TaskDeferred
@@ -1116,3 +1117,60 @@ class TestEksPodOperator:
 
         # Verify super()._refresh_cached_properties() was NOT called since we 
raised
         mock_super_refresh.assert_not_called()
+
+    @mock.patch(
+        
"airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.convert_config_file_to_dict"
+    )
+    def test_invoke_defer_method_uses_eks_trigger(self, mock_convert_config):
+        """invoke_defer_method should create an EksPodTrigger and call 
defer."""
+        op = EksPodOperator(
+            task_id="run_pod",
+            pod_name="run_pod",
+            cluster_name=CLUSTER_NAME,
+            image="amazon/aws-cli:latest",
+            cmds=["sh", "-c", "ls"],
+            labels={"demo": "hello_world"},
+            get_logs=True,
+            on_finish_action="delete_pod",
+        )
+
+        # Set up pod metadata as it would be after execute creates the pod
+        mock_pod = mock.MagicMock()
+        mock_pod.metadata.name = "test-pod-abc123"
+        mock_pod.metadata.namespace = "default"
+        # Set status to None so define_container_state returns UNDEFINED (not 
terminal)
+        mock_pod.status = None
+        op.pod = mock_pod
+
+        with pytest.raises(TaskDeferred) as exc:
+            op.invoke_defer_method()
+
+        # Verify the trigger is an EksPodTrigger (not the base 
KubernetesPodTrigger)
+        trigger = exc.value.trigger
+        assert isinstance(trigger, EksPodTrigger)
+        assert trigger.eks_cluster_name == CLUSTER_NAME
+        assert trigger._aws_conn_id == "aws_default"
+        assert trigger.pod_name == "test-pod-abc123"
+        assert trigger.pod_namespace == "default"
+
+    @mock.patch(
+        
"airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.convert_config_file_to_dict"
+    )
+    def test_invoke_defer_method_raises_when_pod_is_none(self, 
mock_convert_config):
+        """invoke_defer_method should raise RuntimeError when pod is None."""
+        op = EksPodOperator(
+            task_id="run_pod",
+            pod_name="run_pod",
+            cluster_name=CLUSTER_NAME,
+            image="amazon/aws-cli:latest",
+            cmds=["sh", "-c", "ls"],
+            labels={"demo": "hello_world"},
+            get_logs=True,
+            on_finish_action="delete_pod",
+        )
+
+        # pod is None by default
+        op.pod = None
+
+        with pytest.raises(RuntimeError, match="Pod must be created with 
metadata before deferring"):
+            op.invoke_defer_method()
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py 
b/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py
index e8fe934be8f..fcdd712cc56 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_eks.py
@@ -16,7 +16,8 @@
 # under the License.
 from __future__ import annotations
 
-from unittest.mock import AsyncMock, Mock, call, patch
+import datetime
+from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
 
 import pytest
 from botocore.exceptions import ClientError
@@ -24,6 +25,7 @@ from botocore.exceptions import ClientError
 from airflow.providers.amazon.aws.triggers.eks import (
     EksCreateClusterTrigger,
     EksDeleteClusterTrigger,
+    EksPodTrigger,
 )
 from airflow.providers.common.compat.sdk import AirflowException
 from airflow.triggers.base import TriggerEvent
@@ -318,3 +320,127 @@ class 
TestEksDeleteClusterTriggerDeleteNodegroupsAndFargateProfiles(TestEksTrigg
         self.trigger.log.info.assert_called_once_with(
             "No Fargate profiles associated with cluster %s", CLUSTER_NAME
         )
+
+
class TestEksPodTrigger:
    """Tests for EksPodTrigger."""

    # Fixed, timezone-aware start time so serialization round-trips are stable.
    TRIGGER_START_TIME = datetime.datetime(2026, 1, 1, tzinfo=datetime.timezone.utc)

    def _create_trigger(self, **overrides):
        """Create an EksPodTrigger with sensible defaults."""
        defaults = {
            "eks_cluster_name": CLUSTER_NAME,
            "aws_conn_id": AWS_CONN_ID,
            "region": REGION_NAME,
            "pod_name": "test-pod",
            "pod_namespace": "default",
            "trigger_start_time": self.TRIGGER_START_TIME,
            "base_container_name": "base",
            # Deliberately stale: run() must replace it with a fresh kubeconfig.
            "config_dict": {"old": "stale-config"},
        }
        defaults.update(overrides)
        return EksPodTrigger(**defaults)

    def test_serialize_includes_eks_fields(self):
        """serialize() should include eks_cluster_name, aws_conn_id, and region."""
        trigger = self._create_trigger()
        classpath, kwargs = trigger.serialize()

        assert classpath == "airflow.providers.amazon.aws.triggers.eks.EksPodTrigger"
        assert kwargs["eks_cluster_name"] == CLUSTER_NAME
        assert kwargs["aws_conn_id"] == AWS_CONN_ID
        assert kwargs["region"] == REGION_NAME
        # Also verify parent fields are present
        assert kwargs["pod_name"] == "test-pod"
        assert kwargs["pod_namespace"] == "default"

    def test_serialize_roundtrip(self):
        """A trigger created from serialized kwargs should serialize identically."""
        trigger = self._create_trigger()
        classpath, kwargs = trigger.serialize()

        # Rebuilding from the serialized kwargs must be lossless.
        trigger2 = EksPodTrigger(**kwargs)
        classpath2, kwargs2 = trigger2.serialize()

        assert classpath == classpath2
        assert kwargs == kwargs2

    @pytest.mark.asyncio
    @patch("airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger.run")
    @patch("airflow.providers.amazon.aws.hooks.eks.EksHook.generate_config_file")
    @patch("airflow.providers.amazon.aws.hooks.eks.EksHook._secure_credential_context")
    @patch("airflow.providers.amazon.aws.hooks.eks.EksHook.get_session")
    @patch("airflow.providers.amazon.aws.hooks.eks.EksHook.__init__", return_value=None)
    async def test_run_generates_fresh_kubeconfig(
        self,
        mock_eks_hook_init,
        mock_get_session,
        mock_secure_credential_context,
        mock_generate_config_file,
        mock_parent_run,
    ):
        """run() should get fresh credentials, generate kubeconfig, and delegate to parent."""
        # Set up credential mocks: session -> credentials -> frozen snapshot.
        mock_session = MagicMock()
        mock_credentials = MagicMock()
        mock_frozen = MagicMock()
        mock_frozen.access_key = "AKIATEST"
        mock_frozen.secret_key = "secret123"
        mock_frozen.token = "token456"
        mock_get_session.return_value = mock_session
        mock_session.get_credentials.return_value = mock_credentials
        mock_credentials.get_frozen_credentials.return_value = mock_frozen

        # Set up context manager mocks for the temp credential/kubeconfig files.
        mock_secure_credential_context.return_value.__enter__.return_value = "/tmp/test.aws_creds"
        mock_generate_config_file.return_value.__enter__.return_value = "/tmp/test_kubeconfig"

        # Mock reading the kubeconfig file
        with patch("pathlib.Path.read_text", return_value="apiVersion: v1\nkind: Config\nclusters: []"):

            # Substitute the parent's async generator with one that yields a
            # single success event.
            async def mock_gen():
                yield TriggerEvent({"status": "success"})

            mock_parent_run.return_value = mock_gen()

            trigger = self._create_trigger()
            events = []
            async for event in trigger.run():
                events.append(event)

        # The parent's event must be passed through unchanged.
        assert len(events) == 1
        assert events[0] == TriggerEvent({"status": "success"})

        # Verify credentials were fetched
        mock_eks_hook_init.assert_called_once_with(aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME)
        mock_get_session.assert_called_once()
        mock_session.get_credentials.assert_called_once()
        mock_credentials.get_frozen_credentials.assert_called_once()

        # Verify credential context and config generation
        mock_secure_credential_context.assert_called_once_with("AKIATEST", "secret123", "token456")
        mock_generate_config_file.assert_called_once_with(
            eks_cluster_name=CLUSTER_NAME,
            pod_namespace="default",
            credentials_file="/tmp/test.aws_creds",
        )

    @pytest.mark.asyncio
    @patch("airflow.providers.amazon.aws.hooks.eks.EksHook.get_session")
    @patch("airflow.providers.amazon.aws.hooks.eks.EksHook.__init__", return_value=None)
    async def test_run_raises_when_credentials_unavailable(
        self,
        mock_eks_hook_init,
        mock_get_session,
    ):
        """run() should raise RuntimeError when credentials cannot be retrieved."""
        # A session with no resolvable credentials returns None.
        mock_session = MagicMock()
        mock_get_session.return_value = mock_session
        mock_session.get_credentials.return_value = None

        trigger = self._create_trigger()

        # The generator only raises once iterated.
        with pytest.raises(RuntimeError, match="Unable to retrieve AWS credentials"):
            async for _ in trigger.run():
                pass

Reply via email to