This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch localexecutor-uses-task-sdk
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 5a7aca4d69f6baa9e821bb30c2d795bbf04a9ca1
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Wed Nov 20 17:49:45 2024 +0000

    Convert the LocalExecutor to run tasks using new Task SDK supervisor code
    
    This also lays the groundwork for a more general purpose "workload" 
execution
    system, making a single interface for executors to run tasks and callbacks.
    
    Also in this PR we set up the supervise function to send Task logs to a 
file,
    and handle the task log template rendering in the scheduler before queueing
    the workload.
    
    Additionally we don't pass the activity directly to `supervise()` but 
instead
    the properties/fields of it to reduce the coupling between SDK and Executor.
    (More separation will appear in PRs over the next few weeks.)
    
    The big change of note here is that rather than sending an airflow command
    line to execute (`["airflow", "tasks", "run", ...]`) and going back in via 
the
    CLI parser we go directly to a special purpose function. Much simpler.
    
    It doesn't remove any of the old behaviour (CeleryExecutor still uses
    LocalTaskJob via the CLI parser etc.), nor does anything currently send
    callback requests via this new workload mechanism.
    
    The `airflow.executors.workloads` module currently needs to be shared 
between
    the Scheduler (or more specifically the Executor) and the "worker" side of
    things. In the future these will be separate python dists and this module 
will
    need to live somewhere else.
    
    Right now we check if `executor.queue_workload` is different from the
    BaseExecutor version (which just raises an error right now) to see which
    executors support this new version. That check will be removed as soon as 
all
    the in-tree executors have been migrated.
---
 airflow/cli/commands/task_command.py               |  32 +++--
 airflow/config_templates/config.yml                |   2 +-
 airflow/executors/base_executor.py                 |   7 ++
 airflow/executors/local_executor.py                | 140 ++++++---------------
 airflow/executors/workloads.py                     |  98 +++++++++++++++
 airflow/jobs/scheduler_job_runner.py               |  11 +-
 airflow/utils/helpers.py                           |  25 +++-
 task_sdk/pyproject.toml                            |   1 +
 .../src/airflow/sdk/execution_time/supervisor.py   |  71 ++++++++---
 .../src/airflow/sdk/execution_time/task_runner.py  |   4 +
 task_sdk/src/airflow/sdk/log.py                    |  81 ++++++++++--
 task_sdk/tests/execution_time/test_supervisor.py   |  19 ++-
 tests/executors/test_local_executor.py             |  77 ++++++------
 tests/jobs/test_scheduler_job.py                   |   6 +
 14 files changed, 383 insertions(+), 191 deletions(-)

diff --git a/airflow/cli/commands/task_command.py 
b/airflow/cli/commands/task_command.py
index 55c1662d35d..3962592e9b2 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -300,6 +300,8 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) 
-> None:
 
     This can result in the task being started by another host if the executor 
