potiuk commented on code in PR #68533: URL: https://github.com/apache/airflow/pull/68533#discussion_r3411154296
########## devel-common/src/tests_common/test_utils/in_process_taskrun.py: ########## @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""DB-free, xdist-safe execution of a task through a *real* supervisor socket. + +`run_task` (in ``pytest_plugin``) mocks supervisor comms entirely in-process and +has **no real socket**, so operators that spawn a subprocess which re-connects to +the supervisor — ``PythonVirtualenvOperator``, ``ExternalPythonOperator``, +``run_as_user`` — fail there with ``OSError: Socket operation on non-socket``. + +This helper drives the *real* ``InProcessTestSupervisor`` socketpair machinery +(created explicitly for VirtualEnv operators) but serves every Execution-API call +from an in-memory stub instead of the DB-backed ``InProcessExecutionAPI`` — so the +subprocess gets a working supervisor socket without touching the metadata DB. The +result: such tests need no ``@pytest.mark.db_test`` and run under xdist. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest import mock + +if TYPE_CHECKING: + from collections.abc import Callable + + from airflow.sdk.types import Operator + + +class _StubXComs: + """Dict-backed stand-in for ``client.xcoms`` (the only resource that must round-trip).""" + + def __init__(self) -> None: + self.store: dict[tuple, Any] = {} + + def set(self, dag_id, run_id, task_id, key, value, map_index, **kwargs): + self.store[(dag_id, run_id, task_id, key, map_index)] = value + + def get(self, dag_id, run_id, task_id, key, map_index, include_prior_dates=False): + from airflow.sdk.api.datamodels._generated import XComResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if (dag_id, run_id, task_id, key, map_index) in self.store: + return XComResponse(key=key, value=self.store[(dag_id, run_id, task_id, key, map_index)]) + return ErrorResponse(error=ErrorType.XCOM_NOT_FOUND) + + def delete(self, *args, **kwargs): + return None + + +class _StubVariables: + def __init__(self, values: dict[str, Any] | None = None) -> None: + self.store = dict(values or {}) + + def get(self, key): + from airflow.sdk.api.datamodels._generated import VariableResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if key in self.store: + return VariableResponse(key=key, value=self.store[key]) + return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND) + + def set(self, key, value, description=None): + self.store[key] = value + + def delete(self, key): + self.store.pop(key, None) + return None + + +class _StubConnections: + def __init__(self, conns: dict[str, Any] | None = None) -> None: + self.store = dict(conns or {}) + + def get(self, conn_id): + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if conn_id in self.store: + return 
self.store[conn_id] + return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) + + +class _InMemoryExecutionClient: Review Comment: Done — collapsed into a real `Client(dry_run=True)` whose `MockTransport` remembers XCom writes in a dict (replaying the run-context, since `noop_handler`'s fake `TIRunContext` no longer validates against the live schema). The three `_Stub*` classes are gone, and with a real client there's no `MagicMock`/`__getattr__` fallback either, so the "passes against a mock" footgun is removed. --- Drafted-by: Claude Code (Opus 4.8); reviewed by @potiuk before posting ########## devel-common/src/tests_common/test_utils/in_process_taskrun.py: ########## @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""DB-free, xdist-safe execution of a task through a *real* supervisor socket. + +`run_task` (in ``pytest_plugin``) mocks supervisor comms entirely in-process and +has **no real socket**, so operators that spawn a subprocess which re-connects to +the supervisor — ``PythonVirtualenvOperator``, ``ExternalPythonOperator``, +``run_as_user`` — fail there with ``OSError: Socket operation on non-socket``. 
+ +This helper drives the *real* ``InProcessTestSupervisor`` socketpair machinery +(created explicitly for VirtualEnv operators) but serves every Execution-API call +from an in-memory stub instead of the DB-backed ``InProcessExecutionAPI`` — so the +subprocess gets a working supervisor socket without touching the metadata DB. The +result: such tests need no ``@pytest.mark.db_test`` and run under xdist. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest import mock + +if TYPE_CHECKING: + from collections.abc import Callable + + from airflow.sdk.types import Operator + + +class _StubXComs: + """Dict-backed stand-in for ``client.xcoms`` (the only resource that must round-trip).""" + + def __init__(self) -> None: + self.store: dict[tuple, Any] = {} + + def set(self, dag_id, run_id, task_id, key, value, map_index, **kwargs): + self.store[(dag_id, run_id, task_id, key, map_index)] = value + + def get(self, dag_id, run_id, task_id, key, map_index, include_prior_dates=False): + from airflow.sdk.api.datamodels._generated import XComResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if (dag_id, run_id, task_id, key, map_index) in self.store: + return XComResponse(key=key, value=self.store[(dag_id, run_id, task_id, key, map_index)]) + return ErrorResponse(error=ErrorType.XCOM_NOT_FOUND) + + def delete(self, *args, **kwargs): + return None + + +class _StubVariables: + def __init__(self, values: dict[str, Any] | None = None) -> None: + self.store = dict(values or {}) + + def get(self, key): + from airflow.sdk.api.datamodels._generated import VariableResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if key in self.store: + return VariableResponse(key=key, value=self.store[key]) + return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND) + + def set(self, key, value, description=None): + 
self.store[key] = value + + def delete(self, key): + self.store.pop(key, None) + return None + + +class _StubConnections: + def __init__(self, conns: dict[str, Any] | None = None) -> None: + self.store = dict(conns or {}) + + def get(self, conn_id): + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if conn_id in self.store: + return self.store[conn_id] + return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) + + +class _InMemoryExecutionClient: + """In-memory stand-in for the Task SDK execution-API ``Client`` (no metadata DB).""" + + def __init__(self, ti_context, variables=None, connections=None) -> None: + self.task_instances = mock.MagicMock(name="stub.task_instances") + self.task_instances.start.return_value = ti_context + self.xcoms = _StubXComs() + self.variables = _StubVariables(variables) + self.connections = _StubConnections(connections) + + def __getattr__(self, name): + # Resources we don't model (assets, dag_runs, hitl, task_store, ...) are + # absorbed — venv operator tests don't exercise them. + if name.startswith("__"): + raise AttributeError(name) + return mock.MagicMock(name=f"stub_client.{name}") + + +class TaskRunResultNoDB: Review Comment: Done — `run_task_no_db` now returns the stock `TaskRunResult` (plus the pushed-XCom dict for assertions); `TaskRunResultNoDB` is removed. --- Drafted-by: Claude Code (Opus 4.8); reviewed by @potiuk before posting ########## devel-common/src/tests_common/test_utils/in_process_taskrun.py: ########## @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""DB-free, xdist-safe execution of a task through a *real* supervisor socket. + +`run_task` (in ``pytest_plugin``) mocks supervisor comms entirely in-process and +has **no real socket**, so operators that spawn a subprocess which re-connects to +the supervisor — ``PythonVirtualenvOperator``, ``ExternalPythonOperator``, +``run_as_user`` — fail there with ``OSError: Socket operation on non-socket``. + +This helper drives the *real* ``InProcessTestSupervisor`` socketpair machinery +(created explicitly for VirtualEnv operators) but serves every Execution-API call +from an in-memory stub instead of the DB-backed ``InProcessExecutionAPI`` — so the +subprocess gets a working supervisor socket without touching the metadata DB. The +result: such tests need no ``@pytest.mark.db_test`` and run under xdist. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest import mock + +if TYPE_CHECKING: + from collections.abc import Callable + + from airflow.sdk.types import Operator + + +class _StubXComs: + """Dict-backed stand-in for ``client.xcoms`` (the only resource that must round-trip).""" + + def __init__(self) -> None: + self.store: dict[tuple, Any] = {} + + def set(self, dag_id, run_id, task_id, key, value, map_index, **kwargs): + self.store[(dag_id, run_id, task_id, key, map_index)] = value + + def get(self, dag_id, run_id, task_id, key, map_index, include_prior_dates=False): + from airflow.sdk.api.datamodels._generated import XComResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if (dag_id, run_id, task_id, key, map_index) in self.store: + return XComResponse(key=key, value=self.store[(dag_id, run_id, task_id, key, map_index)]) + return ErrorResponse(error=ErrorType.XCOM_NOT_FOUND) + + def delete(self, *args, **kwargs): + return None + + +class _StubVariables: + def __init__(self, values: dict[str, Any] | None = None) -> None: + self.store = dict(values or {}) + + def get(self, key): + from airflow.sdk.api.datamodels._generated import VariableResponse + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if key in self.store: + return VariableResponse(key=key, value=self.store[key]) + return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND) + + def set(self, key, value, description=None): + self.store[key] = value + + def delete(self, key): + self.store.pop(key, None) + return None + + +class _StubConnections: + def __init__(self, conns: dict[str, Any] | None = None) -> None: + self.store = dict(conns or {}) + + def get(self, conn_id): + from airflow.sdk.exceptions import ErrorType + from airflow.sdk.execution_time.comms import ErrorResponse + + if conn_id in self.store: + return 
self.store[conn_id] + return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) + + +class _InMemoryExecutionClient: + """In-memory stand-in for the Task SDK execution-API ``Client`` (no metadata DB).""" + + def __init__(self, ti_context, variables=None, connections=None) -> None: + self.task_instances = mock.MagicMock(name="stub.task_instances") + self.task_instances.start.return_value = ti_context + self.xcoms = _StubXComs() + self.variables = _StubVariables(variables) + self.connections = _StubConnections(connections) + + def __getattr__(self, name): + # Resources we don't model (assets, dag_runs, hitl, task_store, ...) are + # absorbed — venv operator tests don't exercise them. + if name.startswith("__"): + raise AttributeError(name) + return mock.MagicMock(name=f"stub_client.{name}") + + +class TaskRunResultNoDB: + """Result of :func:`run_task_no_db`, mirroring the ``run_task`` fixture surface.""" + + def __init__(self, result, client: _InMemoryExecutionClient, ti) -> None: + self._result = result + self.client = client + self._ti = ti + + @property + def state(self): + return self._result.state + + @property + def error(self): + return self._result.error + + @property + def msg(self): + return self._result.msg + + def xcom_get( + self, + key: str = "return_value", + task_id: str | None = None, + dag_id: str | None = None, + run_id: str | None = None, + map_index: int | None = None, + ) -> Any: + task_id = task_id or self._ti.task_id + dag_id = dag_id or self._ti.dag_id + run_id = run_id or self._ti.run_id + map_index = map_index if map_index is not None else self._ti.map_index + return self.client.xcoms.store.get((dag_id, run_id, task_id, key, map_index)) + + +def run_task_no_db( + task: Operator, + create_runtime_ti: Callable[..., Any], + *, + logical_date: Any | None = None, + variables: dict[str, Any] | None = None, + connections: dict[str, Any] | None = None, +) -> TaskRunResultNoDB: + """Run *task* DB-free through the real-socket in-process 
supervisor.""" + from uuid6 import uuid7 + + from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceDTO + from airflow.sdk.execution_time.supervisor import InProcessTestSupervisor + + ti_kwargs = {} if logical_date is None else {"logical_date": logical_date} + rti = create_runtime_ti(task, **ti_kwargs) + ti_context = rti._ti_context_from_server + + # `start()` model_dumps `what`; the plain DTO dumps cleanly, whereas the + # operator-laden RuntimeTaskInstance trips forward refs (RetryPolicy/WeightRuleParam). + what = TaskInstanceDTO( + id=rti.id, + task_id=rti.task_id, + dag_id=rti.dag_id, + run_id=rti.run_id, + try_number=rti.try_number, + map_index=rti.map_index, + dag_version_id=uuid7(), + queue="default", + ) + + client = _InMemoryExecutionClient(ti_context, variables=variables, connections=connections) + + class _StubBackendSupervisor(InProcessTestSupervisor): Review Comment: Done — added an optional `client=` to `start()`/`run_task_in_process`; `run_task_no_db` injects through it, so the per-call subclass is gone on `main` (kept only as a compat shim for Task SDKs that have `InProcessTestSupervisor` but not the `client=` parameter yet). --- Drafted-by: Claude Code (Opus 4.8); reviewed by @potiuk before posting -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
