This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 75de1d87710 Start building the replacement task runner for Task Execution SDK (#43893)
75de1d87710 is described below

commit 75de1d877108ee9859fd4c57054d6775daa27256
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Nov 14 14:31:50 2024 +0000

    Start building the replacement task runner for Task Execution SDK (#43893)
    
    The eventual goal of this "airflow.sdk.execution_time" package is to replace
    LocalTaskJob and StandardTaskRunner, but at this stage it co-exists with the
    code it replaces.
    
    As this PR is not a complete re-implementation of all the features that exist
    currently (no handling of task-level callbacks yet, no AirflowSkipException,
    etc.) the current tests are skeletal at best. Once we get closer to feature
    parity (in future PRs) the tests will grow to match.
    
    This supervisor and task runner operate slightly differently from the current
    classes in the following ways:
    
    **Logs from the subprocess are sent over a different channel to stdout/stderr**
    
    This makes the task supervisor a little bit more complex as it now has to
    read stdout, stderr and a logs channel. The advantage of this approach is
    that it makes the logs setup in the task process itself markedly simpler --
    all it has to do is write log output to the custom file handle as JSON and
    it will show up "natively" as logs.
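
    As an illustrative sketch (not the actual implementation in
    `airflow/sdk/log.py`), the task-side end of that channel can be as simple as
    serialising each log event as one JSON document per line and writing it to
    the inherited file descriptor; the supervisor decodes each line and re-emits
    it through its own logger:

        import json
        import os
        from datetime import datetime, timezone

        # In the real task process this extra fd is inherited from the supervisor;
        # for a self-contained demo we create a pipe and write to its write end.
        read_fd, write_fd = os.pipe()
        log_io = os.fdopen(write_fd, "wb", buffering=0)

        def emit(event: str, level: str = "info", **fields) -> None:
            # One JSON document per line ("JSON lines"); the supervisor reads the
            # "timestamp", "level" and "event" keys and re-emits the record
            # through its own structlog logger.
            record = {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "level": level,
                "event": event,
                **fields,
            }
            log_io.write(json.dumps(record).encode() + b"\n")

        emit("Task started", dag_id="example_dag", try_number=1)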
    
    structlog has been chosen as the logging engine over stdlib's own logging as
    the ability to have structured fields in the logs is nice, and stdlib logging
    is configured to send its records to a structlog processor.
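
    A minimal sketch of that kind of wiring (the real configuration in
    `task_sdk/src/airflow/sdk/log.py` is more involved): stdlib records are
    rendered by a structlog `ProcessorFormatter`, and structlog's own events are
    routed through the same handler so both APIs share one output channel:

        import logging
        import structlog

        # Shared rendering chain: turn every event dict into one JSON line.
        formatter = structlog.stdlib.ProcessorFormatter(
            processors=[
                structlog.stdlib.ProcessorFormatter.remove_processors_meta,
                structlog.processors.TimeStamper(fmt="iso"),
                structlog.processors.JSONRenderer(),
            ],
        )

        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logging.basicConfig(handlers=[handler], level=logging.INFO, force=True)

        # Route structlog events through stdlib logging so they hit the same handler.
        structlog.configure(
            processors=[structlog.stdlib.ProcessorFormatter.wrap_for_formatter],
            logger_factory=structlog.stdlib.LoggerFactory(),
        )

        logging.getLogger("airflow.example").warning("plain stdlib message")
        structlog.get_logger("task").info("structured message", try_number=1)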
    
    **Direct database access is replaced with an HTTP API client**
    
    This is the crux of this feature and of AIP-72 in general -- tasks run via
    this runner can no longer access DB models or the DB session directly. This PR
    doesn't yet implement the code/shims to make `Connection.get_connection_from_secrets`
    use this client - that will be future work.
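
    For illustration, the supervisor-side client added in this PR is used roughly
    like this (a sketch only; the URL and token values are placeholders):

        import uuid

        from airflow.sdk.api.client import Client

        # In dry-run mode a MockTransport answers every request, so simple flows
        # such as heartbeats can be exercised without a running API server.
        client = Client(base_url=None, dry_run=True, token="")
        client.task_instances.heartbeat(uuid.uuid4())

        # Against a real server (placeholders for URL and the per-task-try JWT):
        # client = Client(base_url="http://localhost:9091/execution/", token="<jwt>")
        # conn = client.connections.get("my_postgres")  # -> ConnectionResponse model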
    
    Tasks don't speak directly to the API server for two main reasons (a sketch
    of the task-side round trip follows this list):

    1. The supervisor process already needs to maintain an HTTP session in order
       to report the task as started, to heartbeat it, and to mark it as
       finished; and because of that
    2. we can reduce the number of active HTTP connections to 1 per task
       (instead of 2 per task).
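
    As a sketch of that task-side round trip (mirroring the message models in
    `airflow.sdk.execution_time.comms`; error handling and the real helper classes
    are omitted), a request is one JSON document per line written to the requests
    fd handed over in `StartupDetails`, and the reply comes back as a single JSON
    line on stdin:

        import json
        import os
        import sys

        def get_connection_via_supervisor(requests_fd: int, conn_id: str) -> dict:
            # The fd was inherited from the supervisor; buffering=0 so the
            # request line is flushed immediately.
            requests = os.fdopen(requests_fd, "wb", buffering=0)
            requests.write(json.dumps({"type": "GetConnection", "id": conn_id}).encode() + b"\n")
            # The supervisor answers on our stdin with a single JSON document per
            # line, e.g. {"type": "ConnectionResponse", "conn": {...}}.
            return json.loads(sys.stdin.readline())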
    
    The other reason we have this interface is that the DAG parsing code will very
    soon need to be updated to not have direct DB access either, and having this
    "in process" interface ability already means that we can support commands like
    `airflow dags reserialize` without having a running API server.
    
    The API client itself is not auto-generated: I tried a number of different
    client generators based on the OpenAPI spec and found them all lacking or buggy
    in different ways, and the HTTP client side itself is very simple. The only
    interesting/difficult bit is the generation of the datamodels from the OpenAPI
    spec, for which I did find a generator that works (datamodel-code-generator,
    configured in `task_sdk/pyproject.toml`).
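
    Those generated pydantic models are what the client serialises onto the wire;
    a small sketch of how they are used (the values are illustrative only):

        from datetime import datetime, timezone

        from airflow.sdk.api.datamodels._generated import TIEnterRunningPayload

        # Illustrative values; in the supervisor these come from the forked task process.
        payload = TIEnterRunningPayload(
            pid=1234,
            hostname="worker-1",
            unixname="airflow",
            start_date=datetime.now(timezone.utc),
        )
        body = payload.model_dump_json()  # JSON body PATCHed to task-instance/{id}/state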
    
    ---------
    
    Co-authored-by: Kaxil Naik <[email protected]>
---
 Dockerfile                                         |   2 +-
 Dockerfile.ci                                      |   2 +-
 airflow/utils/net.py                               |   4 +-
 scripts/docker/install_airflow.sh                  |   2 +-
 task_sdk/pyproject.toml                            |  59 +-
 task_sdk/src/airflow/sdk/__init__.py               |   3 +
 .../airflow/sdk/api/__init__.py}                   |  15 -
 task_sdk/src/airflow/sdk/api/client.py             | 216 ++++++++
 .../airflow/sdk/api/datamodels/__init__.py}        |  15 -
 .../src/airflow/sdk/api/datamodels/_generated.py   | 148 +++++
 .../airflow/sdk/api/datamodels/activities.py}      |  16 +-
 .../airflow/sdk/api/datamodels/ti.py}              |  19 +-
 .../airflow/sdk/execution_time/__init__.py}        |  16 +-
 task_sdk/src/airflow/sdk/execution_time/comms.py   | 120 +++++
 .../src/airflow/sdk/execution_time/supervisor.py   | 599 +++++++++++++++++++++
 .../src/airflow/sdk/execution_time/task_runner.py  | 191 +++++++
 task_sdk/src/airflow/sdk/log.py                    | 372 +++++++++++++
 task_sdk/src/airflow/sdk/types.py                  |   2 +-
 task_sdk/tests/{conftest.py => api/__init__.py}    |  15 -
 task_sdk/tests/api/test_client.py                  |  62 +++
 task_sdk/tests/conftest.py                         |  58 ++
 .../tests/{conftest.py => defintions/__init__.py}  |  15 -
 task_sdk/tests/defintions/test_baseoperator.py     |  19 +
 .../{conftest.py => execution_time/__init__.py}    |  15 -
 task_sdk/tests/{ => execution_time}/conftest.py    |  18 +-
 task_sdk/tests/execution_time/test_supervisor.py   | 150 ++++++
 task_sdk/tests/execution_time/test_task_runner.py  |  56 ++
 tests/cli/commands/test_celery_command.py          |   1 +
 28 files changed, 2087 insertions(+), 123 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 5ca9949b021..d9fb1878f11 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -890,7 +890,7 @@ function install_airflow() {
 
         # Similarly we need _a_ file for task_sdk too
         mkdir -p ./task_sdk/src/airflow/sdk/
-        touch ./task_sdk/src/airflow/sdk/__init__.py
+        echo '__version__ = "0.0.0dev0"' > 
./task_sdk/src/airflow/sdk/__init__.py
 
         trap 'rm -f ./providers/src/airflow/providers/__init__.py 
./task_sdk/src/airflow/__init__.py 2>/dev/null' EXIT
 
diff --git a/Dockerfile.ci b/Dockerfile.ci
index 943270aec69..952993984e5 100644
--- a/Dockerfile.ci
+++ b/Dockerfile.ci
@@ -660,7 +660,7 @@ function install_airflow() {
 
         # Similarly we need _a_ file for task_sdk too
         mkdir -p ./task_sdk/src/airflow/sdk/
-        touch ./task_sdk/src/airflow/sdk/__init__.py
+        echo '__version__ = "0.0.0dev0"' > 
./task_sdk/src/airflow/sdk/__init__.py
 
         trap 'rm -f ./providers/src/airflow/providers/__init__.py 
./task_sdk/src/airflow/__init__.py 2>/dev/null' EXIT
 
diff --git a/airflow/utils/net.py b/airflow/utils/net.py
index 992aee67e80..9fc79b3842c 100644
--- a/airflow/utils/net.py
+++ b/airflow/utils/net.py
@@ -20,8 +20,6 @@ from __future__ import annotations
 import socket
 from functools import lru_cache
 
-from airflow.configuration import conf
-
 
 # patched version of socket.getfqdn() - see 
https://github.com/python/cpython/issues/49254
 @lru_cache(maxsize=None)
@@ -53,4 +51,6 @@ def get_host_ip_address():
 
 def get_hostname():
     """Fetch the hostname using the callable from config or use 
`airflow.utils.net.getfqdn` as a fallback."""
+    from airflow.configuration import conf
+
     return conf.getimport("core", "hostname_callable", 
fallback="airflow.utils.net.getfqdn")()
diff --git a/scripts/docker/install_airflow.sh 
b/scripts/docker/install_airflow.sh
index 2975c50c2d6..27dd25ba260 100644
--- a/scripts/docker/install_airflow.sh
+++ b/scripts/docker/install_airflow.sh
@@ -54,7 +54,7 @@ function install_airflow() {
 
         # Similarly we need _a_ file for task_sdk too
         mkdir -p ./task_sdk/src/airflow/sdk/
-        touch ./task_sdk/src/airflow/sdk/__init__.py
+        echo '__version__ = "0.0.0dev0"' > 
./task_sdk/src/airflow/sdk/__init__.py
 
         trap 'rm -f ./providers/src/airflow/providers/__init__.py 
./task_sdk/src/airflow/__init__.py 2>/dev/null' EXIT
 
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml
index f290dfb17fd..5da673a79bf 100644
--- a/task_sdk/pyproject.toml
+++ b/task_sdk/pyproject.toml
@@ -17,20 +17,30 @@
 
 [project]
 name = "apache-airflow-task-sdk"
-version = "0.1.0.dev0"
+dynamic = ["version"]
 description = "Python Task SDK for Apache Airflow DAG Authors"
 readme = { file = "README.md", content-type = "text/markdown" }
 requires-python = ">=3.9, <3.13"
 dependencies = [
     "attrs>=24.2.0",
     "google-re2>=1.1.20240702",
+    "httpx>=0.27.0",
     "methodtools>=0.4.7",
+    "msgspec>=0.18.6",
+    "psutil>=6.1.0",
+    "structlog>=24.4.0",
+]
+classifiers = [
+  "Framework :: Apache Airflow",
 ]
 
 [build-system]
 requires = ["hatchling"]
 build-backend = "hatchling.build"
 
+[tool.hatch.version]
+path = "src/airflow/sdk/__init__.py"
+
 [tool.hatch.build.targets.wheel]
 packages = ["src/airflow"]
 # This file only exists to make pyright/VSCode happy, don't ship it
@@ -46,11 +56,24 @@ namespace-packages = ["src/airflow"]
 # Ignore Doc rules et al for anything outside of tests
 "!src/*" = ["D", "TID253", "S101", "TRY002"]
 
-"src/airflow/sdk/__init__.py" = ["TCH004"]
+# Ignore the pytest rules outside the tests folder - 
https://github.com/astral-sh/ruff/issues/14205
+"!tests/*" = ["PT"]
 
 # Pycharm barfs if this "stub" file has future imports
 "src/airflow/__init__.py" = ["I002"]
 
+"src/airflow/sdk/__init__.py" = ["TCH004"]
+
+# msgspec needs types for annotations to be defined, even with future
+# annotations, so disable the "type check only import" for these files
+"src/airflow/sdk/api/datamodels/*.py" = ["TCH001"]
+
+# Only the public API should _require_ docstrings on classes
+"!src/airflow/sdk/definitions/*" = ["D101"]
+
+# Generated file, be less strict
+"src/airflow/sdk/*/_generated.py" = ["D"]
+
 [tool.uv]
 dev-dependencies = [
     "kgb>=7.1.1",
@@ -59,6 +82,7 @@ dev-dependencies = [
     "pytest>=8.3.3",
 ]
 
+
 [tool.coverage.run]
 branch = true
 relative_files = true
@@ -76,3 +100,34 @@ exclude_also = [
     "@(typing(_extensions)?\\.)?overload",
     "if (typing(_extensions)?\\.)?TYPE_CHECKING:",
 ]
+
+[dependency-groups]
+codegen = [
+    "datamodel-code-generator[http]>=0.26.3",
+]
+
+[tool.black]
+# This is needed for datamodel-codegen to treat this as the "project" file
+
+# To use:
+#
+#   uv run --group codegen --project apache-airflow-task-sdk --directory 
task_sdk datamodel-codegen
+[tool.datamodel-codegen]
+capitalise-enum-members=true # `State.RUNNING` not `State.running`
+disable-timestamp=true
+enable-version-header=true
+enum-field-as-literal='one' # When a single enum member, make it output a 
`Literal["..."]`
+input-file-type='openapi'
+output-model-type='pydantic_v2.BaseModel'
+output-datetime-class='datetime'
+target-python-version='3.9'
+use-annotated=true
+use-default=true
+use-double-quotes=true
+use-schema-description=true  # Desc becomes class doc comment
+use-standard-collections=true # list[] not List[]
+use-subclass-enum=true # enum, not union of Literals
+use-union-operator=true # 3.9+annotations, not `Union[]`
+
+url = 'http://0.0.0.0:9091/execution/openapi.json'
+output = 'src/airflow/sdk/api/datamodels/_generated.py'
diff --git a/task_sdk/src/airflow/sdk/__init__.py 
b/task_sdk/src/airflow/sdk/__init__.py
index f538baedff0..bd882f43dd0 100644
--- a/task_sdk/src/airflow/sdk/__init__.py
+++ b/task_sdk/src/airflow/sdk/__init__.py
@@ -25,8 +25,11 @@ __all__ = [
     "Label",
     "TaskGroup",
     "dag",
+    "__version__",
 ]
 
+__version__ = "1.0.0.dev1"
+
 if TYPE_CHECKING:
     from airflow.sdk.definitions.baseoperator import BaseOperator
     from airflow.sdk.definitions.dag import DAG, dag
diff --git a/task_sdk/tests/conftest.py 
b/task_sdk/src/airflow/sdk/api/__init__.py
similarity index 69%
copy from task_sdk/tests/conftest.py
copy to task_sdk/src/airflow/sdk/api/__init__.py
index ddc7c61656a..13a83393a91 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/src/airflow/sdk/api/__init__.py
@@ -14,18 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
-
-import os
-
-import pytest
-
-pytest_plugins = "tests_common.pytest_plugin"
-
-# Task SDK does not need access to the Airflow database
-os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
-
-
[email protected](tryfirst=True)
-def pytest_configure(config: pytest.Config) -> None:
-    config.inicfg["airflow_deprecations_ignore"] = []
diff --git a/task_sdk/src/airflow/sdk/api/client.py 
b/task_sdk/src/airflow/sdk/api/client.py
new file mode 100644
index 00000000000..ece3bc96009
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -0,0 +1,216 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import sys
+import uuid
+from typing import TYPE_CHECKING, Any
+
+import httpx
+import methodtools
+import structlog
+from pydantic import BaseModel
+from uuid6 import uuid7
+
+from airflow.sdk import __version__
+from airflow.sdk.api.datamodels._generated import (
+    ConnectionResponse,
+    State1 as TerminalState,
+    TaskInstanceState,
+    TIEnterRunningPayload,
+    TITerminalStatePayload,
+    ValidationError as RemoteValidationError,
+)
+from airflow.utils.net import get_hostname
+from airflow.utils.platform import getuser
+
+if TYPE_CHECKING:
+    from datetime import datetime
+
+
+log = structlog.get_logger(logger_name=__name__)
+
+__all__ = [
+    "Client",
+    "ConnectionOperations",
+    "ErrorBody",
+    "ServerResponseError",
+    "TaskInstanceOperations",
+]
+
+
+def get_json_error(response: httpx.Response):
+    """Raise a ServerResponseError if we can extract error info from the 
error."""
+    err = ServerResponseError.from_response(response)
+    if err:
+        log.warning("Server error", detail=err.detail)
+        raise err
+
+
+def raise_on_4xx_5xx(response: httpx.Response):
+    return get_json_error(response) or response.raise_for_status()
+
+
+# Py 3.11+ version
+def raise_on_4xx_5xx_with_note(response: httpx.Response):
+    try:
+        return get_json_error(response) or response.raise_for_status()
+    except httpx.HTTPStatusError as e:
+        if TYPE_CHECKING:
+            assert hasattr(e, "add_note")
+        e.add_note(
+            f"Correlation-id={response.headers.get('correlation-id', None) or 
response.request.headers.get('correlation-id', 'no-correlction-id')}"
+        )
+        raise
+
+
+if hasattr(BaseException, "add_note"):
+    # Py 3.11+
+    raise_on_4xx_5xx = raise_on_4xx_5xx_with_note
+
+
+def add_correlation_id(request: httpx.Request):
+    request.headers["correlation-id"] = str(uuid7())
+
+
+class TaskInstanceOperations:
+    __slots__ = ("client",)
+
+    def __init__(self, client: Client):
+        self.client = client
+
+    def start(self, id: uuid.UUID, pid: int, when: datetime):
+        """Tell the API server that this TI has started running."""
+        body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), 
unixname=getuser(), start_date=when)
+
+        self.client.patch(f"task-instance/{id}/state", 
content=body.model_dump_json())
+
+    def finish(self, id: uuid.UUID, state: TaskInstanceState, when: datetime):
+        """Tell the API server that this TI has reached a terminal state."""
+        body = TITerminalStatePayload(end_date=when, 
state=TerminalState(state))
+
+        self.client.patch(f"task-instance/{id}/state", 
content=body.model_dump_json())
+
+    def heartbeat(self, id: uuid.UUID):
+        self.client.put(f"task-instance/{id}/heartbeat")
+
+
+class ConnectionOperations:
+    __slots__ = ("client", "decoder")
+
+    def __init__(self, client: Client):
+        self.client = client
+
+    def get(self, id: str) -> ConnectionResponse:
+        """Get a connection from the API server."""
+        resp = self.client.get(f"connection/{id}")
+        return ConnectionResponse.model_validate_json(resp.read())
+
+
+class BearerAuth(httpx.Auth):
+    def __init__(self, token: str):
+        self.token: str = token
+
+    def auth_flow(self, request: httpx.Request):
+        if self.token:
+            request.headers["Authorization"] = "Bearer " + self.token
+        yield request
+
+
+# This exists as an aid for debugging or local running via the `dry_run` argument to Client. It doesn't make
+# sense for returning connections etc.
+def noop_handler(request: httpx.Request) -> httpx.Response:
+    log.debug("Dry-run request", method=request.method, path=request.url.path)
+    return httpx.Response(200, json={"text": "Hello, world!"})
+
+
+class Client(httpx.Client):
+    def __init__(self, *, base_url: str | None, dry_run: bool = False, token: 
str, **kwargs: Any):
+        if (not base_url) ^ dry_run:
+            raise ValueError(f"Can only specify one of {base_url=} or 
{dry_run=}")
+        auth = BearerAuth(token)
+
+        if dry_run:
+            # If dry run is requested, install a no-op handler so that simple tasks can "heartbeat" using a
+            # real client, but just don't make any HTTP requests
+            kwargs["transport"] = httpx.MockTransport(noop_handler)
+            kwargs["base_url"] = "dry-run://server"
+        else:
+            kwargs["base_url"] = base_url
+        pyver = f"{'.'.join(map(str, sys.version_info[:3]))}"
+        super().__init__(
+            auth=auth,
+            headers={"user-agent": f"apache-airflow-task-sdk/{__version__} 
(Python/{pyver})"},
+            event_hooks={"response": [raise_on_4xx_5xx], "request": 
[add_correlation_id]},
+            **kwargs,
+        )
+
+    # We "group" or "namespace" operations by what they operate on, rather 
than a flat namespace with all
+    # methods on one object prefixed with the object type 
(`.task_instances.update` rather than
+    # `task_instance_update` etc.)
+
+    @methodtools.lru_cache()  # type: ignore[misc]
+    @property
+    def task_instances(self) -> TaskInstanceOperations:
+        """Operations related to TaskInstances."""
+        return TaskInstanceOperations(self)
+
+    @methodtools.lru_cache()  # type: ignore[misc]
+    @property
+    def connections(self) -> ConnectionOperations:
+        """Operations related to TaskInstances."""
+        return ConnectionOperations(self)
+
+
+class ErrorBody(BaseModel):
+    detail: list[RemoteValidationError] | dict[str, Any]
+
+    def __repr__(self):
+        return repr(self.detail)
+
+
+class ServerResponseError(httpx.HTTPStatusError):
+    def __init__(self, message: str, *, request: httpx.Request, response: 
httpx.Response):
+        super().__init__(message, request=request, response=response)
+
+    detail: ErrorBody
+
+    @classmethod
+    def from_response(cls, response: httpx.Response) -> ServerResponseError | 
None:
+        if response.is_success:
+            return None
+        # 4xx or 5xx error?
+        if not (400 <= response.status_code < 600):
+            return None
+
+        if response.headers.get("content-type") != "application/json":
+            return None
+
+        try:
+            err = ErrorBody.model_validate_json(response.read())
+            if isinstance(err.detail, list):
+                msg = "Remote server returned validation error"
+            else:
+                msg = err.detail.get("message", "") or "Un-parseable error"
+        except Exception:
+            err = ErrorBody.model_validate_json(response.content)
+            msg = "Server returned error"
+
+        self = cls(msg, request=response.request, response=response)
+        self.detail = err
+        return self
diff --git a/task_sdk/tests/conftest.py 
b/task_sdk/src/airflow/sdk/api/datamodels/__init__.py
similarity index 69%
copy from task_sdk/tests/conftest.py
copy to task_sdk/src/airflow/sdk/api/datamodels/__init__.py
index ddc7c61656a..13a83393a91 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/__init__.py
@@ -14,18 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
-
-import os
-
-import pytest
-
-pytest_plugins = "tests_common.pytest_plugin"
-
-# Task SDK does not need access to the Airflow database
-os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
-
-
[email protected](tryfirst=True)
-def pytest_configure(config: pytest.Config) -> None:
-    config.inicfg["airflow_deprecations_ignore"] = []
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
new file mode 100644
index 00000000000..f41508cae2a
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# generated by datamodel-codegen:
+#   filename:  http://0.0.0.0:9091/execution/openapi.json
+#   version:   0.26.3
+
+from __future__ import annotations
+
+from datetime import datetime
+from enum import Enum
+from typing import Annotated, Any, Literal
+
+from pydantic import BaseModel, Field
+
+
+class ConnectionResponse(BaseModel):
+    """
+    Connection schema for responses with fields that are needed for Runtime.
+    """
+
+    conn_id: Annotated[str, Field(title="Conn Id")]
+    conn_type: Annotated[str, Field(title="Conn Type")]
+    host: Annotated[str | None, Field(title="Host")] = None
+    schema_: Annotated[str | None, Field(alias="schema", title="Schema")] = 
None
+    login: Annotated[str | None, Field(title="Login")] = None
+    password: Annotated[str | None, Field(title="Password")] = None
+    port: Annotated[int | None, Field(title="Port")] = None
+    extra: Annotated[str | None, Field(title="Extra")] = None
+
+
+class TIEnterRunningPayload(BaseModel):
+    """
+    Schema for updating TaskInstance to 'RUNNING' state with minimal required 
fields.
+    """
+
+    state: Annotated[Literal["running"] | None, Field(title="State")] = 
"running"
+    hostname: Annotated[str, Field(title="Hostname")]
+    unixname: Annotated[str, Field(title="Unixname")]
+    pid: Annotated[int, Field(title="Pid")]
+    start_date: Annotated[datetime, Field(title="Start Date")]
+
+
+class TIHeartbeatInfo(BaseModel):
+    """
+    Schema for TaskInstance heartbeat endpoint.
+    """
+
+    hostname: Annotated[str, Field(title="Hostname")]
+    pid: Annotated[int, Field(title="Pid")]
+
+
+class State(Enum):
+    REMOVED = "removed"
+    SCHEDULED = "scheduled"
+    QUEUED = "queued"
+    RUNNING = "running"
+    RESTARTING = "restarting"
+    UP_FOR_RETRY = "up_for_retry"
+    UP_FOR_RESCHEDULE = "up_for_reschedule"
+    UPSTREAM_FAILED = "upstream_failed"
+    DEFERRED = "deferred"
+
+
+class TITargetStatePayload(BaseModel):
+    """
+    Schema for updating TaskInstance to a target state, excluding terminal and 
running states.
+    """
+
+    state: State
+
+
+class State1(Enum):
+    FAILED = "failed"
+    SUCCESS = "success"
+    SKIPPED = "skipped"
+
+
+class TITerminalStatePayload(BaseModel):
+    """
+    Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or 
FAILED).
+    """
+
+    state: Annotated[State1, Field(title="TerminalState")]
+    end_date: Annotated[datetime, Field(title="End Date")]
+
+
+class TaskInstanceState(str, Enum):
+    """
+    All possible states that a Task Instance can be in.
+
+    Note that None is also allowed, so always use this in a type hint with 
Optional.
+    """
+
+    REMOVED = "removed"
+    SCHEDULED = "scheduled"
+    QUEUED = "queued"
+    RUNNING = "running"
+    SUCCESS = "success"
+    RESTARTING = "restarting"
+    FAILED = "failed"
+    UP_FOR_RETRY = "up_for_retry"
+    UP_FOR_RESCHEDULE = "up_for_reschedule"
+    UPSTREAM_FAILED = "upstream_failed"
+    SKIPPED = "skipped"
+    DEFERRED = "deferred"
+
+
+class ValidationError(BaseModel):
+    loc: Annotated[list[str | int], Field(title="Location")]
+    msg: Annotated[str, Field(title="Message")]
+    type: Annotated[str, Field(title="Error Type")]
+
+
+class VariableResponse(BaseModel):
+    """
+    Variable schema for responses with fields that are needed for Runtime.
+    """
+
+    key: Annotated[str, Field(title="Key")]
+    value: Annotated[str | None, Field(title="Value")] = None
+
+
+class XComResponse(BaseModel):
+    """
+    XCom schema for responses with fields that are needed for Runtime.
+    """
+
+    key: Annotated[str, Field(title="Key")]
+    value: Annotated[Any, Field(title="Value")]
+
+
+class HTTPValidationError(BaseModel):
+    detail: Annotated[list[ValidationError] | None, Field(title="Detail")] = 
None
diff --git a/task_sdk/tests/conftest.py 
b/task_sdk/src/airflow/sdk/api/datamodels/activities.py
similarity index 73%
copy from task_sdk/tests/conftest.py
copy to task_sdk/src/airflow/sdk/api/datamodels/activities.py
index ddc7c61656a..04f2b389d5d 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/activities.py
@@ -14,18 +14,18 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 from __future__ import annotations
 
 import os
 
-import pytest
-
-pytest_plugins = "tests_common.pytest_plugin"
+from pydantic import BaseModel
 
-# Task SDK does not need access to the Airflow database
-os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
+from airflow.sdk.api.datamodels.ti import TaskInstance
 
 
[email protected](tryfirst=True)
-def pytest_configure(config: pytest.Config) -> None:
-    config.inicfg["airflow_deprecations_ignore"] = []
+class ExecuteTaskActivity(BaseModel):
+    ti: TaskInstance
+    path: os.PathLike[str]
+    token: str
+    """The identity token for this workload"""
diff --git a/task_sdk/tests/conftest.py 
b/task_sdk/src/airflow/sdk/api/datamodels/ti.py
similarity index 72%
copy from task_sdk/tests/conftest.py
copy to task_sdk/src/airflow/sdk/api/datamodels/ti.py
index ddc7c61656a..ce9e1e870ae 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/ti.py
@@ -14,18 +14,19 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
 
-import os
+from __future__ import annotations
 
-import pytest
+import uuid
 
-pytest_plugins = "tests_common.pytest_plugin"
+from pydantic import BaseModel
 
-# Task SDK does not need access to the Airflow database
-os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
 
+class TaskInstance(BaseModel):
+    id: uuid.UUID
 
[email protected](tryfirst=True)
-def pytest_configure(config: pytest.Config) -> None:
-    config.inicfg["airflow_deprecations_ignore"] = []
+    task_id: str
+    dag_id: str
+    run_id: str
+    try_number: int
+    map_index: int | None = None
diff --git a/task_sdk/tests/conftest.py 
b/task_sdk/src/airflow/sdk/execution_time/__init__.py
similarity index 69%
copy from task_sdk/tests/conftest.py
copy to task_sdk/src/airflow/sdk/execution_time/__init__.py
index ddc7c61656a..217e5db9607 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/src/airflow/sdk/execution_time/__init__.py
@@ -1,3 +1,4 @@
+#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -14,18 +15,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
-
-import os
-
-import pytest
-
-pytest_plugins = "tests_common.pytest_plugin"
-
-# Task SDK does not need access to the Airflow database
-os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
-
-
[email protected](tryfirst=True)
-def pytest_configure(config: pytest.Config) -> None:
-    config.inicfg["airflow_deprecations_ignore"] = []
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py 
b/task_sdk/src/airflow/sdk/execution_time/comms.py
new file mode 100644
index 00000000000..3128e98bf43
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+r"""
+Communication protocol between the Supervisor and the task process
+==================================================================
+
+* All communication is done over stdout/stdin in the form of "JSON lines" (each
+  message is a single JSON document terminated by `\n` character)
+* Messages from the subprocess are all log messages and are sent directly to 
the log
+* No messages are sent to the task process except in response to a request. (This is because the task process will
+  be running the user's code, so we can't read from stdin until we enter our code, such as when requesting an XCom
+  value etc.)
+
+The reason this communication protocol exists, rather than the task process speaking directly to the Task
+Execution API server, is:
+
+1. To reduce the number of concurrent HTTP connections on the API server.
+
+   The supervisor already has to speak to the API server to heartbeat the running Task, so having the task speak to its
+   parent process and having all API traffic go through that means that the 
number of HTTP connections is
+   "halved". (Not every task will make API calls, so it's not always halved, 
but it is reduced.)
+
+2. This means that the user Task code doesn't ever directly see the task 
identity JWT token.
+
+   This is a short-lived token tied to one specific task instance try, so it being leaked/exfiltrated is not a
+   large risk, but it's easy to not give it to the user code, so let's do that.
+"""  # noqa: D400, D205
+
+from __future__ import annotations
+
+from typing import Annotated, Any, Literal, Union
+
+from pydantic import BaseModel, ConfigDict, Field
+
+from airflow.sdk.api.datamodels._generated import TaskInstanceState  # noqa: 
TCH001
+from airflow.sdk.api.datamodels.ti import TaskInstance  # noqa: TCH001
+
+
+class StartupDetails(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    ti: TaskInstance
+    file: str
+    requests_fd: int
+    """
+    The channel for the task to send requests over.
+
+    Responses will come back on stdin
+    """
+    type: Literal["StartupDetails"] = "StartupDetails"
+
+
+class XComResponse(BaseModel):
+    """Response to ReadXCom request."""
+
+    key: str
+    value: Any
+
+    type: Literal["XComResponse"] = "XComResponse"
+
+
+class ConnectionResponse(BaseModel):
+    conn: Any
+
+    type: Literal["ConnectionResponse"] = "ConnectionResponse"
+
+
+ToTask = Annotated[
+    Union[StartupDetails, XComResponse, ConnectionResponse],
+    Field(discriminator="type"),
+]
+
+
+class TaskState(BaseModel):
+    """
+    Update a task's state.
+
+    If a process exits without sending one of these the state will be derived 
from the exit code:
+    - 0 = SUCCESS
+    - anything else = FAILED
+    """
+
+    state: TaskInstanceState
+    type: Literal["TaskState"] = "TaskState"
+
+
+class ReadXCom(BaseModel):
+    key: str
+    type: Literal["ReadXCom"] = "ReadXCom"
+
+
+class GetConnection(BaseModel):
+    id: str
+    type: Literal["GetConnection"] = "GetConnection"
+
+
+class GetVariable(BaseModel):
+    id: str
+    type: Literal["GetVariable"] = "GetVariable"
+
+
+ToSupervisor = Annotated[
+    Union[TaskState, ReadXCom, GetConnection, GetVariable],
+    Field(discriminator="type"),
+]
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
new file mode 100644
index 00000000000..3c0623ba1b0
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -0,0 +1,599 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Supervise and run Tasks in a subprocess."""
+
+from __future__ import annotations
+
+import atexit
+import io
+import logging
+import os
+import selectors
+import signal
+import sys
+import time
+import weakref
+from collections.abc import Generator
+from contextlib import suppress
+from datetime import datetime, timezone
+from socket import socket, socketpair
+from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, 
NoReturn, cast, overload
+from uuid import UUID
+
+import attrs
+import httpx
+import msgspec
+import psutil
+import structlog
+from pydantic import TypeAdapter
+
+from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import TaskInstanceState
+from airflow.sdk.execution_time.comms import (
+    ConnectionResponse,
+    GetConnection,
+    StartupDetails,
+    ToSupervisor,
+)
+
+if TYPE_CHECKING:
+    from structlog.typing import FilteringBoundLogger
+
+    from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
+    from airflow.sdk.api.datamodels.ti import TaskInstance
+
+
+__all__ = ["WatchedSubprocess", "supervise"]
+
+log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor")
+
+# TODO: Pull this from config
+SLOWEST_HEARTBEAT_INTERVAL: int = 30
+# Don't heartbeat more often than this
+FASTEST_HEARTBEAT_INTERVAL: int = 5
+
+
+@overload
+def mkpipe() -> tuple[socket, socket]: ...
+
+
+@overload
+def mkpipe(remote_read: Literal[True]) -> tuple[socket, BinaryIO]: ...
+
+
+def mkpipe(
+    remote_read: bool = False,
+) -> tuple[socket, socket | BinaryIO]:
+    """
+    Create a pair of connected sockets.
+
+    The inheritable flag will be set correctly so that the end destined for 
the subprocess is kept open but
+    the end for this process is closed automatically by the OS.
+    """
+    rsock, wsock = socketpair()
+    local, remote = (wsock, rsock) if remote_read else (rsock, wsock)
+
+    remote.set_inheritable(True)
+    local.setblocking(False)
+
+    io: BinaryIO | socket
+    if remote_read:
+        # If _we_ are writing, we don't want to buffer
+        io = cast(BinaryIO, local.makefile("wb", buffering=0))
+    else:
+        io = local
+
+    return remote, io
+
+
+def _subprocess_main():
+    from airflow.sdk.execution_time.task_runner import main
+
+    main()
+
+
+def _reset_signals():
+    # Uninstall the rich etc. exception handler
+    sys.excepthook = sys.__excepthook__
+    signal.signal(signal.SIGINT, signal.SIG_DFL)
+    signal.signal(signal.SIGUSR2, signal.SIG_DFL)
+
+
+def _configure_logs_over_json_channel(log_fd: int):
+    # A channel that the task can send JSON-formatted logs over.
+    #
+    # JSON logs sent this way will be handled nicely
+    from airflow.sdk.log import configure_logging
+
+    log_io = os.fdopen(log_fd, "wb", buffering=0)
+    configure_logging(enable_pretty_log=False, output=log_io)
+
+
+def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr):
+    if "PYTEST_CURRENT_TEST" in os.environ:
+        # When we are running in pytest, its output capturing messes us up. This works around it
+        sys.stdout = sys.__stdout__
+        sys.stderr = sys.__stderr__
+
+    # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the
+    # pipes from the supervisor
+
+    for handle_name, sock, mode, close in (
+        ("stdin", child_stdin, "r", True),
+        ("stdout", child_stdout, "w", True),
+        ("stderr", child_stderr, "w", False),
+    ):
+        handle = getattr(sys, handle_name)
+        try:
+            fd = handle.fileno()
+            os.dup2(sock.fileno(), fd)
+            if close:
+                handle.close()
+        except io.UnsupportedOperation:
+            if "PYTEST_CURRENT_TEST" in os.environ:
+                # When we're running under pytest, the stdin is not a real 
filehandle with an fd, so we need
+                # to handle that differently
+                fd = sock.fileno()
+            else:
+                raise
+
+        setattr(sys, handle_name, os.fdopen(fd, mode))
+
+
+def _fork_main(
+    child_stdin: socket,
+    child_stdout: socket,
+    child_stderr: socket,
+    log_fd: int,
+    target: Callable[[], None],
+) -> NoReturn:
+    """
+    "Entrypoint" of the child process.
+
+    Ultimately this process will be running the user's code in the operator's ``execute()`` function.
+
+    The responsibility of this function is to:
+
+    - Reset any signal handlers we inherited from the parent process (so they don't fire twice - once in
+      parent, and once in child)
+    - Set up the out/err handles to the streams created in the parent (to 
capture stdout and stderr for
+      logging)
+    - Configure the loggers in the child (both stdlib logging and Structlog) 
to send JSON logs back to the
+      supervisor for processing/output.
+    - Catch un-handled exceptions and attempt to show _something_ in case of 
error
+    - Finally, run the actual task runner code (``target`` argument, defaults to ``.task_runner:main``)
+    """
+    # TODO: Make this process a session leader
+
+    # Store original stderr for last-chance exception handling
+    last_chance_stderr = sys.__stderr__ or sys.stderr
+
+    _reset_signals()
+    if log_fd:
+        _configure_logs_over_json_channel(log_fd)
+    _reopen_std_io_handles(child_stdin, child_stdout, child_stderr)
+
+    def exit(n: int) -> NoReturn:
+        with suppress(ValueError, OSError):
+            sys.stdout.flush()
+        with suppress(ValueError, OSError):
+            sys.stderr.flush()
+        with suppress(ValueError, OSError):
+            last_chance_stderr.flush()
+        os._exit(n)
+
+    if hasattr(atexit, "_clear"):
+        # Since we're in a fork we want to try and clear them. If we can't do 
it cleanly, then we won't try
+        # and run new atexit handlers.
+        with suppress(Exception):
+            atexit._clear()
+            base_exit = exit
+
+            def exit(n: int) -> NoReturn:
+                # This will only run any atexit funcs registered after we've 
forked.
+                atexit._run_exitfuncs()
+                base_exit(n)
+
+    try:
+        target()
+        exit(0)
+    except SystemExit as e:
+        code = 1
+        if isinstance(e.code, int):
+            code = e.code
+        elif e.code:
+            print(e.code, file=sys.stderr)
+        exit(code)
+    except Exception:
+        # Last ditch log attempt
+        exc, v, tb = sys.exc_info()
+
+        import traceback
+
+        try:
+            last_chance_stderr.write("--- Last chance exception handler ---\n")
+            traceback.print_exception(exc, value=v, tb=tb, 
file=last_chance_stderr)
+            # Exit code 126 and 125 don't have any "special" meaning, they are 
only meant to serve as an
+            # identifier that the task process died in a really odd way.
+            exit(126)
+        except Exception as e:
+            with suppress(Exception):
+                print(
+                    f"--- Last chance exception handler failed --- 
{repr(str(e))}\n", file=last_chance_stderr
+                )
+            exit(125)
+
+
[email protected]()
+class WatchedSubprocess:
+    ti_id: UUID
+    pid: int
+
+    stdin: BinaryIO
+    stdout: socket
+    stderr: socket
+
+    client: Client
+
+    _process: psutil.Process
+    _exit_code: int | None = None
+    _terminal_state: str | None = None
+
+    _last_heartbeat: float = 0
+
+    selector: selectors.BaseSelector = 
attrs.field(factory=selectors.DefaultSelector)
+
+    procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = 
weakref.WeakValueDictionary()
+
+    def __attrs_post_init__(self):
+        self.procs[self.pid] = self
+
+    @classmethod
+    def start(
+        cls,
+        path: str | os.PathLike[str],
+        ti: TaskInstance,
+        client: Client,
+        target: Callable[[], None] = _subprocess_main,
+    ) -> WatchedSubprocess:
+        """Fork and start a new subprocess to execute the given task."""
+        # Create socketpairs/"pipes" to connect to the stdin and out from the 
subprocess
+        child_stdin, feed_stdin = mkpipe(remote_read=True)
+        child_stdout, read_stdout = mkpipe()
+        child_stderr, read_stderr = mkpipe()
+
+        # Open these socketpairs before forking off the child, so that they are open when we fork.
+        child_comms, read_msgs = mkpipe()
+        child_logs, read_logs = mkpipe()
+
+        pid = os.fork()
+        if pid == 0:
+            # Parent ends of the sockets are closed by the OS as they are set 
as non-inheritable
+
+            # Run the child entrypoint
+            _fork_main(child_stdin, child_stdout, child_stderr, 
child_logs.fileno(), target)
+
+        proc = cls(
+            ti_id=ti.id,
+            pid=pid,
+            stdin=feed_stdin,
+            stdout=read_stdout,
+            stderr=read_stderr,
+            process=psutil.Process(pid),
+            client=client,
+        )
+
+        # We've forked, but the task won't start until we send it the 
StartupDetails message. But before we do
+        # that, we need to tell the server it's started (so it has the chance 
to tell us "no, stop!" for any
+        # reason)
+        try:
+            client.task_instances.start(ti.id, pid, 
datetime.now(tz=timezone.utc))
+            proc._last_heartbeat = time.monotonic()
+        except Exception:
+            # On any error kill that subprocess!
+            proc.kill(signal.SIGKILL)
+            raise
+
+        # TODO: Use logging providers to handle the chunked upload for us
+        task_logger: FilteringBoundLogger = 
structlog.get_logger(logger_name="task").bind()
+
+        # proc.selector is a way of registering a handler/callback to be 
called when the given IO channel has
+        # activity to read on 
(https://www.man7.org/linux/man-pages/man2/select.2.html etc, but better
+        # alternatives are used automatically) -- this is a way of having 
"event-based" code, but without
+        # needing full async, to read and process output from each socket as 
it is received.
+
+        cb = 
make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stdout"), 
level=logging.INFO))
+        proc.selector.register(read_stdout, selectors.EVENT_READ, cb)
+
+        cb = 
make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stderr"), 
level=logging.ERROR))
+        proc.selector.register(read_stderr, selectors.EVENT_READ, cb)
+
+        proc.selector.register(
+            read_logs,
+            selectors.EVENT_READ,
+            
make_buffered_socket_reader(process_log_messages_from_subprocess(task_logger)),
+        )
+        proc.selector.register(
+            read_msgs,
+            selectors.EVENT_READ,
+            make_buffered_socket_reader(proc.handle_requests(log=log)),
+        )
+
+        # Close the remaining parent-end of the sockets we've passed to the 
child via fork. We still have the
+        # other end of the pair open
+        child_stdout.close()
+        child_stdin.close()
+        child_comms.close()
+        child_logs.close()
+
+        # Tell the task process what it needs to do!
+
+        msg = StartupDetails(
+            ti=ti,
+            file=str(path),
+            requests_fd=child_comms.fileno(),
+        )
+
+        # Send the message to tell the process what it needs to execute
+        log.debug("Sending", msg=msg)
+        feed_stdin.write(msg.model_dump_json().encode())
+        feed_stdin.write(b"\n")
+
+        return proc
+
+    def kill(self, signal: signal.Signals = signal.SIGINT):
+        if self._exit_code is not None:
+            return
+
+        with suppress(ProcessLookupError):
+            os.kill(self.pid, signal)
+
+    def wait(self) -> int:
+        if self._exit_code is not None:
+            return self._exit_code
+
+        # Until we have a selector for the process, don't poll for more than 10s, just in case it exits but
+        # doesn't produce any output
+        max_poll_interval = 10
+
+        try:
+            while self._exit_code is None or len(self.selector.get_map()):
+                last_heartbeat_ago = time.monotonic() - self._last_heartbeat
+                # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible
+                # so we notice the subprocess finishing as quickly as we can.
+                max_wait_time = max(
+                    0,  # Make sure this value is never negative,
+                    min(
+                        # Ensure we heartbeat at most 75% of the way through the zombie threshold time
+                        SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75,
+                        max_poll_interval,
+                    ),
+                )
+                events = self.selector.select(timeout=max_wait_time)
+                for key, _ in events:
+                    socket_handler = key.data
+                    need_more = socket_handler(key.fileobj)
+
+                    if not need_more:
+                        self.selector.unregister(key.fileobj)
+                        key.fileobj.close()  # type: ignore[union-attr]
+
+                if self._exit_code is None:
+                    try:
+                        self._exit_code = self._process.wait(timeout=0)
+                        log.debug("Task process exited", 
exit_code=self._exit_code)
+                    except psutil.TimeoutExpired:
+                        pass
+
+                if last_heartbeat_ago < FASTEST_HEARTBEAT_INTERVAL:
+                    # Avoid heartbeating too frequently
+                    continue
+
+                try:
+                    self.client.task_instances.heartbeat(self.ti_id)
+                    self._last_heartbeat = time.monotonic()
+                except Exception:
+                    log.warning("Couldn't heartbeat", exc_info=True)
+                    # TODO: If we couldn't heartbeat for X times the interval, 
kill ourselves
+                    pass
+        finally:
+            self.selector.close()
+
+        self.client.task_instances.finish(
+            id=self.ti_id, state=self.final_state, 
when=datetime.now(tz=timezone.utc)
+        )
+        return self._exit_code
+
+    @property
+    def final_state(self):
+        """
+        The final state of the TaskInstance.
+
+        By default this will be derived from the exit code of the task
+        (0=success, failed otherwise) but can be changed by the subprocess
+        sending a TaskState message, as long as the process exits with 0
+
+        Not valid before the process has finished.
+        """
+        if self._exit_code == 0:
+            return self._terminal_state or TaskInstanceState.SUCCESS
+        return TaskInstanceState.FAILED
+
+    def __rich_repr__(self):
+        yield "pid", self.pid
+        yield "exit_code", self._exit_code, None
+
+    __rich_repr__.angular = True  # type: ignore[attr-defined]
+
+    def __repr__(self) -> str:
+        rep = f"<WatchedSubprocess pid={self.pid}"
+        if self._exit_code is not None:
+            rep += f" exit_code={self._exit_code}"
+        return rep + " >"
+
+    def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, 
bytes, None]:
+        encoder = ConnectionResponse.model_dump_json
+        # Use a buffer to avoid small allocations
+        buffer = bytearray(64)
+
+        decoder = TypeAdapter[ToSupervisor](ToSupervisor)
+
+        while True:
+            line = yield
+
+            try:
+                msg = decoder.validate_json(line)
+            except Exception:
+                log.exception("Unable to decode message", line=line)
+                continue
+
+            # if isinstance(msg, TaskState):
+            #     self._terminal_state = msg.state
+            # elif isinstance(msg, ReadXCom):
+            #     resp = XComResponse(key="secret", value=True)
+            #     encoder.encode_into(resp, buffer)
+            #     self.stdin.write(buffer + b"\n")
+            if isinstance(msg, GetConnection):
+                conn = self.client.connections.get(msg.id)
+                resp = ConnectionResponse(conn=conn)
+                encoded_resp = encoder(resp)
+                buffer.extend(encoded_resp.encode())
+            else:
+                log.error("Unhandled request", msg=msg)
+                continue
+
+            buffer.extend(b"\n")
+            self.stdin.write(buffer)
+
+            # Ensure the buffer doesn't grow and stay large if a large payload 
is used. This won't grow it
+            # larger than it is, but it will shrink it
+            if len(buffer) > 1024:
+                buffer = buffer[:1024]
+
+
+# Sockets, even via the `.makefile()` function, don't correctly do line buffering on reading. If a chunk is read
+# and it doesn't contain a newline character, `.readline()` will just return the chunk as is.
+#
+# This returns a callback suitable for attaching to a `selector` that reads in 
to a buffer, and yields lines
+# to a (sync) generator
+def make_buffered_socket_reader(
+    gen: Generator[None, bytes, None], buffer_size: int = 4096
+) -> Callable[[socket], bool]:
+    buffer = bytearray()  # This will hold our accumulated binary data
+    read_buffer = bytearray(buffer_size)  # Temporary buffer for each read
+
+    # We need to start up the generator to get it to the point where it's waiting on the yield
+    next(gen)
+
+    def cb(sock: socket):
+        nonlocal buffer, read_buffer
+        # Read up to `buffer_size` bytes of data from the socket
+        n_received = sock.recv_into(read_buffer)
+
+        if not n_received:
+            # If no data is returned, the connection is closed. Return 
whatever is left in the buffer
+            if len(buffer):
+                gen.send(buffer)
+            # Tell loop to close this selector
+            return False
+
+        buffer.extend(read_buffer[:n_received])
+
+        # We could have read multiple lines in one go, yield them all
+        while (newline_pos := buffer.find(b"\n")) != -1:
+            if TYPE_CHECKING:
+                # We send in a memoryview, but pretend it's bytes, as Buffer is only in 3.12+
+                line = buffer[: newline_pos + 1]
+            else:
+                line = memoryview(buffer)[: newline_pos + 1]  # Include the 
newline character
+            gen.send(line)
+            buffer = buffer[newline_pos + 1 :]  # Update the buffer with 
remaining data
+
+        return True
+
+    return cb
+
+
+def process_log_messages_from_subprocess(log: FilteringBoundLogger) -> 
Generator[None, bytes, None]:
+    from structlog.stdlib import NAME_TO_LEVEL
+
+    while True:
+        # Generator receive syntax: values are "sent" in by `make_buffered_socket_reader` and received at
+        # the yield.
+        line = yield
+
+        try:
+            event = msgspec.json.decode(line)
+        except Exception:
+            log.exception("Malformed json log line", line=line)
+            continue
+
+        if ts := event.get("timestamp"):
+            # We use msgspec to decode the timestamp as it does it orders of magnitude quicker than
+            # datetime.strptime can
+            #
+            # We remove the timezone info here, as the json encoding has 
`+00:00`, and since the log came
+            # from a subprocess we know that the timezone of the log message 
is the same, so having some
+            # messages include tz (from subprocess) but others not (ones from 
supervisor process) is
+            # confusing.
+            event["timestamp"] = msgspec.json.decode(f'"{ts}"', 
type=datetime).replace(tzinfo=None)
+
+        if exc := event.pop("exception", None):
+            # TODO: convert the dict back to a pretty stack trace
+            event["error_detail"] = exc
+        log.log(NAME_TO_LEVEL[event.pop("level")], event.pop("event", None), 
**event)
+
+
+def forward_to_log(target_log: FilteringBoundLogger, level: int) -> 
Generator[None, bytes, None]:
+    while True:
+        buf = yield
+        line = bytes(buf)
+        # Strip off new line
+        line = line.rstrip()
+        try:
+            msg = line.decode("utf-8", errors="replace")
+            target_log.log(level, msg)
+        except UnicodeDecodeError:
+            msg = line.decode("ascii", errors="replace")
+            target_log.log(level, msg)
+
+
+def supervise(activity: ExecuteTaskActivity, server: str | None = None, 
dry_run: bool = False) -> int:
+    """
+    Run a single task execution to completion.
+
+    Returns the exit code of the process
+    """
+    # One or the other
+    if (server == "") ^ dry_run:
+        raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
+
+    if not activity.path:
+        raise ValueError("path filed of activity missing")
+
+    limits = httpx.Limits(max_keepalive_connections=1, max_connections=10)
+    client = Client(base_url=server or "", limits=limits, dry_run=dry_run, 
token=activity.token)
+
+    start = time.monotonic()
+
+    process = WatchedSubprocess.start(activity.path, activity.ti, 
client=client)
+
+    exit_code = process.wait()
+    end = time.monotonic()
+    log.debug("Task finished", exit_code=exit_code, duration=end - start)
+    return exit_code
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
new file mode 100644
index 00000000000..382e29c59b6
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -0,0 +1,191 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The entrypoint for the actual task execution process."""
+
+from __future__ import annotations
+
+import os
+import sys
+from io import FileIO
+from typing import TYPE_CHECKING, TextIO
+
+import attrs
+import structlog
+from pydantic import ConfigDict, TypeAdapter
+
+from airflow.sdk import BaseOperator
+from airflow.sdk.execution_time.comms import StartupDetails, TaskInstance, 
ToSupervisor, ToTask
+
+if TYPE_CHECKING:
+    from structlog.typing import FilteringBoundLogger as Logger
+
+
+class RuntimeTaskInstance(TaskInstance):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    task: BaseOperator
+
+
+def parse(what: StartupDetails) -> RuntimeTaskInstance:
+    # TODO: Task-SDK:
+    # Using DagBag here is about 98% wrong, but it'll do for now
+
+    from airflow.models.dagbag import DagBag
+
+    bag = DagBag(
+        dag_folder=what.file,
+        include_examples=False,
+        safe_mode=False,
+        load_op_links=False,
+    )
+    if TYPE_CHECKING:
+        assert what.ti.dag_id
+
+    dag = bag.dags[what.ti.dag_id]
+
+    # install_loader()
+
+    # TODO: Handle task not found
+    task = dag.task_dict[what.ti.task_id]
+    if not isinstance(task, BaseOperator):
+        raise TypeError(f"task is of the wrong type, got {type(task)}, wanted 
{BaseOperator}")
+    return RuntimeTaskInstance(**what.ti.model_dump(exclude_unset=True), 
task=task)
+
+
+@attrs.define()
+class CommsDecoder:
+    """Handle communication between the task in this process and the 
supervisor parent process."""
+
+    input: TextIO = sys.stdin
+
+    request_socket: FileIO = attrs.field(init=False, default=None)
+
+    decoder: TypeAdapter[ToTask] = attrs.field(init=False, factory=lambda: 
TypeAdapter(ToTask))
+
+    def get_message(self) -> ToTask:
+        """
+        Get a message from the parent.
+
+        This will block until the message has been received.
+        """
+        line = self.input.readline()
+        try:
+            msg = self.decoder.validate_json(line)
+        except Exception:
+            structlog.get_logger(logger_name="CommsDecoder").exception("Unable 
to decode message", line=line)
+            raise
+
+        if isinstance(msg, StartupDetails):
+            # If we read a startup message, pull out the FDs we care about!
+            if msg.requests_fd > 0:
+                self.request_socket = os.fdopen(msg.requests_fd, "wb", 
buffering=0)
+        return msg
+
+    def send_request(self, log: Logger, msg: ToSupervisor):
+        encoded_msg = msg.model_dump_json().encode() + b"\n"
+
+        log.debug("Sending request", json=encoded_msg)
+        self.request_socket.write(encoded_msg)
+
+
+# This global variable will be used by Connection/Variable classes etc. to send
+# requests to the supervisor process.
+SUPERVISOR_COMMS: CommsDecoder
+
+# State machine!
+# 1. Start up (receive details from supervisor)
+# 2. Execution (run task code, possibly send requests)
+# 3. Shutdown and report status
+
+
+def startup() -> tuple[RuntimeTaskInstance, Logger]:
+    msg = SUPERVISOR_COMMS.get_message()
+
+    if isinstance(msg, StartupDetails):
+        log = structlog.get_logger(logger_name="task")
+        # TODO: set the "magic loop" context vars for parsing
+        ti = parse(msg)
+        log.debug("DAG file parsed", file=msg.file)
+        return ti, log
+    else:
+        raise RuntimeError(f"Unhandled  startup message {type(msg)} {msg}")
+
+    # TODO: Render fields here
+
+
+def run(ti: RuntimeTaskInstance, log: Logger):
+    """Run the task in this process."""
+    from airflow.exceptions import (
+        AirflowException,
+        AirflowFailException,
+        AirflowRescheduleException,
+        AirflowSensorTimeout,
+        AirflowSkipException,
+        AirflowTaskTerminated,
+        AirflowTaskTimeout,
+        TaskDeferred,
+    )
+
+    if TYPE_CHECKING:
+        assert ti.task is not None
+        assert isinstance(ti.task, BaseOperator)
+    try:
+        # TODO: pre execute etc.
+        # TODO next_method to support resuming from deferred
+        # TODO: Get a real context object
+        ti.task.execute({"task_instance": ti})  # type: ignore[attr-defined]
+    except TaskDeferred:
+        ...
+    except AirflowSkipException:
+        ...
+    except AirflowRescheduleException:
+        ...
+    except (AirflowFailException, AirflowSensorTimeout):
+        # If AirflowFailException is raised, task should not retry.
+        ...
+    except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated):
+        ...
+    except SystemExit:
+        ...
+    except BaseException:
+        ...
+
+
+def finalize(log: Logger): ...
+
+
+def main():
+    # TODO: add an exception here, it causes an oof of a stack trace!
+
+    global SUPERVISOR_COMMS
+    SUPERVISOR_COMMS = CommsDecoder()
+    try:
+        ti, log = startup()
+        run(ti, log)
+        finalize(log)
+    except KeyboardInterrupt:
+        log = structlog.get_logger(logger_name="task")
+        log.exception("Ctrl-c hit")
+        exit(2)
+    except Exception:
+        log = structlog.get_logger(logger_name="task")
+        log.exception("Top level error")
+        exit(1)
+
+
+if __name__ == "__main__":
+    main()
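
For orientation, a sketch of the line-oriented JSON protocol that CommsDecoder
above implements. The StartupDetails payload mirrors the one used in
task_sdk/tests/execution_time/test_task_runner.py further down; the fd number
is made up.

    import json

    startup = {
        "type": "StartupDetails",
        "ti": {
            "id": "4d828a62-a417-4936-a7a6-2b3fabacecab",
            "task_id": "a",
            "dag_id": "c",
            "run_id": "b",
            "try_number": 1,
        },
        "file": "/dev/null",
        "requests_fd": 9,  # fd the task opens with os.fdopen to send requests back
    }

    # The supervisor writes one JSON document per line on the task's stdin;
    # get_message() reads it back with input.readline() followed by
    # TypeAdapter(ToTask).validate_json(line).
    wire_line = json.dumps(startup) + "\n"

    # Requests flow the other way as JSON lines too: send_request() writes
    # msg.model_dump_json().encode() + b"\n" to the requests_fd socket.
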
diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py
new file mode 100644
index 00000000000..f8e06eda4a6
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/log.py
@@ -0,0 +1,372 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import itertools
+import logging.config
+import os
+import sys
+import warnings
+from functools import cache
+from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Generic, TextIO, 
TypeVar
+
+import msgspec
+import structlog
+
+if TYPE_CHECKING:
+    from structlog.typing import EventDict, ExcInfo, Processor
+
+
+__all__ = [
+    "configure_logging",
+    "reset_logging",
+]
+
+
+def exception_group_tracebacks(format_exception: Callable[[ExcInfo], 
list[dict[str, Any]]]) -> Processor:
+    # Make mypy happy
+    if not hasattr(__builtins__, "BaseExceptionGroup"):
+        T = TypeVar("T")
+
+        class BaseExceptionGroup(Generic[T]):
+            exceptions: list[T]
+
+    def _exception_group_tracebacks(logger: Any, method_name: Any, event_dict: 
EventDict) -> EventDict:
+        if exc_info := event_dict.get("exc_info", None):
+            group: BaseExceptionGroup[Exception] | None = None
+            if exc_info is True:
+                # The `log.exception("msg")` case
+                exc_info = sys.exc_info()
+                if exc_info[0] is None:
+                    exc_info = None
+
+            if (
+                isinstance(exc_info, tuple)
+                and len(exc_info) == 3
+                and isinstance(exc_info[1], BaseExceptionGroup)
+            ):
+                group = exc_info[1]
+            elif isinstance(exc_info, BaseExceptionGroup):
+                group = exc_info
+
+            if group:
+                # Only remove it from event_dict if we handle it
+                del event_dict["exc_info"]
+                event_dict["exception"] = list(
+                    itertools.chain.from_iterable(
+                        format_exception((type(exc), exc, exc.__traceback__))  
# type: ignore[attr-defined,arg-type]
+                        for exc in (*group.exceptions, group)
+                    )
+                )
+
+        return event_dict
+
+    return _exception_group_tracebacks
+
+
+def logger_name(logger: Any, method_name: Any, event_dict: EventDict) -> 
EventDict:
+    if logger_name := event_dict.pop("logger_name", None):
+        event_dict.setdefault("logger", logger_name)
+    return event_dict
+
+
+def redact_jwt(logger: Any, method_name: str, event_dict: EventDict) -> 
EventDict:
+    for k, v in event_dict.items():
+        if isinstance(v, str) and v.startswith("eyJ"):
+            event_dict[k] = "eyJ***"
+    return event_dict
+
+
+def drop_positional_args(logger: Any, method_name: Any, event_dict: EventDict) 
-> EventDict:
+    event_dict.pop("positional_args", None)
+    return event_dict
+
+
+class StdBinaryStreamHandler(logging.StreamHandler):
+    """A logging.StreamHandler that sends logs as binary JSON over the given 
stream."""
+
+    stream: BinaryIO
+
+    def __init__(self, stream: BinaryIO):
+        super().__init__(stream)
+
+    def emit(self, record: logging.LogRecord):
+        try:
+            msg = self.format(record)
+            buffer = bytearray(msg, "ascii", "backslashreplace")
+
+            buffer += b"\n"
+
+            stream = self.stream
+            stream.write(buffer)
+            self.flush()
+        except RecursionError:  # See issue 36272
+            raise
+        except Exception:
+            self.handleError(record)
+
+
+@cache
+def logging_processors(
+    enable_pretty_log: bool,
+):
+    if enable_pretty_log:
+        timestamper = structlog.processors.MaybeTimeStamper(fmt="%Y-%m-%d 
%H:%M:%S.%f")
+    else:
+        timestamper = structlog.processors.MaybeTimeStamper(fmt="iso")
+
+    processors: list[structlog.typing.Processor] = [
+        timestamper,
+        structlog.contextvars.merge_contextvars,
+        structlog.processors.add_log_level,
+        structlog.stdlib.PositionalArgumentsFormatter(),
+        logger_name,
+        redact_jwt,
+        structlog.processors.StackInfoRenderer(),
+    ]
+
+    # Imports to suppress showing code from these modules. We need the import 
to get the filepath for
+    # structlog to ignore.
+    import contextlib
+
+    import click
+    import httpcore
+    import httpx
+
+    suppress = (
+        click,
+        contextlib,
+        httpx,
+        httpcore,
+    )
+
+    if enable_pretty_log:
+        rich_exc_formatter = structlog.dev.RichTracebackFormatter(
+            # These values are picked somewhat arbitrarily to produce 
useful-but-compact tracebacks. If
+            # we ever need to change these then they should be configurable.
+            extra_lines=0,
+            max_frames=30,
+            indent_guides=False,
+            suppress=suppress,
+        )
+        my_styles = structlog.dev.ConsoleRenderer.get_default_level_styles()
+        my_styles["debug"] = structlog.dev.CYAN
+
+        console = structlog.dev.ConsoleRenderer(
+            exception_formatter=rich_exc_formatter, level_styles=my_styles
+        )
+        processors.append(console)
+        return processors, {
+            "timestamper": timestamper,
+            "console": console,
+        }
+    else:
+        # Imports to suppress showing code from these modules
+        import contextlib
+
+        import click
+        import httpcore
+        import httpx
+
+        dict_exc_formatter = structlog.tracebacks.ExceptionDictTransformer(
+            use_rich=False, show_locals=False, suppress=suppress
+        )
+
+        dict_tracebacks = 
structlog.processors.ExceptionRenderer(dict_exc_formatter)
+        if hasattr(__builtins__, "BaseExceptionGroup"):
+            exc_group_processor = 
exception_group_tracebacks(dict_exc_formatter)
+            processors.append(exc_group_processor)
+        else:
+            exc_group_processor = None
+
+        encoder = msgspec.json.Encoder()
+
+        def json_dumps(msg, default):
+            return encoder.encode(msg)
+
+        def json_processor(logger: Any, method_name: Any, event_dict: 
EventDict) -> str:
+            return encoder.encode(event_dict).decode("ascii")
+
+        json = structlog.processors.JSONRenderer(serializer=json_dumps)
+
+        processors.extend(
+            (
+                dict_tracebacks,
+                structlog.processors.UnicodeDecoder(),
+                json,
+            ),
+        )
+
+        return processors, {
+            "timestamper": timestamper,
+            "exc_group_processor": exc_group_processor,
+            "dict_tracebacks": dict_tracebacks,
+            "json": json_processor,
+        }
+
+
+@cache
+def configure_logging(
+    enable_pretty_log: bool = True,
+    log_level: str = "DEBUG",
+    output: BinaryIO | None = None,
+    cache_logger_on_first_use: bool = True,
+):
+    """Set up struct logging and stdlib logging config."""
+    if enable_pretty_log and output is not None:
+        raise ValueError("output can only be set if enable_pretty_log is not")
+
+    lvl = structlog.stdlib.NAME_TO_LEVEL[log_level.lower()]
+
+    if enable_pretty_log:
+        formatter = "colored"
+    else:
+        formatter = "plain"
+    processors, named = logging_processors(enable_pretty_log)
+    timestamper = named["timestamper"]
+
+    pre_chain: list[structlog.typing.Processor] = [
+        # Add the log level and a timestamp to the event_dict if the log entry
+        # is not from structlog.
+        structlog.stdlib.add_log_level,
+        structlog.stdlib.add_logger_name,
+        timestamper,
+    ]
+
+    # Don't cache the loggers during tests, it makes it hard to capture them
+    if "PYTEST_CURRENT_TEST" in os.environ:
+        cache_logger_on_first_use = False
+
+    color_formatter: list[structlog.typing.Processor] = [
+        structlog.stdlib.ProcessorFormatter.remove_processors_meta,
+        drop_positional_args,
+    ]
+    std_lib_formatter: list[structlog.typing.Processor] = [
+        structlog.stdlib.ProcessorFormatter.remove_processors_meta,
+        drop_positional_args,
+    ]
+
+    wrapper_class = structlog.make_filtering_bound_logger(lvl)
+    if enable_pretty_log:
+        structlog.configure(
+            processors=processors,
+            cache_logger_on_first_use=cache_logger_on_first_use,
+            wrapper_class=wrapper_class,
+        )
+        color_formatter.append(named["console"])
+    else:
+        structlog.configure(
+            processors=processors,
+            cache_logger_on_first_use=cache_logger_on_first_use,
+            wrapper_class=wrapper_class,
+            logger_factory=structlog.BytesLoggerFactory(output),
+        )
+
+        if processor := named["exc_group_processor"]:
+            pre_chain.append(processor)
+        pre_chain.append(named["dict_tracebacks"])
+        color_formatter.append(named["json"])
+        std_lib_formatter.append(named["json"])
+
+    global _warnings_showwarning
+    _warnings_showwarning = warnings.showwarning
+    # Capture warnings and show them via structlog
+    warnings.showwarning = _showwarning
+
+    logging.config.dictConfig(
+        {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "plain": {
+                    "()": structlog.stdlib.ProcessorFormatter,
+                    "processors": std_lib_formatter,
+                    "foreign_pre_chain": pre_chain,
+                    "pass_foreign_args": True,
+                },
+                "colored": {
+                    "()": structlog.stdlib.ProcessorFormatter,
+                    "processors": color_formatter,
+                    "foreign_pre_chain": pre_chain,
+                    "pass_foreign_args": True,
+                },
+            },
+            "handlers": {
+                "default": {
+                    "level": log_level.upper(),
+                    "class": "logging.StreamHandler",
+                    "formatter": formatter,
+                },
+                "to_supervisor": {
+                    "level": log_level.upper(),
+                    "()": StdBinaryStreamHandler,
+                    "formatter": formatter,
+                    "stream": output,
+                },
+            },
+            "loggers": {
+                "": {
+                    "handlers": ["to_supervisor" if output else "default"],
+                    "level": log_level.upper(),
+                    "propagate": True,
+                },
+                # Some modules we _never_ want at debug level
+                "asyncio": {"level": "INFO"},
+                "alembic": {"level": "INFO"},
+                "httpcore": {"level": "INFO"},
+                "httpx": {"level": "WARN"},
+                "psycopg.pq": {"level": "INFO"},
+                "sqlalchemy.engine": {"level": "WARN"},
+            },
+        }
+    )
+
+
+def reset_logging():
+    global _warnings_showwarning
+    warnings.showwarning = _warnings_showwarning
+    configure_logging.cache_clear()
+
+
+_warnings_showwarning = None
+
+
+def _showwarning(
+    message: str | Warning,
+    category: type[Warning],
+    filename: str,
+    lineno: int,
+    file: TextIO | None = None,
+    line: str | None = None,
+):
+    """
+    Redirects warnings to structlog so they appear in task logs etc.
+
+    Implementation of showwarning which redirects to logging, which will first
+    check to see if the file parameter is None. If a file is specified, it will
+    delegate to the original warnings implementation of showwarning. Otherwise,
+    it will call warnings.formatwarning and will log the resulting string to a
+    warnings logger named "py.warnings" with level logging.WARNING.
+    """
+    if file is not None:
+        if _warnings_showwarning is not None:
+            _warnings_showwarning(message, category, filename, lineno, file, 
line)
+    else:
+        log = structlog.get_logger(logger_name="py.warnings")
+        log.warning(str(message), category=category.__name__, 
filename=filename, lineno=lineno)
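
A short sketch of how this log config might be used; illustrative only, and the
fd number in the commented-out part is hypothetical.

    import structlog

    from airflow.sdk.log import configure_logging

    # Supervisor process / local development: pretty, coloured console output.
    configure_logging(enable_pretty_log=True, log_level="INFO")
    structlog.get_logger(logger_name="task").info("hello", dag_id="example")

    # In the task subprocess the same helper is pointed at the dedicated logs
    # channel instead, so structlog and stdlib logging both come out as JSON
    # lines that the supervisor can parse:
    #
    #   import os
    #   logs_fd = 9  # hypothetical fd inherited from the supervisor
    #   configure_logging(enable_pretty_log=False, output=os.fdopen(logs_fd, "wb", buffering=0))
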
diff --git a/task_sdk/src/airflow/sdk/types.py 
b/task_sdk/src/airflow/sdk/types.py
index 232d08e27f9..ffde2170b17 100644
--- a/task_sdk/src/airflow/sdk/types.py
+++ b/task_sdk/src/airflow/sdk/types.py
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
     Logger = logging.Logger
 else:
 
-    class Logger: ...  # noqa: D101
+    class Logger: ...
 
 
 def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, 
Any]) -> None:
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/api/__init__.py
similarity index 69%
copy from task_sdk/tests/conftest.py
copy to task_sdk/tests/api/__init__.py
index ddc7c61656a..13a83393a91 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/api/__init__.py
@@ -14,18 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
-
-import os
-
-import pytest
-
-pytest_plugins = "tests_common.pytest_plugin"
-
-# Task SDK does not need access to the Airflow database
-os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
-
-
[email protected](tryfirst=True)
-def pytest_configure(config: pytest.Config) -> None:
-    config.inicfg["airflow_deprecations_ignore"] = []
diff --git a/task_sdk/tests/api/test_client.py 
b/task_sdk/tests/api/test_client.py
new file mode 100644
index 00000000000..a32b321545d
--- /dev/null
+++ b/task_sdk/tests/api/test_client.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import httpx
+import pytest
+
+from airflow.sdk.api.client import Client, ErrorBody, RemoteValidationError, 
ServerResponseError
+
+
+class TestClient:
+    def test_error_parsing(self):
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            """
+            A transport handler that always returns errors.
+            """
+
+            return httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": 
"err", "type": "required"}]})
+
+        client = Client(
+            base_url=None, dry_run=True, token="", mounts={"http://": httpx.MockTransport(handle_request)}
+        )
+
+        with pytest.raises(ServerResponseError) as err:
+            client.get("http://error";)
+
+        assert isinstance(err.value, ServerResponseError)
+        assert isinstance(err.value.detail, ErrorBody)
+        assert err.value.detail.detail == [
+            RemoteValidationError(loc=["#0"], msg="err", type="required"),
+        ]
+
+    def test_error_parsing_plain_text(self):
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            """
+            A transport handler that always returns errors.
+            """
+
+            return httpx.Response(422, content=b"Internal Server Error")
+
+        client = Client(
+            base_url=None, dry_run=True, token="", mounts={"http://": httpx.MockTransport(handle_request)}
+        )
+
+        with pytest.raises(httpx.HTTPStatusError) as err:
+            client.get("http://error";)
+        assert not isinstance(err.value, ServerResponseError)
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py
index ddc7c61656a..dffd1370f4e 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/conftest.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import os
+from typing import TYPE_CHECKING, NoReturn
 
 import pytest
 
@@ -25,7 +26,64 @@ pytest_plugins = "tests_common.pytest_plugin"
 # Task SDK does not need access to the Airflow database
 os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
 
+if TYPE_CHECKING:
+    from structlog.typing import EventDict, WrappedLogger
+
+
+@pytest.hookimpl()
+def pytest_addhooks(pluginmanager: pytest.PytestPluginManager):
+    # Python 3.12 starts warning about mixing os.fork + Threads, and the 
pytest-rerunfailures plugin uses
+    # threads internally. Since this is new code, and it should be flake free, 
we disable the re-run failures
+    # plugin early (so that it doesn't run it's pytest_configure which is 
where the thread starts up if xdist
+    # is discovered).
+    pluginmanager.set_blocked("rerunfailures")
+
 
 @pytest.hookimpl(tryfirst=True)
 def pytest_configure(config: pytest.Config) -> None:
     config.inicfg["airflow_deprecations_ignore"] = []
+
+
+class LogCapture:
+    # Like structlog.testing.LogCapture, but this one doesn't add log_level into
+    # the event dict
+    entries: list[EventDict]
+
+    def __init__(self) -> None:
+        self.entries = []
+
+    def __call__(self, _: WrappedLogger, method_name: str, event_dict: 
EventDict) -> NoReturn:
+        from structlog.exceptions import DropEvent
+
+        if "level" not in event_dict:
+            event_dict["_log_level"] = method_name
+
+        self.entries.append(event_dict)
+
+        raise DropEvent
+
+
+@pytest.fixture
+def captured_logs():
+    import structlog
+
+    from airflow.sdk.log import configure_logging, reset_logging
+
+    # Use our real log config
+    reset_logging()
+    configure_logging(enable_pretty_log=False)
+
+    # But we need to remove the last processor (the one that renders the event
+    # dict as JSON text), as we want the raw event dict for tests
+    cur_processors = structlog.get_config()["processors"]
+    processors = cur_processors.copy()
+    proc = processors.pop()
+    assert isinstance(
+        proc, (structlog.dev.ConsoleRenderer, 
structlog.processors.JSONRenderer)
+    ), "Pre-condition"
+    try:
+        cap = LogCapture()
+        processors.append(cap)
+        structlog.configure(processors=processors)
+        yield cap.entries
+    finally:
+        structlog.configure(processors=cur_processors)
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/defintions/__init__.py
similarity index 69%
copy from task_sdk/tests/conftest.py
copy to task_sdk/tests/defintions/__init__.py
index ddc7c61656a..13a83393a91 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/defintions/__init__.py
@@ -14,18 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
-
-import os
-
-import pytest
-
-pytest_plugins = "tests_common.pytest_plugin"
-
-# Task SDK does not need access to the Airflow database
-os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
-
-
[email protected](tryfirst=True)
-def pytest_configure(config: pytest.Config) -> None:
-    config.inicfg["airflow_deprecations_ignore"] = []
diff --git a/task_sdk/tests/defintions/test_baseoperator.py 
b/task_sdk/tests/defintions/test_baseoperator.py
index 427d1ee0e3e..19035319cdc 100644
--- a/task_sdk/tests/defintions/test_baseoperator.py
+++ b/task_sdk/tests/defintions/test_baseoperator.py
@@ -29,6 +29,25 @@ from airflow.task.priority_strategy import 
_DownstreamPriorityWeightStrategy, _U
 DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
 
 
+@pytest.fixture(autouse=True, scope="module")
+def _disable_ol_plugin():
+    # The OpenLineage plugin imports setproctitle, and that now causes (C) 
level thread calls, which on Py
+    # 3.12+ issues a warning when os.fork happens. So for this plugin we 
disable it
+
+    # And we load plugins when setting the priority_weight field
+    import airflow.plugins_manager
+
+    old = airflow.plugins_manager.plugins
+
+    assert old is None, "Plugins already loaded, too late to stop them being 
loaded!"
+
+    airflow.plugins_manager.plugins = []
+
+    yield
+
+    airflow.plugins_manager.plugins = None
+
+
 # Essentially similar to airflow.models.baseoperator.BaseOperator
 class FakeOperator(metaclass=BaseOperatorMeta):
     def __init__(self, test_param, params=None, default_args=None):
diff --git a/task_sdk/tests/conftest.py 
b/task_sdk/tests/execution_time/__init__.py
similarity index 69%
copy from task_sdk/tests/conftest.py
copy to task_sdk/tests/execution_time/__init__.py
index ddc7c61656a..13a83393a91 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/execution_time/__init__.py
@@ -14,18 +14,3 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
-
-import os
-
-import pytest
-
-pytest_plugins = "tests_common.pytest_plugin"
-
-# Task SDK does not need access to the Airflow database
-os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
-
-
[email protected](tryfirst=True)
-def pytest_configure(config: pytest.Config) -> None:
-    config.inicfg["airflow_deprecations_ignore"] = []
diff --git a/task_sdk/tests/conftest.py 
b/task_sdk/tests/execution_time/conftest.py
similarity index 73%
copy from task_sdk/tests/conftest.py
copy to task_sdk/tests/execution_time/conftest.py
index ddc7c61656a..4a537373363 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/execution_time/conftest.py
@@ -14,18 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 from __future__ import annotations
 
-import os
+import sys
 
 import pytest
 
-pytest_plugins = "tests_common.pytest_plugin"
-
-# Task SDK does not need access to the Airflow database
-os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
 
+@pytest.fixture
+def disable_capturing():
+    old_in, old_out, old_err = sys.stdin, sys.stdout, sys.stderr
 
[email protected](tryfirst=True)
-def pytest_configure(config: pytest.Config) -> None:
-    config.inicfg["airflow_deprecations_ignore"] = []
+    sys.stdin = sys.__stdin__
+    sys.stdout = sys.__stdout__
+    sys.stderr = sys.__stderr__
+    yield
+    sys.stdin, sys.stdout, sys.stderr = old_in, old_out, old_err
diff --git a/task_sdk/tests/execution_time/test_supervisor.py 
b/task_sdk/tests/execution_time/test_supervisor.py
new file mode 100644
index 00000000000..f1bf287cd22
--- /dev/null
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import inspect
+import logging
+import os
+import signal
+import sys
+from unittest.mock import MagicMock
+
+import pytest
+import structlog
+import structlog.testing
+
+from airflow.sdk.api import client as sdk_client
+from airflow.sdk.api.datamodels.ti import TaskInstance
+from airflow.sdk.execution_time.supervisor import WatchedSubprocess
+from airflow.utils import timezone as tz
+
+
+def lineno():
+    """Returns the current line number in our program."""
+    return inspect.currentframe().f_back.f_lineno
+
+
[email protected]("disable_capturing")
+class TestWatchedSubprocess:
+    def test_reading_from_pipes(self, captured_logs, time_machine):
+        # Ignore anything lower than INFO for this test. The captured_logs fixture
+        # resets things for us afterwards.
+        
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
+
+        line = lineno()
+
+        def subprocess_main():
+            # This is run in the subprocess!
+
+            # Flush calls are to ensure ordering of output for predictable 
tests
+            import logging
+            import warnings
+
+            print("I'm a short message")
+            sys.stdout.write("Message ")
+            sys.stdout.write("split across two writes\n")
+            sys.stdout.flush()
+
+            print("stderr message", file=sys.stderr)
+            sys.stderr.flush()
+
+            logging.getLogger("airflow.foobar").error("An error message")
+
+            warnings.warn("Warning should be captured too", stacklevel=1)
+
+        instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901)
+        time_machine.move_to(instant, tick=False)
+
+        proc = WatchedSubprocess.start(
+            path=os.devnull,
+            ti=TaskInstance(
+                id="4d828a62-a417-4936-a7a6-2b3fabacecab",
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+            ),
+            client=MagicMock(spec=sdk_client.Client),
+            target=subprocess_main,
+        )
+
+        rc = proc.wait()
+
+        assert rc == 0
+        assert captured_logs == [
+            {
+                "chan": "stdout",
+                "event": "I'm a short message",
+                "level": "info",
+                "logger": "task",
+                "timestamp": "2024-11-07T12:34:56.078901Z",
+            },
+            {
+                "chan": "stdout",
+                "event": "Message split across two writes",
+                "level": "info",
+                "logger": "task",
+                "timestamp": "2024-11-07T12:34:56.078901Z",
+            },
+            {
+                "chan": "stderr",
+                "event": "stderr message",
+                "level": "error",
+                "logger": "task",
+                "timestamp": "2024-11-07T12:34:56.078901Z",
+            },
+            {
+                "event": "An error message",
+                "level": "error",
+                "logger": "airflow.foobar",
+                "timestamp": instant.replace(tzinfo=None),
+            },
+            {
+                "category": "UserWarning",
+                "event": "Warning should be captured too",
+                "filename": __file__,
+                "level": "warning",
+                "lineno": line + 19,
+                "logger": "py.warnings",
+                "timestamp": instant.replace(tzinfo=None),
+            },
+        ]
+
+    def test_subprocess_sigkilled(self):
+        main_pid = os.getpid()
+
+        def subprocess_main():
+            # This is run in the subprocess!
+            assert os.getpid() != main_pid
+            os.kill(os.getpid(), signal.SIGKILL)
+
+        proc = WatchedSubprocess.start(
+            path=os.devnull,
+            ti=TaskInstance(
+                id="4d828a62-a417-4936-a7a6-2b3fabacecab",
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+            ),
+            client=MagicMock(spec=sdk_client.Client),
+            target=subprocess_main,
+        )
+
+        rc = proc.wait()
+
+        assert rc == -9
diff --git a/task_sdk/tests/execution_time/test_task_runner.py 
b/task_sdk/tests/execution_time/test_task_runner.py
new file mode 100644
index 00000000000..5a90701cb2c
--- /dev/null
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import uuid
+from socket import socketpair
+
+import pytest
+
+from airflow.sdk.execution_time.comms import StartupDetails
+from airflow.sdk.execution_time.task_runner import CommsDecoder
+
+
+class TestCommsDecoder:
+    """Test the communication between the subprocess and the "supervisor"."""
+
+    @pytest.mark.usefixtures("disable_capturing")
+    def test_recv_StartupDetails(self):
+        r, w = socketpair()
+        # Create a valid FD for the decoder to open
+        _, w2 = socketpair()
+
+        w.makefile("wb").write(
+            b'{"type":"StartupDetails", "ti": {'
+            b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", 
"try_number": 1, "run_id": "b", "dag_id": "c" }, '
+            b'"file": "/dev/null", "requests_fd": ' + 
str(w2.fileno()).encode("ascii") + b"}\n"
+        )
+
+        decoder = CommsDecoder(input=r.makefile("r"))
+
+        msg = decoder.get_message()
+        assert isinstance(msg, StartupDetails)
+        assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab")
+        assert msg.ti.task_id == "a"
+        assert msg.ti.dag_id == "c"
+        assert msg.file == "/dev/null"
+
+        # Since this was a StartupDetails message, the decoder should open the 
other socket
+        assert decoder.request_socket is not None
+        assert decoder.request_socket.writable()
+        assert decoder.request_socket.fileno() == w2.fileno()
diff --git a/tests/cli/commands/test_celery_command.py 
b/tests/cli/commands/test_celery_command.py
index ae2b71a5909..c417b6eb61c 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -276,6 +276,7 @@ class TestFlowerCommand:
     @mock.patch("airflow.cli.commands.daemon_utils.setup_locations")
     @mock.patch("airflow.cli.commands.daemon_utils.daemon")
     @mock.patch("airflow.providers.celery.executors.celery_executor.app")
+    @pytest.mark.usefixtures("capfd")  # This test needs fd capturing to work
     def test_run_command_daemon(self, mock_celery_app, mock_daemon, 
mock_setup_locations, mock_pid_file):
         mock_setup_locations.return_value = (
             mock.MagicMock(name="pidfile"),