implementation does.
     """
+    from airflow.executors.base_executor import BaseExecutor
+
     if ti.executor:
         executor = ExecutorLoader.load_executor(ti.executor)
     else:
@@ -307,17 +309,25 @@ def _run_task_by_executor(args, dag: DAG, ti: 
TaskInstance) -> None:
     executor.job_id = None
     executor.start()
     print("Sending to executor.")
-    executor.queue_task_instance(
-        ti,
-        mark_success=args.mark_success,
-        ignore_all_deps=args.ignore_all_dependencies,
-        ignore_depends_on_past=should_ignore_depends_on_past(args),
-        wait_for_past_depends_before_skipping=(args.depends_on_past == "wait"),
-        ignore_task_deps=args.ignore_dependencies,
-        ignore_ti_state=args.force,
-        pool=args.pool,
-    )
-    executor.heartbeat()
+
+    # TODO: Task-SDK: this is temporary while we migrate the other executors 
over
+    if executor.queue_workload.__func__ is not BaseExecutor.queue_workload:  # 
type: ignore[attr-defined]
+        from airflow.executors import workloads
+
+        workload = workloads.ExecuteTask.make(ti, 
dag_path=dag.relative_fileloc)
+        executor.queue_workload(workload)
+    else:
+        executor.queue_task_instance(
+            ti,
+            mark_success=args.mark_success,
+            ignore_all_deps=args.ignore_all_dependencies,
+            ignore_depends_on_past=should_ignore_depends_on_past(args),
+            wait_for_past_depends_before_skipping=(args.depends_on_past == 
"wait"),
+            ignore_task_deps=args.ignore_dependencies,
+            ignore_ti_state=args.force,
+            pool=args.pool,
+        )
+        executor.heartbeat()
     executor.end()
 
 
diff --git a/airflow/config_templates/config.yml 
b/airflow/config_templates/config.yml
index 7c50f3b7cd3..5ca17e63ee2 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -908,7 +908,7 @@ logging:
       is_template: true
       default: "dag_id={{ ti.dag_id }}/run_id={{ ti.run_id }}/task_id={{ 
ti.task_id }}/\
                {%% if ti.map_index >= 0 %%}map_index={{ ti.map_index }}/{%% 
endif %%}\
-               attempt={{ try_number }}.log"
+               attempt={{ try_number|default(ti.try_number) }}.log"
     log_processor_filename_template:
       description: |
         Formatting for how airflow generates file names for log
diff --git a/airflow/executors/base_executor.py 
b/airflow/executors/base_executor.py
index 3e406e3d74d..b57f7458bbb 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -49,6 +49,7 @@ if TYPE_CHECKING:
     from airflow.callbacks.base_callback_sink import BaseCallbackSink
     from airflow.callbacks.callback_requests import CallbackRequest
     from airflow.cli.cli_config import GroupCommand
+    from airflow.executors import workloads
     from airflow.executors.executor_utils import ExecutorName
     from airflow.models.taskinstance import TaskInstance
     from airflow.models.taskinstancekey import TaskInstanceKey
@@ -170,6 +171,9 @@ class BaseExecutor(LoggingMixin):
         else:
             self.log.error("could not queue task %s", task_instance.key)
 
+    def queue_workload(self, workload: workloads.All) -> None:
+        raise ValueError(f"Un-handled workload kind 
{type(workload).__name__!r} in {type(self).__name__}")
+
     def queue_task_instance(
         self,
         task_instance: TaskInstance,
@@ -409,6 +413,9 @@ class BaseExecutor(LoggingMixin):
                 self.execute_async(key=key, command=command, queue=queue, 
executor_config=executor_config)
                 self.running.add(key)
 
+    # TODO: This should not be using `TaskInstanceState` here, this is just 
"did the process complete, or did
+    # it die". It is possible for the task itself to finish with success, but 
the state of the task to be set
+    # to FAILED. By using TaskInstanceState enum here it confuses matters!
     def change_state(
         self, key: TaskInstanceKey, state: TaskInstanceState, info=None, 
remove_running=True
     ) -> None:
diff --git a/airflow/executors/local_executor.py 
b/airflow/executors/local_executor.py
index 02e201f7b5b..397901a8ff4 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -30,32 +30,24 @@ import logging
 import multiprocessing
 import multiprocessing.sharedctypes
 import os
-import subprocess
 from multiprocessing import Queue, SimpleQueue
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Optional
 
 from setproctitle import setproctitle
 
 from airflow import settings
 from airflow.executors.base_executor import PARALLELISM, BaseExecutor
-from airflow.traces.tracer import add_span
-from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
-    from airflow.executors.base_executor import CommandType
-    from airflow.models.taskinstancekey import TaskInstanceKey
+    from airflow.executors import workloads
 
-    # This is a work to be executed by a worker.
-    # It can Key and Command - but it can also be None, None which is actually 
a
-    # "Poison Pill" - worker seeing Poison Pill should take the pill and ... 
die instantly.
-    ExecutorWorkType = Optional[tuple[TaskInstanceKey, CommandType]]
-    TaskInstanceStateType = tuple[TaskInstanceKey, TaskInstanceState, 
Optional[Exception]]
+    TaskInstanceStateType = tuple[workloads.TaskInstance, TaskInstanceState, 
Optional[Exception]]
 
 
 def _run_worker(
     logger_name: str,
-    input: SimpleQueue[ExecutorWorkType],
+    input: SimpleQueue[workloads.All | None],
     output: Queue[TaskInstanceStateType],
     unread_messages: multiprocessing.sharedctypes.Synchronized[int],
 ):
@@ -65,16 +57,16 @@ def _run_worker(
     signal.signal(signal.SIGINT, signal.SIG_IGN)
 
     log = logging.getLogger(logger_name)
+    log.info("Worker starting up pid=%d", os.getpid())
 
     # We know we've just started a new process, so lets disconnect from the 
metadata db now
     settings.engine.pool.dispose()
     settings.engine.dispose()
 
-    setproctitle("airflow worker -- LocalExecutor: <idle>")
-
     while True:
+        setproctitle("airflow worker -- LocalExecutor: <idle>")
         try:
-            item = input.get()
+            workload = input.get()
         except EOFError:
             log.info(
                 "Failed to read tasks from the task queue because the other "
@@ -83,88 +75,49 @@ def _run_worker(
             )
             break
 
-        if item is None:
+        if workload is None:
             # Received poison pill, no more tasks to run
             return
 
         # Decrement this as soon as we pick up a message off the queue
         with unread_messages:
             unread_messages.value -= 1
+        key = None
+        if ti := getattr(workload, "ti", None):
+            key = ti.key
+        else:
+            raise TypeError(f"Don't know how to get ti key from 
{type(workload).__name__}")
 
-        (key, command) = item
         try:
-            state = _execute_work(log, key, command)
+            _execute_work(log, workload)
 
-            output.put((key, state, None))
+            output.put((key, TaskInstanceState.SUCCESS, None))
         except Exception as e:
+            log.exception("uhoh")
             output.put((key, TaskInstanceState.FAILED, e))
 
 
-def _execute_work(log: logging.Logger, key: TaskInstanceKey, command: 
CommandType) -> TaskInstanceState:
+def _execute_work(log: logging.Logger, workload: workloads.ExecuteTask) -> 
None:
     """
     Execute command received and stores result state in queue.
 
     :param key: the key to identify the task instance
     :param command: the command to execute
     """
-    setproctitle(f"airflow worker -- LocalExecutor: {command}")
-    dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command)
-    try:
-        with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id):
-            if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
-                return _execute_work_in_subprocess(log, command)
-            else:
-                return _execute_work_in_fork(log, command)
-    finally:
-        # Remove the command since the worker is done executing the task
-        setproctitle("airflow worker -- LocalExecutor: <idle>")
-
-
-def _execute_work_in_subprocess(log: logging.Logger, command: CommandType) -> 
TaskInstanceState:
-    try:
-        subprocess.check_call(command, close_fds=True)
-        return TaskInstanceState.SUCCESS
-    except subprocess.CalledProcessError as e:
-        log.error("Failed to execute task %s.", e)
-        return TaskInstanceState.FAILED
-
-
-def _execute_work_in_fork(log: logging.Logger, command: CommandType) -> 
TaskInstanceState:
-    pid = os.fork()
-    if pid:
-        # In parent, wait for the child
-        pid, ret = os.waitpid(pid, 0)
-        return TaskInstanceState.SUCCESS if ret == 0 else 
TaskInstanceState.FAILED
-
-    from airflow.sentry import Sentry
-
-    ret = 1
-    try:
-        import signal
-
-        from airflow.cli.cli_parser import get_parser
-
-        signal.signal(signal.SIGINT, signal.SIG_IGN)
-        signal.signal(signal.SIGTERM, signal.SIG_DFL)
-        signal.signal(signal.SIGUSR2, signal.SIG_DFL)
-
-        parser = get_parser()
-        # [1:] - remove "airflow" from the start of the command
-        args = parser.parse_args(command[1:])
-        args.shut_down_logging = False
-
-        setproctitle(f"airflow task supervisor: {command}")
-
-        args.func(args)
-        ret = 0
-        return TaskInstanceState.SUCCESS
-    except Exception as e:
-        log.exception("Failed to execute task %s.", e)
-        return TaskInstanceState.FAILED
-    finally:
-        Sentry.flush()
-        logging.shutdown()
-        os._exit(ret)
+    from airflow.configuration import conf
+    from airflow.sdk.execution_time.supervisor import supervise
+
+    setproctitle(f"airflow worker -- LocalExecutor: {workload.ti.id}")
+    # This will return the exit code of the task process, but we don't care 
about that, just if the
+    # _supervisor_ had an error reporting the state back (which will result in 
an exception.)
+    supervise(
+        # This is the "wrong" ti type, but it duck types the same. TODO: 
Create a protocol for this.
+        ti=workload.ti,  # type: ignore[arg-type]
+        dag_path=workload.dag_path,
+        token=workload.token,
+        server=conf.get("workers", "execution_api_server_url", 
fallback="http://localhost:9091/execution/";),
+        log_path=workload.log_path,
+    )
 
 
 class LocalExecutor(BaseExecutor):
@@ -180,7 +133,7 @@ class LocalExecutor(BaseExecutor):
 
     serve_logs: bool = True
 
-    activity_queue: SimpleQueue[ExecutorWorkType]
+    activity_queue: SimpleQueue[workloads.All | None]
     result_queue: SimpleQueue[TaskInstanceStateType]
     workers: dict[int, multiprocessing.Process]
     _unread_messages: multiprocessing.sharedctypes.Synchronized[int]
@@ -203,22 +156,7 @@ class LocalExecutor(BaseExecutor):
         # (it looks like an int to python)
         self._unread_messages = multiprocessing.Value(ctypes.c_uint)  # type: 
ignore[assignment]
 
-    @add_span
-    def execute_async(
-        self,
-        key: TaskInstanceKey,
-        command: CommandType,
-        queue: str | None = None,
-        executor_config: Any | None = None,
-    ) -> None:
-        """Execute asynchronously."""
-        self.validate_airflow_tasks_run_command(command)
-        self.activity_queue.put((key, command))
-        with self._unread_messages:
-            self._unread_messages.value += 1
-        self._check_workers(can_start=True)
-
-    def _check_workers(self, can_start: bool = True):
+    def _check_workers(self):
         # Reap any dead workers
         to_remove = set()
         for pid, proc in self.workers.items():
@@ -270,12 +208,6 @@ class LocalExecutor(BaseExecutor):
         while not self.result_queue.empty():
             key, state, exc = self.result_queue.get()
 
-            if exc:
-                # TODO: This needs a better stacktrace, it appears from here
-                if hasattr(exc, "add_note"):
-                    exc.add_note("(This stacktrace is incorrect -- the 
exception came from a subprocess)")
-                raise exc
-
             self.change_state(key, state)
 
     def end(self) -> None:
@@ -306,3 +238,9 @@ class LocalExecutor(BaseExecutor):
 
     def terminate(self):
         """Terminate the executor is not doing anything."""
+
+    def queue_workload(self, workload: workloads.All):
+        self.activity_queue.put(workload)
+        with self._unread_messages:
+            self._unread_messages.value += 1
+        self._check_workers()
diff --git a/airflow/executors/workloads.py b/airflow/executors/workloads.py
new file mode 100644
index 00000000000..0adb54cd6da
--- /dev/null
+++ b/airflow/executors/workloads.py
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+import uuid
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal, Union
+
+from pydantic import BaseModel, Field
+
+if TYPE_CHECKING:
+    from airflow.models.taskinstance import TaskInstance as TIModel
+    from airflow.models.taskinstancekey import TaskInstanceKey
+
+
+__all__ = [
+    "All",
+    "ExecuteTask",
+]
+
+
+class BaseActivity(BaseModel):
+    token: str
+    """The identity token for this workload"""
+
+
+class TaskInstance(BaseModel):
+    """Schema for TaskInstance with minimal required fields needed for 
Executors and Task SDK."""
+
+    id: uuid.UUID
+
+    task_id: str
+    dag_id: str
+    run_id: str
+    try_number: int
+    map_index: int | None = None
+
+    # TODO: Task-SDK: Can we replace TastInstanceKey with just the uuid across 
the codebase?
+    @property
+    def key(self) -> TaskInstanceKey:
+        from airflow.models.taskinstancekey import TaskInstanceKey
+
+        return TaskInstanceKey(
+            dag_id=self.dag_id,
+            task_id=self.task_id,
+            run_id=self.run_id,
+            try_number=self.try_number,
+            map_index=-1 if self.map_index is None else self.map_index,
+        )
+
+
+class ExecuteTask(BaseActivity):
+    """Execute the given Task."""
+
+    ti: TaskInstance
+    """The TaskInstance to execute"""
+    dag_path: os.PathLike[str]
+    """The filepath where the DAG can be found (likely prefixed with 
`DAG_FOLDER/`)"""
+
+    log_path: str | None
+    """The rendered relative log filename template the task logs should be 
written to"""
+
+    kind: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask")
+
+    @classmethod
+    def make(cls, ti: TIModel, dag_path: Path | None = None) -> ExecuteTask:
+        from pathlib import Path
+
+        from airflow.utils.helpers import log_filename_template_renderer
+
+        ser_ti = TaskInstance.model_validate(ti, from_attributes=True)
+
+        dag_path = dag_path or Path(ti.dag_run.dag_model.relative_fileloc)
+
+        if dag_path and not dag_path.is_absolute():
+            # TODO: What about multiple dag sub folders
+            dag_path = "DAGS_FOLDER" / dag_path
+
+        fname = log_filename_template_renderer()(ti=ti)
+        return cls(ti=ser_ti, dag_path=dag_path, token="", log_path=fname)
+
+
+All = Union[ExecuteTask]
diff --git a/airflow/jobs/scheduler_job_runner.py 
b/airflow/jobs/scheduler_job_runner.py
index da50deb31c9..56a65009e2b 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -44,6 +44,8 @@ from airflow.callbacks.callback_requests import 
DagCallbackRequest, TaskCallback
 from airflow.callbacks.pipe_callback_sink import PipeCallbackSink
 from airflow.configuration import conf
 from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.executors import workloads
+from airflow.executors.base_executor import BaseExecutor
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.jobs.base_job_runner import BaseJobRunner
 from airflow.jobs.job import Job, perform_heartbeat
@@ -91,7 +93,6 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Query, Session
 
     from airflow.dag_processing.manager import DagFileProcessorAgent
-    from airflow.executors.base_executor import BaseExecutor
     from airflow.executors.executor_utils import ExecutorName
     from airflow.models.taskinstance import TaskInstanceKey
     from airflow.utils.sqlalchemy import (
@@ -654,6 +655,14 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             if ti.dag_run.state in State.finished_dr_states:
                 ti.set_state(None, session=session)
                 continue
+
+            # TODO: Task-SDK: This check is transitionary. Remove once all 
executors are ported over.
+            # Has a real queue_activity implemented
+            if executor.queue_workload.__func__ is not 
BaseExecutor.queue_workload:  # type: ignore[attr-defined]
+                workload = workloads.ExecuteTask.make(ti)
+                executor.queue_workload(workload)
+                continue
+
             command = ti.command_as_list(
                 local=True,
             )
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 4cfc62acdd2..30f8dde41af 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -23,7 +23,7 @@ import re
 import signal
 from collections.abc import Generator, Iterable, Mapping, MutableMapping
 from datetime import datetime
-from functools import reduce
+from functools import cache, reduce
 from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
 
 from lazy_object_proxy import Proxy
@@ -171,7 +171,7 @@ def as_flattened_list(iterable: Iterable[Iterable[T]]) -> 
list[T]:
     return [e for i in iterable for e in i]
 
 
-def parse_template_string(template_string: str) -> tuple[str | None, 
jinja2.Template | None]:
+def parse_template_string(template_string: str) -> tuple[str, None] | 
tuple[None, jinja2.Template]:
     """Parse Jinja template string."""
     import jinja2
 
@@ -181,6 +181,27 @@ def parse_template_string(template_string: str) -> 
tuple[str | None, jinja2.Temp
         return template_string, None
 
 
+@cache
+def log_filename_template_renderer() -> Callable[..., str]:
+    template = conf.get("logging", "log_filename_template")
+
+    if "{{" in template:
+        import jinja2
+
+        return jinja2.Template(template).render
+    else:
+
+        def f_str_format(ti: TaskInstance, try_number: int | None = None):
+            return template.format(
+                dag_id=ti.dag_id,
+                task_id=ti.task_id,
+                logical_date=ti.logical_date.isoformat(),
+                try_number=try_number or ti.try_number,
+            )
+
+        return f_str_format
+
+
 def render_log_filename(ti: TaskInstance, try_number, filename_template) -> 
str:
     """
     Given task instance, try_number, filename_template, return the rendered 
