damccorm commented on code in PR #39074:
URL: https://github.com/apache/beam/pull/39074#discussion_r3499997220


##########
sdks/python/apache_beam/ml/inference/agent_development_kit.py:
##########
@@ -259,6 +389,26 @@ async def _run_concurrently():
 
     return results
 
+  def _update_agent_port(self, agent: "Agent", port: int):
+    if ADK_AVAILABLE:
+      from google.adk.models.lite_llm import LiteLlm
+      if hasattr(agent, 'model') and isinstance(agent.model, LiteLlm):
+        agent.model = LiteLlm(
+            model=agent.model.model,
+            api_base=f"http://localhost:{port}/v1"
+        )
+    if hasattr(agent, 'tools'):
+      for tool in agent.tools:
+        if hasattr(tool, 'agent'):
+          self._update_agent_port(tool.agent, port)
+        elif isinstance(tool, Agent):
+          self._update_agent_port(tool, port)

Review Comment:
   Fixed. Safely checking if tools is not None before iterating. Added 
test_local_model_injection_with_none_tools to verify.



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -406,6 +408,225 @@ def should_garbage_collect_on_timeout(self) -> bool:
     return self.share_model_across_processes()
 
 
+class SubprocessModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
+  """Base class for model handlers that spin up a subprocess server."""
+  @abstractmethod
+  def get_port(self, model: ModelT) -> int:
+    """Returns the port the subprocess server is listening on."""
+    pass
+
+  @abstractmethod
+  def get_model_name(self) -> str:
+    """Returns the model name."""
+    pass
+
+  @abstractmethod
+  def check_connectivity(self, model: ModelT) -> None:
+    """Checks connectivity to the server and attempts to recover/mark for restart."""
+    pass
+
+
+class SubProcessModelServer:
+  """Manages the lifecycle of a generic subprocess model server."""
+  def __init__(self, handler_path: str, model_name: str, port: int = None, temp_dir: tempfile.TemporaryDirectory = None):
+    self._handler_path = handler_path
+    self._model_name = model_name
+    self._port = port
+    self._temp_dir = temp_dir
+    self._process = None
+    self._server_started = False
+    self._server_process_lock = threading.RLock()
+    self.start_server()
+
+  def start_server(self, retries=3):
+    with self._server_process_lock:
+      if not self._server_started:

Review Comment:
   Fixed. Resetting _server_started to False if the process has exited, which 
triggers a restart on the next start_server call. Added 
test_subprocess_server_recovery to verify.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to