This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 56e083f1455 AIP-72: Swap KubernetesExecutor to use taskSDK for
execution (#46860)
56e083f1455 is described below
commit 56e083f1455a83dcd8204e9746cfb2fad18f06da
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Feb 20 12:29:26 2025 +0530
AIP-72: Swap KubernetesExecutor to use taskSDK for execution (#46860)
---
kubernetes_tests/test_kubernetes_executor.py | 6 ++
.../kubernetes/executors/kubernetes_executor.py | 33 +++++++-
.../executors/kubernetes_executor_utils.py | 21 ++++-
.../providers/cncf/kubernetes/pod_generator.py | 60 ++++++++++++---
.../unit/cncf/kubernetes/test_pod_generator.py | 62 +++++++++++++++
.../airflow/sdk/execution_time/execute_workload.py | 89 ++++++++++++++++++++++
.../src/airflow/sdk/execution_time/task_runner.py | 1 -
7 files changed, 256 insertions(+), 16 deletions(-)
diff --git a/kubernetes_tests/test_kubernetes_executor.py
b/kubernetes_tests/test_kubernetes_executor.py
index 622a4daaa0d..92e58d98118 100644
--- a/kubernetes_tests/test_kubernetes_executor.py
+++ b/kubernetes_tests/test_kubernetes_executor.py
@@ -26,6 +26,9 @@ from kubernetes_tests.test_base import (
@pytest.mark.skipif(EXECUTOR != "KubernetesExecutor", reason="Only runs on
KubernetesExecutor")
class TestKubernetesExecutor(BaseK8STest):
+ @pytest.mark.skip(
+ reason="TODO: AIP-72 Porting over executor_config not yet done. Remove
once #46892 is handled"
+ )
@pytest.mark.execution_timeout(300)
def test_integration_run_dag(self):
dag_id = "example_kubernetes_executor"
@@ -51,6 +54,9 @@ class TestKubernetesExecutor(BaseK8STest):
)
@pytest.mark.execution_timeout(300)
+ @pytest.mark.skip(
+ reason="TODO: AIP-72 Porting over executor_config not yet done. Remove
once #46892 is handled"
+ )
def test_integration_run_dag_with_scheduler_failure(self):
dag_id = "example_kubernetes_executor"
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
index 845c91e0db3..3f6f49070ae 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
@@ -39,6 +39,8 @@ from typing import TYPE_CHECKING, Any
from deprecated import deprecated
from sqlalchemy import select
+from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
+from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS
from kubernetes.dynamic import DynamicClient
try:
@@ -78,6 +80,7 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
+ from airflow.executors import workloads
from airflow.executors.base_executor import CommandType
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -136,6 +139,11 @@ class KubernetesExecutor(BaseExecutor):
RUNNING_POD_LOG_LINES = 100
supports_ad_hoc_ti_run: bool = True
+ if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
+ # In the v3 path, we store workloads, not commands as strings.
+ # TODO: TaskSDK: move this type change into BaseExecutor
+ queued_tasks: dict[TaskInstanceKey, workloads.All] # type:
ignore[assignment]
+
def __init__(self):
self.kube_config = KubeConfig()
self._manager = multiprocessing.Manager()
@@ -250,8 +258,6 @@ class KubernetesExecutor(BaseExecutor):
else:
self.log.info("Add task %s with command %s", key, command)
- from airflow.providers.cncf.kubernetes.pod_generator import
PodGenerator
-
try:
kube_executor_config = PodGenerator.from_obj(executor_config)
except Exception:
@@ -269,6 +275,29 @@ class KubernetesExecutor(BaseExecutor):
# try and remove it from the QUEUED state while we process it
self.last_handled[key] = time.time()
+ def queue_workload(self, workload: workloads.All, session: Session | None)
-> None:
+ from airflow.executors import workloads
+
+ if not isinstance(workload, workloads.ExecuteTask):
+ raise RuntimeError(f"{type(self)} cannot handle workloads of type
{type(workload)}")
+ ti = workload.ti
+ self.queued_tasks[ti.key] = workload
+
+ def _process_workloads(self, workloads: list[workloads.All]) -> None:
+ # Airflow V3 version
+ for w in workloads:
+ # TODO: AIP-72 handle populating tokens once
https://github.com/apache/airflow/issues/45107 is handled.
+ command = [w]
+ key = w.ti.key # type: ignore[union-attr]
+ queue = w.ti.queue # type: ignore[union-attr]
+
+ # TODO: will be handled by
https://github.com/apache/airflow/issues/46892
+ executor_config = {} # type: ignore[var-annotated]
+
+ del self.queued_tasks[key]
+ self.execute_async(key=key, command=command, queue=queue,
executor_config=executor_config) # type: ignore[arg-type]
+ self.running.add(key)
+
def sync(self) -> None:
"""Synchronize task state."""
if TYPE_CHECKING:
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
index 1b7917502f0..dda6b6e6b10 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
@@ -388,8 +388,24 @@ class AirflowKubernetesScheduler(LoggingMixin):
key, command, kube_executor_config, pod_template_file = next_job
dag_id, task_id, run_id, try_number, map_index = key
-
- if command[0:3] != ["airflow", "tasks", "run"]:
+ ser_input = ""
+ if len(command) == 1:
+ from airflow.executors.workloads import ExecuteTask
+
+ if isinstance(command[0], ExecuteTask):
+ workload = command[0]
+ ser_input = workload.model_dump_json()
+ command = [
+ "python",
+ "-m",
+ "airflow.sdk.execution_time.execute_workload",
+ "/tmp/execute/input.json",
+ ]
+ else:
+ raise ValueError(
+ f"KubernetesExecutor doesn't know how to handle workload
of type: {type(command[0])}"
+ )
+ elif command[0:3] != ["airflow", "tasks", "run"]:
raise ValueError('The command must start with ["airflow", "tasks",
"run"].')
base_worker_pod = get_base_pod_from_template(pod_template_file,
self.kube_config)
@@ -411,6 +427,7 @@ class AirflowKubernetesScheduler(LoggingMixin):
date=None,
run_id=run_id,
args=list(command),
+ content_json_for_volume=ser_input,
pod_override_object=kube_executor_config,
base_worker_pod=base_worker_pod,
with_mutation_hook=True,
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py
index 5fab194963e..e1f599a40a8 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py
@@ -47,7 +47,7 @@ from
airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
from airflow.utils import yaml
from airflow.utils.hashlib_wrapper import md5
from airflow.version import version as airflow_version
-from kubernetes.client import models as k8s
+from kubernetes.client import V1EmptyDirVolumeSource, V1Volume, V1VolumeMount,
models as k8s
from kubernetes.client.api_client import ApiClient
if TYPE_CHECKING:
@@ -288,6 +288,7 @@ class PodGenerator:
scheduler_job_id: str,
run_id: str | None = None,
map_index: int = -1,
+ content_json_for_volume: str = "",
*,
with_mutation_hook: bool = False,
) -> k8s.V1Pod:
@@ -326,6 +327,14 @@ class PodGenerator:
if run_id:
annotations["run_id"] = run_id
+ main_container = k8s.V1Container(
+ name="base",
+ args=args,
+ image=image,
+ env=[
+ k8s.V1EnvVar(name="AIRFLOW_IS_K8S_EXECUTOR_POD", value="True"),
+ ],
+ )
dynamic_pod = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
namespace=namespace,
@@ -341,18 +350,47 @@ class PodGenerator:
run_id=run_id,
),
),
- spec=k8s.V1PodSpec(
- containers=[
- k8s.V1Container(
- name="base",
- args=args,
- image=image,
- env=[k8s.V1EnvVar(name="AIRFLOW_IS_K8S_EXECUTOR_POD",
value="True")],
- )
- ]
- ),
)
+ podspec = k8s.V1PodSpec(
+ containers=[main_container],
+ )
+
+ if content_json_for_volume:
+ import shlex
+
+ input_file_path = "/tmp/execute/input.json"
+ execute_volume = V1Volume(
+ name="execute-volume",
+ empty_dir=V1EmptyDirVolumeSource(),
+ )
+
+ execute_volume_mount = V1VolumeMount(
+ name="execute-volume",
+ mount_path="/tmp/execute",
+ read_only=False,
+ )
+
+ escaped_json = shlex.quote(content_json_for_volume)
+ init_container = k8s.V1Container(
+ name="init-container",
+ image="busybox",
+ command=["/bin/sh", "-c", f"echo {escaped_json} >
{input_file_path}"],
+ volume_mounts=[execute_volume_mount],
+ )
+
+ main_container.volume_mounts = [execute_volume_mount]
+ main_container.command = args[:-1]
+ main_container.args = args[-1:]
+
+ podspec = k8s.V1PodSpec(
+ containers=[main_container],
+ volumes=[execute_volume],
+ init_containers=[init_container],
+ )
+
+ dynamic_pod.spec = podspec
+
# Reconcile the pods starting with the first chronologically,
# Pod from the pod_template_File -> Pod from the K8s executor -> Pod
from executor_config arg
pod_list = [base_worker_pod, dynamic_pod, pod_override_object]
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py
index 77166452410..e4c9db06688 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_pod_generator.py
@@ -167,6 +167,68 @@ class TestPodGenerator:
),
)
+ @pytest.mark.parametrize(
+ "content_json, expected",
+ [
+ pytest.param(
+
'{"token":"mock","ti":{"id":"4d828a62-a417-4936-a7a6-2b3fabacecab","task_id":"mock","dag_id":"mock","run_id":"mock","try_number":1,"map_index":-1,"pool_slots":1,"queue":"default","priority_weight":1},"dag_rel_path":"mock.py","bundle_info":{"name":"n/a","version":"no
matter"},"log_path":"mock.log","kind":"ExecuteTask"}',
+
'{"token":"mock","ti":{"id":"4d828a62-a417-4936-a7a6-2b3fabacecab","task_id":"mock","dag_id":"mock","run_id":"mock","try_number":1,"map_index":-1,"pool_slots":1,"queue":"default","priority_weight":1},"dag_rel_path":"mock.py","bundle_info":{"name":"n/a","version":"no
matter"},"log_path":"mock.log","kind":"ExecuteTask"}',
+ id="regular-input",
+ ),
+ pytest.param(
+
'{"token":"mock","ti":{"id":"4d828a62-a417-4936-a7a6-2b3fabacecab","task_id":"moc\'k","dag_id":"mock","run_id":"mock","try_number":1,"map_index":-1,"pool_slots":1,"queue":"default","priority_weight":1},"dag_rel_path":"mock.py","bundle_info":{"name":"n/a","version":"no
matter"},"log_path":"mock.log","kind":"ExecuteTask"}',
+
'{"token":"mock","ti":{"id":"4d828a62-a417-4936-a7a6-2b3fabacecab","task_id":"moc\'"\'"\'k","dag_id":"mock","run_id":"mock","try_number":1,"map_index":-1,"pool_slots":1,"queue":"default","priority_weight":1},"dag_rel_path":"mock.py","bundle_info":{"name":"n/a","version":"no
matter"},"log_path":"mock.log","kind":"ExecuteTask"}',
+ id="input-with-single-quote-in-task-id",
+ ),
+ ],
+ )
+ def test_pod_spec_for_task_sdk_runs(self, content_json, expected,
data_file):
+ template_file =
data_file("pods/generator_base_with_secrets.yaml").as_posix()
+ worker_config = PodGenerator.deserialize_model_file(template_file)
+ result = PodGenerator.construct_pod(
+ dag_id="dag_id",
+ task_id="task_id",
+ pod_id="pod_id",
+ kube_image="test-image",
+ try_number=3,
+ date=self.logical_date,
+ args=[
+ "python",
+ "-m",
+ "airflow.sdk.execution_time.execute_workload",
+ "/tmp/execute/input.json",
+ ],
+ pod_override_object=None,
+ base_worker_pod=worker_config,
+ namespace="namespace",
+ scheduler_job_id="uuid",
+ content_json_for_volume=content_json,
+ )
+ sanitized_result = self.k8s_client.sanitize_for_serialization(result)
+
+ init_containers = sanitized_result["spec"]["initContainers"]
+ assert len(init_containers) == 1
+ init_container = init_containers[0]
+ assert init_container == {
+ "command": [
+ "/bin/sh",
+ "-c",
+ f"echo '{expected}' > /tmp/execute/input.json",
+ ],
+ "image": "busybox",
+ "name": "init-container",
+ "volumeMounts": [{"mountPath": "/tmp/execute", "name":
"execute-volume", "readOnly": False}],
+ }
+
+ volumes = sanitized_result["spec"]["volumes"]
+ assert len(volumes) == 1
+ volume = volumes[0]
+ assert volume == {"emptyDir": {}, "name": "execute-volume"}
+
+ main_container = sanitized_result["spec"]["containers"][0]
+ assert main_container["command"] == ["python", "-m",
"airflow.sdk.execution_time.execute_workload"]
+ assert main_container["args"] == ["/tmp/execute/input.json"]
+
def test_from_obj_pod_override_object(self):
obj = {
"pod_override": k8s.V1Pod(
diff --git a/task_sdk/src/airflow/sdk/execution_time/execute_workload.py
b/task_sdk/src/airflow/sdk/execution_time/execute_workload.py
new file mode 100644
index 00000000000..489c26fe607
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/execution_time/execute_workload.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Module for executing an Airflow task using the workload JSON provided by an
+input file.
+
+Usage:
+ python execute_workload.py <input_file>
+
+Arguments:
+ input_file (str): Path to the JSON file containing the workload definition.
+"""
+
+from __future__ import annotations
+
+import argparse
+import sys
+
+import structlog
+
+log = structlog.get_logger(logger_name=__name__)
+
+
+def execute_workload(input: str) -> None:
+ from pydantic import TypeAdapter
+
+ from airflow.configuration import conf
+ from airflow.executors import workloads
+ from airflow.sdk.execution_time.supervisor import supervise
+ from airflow.sdk.log import configure_logging
+
+ configure_logging(output=sys.stdout.buffer)
+
+ decoder = TypeAdapter[workloads.All](workloads.All)
+ workload = decoder.validate_json(input)
+
+ if not isinstance(workload, workloads.ExecuteTask):
+ raise ValueError(f"KubernetesExecutor does not know how to handle
{type(workload)}")
+
+ log.info("Executing workload in Kubernetes", workload=workload)
+
+ supervise(
+ # This is the "wrong" ti type, but it duck types the same. TODO:
Create a protocol for this.
+ ti=workload.ti, # type: ignore[arg-type]
+ dag_rel_path=workload.dag_rel_path,
+ bundle_info=workload.bundle_info,
+ token=workload.token,
+ # fallback to internal cluster service for api server
+ server=conf.get(
+ "workers",
+ "execution_api_server_url",
+
fallback="http://airflow-api-server.airflow.svc.cluster.local:9091/execution/",
+ ),
+ log_path=workload.log_path,
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Execute a workload in a Containerised executor using the
task SDK."
+ )
+ parser.add_argument(
+ "input_file", help="Path to the input JSON file containing the
execution workload payload."
+ )
+
+ args = parser.parse_args()
+
+ with open(args.input_file) as file:
+ input_data = file.read()
+
+ execute_workload(input_data)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index b2f9c37c776..99967579bb4 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -783,7 +783,6 @@ def finalize(ti: RuntimeTaskInstance, state:
TerminalTIState, log: Logger):
def main():
# TODO: add an exception here, it causes an oof of a stack trace!
-
global SUPERVISOR_COMMS
SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin)
try: