This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 56e083f1455 AIP-72: Swap KubernetesExecutor to use taskSDK for 
execution (#46860)
56e083f1455 is described below

commit 56e083f1455a83dcd8204e9746cfb2fad18f06da
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Feb 20 12:29:26 2025 +0530

    AIP-72: Swap KubernetesExecutor to use taskSDK for execution (#46860)
---
 kubernetes_tests/test_kubernetes_executor.py       |  6 ++
 .../kubernetes/executors/kubernetes_executor.py    | 33 +++++++-
 .../executors/kubernetes_executor_utils.py         | 21 ++++-
 .../providers/cncf/kubernetes/pod_generator.py     | 60 ++++++++++++---
 .../unit/cncf/kubernetes/test_pod_generator.py     | 62 +++++++++++++++
 .../airflow/sdk/execution_time/execute_workload.py | 89 ++++++++++++++++++++++
 .../src/airflow/sdk/execution_time/task_runner.py  |  1 -
 7 files changed, 256 insertions(+), 16 deletions(-)

diff --git a/kubernetes_tests/test_kubernetes_executor.py 
b/kubernetes_tests/test_kubernetes_executor.py
index 622a4daaa0d..92e58d98118 100644
--- a/kubernetes_tests/test_kubernetes_executor.py
+++ b/kubernetes_tests/test_kubernetes_executor.py
@@ -26,6 +26,9 @@ from kubernetes_tests.test_base import (
 
 @pytest.mark.skipif(EXECUTOR != "KubernetesExecutor", reason="Only runs on 
KubernetesExecutor")
 class TestKubernetesExecutor(BaseK8STest):
+    @pytest.mark.skip(
+        reason="TODO: AIP-72 Porting over executor_config not yet done. Remove 
once #46892 is handled"
+    )
     @pytest.mark.execution_timeout(300)
     def test_integration_run_dag(self):
         dag_id = "example_kubernetes_executor"
@@ -51,6 +54,9 @@ class TestKubernetesExecutor(BaseK8STest):
         )
 
     @pytest.mark.execution_timeout(300)
+    @pytest.mark.skip(
+        reason="TODO: AIP-72 Porting over executor_config not yet done. Remove 
once #46892 is handled"
+    )
     def test_integration_run_dag_with_scheduler_failure(self):
         dag_id = "example_kubernetes_executor"
 
diff --git 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
index 845c91e0db3..3f6f49070ae 100644
--- 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
+++ 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
@@ -39,6 +39,8 @@ from typing import TYPE_CHECKING, Any
 from deprecated import deprecated
 from sqlalchemy import select
 
+from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
+from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS
 from kubernetes.dynamic import DynamicClient
 
 try:
@@ -78,6 +80,7 @@ if TYPE_CHECKING:
 
     from sqlalchemy.orm import Session
 
+    from airflow.executors import workloads
     from airflow.executors.base_executor import CommandType
     from airflow.models.taskinstance import TaskInstance
     from airflow.models.taskinstancekey import TaskInstanceKey
@@ -136,6 +139,11 @@ class KubernetesExecutor(BaseExecutor):
     RUNNING_POD_LOG_LINES = 100
     supports_ad_hoc_ti_run: bool = True
 
+    if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
+        # In the v3 path, we store workloads, not commands as strings.
+        # TODO: TaskSDK: move this type change into BaseExecutor
+        queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: 
ignore[assignment]
+
     def __init__(self):
         self.kube_config = KubeConfig()
         self._manager = multiprocessing.Manager()
@@ -250,8 +258,6 @@ class KubernetesExecutor(BaseExecutor):
         else:
             self.log.info("Add task %s with command %s", key, command)
 
-        from airflow.providers.cncf.kubernetes.pod_generator import 
PodGenerator
-
         try:
             kube_executor_config = PodGenerator.from_obj(executor_config)
         except Exception:
@@ -269,6 +275,29 @@ class KubernetesExecutor(BaseExecutor):
         # try and remove it from the QUEUED state while we process it
         self.last_handled[key] = time.time()
 
+    def queue_workload(self, workload: workloads.All, session: Session | None) 
-> None:
+        from airflow.executors import workloads
+
+        if not isinstance(workload, workloads.ExecuteTask):
+            raise RuntimeError(f"{type(self)} cannot handle workloads of type 
{type(workload)}")
+        ti = workload.ti
+        self.queued_tasks[ti.key] = workload
+
+    def _process_workloads(self, workloads: list[workloads.All]) -> None:
+        # Airflow V3 version
+        for w in workloads:
+            # TODO: AIP-72 handle populating tokens once 
https://github.com/apache/airflow/issues/45107 is handled.
+            command = [w]
+            key = w.ti.key  # type: ignore[union-attr]
+            queue = w.ti.queue  # type: ignore[union-attr]
+
+            # TODO: will be handled by 
https://github.com/apache/airflow/issues/46892
+            executor_config = {}  # type: ignore[var-annotated]
+
+            del self.queued_tasks[key]
+            self.execute_async(key=key, command=command, queue=queue, 
executor_config=executor_config)  # type: ignore[arg-type]
+            self.running.add(key)
+
     def sync(self) -> None:
         """Synchronize task state."""
         if TYPE_CHECKING:
diff --git 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
index 1b7917502f0..dda6b6e6b10 100644
--- 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
+++ 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
@@ -388,8 +388,24 @@ class AirflowKubernetesScheduler(LoggingMixin):
         key, command, kube_executor_config, pod_template_file = next_job
 
         dag_id, task_id, run_id, try_number, map_index = key
-
-        if command[0:3] != ["airflow", "tasks", "run"]:
+        ser_input = ""
+        if len(command) == 1:
+            from airflow.executors.workloads import ExecuteTask
+
+            if isinstance(command[0], ExecuteTask):
+                workload = command[0]
+                ser_input = workload.model_dump_json()
+                command = [
+                    "python",
+                    "-m",
+                    "airflow.sdk.execution_time.execute_workload",
+                    "/tmp/execute/input.json",
+                ]
+            else:
+                raise ValueError(
+                    f"KubernetesExecutor doesn't know how to handle workload 
of type: {type(command[0])}"
+                )
+        elif command[0:3] != ["airflow", "tasks", "run"]:
             raise ValueError('The command must start with ["airflow", "tasks", 
"run"].')
 
         base_worker_pod = get_base_pod_from_template(pod_template_file, 
self.kube_config)
@@ -411,6 +427,7 @@ class AirflowKubernetesScheduler(LoggingMixin):
             date=None,
             run_id=run_id,
             args=list(command),
+            content_json_for_volume=ser_input,
             pod_override_object=kube_executor_config,
             base_worker_pod=base_worker_pod,
             with_mutation_hook=True,
diff --git 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py
 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py
index 5fab194963e..e1f599a40a8 100644
--- 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py
+++ 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py
@@ -47,7 +47,7 @@ from 
airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
 from airflow.utils import yaml
 from airflow.utils.hashlib_wrapper import md5
 from airflow.version import version as airflow_version
-from kubernetes.client import models as k8s
+from kubernetes.client import V1EmptyDirVolumeSource, V1Volume, V1VolumeMount, 
models as k8s
 from kubernetes.client.api_client import ApiClient
 
 if TYPE_CHECKING:
@@ -288,6 +288,7 @@ class PodGenerator:
         scheduler_job_id: str,
         run_id: str | None = None,
         map_index: int = -1,
+        content_json_for_volume: str = "",
         *,
         with_mutation_hook: bool = False,
     ) -> k8s.V1Pod:
@@ -326,6 +327,14 @@ class PodGenerator:
         if run_id:
             annotations["run_id"] = run_id
 
+        main_container = k8s.V1Container(
+            name="base",
+            args=args,
+            image=image,
+            env=[
+                k8s.V1EnvVar(name="AIRFLOW_IS_K8S_EXECUTOR_POD", value="True"),
+            ],
+        )
         dynamic_pod = k8s.V1Pod(
             metadata=k8s.V1ObjectMeta(
                 namespace=namespace,
@@ -341,18 +350,47 @@ class PodGenerator:
                     run_id=run_id,
                 ),
             ),
-            spec=k8s.V1PodSpec(
-                containers=[
-                    k8s.V1Container(
-                        name="base",
-                        args=args,
-                        image=image,
-                        env=[k8s.V1EnvVar(name="AIRFLOW_IS_K8S_EXECUTOR_POD", 
value="True")],
-                    )
-                ]
-            ),
         )
 
+        podspec = k8s.V1PodSpec(
+            containers=[main_container],
+        )
+
+        if content_json_for_volume:
+            import shlex
+
+            input_file_path = "/tmp/execute/input.json"
+            execute_volume = V1Volume(
+                name="execute-volume",
+                empty_dir=V1EmptyDirVolumeSource(),
+            )
+
+            execute_volume_mount = V1VolumeMount(
+                name="execute-volume",
+                mount_path="/tmp/execute",
+                read_only=False,
+            )
+
+            escaped_json = shlex.quote(content_json_for_volume)
+            init_container = k8s.V1Container(
+                name="init-container",
+                image="busybox",
+                command=["/bin/sh", "-c", f"echo {escaped_json} > 
{input_file_path}"],
+                volume_mounts=[execute_volume_mount],
+            )
+
+            main_container.volume_mounts = [execute_volume_mount]
+            main_container.command = args[:-1]
+            main_container.args = args[-1:]
+
+            podspec = k8s.V1PodSpec(
+                containers=[main_container],
+                volumes=[execute_volume],
+                init_containers=[init_container],
+            )
+
+        dynamic_pod.spec = podspec
+
         # Reconcile the pods starting with the first chronologically,
         # Pod from the pod_template_File -> Pod from the K8s executor -> Pod 
from executor_config arg
         pod_list = [base_worker_pod, dynamic_pod, pod_override_object]
diff --git 
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py 
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py
index 77166452410..e4c9db06688 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py
@@ -167,6 +167,68 @@ class TestPodGenerator:
             ),
         )
 
+    @pytest.mark.parametrize(
+        "content_json, expected",
+        [
+            pytest.param(
+                
'{"token":"mock","ti":{"id":"4d828a62-a417-4936-a7a6-2b3fabacecab","task_id":"mock","dag_id":"mock","run_id":"mock","try_number":1,"map_index":-1,"pool_slots":1,"queue":"default","priority_weight":1},"dag_rel_path":"mock.py","bundle_info":{"name":"n/a","version":"no
 matter"},"log_path":"mock.log","kind":"ExecuteTask"}',
+                
'{"token":"mock","ti":{"id":"4d828a62-a417-4936-a7a6-2b3fabacecab","task_id":"mock","dag_id":"mock","run_id":"mock","try_number":1,"map_index":-1,"pool_slots":1,"queue":"default","priority_weight":1},"dag_rel_path":"mock.py","bundle_info":{"name":"n/a","version":"no
 matter"},"log_path":"mock.log","kind":"ExecuteTask"}',
+                id="regular-input",
+            ),
+            pytest.param(
+                
'{"token":"mock","ti":{"id":"4d828a62-a417-4936-a7a6-2b3fabacecab","task_id":"moc\'k","dag_id":"mock","run_id":"mock","try_number":1,"map_index":-1,"pool_slots":1,"queue":"default","priority_weight":1},"dag_rel_path":"mock.py","bundle_info":{"name":"n/a","version":"no
 matter"},"log_path":"mock.log","kind":"ExecuteTask"}',
+                
'{"token":"mock","ti":{"id":"4d828a62-a417-4936-a7a6-2b3fabacecab","task_id":"moc\'"\'"\'k","dag_id":"mock","run_id":"mock","try_number":1,"map_index":-1,"pool_slots":1,"queue":"default","priority_weight":1},"dag_rel_path":"mock.py","bundle_info":{"name":"n/a","version":"no
 matter"},"log_path":"mock.log","kind":"ExecuteTask"}',
+                id="input-with-single-quote-in-task-id",
+            ),
+        ],
+    )
+    def test_pod_spec_for_task_sdk_runs(self, content_json, expected, 
data_file):
+        template_file = 
data_file("pods/generator_base_with_secrets.yaml").as_posix()
+        worker_config = PodGenerator.deserialize_model_file(template_file)
+        result = PodGenerator.construct_pod(
+            dag_id="dag_id",
+            task_id="task_id",
+            pod_id="pod_id",
+            kube_image="test-image",
+            try_number=3,
+            date=self.logical_date,
+            args=[
+                "python",
+                "-m",
+                "airflow.sdk.execution_time.execute_workload",
+                "/tmp/execute/input.json",
+            ],
+            pod_override_object=None,
+            base_worker_pod=worker_config,
+            namespace="namespace",
+            scheduler_job_id="uuid",
+            content_json_for_volume=content_json,
+        )
+        sanitized_result = self.k8s_client.sanitize_for_serialization(result)
+
+        init_containers = sanitized_result["spec"]["initContainers"]
+        assert len(init_containers) == 1
+        init_container = init_containers[0]
+        assert init_container == {
+            "command": [
+                "/bin/sh",
+                "-c",
+                f"echo '{expected}' > /tmp/execute/input.json",
+            ],
+            "image": "busybox",
+            "name": "init-container",
+            "volumeMounts": [{"mountPath": "/tmp/execute", "name": 
"execute-volume", "readOnly": False}],
+        }
+
+        volumes = sanitized_result["spec"]["volumes"]
+        assert len(volumes) == 1
+        volume = volumes[0]
+        assert volume == {"emptyDir": {}, "name": "execute-volume"}
+
+        main_container = sanitized_result["spec"]["containers"][0]
+        assert main_container["command"] == ["python", "-m", 
"airflow.sdk.execution_time.execute_workload"]
+        assert main_container["args"] == ["/tmp/execute/input.json"]
+
     def test_from_obj_pod_override_object(self):
         obj = {
             "pod_override": k8s.V1Pod(
diff --git a/task_sdk/src/airflow/sdk/execution_time/execute_workload.py 
b/task_sdk/src/airflow/sdk/execution_time/execute_workload.py
new file mode 100644
index 00000000000..489c26fe607
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/execution_time/execute_workload.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Module for executing an Airflow task using the workload JSON provided by an 
+input file.
+
+Usage:
+    python execute_workload.py <input_file>
+
+Arguments:
+    input_file (str): Path to the JSON file containing the workload definition.
+"""
+
+from __future__ import annotations
+
+import argparse
+import sys
+
+import structlog
+
+log = structlog.get_logger(logger_name=__name__)
+
+
+def execute_workload(input: str) -> None:
+    from pydantic import TypeAdapter
+
+    from airflow.configuration import conf
+    from airflow.executors import workloads
+    from airflow.sdk.execution_time.supervisor import supervise
+    from airflow.sdk.log import configure_logging
+
+    configure_logging(output=sys.stdout.buffer)
+
+    decoder = TypeAdapter[workloads.All](workloads.All)
+    workload = decoder.validate_json(input)
+
+    if not isinstance(workload, workloads.ExecuteTask):
+        raise ValueError(f"KubernetesExecutor does not know how to handle 
{type(workload)}")
+
+    log.info("Executing workload in Kubernetes", workload=workload)
+
+    supervise(
+        # This is the "wrong" ti type, but it duck types the same. TODO: 
Create a protocol for this.
+        ti=workload.ti,  # type: ignore[arg-type]
+        dag_rel_path=workload.dag_rel_path,
+        bundle_info=workload.bundle_info,
+        token=workload.token,
+        # fallback to internal cluster service for api server
+        server=conf.get(
+            "workers",
+            "execution_api_server_url",
+            
fallback="http://airflow-api-server.airflow.svc.cluster.local:9091/execution/";,
+        ),
+        log_path=workload.log_path,
+    )
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Execute a workload in a Containerised executor using the 
task SDK."
+    )
+    parser.add_argument(
+        "input_file", help="Path to the input JSON file containing the 
execution workload payload."
+    )
+
+    args = parser.parse_args()
+
+    with open(args.input_file) as file:
+        input_data = file.read()
+
+    execute_workload(input_data)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index b2f9c37c776..99967579bb4 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -783,7 +783,6 @@ def finalize(ti: RuntimeTaskInstance, state: 
TerminalTIState, log: Logger):
 
 def main():
     # TODO: add an exception here, it causes an oof of a stack trace!
-
     global SUPERVISOR_COMMS
     SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin)
     try:

Reply via email to