jscheffl commented on code in PR #43893:
URL: https://github.com/apache/airflow/pull/43893#discussion_r1837178283
##########
task_sdk/pyproject.toml:
##########
@@ -24,7 +24,11 @@ requires-python = ">=3.9, <3.13"
dependencies = [
"attrs>=24.2.0",
"google-re2>=1.1.20240702",
+ "httpx>=0.27.0",
"methodtools>=0.4.7",
+ "msgspec>=0.18.6",
Review Comment:
If this is 0.18.6... can we assume it is stable and won't have breaking
changes by tomorrow?
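
(msgspec is pre-1.0, so by semver convention minor releases may break; if we want to
guard against that, a conservative option would be an upper bound on the pin, e.g.:)

    "msgspec>=0.18.6,<0.19",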
##########
task_sdk/src/airflow/sdk/api/client.py:
##########
@@ -0,0 +1,215 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import uuid
+from typing import TYPE_CHECKING, Any
+
+import httpx
+import methodtools
+import msgspec
+import structlog
+from uuid6 import uuid7
+
+from airflow.sdk.api.datamodels._generated import (
+ ConnectionResponse,
+ State1 as TerminalState,
+ TaskInstanceState,
+ TIEnterRunningPayload,
+ TITerminalStatePayload,
+ ValidationError as RemoteValidationError,
+)
+from airflow.utils.net import get_hostname
+from airflow.utils.platform import getuser
+
+if TYPE_CHECKING:
+ from datetime import datetime
+
+
+log = structlog.get_logger(logger_name=__name__)
+
+__all__ = [
+ "Client",
+ "ConnectionOperations",
+ "ErrorBody",
+ "ServerResponseError",
+ "TaskInstanceOperations",
+]
+
+
+def get_json_error(response: httpx.Response):
+ """Raise a ServerResponseError if we can extract error info from the
error."""
+ err = ServerResponseError.from_response(response)
+ if err:
+ log.warning("Server error", detail=err.detail)
+ raise err
+
+
+def raise_on_4xx_5xx(response: httpx.Response):
+ return get_json_error(response) or response.raise_for_status()
+
+
+# Py 3.11+ version
+def raise_on_4xx_5xx_with_note(response: httpx.Response):
+ try:
+ return get_json_error(response) or response.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ if TYPE_CHECKING:
+ assert hasattr(e, "add_note")
+ e.add_note(
+ f"Correlation-id={response.headers.get('correlation-id', None) or
response.request.headers.get('correlation-id', 'no-correlction-id')}"
+ )
+ raise
+
+
+if hasattr(BaseException, "add_note"):
+ # Py 3.11+
+ raise_on_4xx_5xx = raise_on_4xx_5xx_with_note
+
+
+def add_correlation_id(request: httpx.Request):
+ request.headers["correlation-id"] = str(uuid7())
+
+
+class TaskInstanceOperations:
+ __slots__ = ("client",)
+
+ def __init__(self, client: Client):
+ self.client = client
+
+ def start(self, id: uuid.UUID, pid: int, when: datetime):
+ """Tell the API server that this TI has started running."""
+        body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when)
+
+ self.client.patch(f"task_instance/{id}/state",
content=self.client.encoder.encode(body))
+
+ def finish(self, id: uuid.UUID, state: TaskInstanceState, when: datetime):
+ """Tell the API server that this TI has reached a terminal state."""
+        body = TITerminalStatePayload(end_date=when, state=TerminalState(state))
+
+ self.client.patch(f"task_instance/{id}/state",
content=self.client.encoder.encode(body))
+
+ def heartbeat(self, id: uuid.UUID):
+ self.client.put(f"task_instance/{id}/heartbeat")
+
+
+class ConnectionOperations:
+ __slots__ = ("client", "decoder")
+
+ def __init__(self, client: Client):
+ self.client = client
+        self.decoder: msgspec.json.Decoder[ConnectionResponse] = msgspec.json.Decoder(type=ConnectionResponse)
+
+ def get(self, id: str) -> ConnectionResponse:
+ """Get a connection from the API server."""
+ resp = self.client.get(f"connection/{id}")
+ return self.decoder.decode(resp.read())
+
+
+class BearerAuth(httpx.Auth):
+ def __init__(self, token: str):
+ self.token: str = token
+
+ def auth_flow(self, request: httpx.Request):
+ if self.token:
+ request.headers["Authorization"] = "Bearer " + self.token
+ yield request
+
+
+def noop_handler(request: httpx.Request) -> httpx.Response:
+ log.debug("Dry-run request", method=request.method, path=request.url.path)
+ return httpx.Response(200, json={"text": "Hello, world!"})
Review Comment:
What is this for?
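
(For context, a guess at the intent: httpx allows a plain callable to serve as the
transport, which would enable a fully offline dry-run client. A minimal sketch,
assuming the Client accepts a standard httpx `transport` argument:)

    import httpx

    # Hypothetical dry-run wiring; the real Client construction is outside this hunk.
    transport = httpx.MockTransport(noop_handler)
    client = httpx.Client(base_url="http://in-process.invalid", transport=transport)

    resp = client.put("task_instance/some-id/heartbeat")
    assert resp.status_code == 200  # answered locally by noop_handler, no network I/O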
##########
task_sdk/src/airflow/sdk/api/datamodels/_generated.py:
##########
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# generated by datamodel-codegen:
+# filename: http://0.0.0.0:9091/execution/openapi.json
+# version: 0.26.3
+
+from __future__ import annotations
+
+from datetime import datetime
+from enum import Enum
+from typing import Literal
+
+from msgspec import Struct, field
+
+
+class ConnectionResponse(Struct):
+ """
+ Connection schema for responses with fields that are needed for Runtime.
+ """
+
+ conn_id: str
+ conn_type: str
+ host: str | None = None
+ schema_: str | None = field(name="schema", default=None)
+ login: str | None = None
+ password: str | None = None
+ port: int | None = None
+ extra: str | None = None
+
+
+class TIEnterRunningPayload(Struct):
+ """
+    Schema for updating TaskInstance to 'RUNNING' state with minimal required fields.
+ """
+
+ hostname: str
+ unixname: str
+ pid: int
+ start_date: datetime
+ state: Literal["running"] | None = "running"
+
+
+class TIHeartbeatInfo(Struct):
+ """
+ Schema for TaskInstance heartbeat endpoint.
+ """
+
+ hostname: str
+ pid: int
+
+
+class State(Enum):
Review Comment:
What is the difference to `TaskInstanceState` below?
##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
Review Comment:
What is the difference between the "Supervisor" and the "Worker" like today
CeleryWorker or EdgeWorker? Will the Supervisor replace these? Or will there be
exactly 1 Supervisor as wrapper per task being executed?
(So if Celery executes with concurrency of 16, will to spawn 16 Supervisor
or one that hosts 16 tasks?)
(And the question is mainly will each have it's own HTTP session and JWT or
will thse calls be multiplexed?)
##########
task_sdk/src/airflow/sdk/execution_time/comms.py:
##########
@@ -0,0 +1,103 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+r"""
+Communication protocol between the Supervisor and the task process
+==================================================================
+
+* All communication is done over stdout/stdin in the form of "JSON lines" (each
+ message is a single JSON document terminated by `\n` character)
+* Messages from the subprocess are all log messages and are sent directly to the log
+* No messages are sent to task process except in response to a request. (This is because the task process will
+  be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom
+  value etc.)
Review Comment:
Would it be better to describe the Supervisor as an HTTP bridge or proxy to
the backend?
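
(For illustration, one request/response cycle over this JSON-lines protocol might
look as follows. This is a sketch: the field names and tag layout are assumed; the
real structs are defined in comms.py:)

    import msgspec

    # Assumed message shape, mirroring the tagged-union style used elsewhere in the PR.
    class GetConnection(msgspec.Struct, tag="GetConnection", tag_field="type"):
        id: str

    # task -> supervisor: one JSON document, newline-terminated
    request_line = msgspec.json.encode(GetConnection(id="my_db")) + b"\n"
    # b'{"type":"GetConnection","id":"my_db"}\n'

    # supervisor -> task: the response, also as a single JSON line on the task's stdin
    response_line = b'{"conn_id": "my_db", "conn_type": "postgres", "host": "db"}\n'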
##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -0,0 +1,547 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Supervise and run Tasks in a subprocess."""
+
+from __future__ import annotations
+
+import atexit
+import io
+import logging
+import os
+import selectors
+import signal
+import sys
+import time
+import weakref
+from collections.abc import Generator
+from contextlib import suppress
+from datetime import datetime, timezone
+from socket import socket, socketpair
+from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, cast, overload
+from uuid import UUID
+
+import attrs
+import httpx
+import msgspec
+import psutil
+import structlog
+
+from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import TaskInstanceState
+from airflow.sdk.execution_time.comms import ConnectionResponse, GetConnection, StartupDetails, ToSupervisor
+
+if TYPE_CHECKING:
+ from structlog.typing import FilteringBoundLogger
+
+ from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
+ from airflow.sdk.api.datamodels.ti import TaskInstance
+
+
+__all__ = ["WatchedSubprocess", "supervise"]
+
+log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor")
+
+# TODO: Pull this from config
+SLOWEST_HEARTBEAT_INTERVAL: int = 30
+# Don't heartbeat more often than this
+FASTEST_HEARTBEAT_INTERVAL: int = 5
+
+
+@overload
+def mkpipe() -> tuple[socket, socket]: ...
+
+
+@overload
+def mkpipe(remote_read: Literal[True]) -> tuple[socket, BinaryIO]: ...
+
+
+def mkpipe(
+ remote_read: bool = False,
+) -> tuple[socket, socket | BinaryIO]:
+ """
+ Create a pair of connected sockets.
+
+    The inheritable flag will be set correctly so that the end destined for the subprocess is kept open but
+ the end for this process is closed automatically by the OS.
+ """
+ rsock, wsock = socketpair()
+ local, remote = (wsock, rsock) if remote_read else (rsock, wsock)
+
+ remote.set_inheritable(True)
+ local.setblocking(False)
+
+ io: BinaryIO | socket
+ if remote_read:
+ # If _we_ are writing, we don't want to buffer
+ io = cast(BinaryIO, local.makefile("wb", buffering=0))
+ else:
+ io = local
+
+ return remote, io
+
+
+def _subprocess_main():
+ from airflow.sdk.execution_time.task_runner import main
+
+ main()
+
+
+def _fork_main(
+ child_stdin: socket,
+ child_stdout: socket,
+ child_stderr: socket,
+ log_fd: int,
+ target: Callable[[], None],
+) -> NoReturn:
+ # TODO: Make this process a session leader
+
+ # Uninstall the rich etc. exception handler
+ sys.excepthook = sys.__excepthook__
+ signal.signal(signal.SIGINT, signal.SIG_DFL)
+ signal.signal(signal.SIGUSR2, signal.SIG_DFL)
+
+ if log_fd > 0:
+        # A channel that the task can send JSON-formatted logs over.
+ #
+ # JSON logs sent this way will be handled nicely
+ from airflow.sdk.log import configure_logging
+
+ log_io = os.fdopen(log_fd, "wb", buffering=0)
+ configure_logging(enable_pretty_log=False, output=log_io)
+
+ last_chance_stderr = sys.__stderr__ or sys.stderr
+
+    # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the
+    # pipes from the supervisor
+
+ for handle_name, sock, mode, close in (
+ ("stdin", child_stdin, "r", True),
+ ("stdout", child_stdout, "w", True),
+ ("stderr", child_stderr, "w", False),
+ ):
+ handle = getattr(sys, handle_name)
+ try:
+ fd = handle.fileno()
+ os.dup2(sock.fileno(), fd)
+ if close:
+ handle.close()
+ except io.UnsupportedOperation:
+ if "PYTEST_CURRENT_TEST" in os.environ:
+            # When we're running under pytest, the stdin is not a real filehandle with an fd, so we need
+ # to handle that differently
+ fd = sock.fileno()
+ else:
+ raise
+
+ setattr(sys, handle_name, os.fdopen(fd, mode))
+
+ def exit(n: int) -> NoReturn:
+ with suppress(ValueError, OSError):
+ sys.stdout.flush()
+ with suppress(ValueError, OSError):
+ sys.stderr.flush()
+ with suppress(ValueError, OSError):
+ last_chance_stderr.flush()
+ os._exit(n)
+
+ if hasattr(atexit, "_clear"):
+ # Since we're in a fork we want to try and clear them
+ atexit._clear()
+ base_exit = exit
+
+ def exit(n: int) -> NoReturn:
+ atexit._run_exitfuncs()
+ base_exit(n)
+
+ try:
+ target()
+ exit(0)
+ except SystemExit as e:
+ code = 1
+ if isinstance(e.code, int):
+ code = e.code
+ elif e.code:
+ print(e.code, file=sys.stderr)
+ exit(code)
+ except Exception:
+ # Last ditch log attempt
+ exc, v, tb = sys.exc_info()
+
+ import traceback
+
+ try:
+ last_chance_stderr.write("--- Last chance exception handler ---\n")
+            traceback.print_exception(exc, value=v, tb=tb, file=last_chance_stderr)
+ exit(99)
+ except Exception as e:
+ with suppress(Exception):
+ print(
+ f"--- Last chance exception handler failed ---
{repr(str(e))}\n", file=last_chance_stderr
+ )
+ exit(98)
+
+
[email protected]()
+class WatchedSubprocess:
+ ti_id: UUID
+ pid: int
+
+ stdin: BinaryIO
+ stdout: socket
+ stderr: socket
+
+ client: Client
+
+ _process: psutil.Process
+ _exit_code: int | None = None
+ _terminal_state: str | None = None
+
+ _last_heartbeat: float = 0
+
+    selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector)
+
+    procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = weakref.WeakValueDictionary()
+
+ def __attrs_post_init__(self):
+ self.procs[self.pid] = self
+
+ @classmethod
+ def start(
+ cls,
+ path: str | os.PathLike[str],
+ ti: TaskInstance,
+ client: Client,
+ target: Callable[[], None] = _subprocess_main,
+ ) -> WatchedSubprocess:
+ """Fork and start a new subprocess to execute the given task."""
+ # Create socketpairs/"pipes" to connect to the stdin and out from the
subprocess
+ child_stdin, feed_stdin = mkpipe(remote_read=True)
+ child_stdout, read_stdout = mkpipe()
+ child_stderr, read_stderr = mkpipe()
+
+        # Open these socketpairs before forking off the child, so that they are open when we fork.
+ child_comms, read_msgs = mkpipe()
+ child_logs, read_logs = mkpipe()
+
+ pid = os.fork()
+ if pid == 0:
+            # Parent ends of the sockets are closed by the OS as they are set as non-inheritable
+
+            # Run the child entrypoint
+            _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target)
+
+ proc = cls(
+ ti_id=ti.id,
+ pid=pid,
+ stdin=feed_stdin,
+ stdout=read_stdout,
+ stderr=read_stderr,
+ process=psutil.Process(pid),
+ client=client,
+ )
+
+        # We've forked, but the task won't start until we send it the StartupDetails message. But before we do
+        # that, we need to tell the server it's started (so it has the chance to tell us "no, stop!" for any
+ # reason)
+ try:
+            client.task_instances.start(ti.id, pid, datetime.now(tz=timezone.utc))
+ proc._last_heartbeat = time.monotonic()
+ except Exception:
+ # On any error kill that subprocess!
+ proc.kill(signal.SIGKILL)
+ raise
+
+ # TODO: Use logging providers to handle the chunked upload for us
+        task_logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind()
+
+        cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stdout"), level=logging.INFO))
+
+ proc.selector.register(read_stdout, selectors.EVENT_READ, cb)
+        cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stderr"), level=logging.ERROR))
+ proc.selector.register(read_stderr, selectors.EVENT_READ, cb)
+
+ proc.selector.register(
+ read_logs,
+ selectors.EVENT_READ,
+            make_buffered_socket_reader(process_log_messages_from_subprocess(task_logger)),
+ )
+ proc.selector.register(
+ read_msgs,
+ selectors.EVENT_READ,
+ make_buffered_socket_reader(proc.handle_requests(log=log)),
+ )
+
+ # Tell the task process what it needs to do!
+ msg = StartupDetails(
+ ti=ti,
+ file=str(path),
+ requests_fd=child_comms.fileno(),
+ )
+
+        # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the
+ # other end of the pair open
+ child_stdout.close()
+ child_stdin.close()
+ child_comms.close()
+ child_logs.close()
+
+ # Send the message to tell the process what it needs to execute
+ log.debug("Sending", msg=msg)
+ feed_stdin.write(msgspec.json.encode(msg))
+ feed_stdin.write(b"\n")
+
+ return proc
+
+ def kill(self, signal: signal.Signals = signal.SIGINT):
+ if self._exit_code is not None:
+ return
+
+ with suppress(ProcessLookupError):
+ os.kill(self.pid, signal)
+
+ def wait(self) -> int:
+ if self._exit_code is not None:
+ return self._exit_code
+
+        # Until we have a selector for the process, don't poll for more than 10s, just in case it exits but
+        # doesn't produce any output
+ max_poll_interval = 10
+
+ try:
+ while self._exit_code is None or len(self.selector.get_map()):
+ last_heartbeat_ago = time.monotonic() - self._last_heartbeat
+                # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible
+ # so we notice the subprocess finishing as quick as we can.
+ max_wait_time = max(
+ 0, # Make sure this value is never negative,
+ min(
+                        # Ensure we heartbeat _at most_ 75% of the way through the zombie threshold time
+ SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75,
+ max_poll_interval,
+ ),
+ )
+ events = self.selector.select(timeout=max_wait_time)
+ for key, _ in events:
+ socket_handler = key.data
+ need_more = socket_handler(key.fileobj)
+
+ if not need_more:
+ self.selector.unregister(key.fileobj)
+ key.fileobj.close() # type: ignore[union-attr]
+
+ if self._exit_code is None:
+ try:
+ self._exit_code = self._process.wait(timeout=0)
+ log.debug("Task process exited",
exit_code=self._exit_code)
+ except psutil.TimeoutExpired:
+ pass
+
+ if last_heartbeat_ago < FASTEST_HEARTBEAT_INTERVAL:
+ # Avoid heartbeating too frequently
+ continue
+
+ try:
+ self.client.task_instances.heartbeat(self.ti_id)
+ self._last_heartbeat = time.monotonic()
+ except Exception:
+ log.warning("Couldn't heartbeat", exc_info=True)
+                    # TODO: If we couldn't heartbeat for X times the interval, kill ourselves
+ pass
+ finally:
+ self.selector.close()
+
+ self.client.task_instances.finish(
+            id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
+ )
+ return self._exit_code
+
+ @property
+ def final_state(self):
+ """
+ The final state of the TaskInstance.
+
+ By default this will be derived from the exit code of the task
+ (0=success, failed otherwise) but can be changed by the subprocess
+ sending a TaskState message, as long as the process exits with 0
+
+ Not valid before the process has finished.
+ """
+ if self._exit_code == 0:
+ return self._terminal_state or TaskInstanceState.SUCCESS
+ return TaskInstanceState.FAILED
+
+ def __rich_repr__(self):
+ yield "pid", self.pid
+ yield "exit_code", self._exit_code, None
+
+ __rich_repr__.angular = True # type: ignore[attr-defined]
+
+ def __repr__(self) -> str:
+ rep = f"<WatchedSubprocess pid={self.pid}"
+ if self._exit_code is not None:
+ rep += f" exit_code={self._exit_code}"
+ return rep + " >"
+
+    def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]:
+        decoder: msgspec.json.Decoder[ToSupervisor] = msgspec.json.Decoder(type=ToSupervisor)
+ encoder = msgspec.json.Encoder()
+ # Use a buffer to avoid small allocations
+ buffer = bytearray(64)
+ while True:
+ line = yield
+
+ try:
+ msg = decoder.decode(line)
+ except Exception:
+ log.exception("Unable to decode message", line=line)
+ continue
+
+            # if isinstance(msg, TaskState):
+ # self._terminal_state = msg.state
+ # elif isinstance(msg, ReadXCom):
+ # resp = XComResponse(key="secret", value=True)
+ # encoder.encode_into(resp, buffer)
+ # self.stdin.write(buffer + b"\n")
+ if isinstance(msg, GetConnection):
Review Comment:
All that communication is very specifically implemented. Have you/we
considered using a generic communication protocol like gRPC instead of writing
our own?
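
(As a reading aid for the dispatch above: handle_requests is a generator that is
handed one complete JSON line per send() call. A rough sketch of how the buffered
reader presumably drives it; make_buffered_socket_reader itself is outside this hunk:)

    # Sketch only; the real driving loop lives in make_buffered_socket_reader.
    gen = proc.handle_requests(log=log)
    next(gen)  # prime the generator so it runs up to its first `yield`

    # For each complete line split out of the socket buffer:
    gen.send(b'{"type": "GetConnection", "id": "my_db"}')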
##########
task_sdk/src/airflow/sdk/api/datamodels/activities.py:
##########
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import os
+
+import msgspec
+
+from airflow.sdk.api.datamodels.ti import TaskInstance
+
+
+class ExecuteTaskActivity(msgspec.Struct, tag="ExecuteTask", tag_field="kind"):
+ ti: TaskInstance
+ path: os.PathLike[str]
Review Comment:
What is this `path` - the working dir of the task where it is called?
##########
task_sdk/src/airflow/sdk/api/datamodels/_generated.py:
##########
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# generated by datamodel-codegen:
+# filename: http://0.0.0.0:9091/execution/openapi.json
+# version: 0.26.3
+
+from __future__ import annotations
+
+from datetime import datetime
+from enum import Enum
+from typing import Literal
+
+from msgspec import Struct, field
+
+
+class ConnectionResponse(Struct):
+ """
+ Connection schema for responses with fields that are needed for Runtime.
+ """
+
+ conn_id: str
+ conn_type: str
+ host: str | None = None
+ schema_: str | None = field(name="schema", default=None)
+ login: str | None = None
+ password: str | None = None
+ port: int | None = None
+ extra: str | None = None
+
+
+class TIEnterRunningPayload(Struct):
+ """
+    Schema for updating TaskInstance to 'RUNNING' state with minimal required fields.
+ """
+
+ hostname: str
+ unixname: str
+ pid: int
+ start_date: datetime
+ state: Literal["running"] | None = "running"
+
+
+class TIHeartbeatInfo(Struct):
+ """
+ Schema for TaskInstance heartbeat endpoint.
+ """
+
+ hostname: str
+ pid: int
+
+
+class State(Enum):
+ REMOVED = "removed"
+ SCHEDULED = "scheduled"
+ QUEUED = "queued"
+ RUNNING = "running"
+ RESTARTING = "restarting"
+ UP_FOR_RETRY = "up_for_retry"
+ UP_FOR_RESCHEDULE = "up_for_reschedule"
+ UPSTREAM_FAILED = "upstream_failed"
+ DEFERRED = "deferred"
+
+
+class TITargetStatePayload(Struct):
+ """
+    Schema for updating TaskInstance to a target state, excluding terminal and running states.
+ """
+
+ state: State
+
+
+class State1(Enum):
Review Comment:
What is the difference to `State`?
##########
task_sdk/src/airflow/sdk/execution_time/comms.py:
##########
@@ -0,0 +1,103 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+r"""
+Communication protocol between the Supervisor and the task process
+==================================================================
+
+* All communication is done over stdout/stdin in the form of "JSON lines" (each
+ message is a single JSON document terminated by `\n` character)
+* Messages from the subprocess are all log messages and are sent directly to the log
+* No messages are sent to task process except in response to a request. (This is because the task process will
+  be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom
+  value etc.)
Review Comment:
Does this mean that HTTP calls are synchronous and no async is possible? Otherwise
you would potentially mix responses and need to multiplex traffic?
##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
Review Comment:
Okay, reading the code below I understand now that 1 Supervisor :: 1 Task.
So Celery --> N Supervisor instances.
##########
task_sdk/src/airflow/sdk/execution_time/comms.py:
##########
@@ -0,0 +1,103 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+r"""
+Communication protocol between the Supervisor and the task process
+==================================================================
+
+* All communication is done over stdout/stdin in the form of "JSON lines" (each
+ message is a single JSON document terminated by `\n` character)
+* Messages from the subprocess are all log messages and are sent directly to the log
+* No messages are sent to task process except in response to a request. (This is because the task process will
+  be running user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom
+  value etc.)
+
+The reason this communication protocol exists, rather than the task process speaking directly to the Task
+Execution API server is because:
+
+1. To reduce the number of concurrent HTTP connections on the API server.
+
+   The supervisor already has to speak to that to heartbeat the running Task, so having the task speak to its
+   parent process and having all API traffic go through that means that the number of HTTP connections is
+   "halved". (Not every task will make API calls, so it's not always halved, but it is reduced.)
+
+2. This means that the user Task code doesn't ever directly see the task identity JWT token.
Review Comment:
3: Also, all connection retry and error handling is done in the supervisor, so an
executed task does not need to care about connection flakiness?
##########
task_sdk/src/airflow/sdk/api/client.py:
##########
@@ -0,0 +1,215 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import uuid
+from typing import TYPE_CHECKING, Any
+
+import httpx
+import methodtools
+import msgspec
+import structlog
+from uuid6 import uuid7
+
+from airflow.sdk.api.datamodels._generated import (
+ ConnectionResponse,
+ State1 as TerminalState,
+ TaskInstanceState,
+ TIEnterRunningPayload,
+ TITerminalStatePayload,
+ ValidationError as RemoteValidationError,
+)
+from airflow.utils.net import get_hostname
+from airflow.utils.platform import getuser
+
+if TYPE_CHECKING:
+ from datetime import datetime
+
+
+log = structlog.get_logger(logger_name=__name__)
+
+__all__ = [
+ "Client",
+ "ConnectionOperations",
+ "ErrorBody",
+ "ServerResponseError",
+ "TaskInstanceOperations",
+]
+
+
+def get_json_error(response: httpx.Response):
+ """Raise a ServerResponseError if we can extract error info from the
error."""
+ err = ServerResponseError.from_response(response)
+ if err:
+ log.warning("Server error", detail=err.detail)
+ raise err
+
+
+def raise_on_4xx_5xx(response: httpx.Response):
+ return get_json_error(response) or response.raise_for_status()
Review Comment:
Does this mean the API client does not implement retries internally?
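
(For reference: httpx's built-in retries only cover connection establishment, not
error responses, so any status-code retry would need explicit handling around the
calls. A sketch of the transport-level option:)

    import httpx

    # Retries here apply to connect failures (DNS, refused) only; a 5xx response
    # still surfaces immediately and would need a retry loop or wrapper.
    transport = httpx.HTTPTransport(retries=3)
    client = httpx.Client(transport=transport)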
##########
task_sdk/src/airflow/sdk/api/datamodels/ti.py:
##########
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import uuid
+
+import msgspec
+
+
+class TaskInstance(msgspec.Struct, omit_defaults=True):
+ id: uuid.UUID
+
+ task_id: str
+ dag_id: str
Review Comment:
Is some DAG version needed here as well?
##########
task_sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -0,0 +1,195 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The entrypoint for the actual task execution process."""
+
+from __future__ import annotations
+
+import os
+import sys
+from io import FileIO
+from typing import TYPE_CHECKING, TextIO
+
+import attrs
+import msgspec
+import structlog
+
+from airflow.sdk import BaseOperator
+from airflow.sdk.execution_time.comms import StartupDetails, TaskInstance, ToSupervisor, ToTask
+
+if TYPE_CHECKING:
+ from structlog.typing import FilteringBoundLogger as Logger
+
+
+class RuntimeTaskInstance(TaskInstance, kw_only=True):
+ task: BaseOperator
+
+
+def parse(what: StartupDetails) -> RuntimeTaskInstance:
+ # TODO: Task-SDK:
+    # Using DagBag here is about 98% wrong, but it'll do for now
+
+ from airflow.models.dagbag import DagBag
+
+ bag = DagBag(
+ dag_folder=what.file,
+ include_examples=False,
+ safe_mode=False,
+ load_op_links=False,
+ )
+ if TYPE_CHECKING:
+ assert what.ti.dag_id
+
+ dag = bag.dags[what.ti.dag_id]
+
+ # install_loader()
+
+ # TODO: Handle task not found
+ task = dag.task_dict[what.ti.task_id]
+ if not isinstance(task, BaseOperator):
+ raise TypeError(f"task is of the wrong type, got {type(task)}, wanted
{BaseOperator}")
+ return RuntimeTaskInstance(**msgspec.structs.asdict(what.ti), task=task)
+
+
[email protected]()
+class CommsDecoder:
+ """Handle communication between the task in this process and the
supervisor parent process."""
+
+ input: TextIO = sys.stdin
+
+    decoder: msgspec.json.Decoder[ToTask] = attrs.field(factory=lambda: msgspec.json.Decoder(type=ToTask))
+ encoder: msgspec.json.Encoder = attrs.field(factory=msgspec.json.Encoder)
+
+ request_socket: FileIO = attrs.field(init=False, default=None)
+
+    # Allocate a buffer for en/decoding. We use a small non-empty buffer to avoid reallocating for small messages.
+ buffer: bytearray = attrs.field(factory=lambda: bytearray(64))
+
+ def get_message(self) -> ToTask:
+ """
+ Get a message from the parent.
+
+ This will block until the message has been received.
+ """
+ line = self.input.readline()
+ try:
+ msg = self.decoder.decode(line)
+ except Exception:
+ structlog.get_logger(logger_name="CommsDecoder").exception("Unable
to decode message", line=line)
+ raise
+
+ if isinstance(msg, StartupDetails):
+ # If we read a startup message, pull out the FDs we care about!
+ if msg.requests_fd > 0:
+                self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
+ return msg
+
+ def send_request(self, log: Logger, msg: ToSupervisor):
+ buffer = self.buffer
+ self.encoder.encode_into(msg, buffer)
+ buffer += b"\n"
+
+ log.debug("Sending request", json=buffer)
+ self.request_socket.write(buffer)
+
+
+# This global variable will be used by Connection/Variable classes etc to send requests to the supervisor
+SUPERVISOR_COMMS: CommsDecoder
Review Comment:
Um... and how will this work if multiple Supervisor instances are running
in parallel?
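
(One possible answer, as a sketch: module globals are per-process, and the
supervisor forks one task process per task, so each task process would initialize
its own instance. The main() below is assumed; it is not part of this hunk:)

    # Hypothetical per-process startup; the actual wiring is outside this hunk.
    def main() -> None:
        global SUPERVISOR_COMMS
        SUPERVISOR_COMMS = CommsDecoder()
        msg = SUPERVISOR_COMMS.get_message()  # blocks until StartupDetails arrives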
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]