log filename.
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml
index 170aff5ec24..77061faa37c 100644
--- a/task_sdk/pyproject.toml
+++ b/task_sdk/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
     "attrs>=24.2.0",
     "google-re2>=1.1.20240702",
     "httpx>=0.27.0",
+    "jinja2>=3.1.4",
     "methodtools>=0.4.7",
     "msgspec>=0.18.6",
     "psutil>=6.1.0",
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 6dfbe415058..4058e46513a 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -54,9 +54,8 @@ from airflow.sdk.execution_time.comms import (
 )
 
 if TYPE_CHECKING:
-    from structlog.typing import FilteringBoundLogger
+    from structlog.typing import FilteringBoundLogger, WrappedLogger
 
-    from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
 
 __all__ = ["WatchedSubprocess", "supervise"]
 
@@ -282,6 +281,7 @@ class WatchedSubprocess:
         ti: TaskInstance,
         client: Client,
         target: Callable[[], None] = _subprocess_main,
+        logger: FilteringBoundLogger | None = None,
     ) -> WatchedSubprocess:
         """Fork and start a new subprocess to execute the given task."""
         # Create socketpairs/"pipes" to connect to the stdin and out from the 
subprocess
@@ -297,6 +297,13 @@ class WatchedSubprocess:
         if pid == 0:
             # Parent ends of the sockets are closed by the OS as they are set 
