gemini-code-assist[bot] commented on code in PR #39074:
URL: https://github.com/apache/beam/pull/39074#discussion_r3461188359
##########
sdks/python/apache_beam/ml/inference/agent_development_kit.py:
##########
@@ -165,11 +193,54 @@ def load_model(self) -> "Runner":
Returns:
A fully initialised :class:`~google.adk.runners.Runner`.
"""
+ local_model = None
+ underlying_model = None
+
+ if self._underlying_model_handler is not None:
+ underlying_model = self._underlying_model_handler.load_model()
+ self._current_port =
self._underlying_model_handler.get_port(underlying_model)
+ model_name = self._underlying_model_handler.get_model_name()
+
+ from google.adk.models.lite_llm import LiteLlm
+ local_model = LiteLlm(
+ model=model_name,
+ api_base=f"http://localhost:{self._current_port}/v1"
+ )
+
+ # Resolve agent and inject model
if callable(self._agent_or_factory) and not isinstance(
self._agent_or_factory, Agent):
- agent = self._agent_or_factory()
+ import inspect
+ sig = inspect.signature(self._agent_or_factory)
+ params = list(sig.parameters.values())
+
+ if len(params) == 1:
+ if local_model is None:
+ raise ValueError("Agent factory expects 1 argument but no local
model was configured.")
+ agent = self._agent_or_factory(local_model)
+ elif len(params) == 0:
+ agent = self._agent_or_factory()
+ if local_model is not None:
+ if not isinstance(agent.model, BeamPlaceholderModel) and agent.model
is not None:
+ raise ValueError(
+ f"Agent model must be BeamPlaceholderModel or None when using
local model. "
+ f"Found: {agent.model}")
+ self._set_agent_model(agent, local_model, is_root=True)
+ else:
+ raise ValueError("Agent factory must take 0 or 1 argument.")
Review Comment:

The current factory inspection logic requires the factory to declare exactly 0
or 1 parameter. It therefore raises a `ValueError` for factories with optional
parameters, even though they are valid to call. Two examples:
`def factory(model, config=None)` declares two parameters, and
`def factory(model=None)` fails when no local model is configured. Counting
only the required parameters (those without default values, excluding `*args`
and `**kwargs`) makes the factory invocation more robust.
```python
import inspect
sig = inspect.signature(self._agent_or_factory)
params = list(sig.parameters.values())
required_params = [
p for p in params
if p.default is inspect.Parameter.empty and p.kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD)
]
if len(required_params) == 1:
if local_model is None:
raise ValueError("Agent factory expects 1 argument but no local
model was configured.")
agent = self._agent_or_factory(local_model)
elif len(required_params) == 0:
if local_model is not None and len(params) > 0:
agent = self._agent_or_factory(local_model)
else:
agent = self._agent_or_factory()
if local_model is not None:
if not isinstance(agent.model, BeamPlaceholderModel) and
agent.model is not None:
raise ValueError(
f"Agent model must be BeamPlaceholderModel or None when
using local model. "
f"Found: {agent.model}")
self._set_agent_model(agent, local_model, is_root=True)
else:
raise ValueError("Agent factory must take 0 or 1 required argument.")
```
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -406,6 +408,219 @@ def should_garbage_collect_on_timeout(self) -> bool:
return self.share_model_across_processes()
+class SubprocessModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
+ """Base class for model handlers that spin up a subprocess server."""
+ @abstractmethod
+ def get_port(self, model: ModelT) -> int:
+ """Returns the port the subprocess server is listening on."""
+ pass
+
+ @abstractmethod
+ def get_model_name(self) -> str:
+ """Returns the model name."""
+ pass
+
+ @abstractmethod
+ def check_connectivity(self, model: ModelT) -> None:
+ """Checks connectivity to the server and attempts to recover/mark for
restart."""
+ pass
+
+
+class SubProcessModelServer:
+ """Manages the lifecycle of a generic subprocess model server."""
+ def __init__(self, handler_path: str, model_name: str, port: int = None,
temp_dir: tempfile.TemporaryDirectory = None):
+ self._handler_path = handler_path
+ self._model_name = model_name
+ self._port = port
+ self._temp_dir = temp_dir
+ self._process = None
+ self._server_started = False
+ self._server_process_lock = threading.RLock()
+ self.start_server()
+
+ def start_server(self, retries=3):
+ with self._server_process_lock:
+ if not self._server_started:
+ if self._process:
+ logging.info("Terminating existing generic subprocess model server
before restart")
+ try:
+ self._process.terminate()
+ self._process.wait(timeout=5)
+ except Exception:
+ try:
+ self._process.kill()
+ except Exception:
+ pass
+ self._process = None
+ self._port = None
+
+ from apache_beam.utils import subprocess_server
+ if self._port is None:
+ self._port, = subprocess_server.pick_port(None)
+
+ cmd = [
+ sys.executable,
+ '-m',
+ 'apache_beam.ml.inference.subprocess_server',
+ '--handler_path',
+ self._handler_path,
+ '--port',
+ str(self._port),
+ ]
+ logging.info("Starting generic model server with %s", cmd)
+ self._process = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+ # Emit the output of this command as info level logging.
+ def log_stdout():
+ line = self._process.stdout.readline()
+ while line:
+ logging.info(line.decode(errors='backslashreplace').rstrip())
+ line = self._process.stdout.readline()
+
+ t = threading.Thread(target=log_stdout)
+ t.daemon = True
+ t.start()
+
+ self.check_connectivity(retries)
+
+ def get_server_port(self) -> int:
+ if not self._server_started:
+ self.start_server()
+ return self._port
+
+ def check_connectivity(self, retries=3):
+ import urllib.request
+ import urllib.error
+
+ url = f"http://localhost:{self._port}/v1/models"
+ attempts = 0
+ max_attempts = 12 # 12 * 5s = 60s timeout
+ while self._process.poll() is None and attempts < max_attempts:
+ try:
+ # Use standard library to check connectivity to avoid extra
dependencies
+ req = urllib.request.Request(url, method="GET")
+ with urllib.request.urlopen(req, timeout=5) as response:
+ if response.status == 200:
+ self._server_started = True
+ return
+ except urllib.error.URLError:
+ pass
+ except Exception as e:
+ logging.warning("Error checking connectivity: %s", e)
+ attempts += 1
+ time.sleep(5)
+
+ if retries == 0:
+ self._server_started = False
+ raise Exception(
+ "Failed to start generic subprocess server, polling process exited
with code " +
+ f"{self._process.poll()}. Next time a request is tried, the server
will be restarted"
+ )
+ else:
+ self.start_server(retries - 1)
+
+ def __del__(self):
+ if self._process:
+ logging.info("Terminating generic subprocess model server")
+ try:
+ self._process.terminate()
+ self._process.wait(timeout=5)
+ except Exception:
+ try:
+ self._process.kill()
+ except Exception:
+ pass
+ if self._temp_dir:
+ try:
+ self._temp_dir.cleanup()
+ except Exception:
+ pass
+
+
+class SubProcessModel(SubprocessModelHandler[ExampleT, PredictionT, Any]):
+ """Wrapper to adapt any ModelHandler to SubprocessModelHandler."""
+ def __init__(
+ self,
+ handler: ModelHandler[ExampleT, PredictionT, ModelT],
+ model_name: str):
+ super().__init__()
+ self._handler = handler
+ self._model_name = model_name
+ self._handler_path = None
+
+ def load_model(self) -> Any:
+ if isinstance(self._handler, SubprocessModelHandler):
+ return self._handler.load_model()
+
+ import tempfile
+ temp_dir = tempfile.TemporaryDirectory(prefix="beam-subprocess-")
+ self._handler_path = os.path.join(temp_dir.name, "handler.pickle")
+ with open(self._handler_path, "wb") as f:
+ pickle.dump(self._handler, f)
+
+ return SubProcessModelServer(self._handler_path, self._model_name,
temp_dir=temp_dir)
+
+ def run_inference(self, batch, model, inference_args=None):
+ if isinstance(model, SubProcessModelServer):
+ return self._run_inference_via_http(batch, model.get_server_port(),
inference_args)
+ else:
+ return self._handler.run_inference(batch, model, inference_args)
+
+ def _run_inference_via_http(self, batch, port, inference_args):
+ import urllib.request
+ import urllib.error
+
+ url = f"http://localhost:{port}/v1/beam/inference"
+ payload = {
+ "batch": batch,
+ "inference_args": inference_args
+ }
+ data = pickle.dumps(payload)
+
+ req = urllib.request.Request(url, data=data, method="POST")
+ req.add_header('Content-Type', 'application/octet-stream')
+
+ try:
+ with urllib.request.urlopen(req, timeout=30) as response:
+ resp_data = response.read()
+ results = pickle.loads(resp_data)
+ return results
+ except urllib.error.HTTPError as e:
+ try:
+ err_details = e.read().decode('utf-8')
+ logging.error("Subprocess server returned error: %s", err_details)
+ except Exception:
+ pass
+ logging.exception("Failed to run inference via HTTP raw endpoint
(HTTPError)")
+ raise e
+ except Exception as e:
+ logging.exception("Failed to run inference via HTTP raw endpoint")
+ raise e
+
+ def get_port(self, model: Any) -> int:
+ if hasattr(model, 'get_server_port'):
+ return model.get_server_port()
+ elif hasattr(model, 'port'):
+ return model.port
+ raise ValueError(f"Could not determine port from model of type
{type(model)}")
+
+ def get_model_name(self) -> str:
+ return self._model_name
+
+ def check_connectivity(self, model: Any) -> None:
+ if hasattr(model, 'check_connectivity'):
+ model.check_connectivity()
+ elif hasattr(self._handler, 'check_connectivity'):
+ self._handler.check_connectivity(model)
+
+ def share_model_across_processes(self) -> bool:
+ return self._handler.share_model_across_processes()
+
+ def __getattr__(self, name):
+ return getattr(self._handler, name)
Review Comment:

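During unpickling or deep-copying, `__getattr__` can be invoked before
`_handler` has been restored into `self.__dict__`. For example, this happens
when a dunder such as `__setstate__` is looked up on the freshly created
instance. The method body then accesses `self._handler`, which is itself
missing, so that access re-enters `__getattr__`. The result is infinite
recursion and a `RecursionError`. Raising `AttributeError` when `_handler` is
not yet present in `self.__dict__` breaks the recursion.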
```suggestion
def __getattr__(self, name):
if name == '_handler':
raise AttributeError()
if '_handler' not in self.__dict__:
raise AttributeError(f"'{type(self).__name__}' object has no attribute
'{name}'")
return getattr(self._handler, name)
```
##########
sdks/python/apache_beam/ml/inference/agent_development_kit.py:
##########
@@ -73,6 +74,30 @@
ADK_AVAILABLE = True
except ImportError:
ADK_AVAILABLE = False
+
+if ADK_AVAILABLE:
+ try:
+ from google.adk.models.base_llm import BaseLlm
+ except ImportError:
+ try:
+ from google.adk.models import BaseLlm
+ except ImportError:
+ BaseLlm = object
+
+ class BeamPlaceholderModel(BaseLlm):
+ """Placeholder model to be used when the model will be injected by
ADKAgentModelHandler."""
+ def __init__(self):
+ pass
+ async def generate_content_async(self, *args, **kwargs):
+ raise NotImplementedError("Placeholder model cannot be used for
inference.")
+else:
+ class BeamPlaceholderModel(str):
+ """Placeholder model to be used when the model will be injected by
ADKAgentModelHandler.
+
+ Fallback when ADK is not available.
+ """
+ def __new__(cls):
+ return super().__new__(cls, "beam-placeholder-model")
genai_Content = Any # type: ignore[assignment, misc]
genai_Part = Any # type: ignore[assignment, misc]
Review Comment:

When `ADK_AVAILABLE` is `False`, the `Agent` and `Runner` names are never
imported or defined. Any runtime reference to `Agent` then raises a
`NameError`. Examples include `isinstance(tool, Agent)` in `_set_agent_model`
and `_update_agent_port`, and `isinstance(self._agent_or_factory, Agent)` in
`load_model`. Defining placeholder fallback classes in the `else` block avoids
this issue.
```suggestion
else:
class BeamPlaceholderModel(str):
"""Placeholder model to be used when the model will be injected by
ADKAgentModelHandler.
Fallback when ADK is not available.
"""
def __new__(cls):
return super().__new__(cls, "beam-placeholder-model")
class Agent:
pass
class Runner:
pass
genai_Content = Any # type: ignore[assignment, misc]
genai_Part = Any # type: ignore[assignment, misc]
```
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -406,6 +408,219 @@ def should_garbage_collect_on_timeout(self) -> bool:
return self.share_model_across_processes()
+class SubprocessModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
+ """Base class for model handlers that spin up a subprocess server."""
+ @abstractmethod
+ def get_port(self, model: ModelT) -> int:
+ """Returns the port the subprocess server is listening on."""
+ pass
+
+ @abstractmethod
+ def get_model_name(self) -> str:
+ """Returns the model name."""
+ pass
+
+ @abstractmethod
+ def check_connectivity(self, model: ModelT) -> None:
+ """Checks connectivity to the server and attempts to recover/mark for
restart."""
+ pass
+
+
+class SubProcessModelServer:
+ """Manages the lifecycle of a generic subprocess model server."""
+ def __init__(self, handler_path: str, model_name: str, port: int = None,
temp_dir: tempfile.TemporaryDirectory = None):
+ self._handler_path = handler_path
+ self._model_name = model_name
+ self._port = port
+ self._temp_dir = temp_dir
+ self._process = None
+ self._server_started = False
+ self._server_process_lock = threading.RLock()
+ self.start_server()
+
+ def start_server(self, retries=3):
+ with self._server_process_lock:
+ if not self._server_started:
+ if self._process:
+ logging.info("Terminating existing generic subprocess model server
before restart")
+ try:
+ self._process.terminate()
+ self._process.wait(timeout=5)
+ except Exception:
+ try:
+ self._process.kill()
+ except Exception:
+ pass
+ self._process = None
+ self._port = None
+
+ from apache_beam.utils import subprocess_server
+ if self._port is None:
+ self._port, = subprocess_server.pick_port(None)
+
+ cmd = [
+ sys.executable,
+ '-m',
+ 'apache_beam.ml.inference.subprocess_server',
+ '--handler_path',
+ self._handler_path,
+ '--port',
+ str(self._port),
+ ]
+ logging.info("Starting generic model server with %s", cmd)
+ self._process = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+ # Emit the output of this command as info level logging.
+ def log_stdout():
+ line = self._process.stdout.readline()
+ while line:
+ logging.info(line.decode(errors='backslashreplace').rstrip())
+ line = self._process.stdout.readline()
+
+ t = threading.Thread(target=log_stdout)
+ t.daemon = True
+ t.start()
+
+ self.check_connectivity(retries)
+
+ def get_server_port(self) -> int:
+ if not self._server_started:
+ self.start_server()
+ return self._port
+
+ def check_connectivity(self, retries=3):
+ import urllib.request
+ import urllib.error
+
+ url = f"http://localhost:{self._port}/v1/models"
+ attempts = 0
+ max_attempts = 12 # 12 * 5s = 60s timeout
+ while self._process.poll() is None and attempts < max_attempts:
+ try:
+ # Use standard library to check connectivity to avoid extra
dependencies
+ req = urllib.request.Request(url, method="GET")
+ with urllib.request.urlopen(req, timeout=5) as response:
+ if response.status == 200:
+ self._server_started = True
+ return
+ except urllib.error.URLError:
+ pass
+ except Exception as e:
+ logging.warning("Error checking connectivity: %s", e)
+ attempts += 1
+ time.sleep(5)
+
+ if retries == 0:
+ self._server_started = False
+ raise Exception(
+ "Failed to start generic subprocess server, polling process exited
with code " +
+ f"{self._process.poll()}. Next time a request is tried, the server
will be restarted"
+ )
+ else:
+ self.start_server(retries - 1)
+
+ def __del__(self):
+ if self._process:
+ logging.info("Terminating generic subprocess model server")
+ try:
+ self._process.terminate()
+ self._process.wait(timeout=5)
+ except Exception:
+ try:
+ self._process.kill()
+ except Exception:
+ pass
+ if self._temp_dir:
+ try:
+ self._temp_dir.cleanup()
+ except Exception:
+ pass
+
+
+class SubProcessModel(SubprocessModelHandler[ExampleT, PredictionT, Any]):
+ """Wrapper to adapt any ModelHandler to SubprocessModelHandler."""
+ def __init__(
+ self,
+ handler: ModelHandler[ExampleT, PredictionT, ModelT],
+ model_name: str):
+ super().__init__()
+ self._handler = handler
+ self._model_name = model_name
+ self._handler_path = None
Review Comment:

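The HTTP inference timeout in `_run_inference_via_http` is hard-coded to 30
seconds. This can cause premature timeouts for heavy ML models or large
batches. Adding a configurable `timeout` constructor parameter (defaulting to a
more generous 300 seconds) and storing it as `self._timeout` makes the timeout
tunable for production workloads. The `urlopen` call must also be updated to
use it (see the companion suggestion on `_run_inference_via_http`).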
```suggestion
def __init__(
self,
handler: ModelHandler[ExampleT, PredictionT, ModelT],
model_name: str,
timeout: int = 300):
super().__init__()
self._handler = handler
self._model_name = model_name
self._timeout = timeout
self._handler_path = None
```
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -406,6 +408,219 @@ def should_garbage_collect_on_timeout(self) -> bool:
return self.share_model_across_processes()
+class SubprocessModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
+ """Base class for model handlers that spin up a subprocess server."""
+ @abstractmethod
+ def get_port(self, model: ModelT) -> int:
+ """Returns the port the subprocess server is listening on."""
+ pass
+
+ @abstractmethod
+ def get_model_name(self) -> str:
+ """Returns the model name."""
+ pass
+
+ @abstractmethod
+ def check_connectivity(self, model: ModelT) -> None:
+ """Checks connectivity to the server and attempts to recover/mark for
restart."""
+ pass
+
+
+class SubProcessModelServer:
+ """Manages the lifecycle of a generic subprocess model server."""
+ def __init__(self, handler_path: str, model_name: str, port: int = None,
temp_dir: tempfile.TemporaryDirectory = None):
+ self._handler_path = handler_path
+ self._model_name = model_name
+ self._port = port
+ self._temp_dir = temp_dir
+ self._process = None
+ self._server_started = False
+ self._server_process_lock = threading.RLock()
+ self.start_server()
+
+ def start_server(self, retries=3):
+ with self._server_process_lock:
+ if not self._server_started:
+ if self._process:
+ logging.info("Terminating existing generic subprocess model server
before restart")
+ try:
+ self._process.terminate()
+ self._process.wait(timeout=5)
+ except Exception:
+ try:
+ self._process.kill()
+ except Exception:
+ pass
+ self._process = None
+ self._port = None
+
+ from apache_beam.utils import subprocess_server
+ if self._port is None:
+ self._port, = subprocess_server.pick_port(None)
+
+ cmd = [
+ sys.executable,
+ '-m',
+ 'apache_beam.ml.inference.subprocess_server',
+ '--handler_path',
+ self._handler_path,
+ '--port',
+ str(self._port),
+ ]
+ logging.info("Starting generic model server with %s", cmd)
+ self._process = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+ # Emit the output of this command as info level logging.
+ def log_stdout():
+ line = self._process.stdout.readline()
+ while line:
+ logging.info(line.decode(errors='backslashreplace').rstrip())
+ line = self._process.stdout.readline()
+
+ t = threading.Thread(target=log_stdout)
+ t.daemon = True
+ t.start()
+
+ self.check_connectivity(retries)
+
+ def get_server_port(self) -> int:
+ if not self._server_started:
+ self.start_server()
+ return self._port
+
+ def check_connectivity(self, retries=3):
+ import urllib.request
+ import urllib.error
+
+ url = f"http://localhost:{self._port}/v1/models"
+ attempts = 0
+ max_attempts = 12 # 12 * 5s = 60s timeout
+ while self._process.poll() is None and attempts < max_attempts:
+ try:
+ # Use standard library to check connectivity to avoid extra
dependencies
+ req = urllib.request.Request(url, method="GET")
+ with urllib.request.urlopen(req, timeout=5) as response:
+ if response.status == 200:
+ self._server_started = True
+ return
+ except urllib.error.URLError:
+ pass
+ except Exception as e:
+ logging.warning("Error checking connectivity: %s", e)
+ attempts += 1
+ time.sleep(5)
+
+ if retries == 0:
+ self._server_started = False
+ raise Exception(
+ "Failed to start generic subprocess server, polling process exited
with code " +
+ f"{self._process.poll()}. Next time a request is tried, the server
will be restarted"
+ )
+ else:
+ self.start_server(retries - 1)
+
+ def __del__(self):
+ if self._process:
+ logging.info("Terminating generic subprocess model server")
+ try:
+ self._process.terminate()
+ self._process.wait(timeout=5)
+ except Exception:
+ try:
+ self._process.kill()
+ except Exception:
+ pass
+ if self._temp_dir:
+ try:
+ self._temp_dir.cleanup()
+ except Exception:
+ pass
+
+
+class SubProcessModel(SubprocessModelHandler[ExampleT, PredictionT, Any]):
+ """Wrapper to adapt any ModelHandler to SubprocessModelHandler."""
+ def __init__(
+ self,
+ handler: ModelHandler[ExampleT, PredictionT, ModelT],
+ model_name: str):
+ super().__init__()
+ self._handler = handler
+ self._model_name = model_name
+ self._handler_path = None
+
+ def load_model(self) -> Any:
+ if isinstance(self._handler, SubprocessModelHandler):
+ return self._handler.load_model()
+
+ import tempfile
+ temp_dir = tempfile.TemporaryDirectory(prefix="beam-subprocess-")
+ self._handler_path = os.path.join(temp_dir.name, "handler.pickle")
+ with open(self._handler_path, "wb") as f:
+ pickle.dump(self._handler, f)
+
+ return SubProcessModelServer(self._handler_path, self._model_name,
temp_dir=temp_dir)
+
+ def run_inference(self, batch, model, inference_args=None):
+ if isinstance(model, SubProcessModelServer):
+ return self._run_inference_via_http(batch, model.get_server_port(),
inference_args)
+ else:
+ return self._handler.run_inference(batch, model, inference_args)
+
+ def _run_inference_via_http(self, batch, port, inference_args):
+ import urllib.request
+ import urllib.error
+
+ url = f"http://localhost:{port}/v1/beam/inference"
+ payload = {
+ "batch": batch,
+ "inference_args": inference_args
+ }
+ data = pickle.dumps(payload)
+
+ req = urllib.request.Request(url, data=data, method="POST")
+ req.add_header('Content-Type', 'application/octet-stream')
+
+ try:
+ with urllib.request.urlopen(req, timeout=30) as response:
+ resp_data = response.read()
+ results = pickle.loads(resp_data)
+ return results
Review Comment:

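Use the configurable `self._timeout` instead of the hard-coded 30-second
timeout, to avoid timeout failures on larger batches or slower workers.
`self._timeout` is introduced by the `__init__` suggestion above, so this
change depends on it. Applied alone, it raises `AttributeError`.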
```suggestion
try:
with urllib.request.urlopen(req, timeout=self._timeout) as response:
resp_data = response.read()
results = pickle.loads(resp_data)
return results
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: github-unsubscribe@beam.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org