as non-inheritable
 
+            # Python GC should delete these for us, but lets make double sure 
that we don't keep anything
+            # around in the forked processes, especially things that might 
involve open files or sockets!
+            del path
+            del client
+            del ti
+            del logger
+
             # Run the child entrypoint
             _fork_main(child_stdin, child_stdout, child_stderr, 
child_logs.fileno(), target)
 
@@ -319,8 +326,13 @@ class WatchedSubprocess:
             proc.kill(signal.SIGKILL)
             raise
 
+        logger = logger or cast("FilteringBoundLogger", 
structlog.get_logger(logger_name="task").bind())
         proc._register_pipe_readers(
-            stdout=read_stdout, stderr=read_stderr, requests=read_msgs, 
logs=read_logs
+            logger=logger,
+            stdout=read_stdout,
+            stderr=read_stderr,
+            requests=read_msgs,
+            logs=read_logs,
         )
 
         # Close the remaining parent-end of the sockets we've passed to the 
child via fork. We still have the
@@ -331,16 +343,15 @@ class WatchedSubprocess:
         proc._send_startup_message(ti, path, child_comms)
         return proc
 
-    def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: 
socket, logs: socket):
+    def _register_pipe_readers(
+        self, logger: FilteringBoundLogger, stdout: socket, stderr: socket, 
requests: socket, logs: socket
+    ):
         """Register handlers for subprocess communication channels."""
         # self.selector is a way of registering a handler/callback to be 
called when the given IO channel has
         # activity to read on 
(https://www.man7.org/linux/man-pages/man2/select.2.html etc, but better
         # alternatives are used automatically) -- this is a way of having 
"event-based" code, but without
         # needing full async, to read and process output from each socket as 
it is received.
 
-        # TODO: Use logging providers to handle the chunked upload for us
-        logger: FilteringBoundLogger = 
structlog.get_logger(logger_name="task").bind()
-
         self.selector.register(stdout, selectors.EVENT_READ, 
self._create_socket_handler(logger, "stdout"))
         self.selector.register(
             stderr,
@@ -369,7 +380,7 @@ class WatchedSubprocess:
 
     def _send_startup_message(self, ti: TaskInstance, path: str | 
os.PathLike[str], child_comms: socket):
         """Send startup message to the subprocess."""
-        msg = StartupDetails(
+        msg = StartupDetails.model_construct(
             ti=ti,
             file=str(path),
             requests_fd=child_comms.fileno(),
@@ -630,25 +641,57 @@ def forward_to_log(target_log: FilteringBoundLogger, 
level: int) -> Generator[No
             target_log.log(level, msg)
 
 
-def supervise(activity: ExecuteTaskActivity, server: str | None = None, 
dry_run: bool = False) -> int:
+def supervise(
+    *,
+    ti: TaskInstance,
+    dag_path: str | os.PathLike[str],
+    token: str,
+    server: str | None = None,
+    dry_run: bool = False,
+    log_path: str | None = None,
+) -> int:
     """
     Run a single task execution to completion.
 
     Returns the exit code of the process
     """
     # One or the other
-    if (server == "") ^ dry_run:
+    if (not server) ^ dry_run:
         raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
 
-    if not activity.path:
-        raise ValueError("path filed of activity missing")
+    if not dag_path:
+        raise ValueError("dag_path is required")
+
+    if (str_path := os.fspath(dag_path)).startswith("DAGS_FOLDER/"):
+        from airflow.settings import DAGS_FOLDER
+
+        dag_path = str_path.replace("DAGS_FOLDER/", DAGS_FOLDER + "/", 1)
 
     limits = httpx.Limits(max_keepalive_connections=1, max_connections=10)
-    client = Client(base_url=server or "", limits=limits, dry_run=dry_run, 
token=activity.token)
+    client = Client(base_url=server or "", limits=limits, dry_run=dry_run, 
token=token)
 
     start = time.monotonic()
 
-    process = WatchedSubprocess.start(activity.path, activity.ti, 
client=client)
+    # TODO: Use logging providers to handle the chunked upload for us etc.
+    logger: FilteringBoundLogger | None = None
+    if log_path:
+        # If we are told to write logs to a file, redirect the task logger to 
it.
+        from airflow.sdk.log import init_log_file, logging_processors
+
+        try:
+            log_file = init_log_file(log_path)
+        except OSError as e:
+            log.warning("OSError while changing ownership of the log file. ", 
e)
+
+        pretty_logs = False
+        if pretty_logs:
+            underlying_logger: WrappedLogger = 
structlog.WriteLogger(log_file.open("w", buffering=1))
+        else:
+            underlying_logger = structlog.BytesLogger(log_file.open("wb"))
+        processors = logging_processors(enable_pretty_log=pretty_logs)[0]
+        logger = structlog.wrap_logger(underlying_logger, 
processors=processors, logger_name="task").bind()
+
+    process = WatchedSubprocess.start(dag_path, ti, client=client, 
logger=logger)
 
     exit_code = process.wait()
     end = time.monotonic()
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 9abf3a4796b..c693a77eac2 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -127,6 +127,10 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]:
     msg = SUPERVISOR_COMMS.get_message()
 
     if isinstance(msg, StartupDetails):
+        from setproctitle import setproctitle
+
+        setproctitle(f"airflow worker -- {msg.ti.id}")
+
         log = structlog.get_logger(logger_name="task")
         # TODO: set the "magic loop" context vars for parsing
         ti = parse(msg)
diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py
index f8e06eda4a6..0873233f536 100644
--- a/task_sdk/src/airflow/sdk/log.py
+++ b/task_sdk/src/airflow/sdk/log.py
@@ -23,6 +23,7 @@ import os
 import sys
 import warnings
 from functools import cache
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Generic, TextIO, 
TypeVar
 
 import msgspec
@@ -285,9 +286,11 @@ def configure_logging(
         std_lib_formatter.append(named["json"])
 
     global _warnings_showwarning
-    _warnings_showwarning = warnings.showwarning
-    # Capture warnings and show them via structlog
-    warnings.showwarning = _showwarning
+
+    if _warnings_showwarning is None:
+        _warnings_showwarning = warnings.showwarning
+        # Capture warnings and show them via structlog
+        warnings.showwarning = _showwarning
 
     logging.config.dictConfig(
         {
@@ -327,11 +330,13 @@ def configure_logging(
                     "propagate": True,
                 },
                 # Some modules we _never_ want at debug level
-                "asyncio": {"level": "INFO"},
                 "alembic": {"level": "INFO"},
+                "asyncio": {"level": "INFO"},
+                "cron_descriptor.GetText": {"level": "INFO"},
                 "httpcore": {"level": "INFO"},
                 "httpx": {"level": "WARN"},
                 "psycopg.pq": {"level": "INFO"},
+                # These ones are too chatty even at info
                 "sqlalchemy.engine": {"level": "WARN"},
             },
         }
@@ -344,17 +349,17 @@ def reset_logging():
     configure_logging.cache_clear()
 
 
-_warnings_showwarning = None
+_warnings_showwarning: Any = None
 
 
 def _showwarning(
-    message: str | Warning,
+    message: Warning | str,
     category: type[Warning],
     filename: str,
     lineno: int,
     file: TextIO | None = None,
     line: str | None = None,
-):
+) -> Any:
     """
     Redirects warnings to structlog so they appear in task logs etc.
 
@@ -370,3 +375,65 @@ def _showwarning(
     else:
         log = structlog.get_logger(logger_name="py.warnings")
         log.warning(str(message), category=category.__name__, 
filename=filename, lineno=lineno)
+
+
+def _prepare_log_folder(directory: Path, mode: int):
+    """
+    Prepare the log folder and ensure its mode is as configured.
+
+    To handle log writing when tasks are impersonated, the log files need to
+    be writable by the user that runs the Airflow command and the user
+    that is impersonated. This is mainly to handle corner cases with the
+    SubDagOperator. When the SubDagOperator is run, all of the operators
+    run under the impersonated user and create appropriate log files
+    as the impersonated user. However, if the user manually runs tasks
+    of the SubDagOperator through the UI, then the log files are created
+    by the user that runs the Airflow command. For example, the Airflow
+    run command may be run by the `airflow_sudoable` user, but the Airflow
+    tasks may be run by the `airflow` user. If the log files are not
+    writable by both users, then it's possible that re-running a task
+    via the UI (or vice versa) results in a permission error as the task
+    tries to write to a log file created by the other user.
+
+    We leave it up to the user to manage their permissions by exposing 
configuration for both
+    new folders and new log files. Default is to make new log folders and 
files group-writeable
+    to handle most common impersonation use cases. The requirement in this 
case will be to make
+    sure that the same group is set as default group for both - impersonated 
user and main airflow
+    user.
+    """
+    for parent in reversed(directory.parents):
+        parent.mkdir(mode=mode, exist_ok=True)
+    directory.mkdir(mode=mode, exist_ok=True)
+
+
def init_log_file(local_relative_path: str) -> Path:
    """
    Ensure the task log file and its parent directories exist.

    Missing directories are created with the permission bits configured in
    ``[logging] file_task_handler_new_folder_permissions`` and the file is
    created with ``[logging] file_task_handler_new_file_permissions``.
    See ``_prepare_log_folder`` for why these bits are configurable.

    :param local_relative_path: path of the log file relative to the
        configured ``[logging] base_log_folder``.
    :return: the absolute path of the (now existing) log file.
    """
    # NOTE: This is duplicated from
    # airflow.utils.log.file_task_handler:FileTaskHandler._init_file, but we
    # want to remove that
    from airflow.configuration import conf

    new_file_permissions = int(
        conf.get("logging", "file_task_handler_new_file_permissions", fallback="0o664"), 8
    )
    new_folder_permissions = int(
        conf.get("logging", "file_task_handler_new_folder_permissions", fallback="0o775"), 8
    )

    base_log_folder = conf.get("logging", "base_log_folder")
    full_path = Path(base_log_folder, local_relative_path)

    _prepare_log_folder(full_path.parent, new_folder_permissions)

    try:
        # ``mode`` only applies when the file is created; an existing file
        # keeps its current permissions.
        full_path.touch(mode=new_file_permissions)
    except OSError as e:
        # Best-effort: a task can often still run (and log elsewhere) even if
        # the log file could not be created, so warn instead of raising.
        log = structlog.get_logger()
        log.warning("OSError while creating the log file or setting its permissions. %s", e)

    return full_path
diff --git a/task_sdk/tests/execution_time/test_supervisor.py 
b/task_sdk/tests/execution_time/test_supervisor.py
index 2cb8c24af1b..3a50a73b3cf 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -36,7 +36,6 @@ from uuid6 import uuid7
 from airflow.sdk.api import client as sdk_client
 from airflow.sdk.api.client import ServerResponseError
 from airflow.sdk.api.datamodels._generated import TaskInstance
-from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
 from airflow.sdk.execution_time.comms import (
     ConnectionResult,
     DeferTask,
@@ -249,19 +248,15 @@ class TestWatchedSubprocess:
         time_machine.move_to(instant, tick=False)
 
         dagfile_path = test_dags_dir / "super_basic_run.py"
-        task_activity = ExecuteTaskActivity(
-            ti=TaskInstance(
-                id=uuid7(),
-                task_id="hello",
-                dag_id="super_basic_run",
-                run_id="c",
-                try_number=1,
-            ),
-            path=dagfile_path,
-            token="",
+        ti = TaskInstance(
+            id=uuid7(),
+            task_id="hello",
+            dag_id="super_basic_run",
+            run_id="c",
+            try_number=1,
         )
         # Assert Exit Code is 0
-        assert supervise(activity=task_activity, server="", dry_run=True) == 0
+        assert supervise(ti=ti, dag_path=dagfile_path, token="", server="", 
dry_run=True) == 0
 
         # We should have a log from the task!
         assert {
diff --git a/tests/executors/test_local_executor.py 
b/tests/executors/test_local_executor.py
index 2545ceb7e70..a64f87bf1ae 100644
--- a/tests/executors/test_local_executor.py
+++ b/tests/executors/test_local_executor.py
@@ -17,17 +17,15 @@
 # under the License.
 from __future__ import annotations
 
-import datetime
 import multiprocessing
 import os
-import subprocess
 from unittest import mock
 
 import pytest
 from kgb import spy_on
+from uuid6 import uuid7
 
-from airflow import settings
-from airflow.exceptions import AirflowException
+from airflow.executors import workloads
 from airflow.executors.local_executor import LocalExecutor
 from airflow.utils.state import State
 
@@ -52,43 +50,43 @@ class TestLocalExecutor:
     def test_serve_logs_default_value(self):
         assert LocalExecutor.serve_logs
 
-    @mock.patch("airflow.executors.local_executor.subprocess.check_call")
-    @mock.patch("airflow.cli.commands.task_command.task_run")
-    def _test_execute(self, mock_run, mock_check_call, parallelism=1):
-        success_command = ["airflow", "tasks", "run", "success", 
"some_parameter", "2020-10-07"]
-        fail_command = ["airflow", "tasks", "run", "failure", "task_id", 
"2020-10-07"]
+    @mock.patch("airflow.sdk.execution_time.supervisor.supervise")
+    def _test_execute(self, mock_supervise, parallelism=1):
+        success_tis = [
+            workloads.TaskInstance(
+                id=uuid7(),
+                task_id=f"success_{i}",
+                dag_id="mydag",
+                run_id="run1",
+                try_number=1,
+                state="queued",
+            )
+            for i in range(self.TEST_SUCCESS_COMMANDS)
+        ]
+        fail_ti = success_tis[0].model_copy(update={"id": uuid7(), "task_id": 
"failure"})
 
         # We just mock both styles here, only one will be hit though
-        def fake_execute_command(command, close_fds=True):
-            if command != success_command:
-                raise subprocess.CalledProcessError(returncode=1, cmd=command)
-            else:
-                return 0
-
-        def fake_task_run(args):
-            if args.dag_id != "success":
-                raise AirflowException("Simulate failed task")
+        def fake_supervise(ti, **kwargs):
+            if ti.id == fail_ti.id:
+                raise RuntimeError("fake failure")
+            return 0
 
-        mock_check_call.side_effect = fake_execute_command
-        mock_run.side_effect = fake_task_run
+        mock_supervise.side_effect = fake_supervise
 
         executor = LocalExecutor(parallelism=parallelism)
         executor.start()
 
-        success_key = "success {}"
         assert executor.result_queue.empty()
 
         with spy_on(executor._spawn_worker) as spawn_worker:
-            run_id = "manual_" + datetime.datetime.now().isoformat()
-            for i in range(self.TEST_SUCCESS_COMMANDS):
-                key_id, command = success_key.format(i), success_command
-                key = key_id, "fake_ti", run_id, 0
-                executor.running.add(key)
-                executor.execute_async(key=key, command=command)
+            for ti in success_tis:
+                executor.queue_workload(
+                    workloads.ExecuteTask(token="", ti=ti, 
dag_path="some/path", log_path=None)
+                )
 
-            fail_key = "fail", "fake_ti", run_id, 0
-            executor.running.add(fail_key)
-            executor.execute_async(key=fail_key, command=fail_command)
+            executor.queue_workload(
+                workloads.ExecuteTask(token="", ti=fail_ti, 
dag_path="some/path", log_path=None)
+            )
 
             executor.end()
 
@@ -100,24 +98,19 @@ class TestLocalExecutor:
         assert len(executor.running) == 0
         assert executor._unread_messages.value == 0
 
-        for i in range(self.TEST_SUCCESS_COMMANDS):
-            key_id = success_key.format(i)
-            key = key_id, "fake_ti", run_id, 0
-            assert executor.event_buffer[key][0] == State.SUCCESS
-        assert executor.event_buffer[fail_key][0] == State.FAILED
+        for ti in success_tis:
+            assert executor.event_buffer[ti.key][0] == State.SUCCESS
+        assert executor.event_buffer[fail_ti.key][0] == State.FAILED
 
     @skip_spawn_mp_start
     @pytest.mark.parametrize(
-        ("parallelism", "fork_or_subproc"),
+        ("parallelism",),
         [
-            pytest.param(0, True, id="unlimited_subprocess"),
-            pytest.param(2, True, id="limited_subprocess"),
-            pytest.param(0, False, id="unlimited_fork"),
-            pytest.param(2, False, id="limited_fork"),
+            pytest.param(0, id="unlimited"),
+            pytest.param(2, id="limited"),
         ],
     )
-    def test_execution(self, parallelism: int, fork_or_subproc: bool, 
monkeypatch: pytest.MonkeyPatch):
-        monkeypatch.setattr(settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", 
fork_or_subproc)
+    def test_execution(self, parallelism: int):
         self._test_execute(parallelism=parallelism)
 
     @mock.patch("airflow.executors.local_executor.LocalExecutor.sync")
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 3b2d1a34dd6..b4039c7bc51 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -181,6 +181,12 @@ class TestSchedulerJob:
         default_executor.name = ExecutorName(alias="default_exec", 
module_path="default.exec.module.path")
         second_executor = mock.MagicMock(name="SeconadaryExecutor", 
slots_available=8, slots_occupied=0)
         second_executor.name = ExecutorName(alias="secondary_exec", 
module_path="secondary.exec.module.path")
+
+        # TODO: Task-SDK Make it look like a bound method. Needed until we 
remove the old queue_command
+        # interface from executors
+        default_executor.queue_workload.__func__ = BaseExecutor.queue_workload
+        second_executor.queue_workload.__func__ = BaseExecutor.queue_workload
+
         with mock.patch("airflow.jobs.job.Job.executors", 
new_callable=PropertyMock) as executors_mock:
             executors_mock.return_value = [default_executor, second_executor]
             yield [default_executor, second_executor]


Reply via email to