gemini-code-assist[bot] commented on code in PR #39074:
URL: https://github.com/apache/beam/pull/39074#discussion_r3461219111


##########
sdks/python/apache_beam/ml/inference/agent_development_kit.py:
##########
@@ -259,6 +389,26 @@ async def _run_concurrently():
 
     return results
 
+  def _update_agent_port(self, agent: "Agent", port: int):
+    if ADK_AVAILABLE:
+      from google.adk.models.lite_llm import LiteLlm
+      if hasattr(agent, 'model') and isinstance(agent.model, LiteLlm):
+        agent.model = LiteLlm(
+            model=agent.model.model,
+            api_base=f"http://localhost:{port}/v1";
+        )
+    if hasattr(agent, 'tools'):
+      for tool in agent.tools:
+        if hasattr(tool, 'agent'):
+          self._update_agent_port(tool.agent, port)
+        elif isinstance(tool, Agent):
+          self._update_agent_port(tool, port)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   As in `_set_agent_model`, if `agent.tools` is `None`, iterating over it here 
will raise a `TypeError: 'NoneType' object is not iterable`. We should check 
that `agent.tools` is not `None` before iterating.
   
   ```suggestion
       if getattr(agent, 'tools', None) is not None:
         for tool in agent.tools:
           if hasattr(tool, 'agent'):
             self._update_agent_port(tool.agent, port)
           elif isinstance(tool, Agent):
             self._update_agent_port(tool, port)
   ```



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -406,6 +408,225 @@ def should_garbage_collect_on_timeout(self) -> bool:
     return self.share_model_across_processes()
 
 
+class SubprocessModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
+  """Base class for model handlers that spin up a subprocess server."""
+  @abstractmethod
+  def get_port(self, model: ModelT) -> int:
+    """Returns the port the subprocess server is listening on."""
+    pass
+
+  @abstractmethod
+  def get_model_name(self) -> str:
+    """Returns the model name."""
+    pass
+
+  @abstractmethod
+  def check_connectivity(self, model: ModelT) -> None:
+    """Checks connectivity to the server and attempts to recover/mark for 
restart."""
+    pass
+
+
+class SubProcessModelServer:
+  """Manages the lifecycle of a generic subprocess model server."""
+  def __init__(self, handler_path: str, model_name: str, port: int = None, 
temp_dir: tempfile.TemporaryDirectory = None):
+    self._handler_path = handler_path
+    self._model_name = model_name
+    self._port = port
+    self._temp_dir = temp_dir
+    self._process = None
+    self._server_started = False
+    self._server_process_lock = threading.RLock()
+    self.start_server()
+
+  def start_server(self, retries=3):
+    with self._server_process_lock:
+      if not self._server_started:

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   If the subprocess server crashes, `self._server_started` remains `True`. If 
`get_server_port()` is called afterwards, it returns the stale port without 
restarting the server, because `self._server_started` is still `True`. At the 
beginning of `start_server`, we should check whether the process has exited 
(`self._process.poll() is not None`) and, if so, reset 
`self._server_started = False` so that the server can recover.
   
   ```python
     def start_server(self, retries=3):
       with self._server_process_lock:
         if self._process and self._process.poll() is not None:
           self._server_started = False
         if not self._server_started:
   ```



##########
sdks/python/apache_beam/ml/inference/agent_development_kit.py:
##########
@@ -181,13 +264,33 @@ def load_model(self) -> "Runner":
         app_name=self._app_name,
         session_service=session_service,
     )
+    
+    if underlying_model is not None:
+      runner._underlying_model = underlying_model
+
     LOGGER.info(
         "Loaded ADK Runner for agent '%s' (app_name='%s')",
         agent.name,
         self._app_name,
     )
     return runner
 
+  def _set_agent_model(self, agent: "Agent", model: Any, is_root: bool = 
False):
+    if is_root:
+      if isinstance(agent.model, BeamPlaceholderModel) or agent.model is None:
+        agent.model = model
+    else:
+      if isinstance(agent.model, BeamPlaceholderModel):
+        agent.model = model
+
+    # Speculative propagation to subagents/tools
+    if hasattr(agent, 'tools'):
+      for tool in agent.tools:
+        if hasattr(tool, 'agent'):
+          self._set_agent_model(tool.agent, model, is_root=False)
+        elif isinstance(tool, Agent):
+          self._set_agent_model(tool, model, is_root=False)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   If `agent.tools` is `None` (for example, when it is explicitly set to `None`; 
confirm the default for the ADK version in use), `hasattr(agent, 'tools')` still 
returns `True`, but iterating over it raises 
`TypeError: 'NoneType' object is not iterable`. We should use 
`getattr(agent, 'tools', None) is not None` to guard against this.
   
   ```suggestion
       if getattr(agent, 'tools', None) is not None:
         for tool in agent.tools:
           if hasattr(tool, 'agent'):
             self._set_agent_model(tool.agent, model, is_root=False)
           elif isinstance(tool, Agent):
             self._set_agent_model(tool, model, is_root=False)
   ```



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -406,6 +408,225 @@ def should_garbage_collect_on_timeout(self) -> bool:
     return self.share_model_across_processes()
 
 
+class SubprocessModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
+  """Base class for model handlers that spin up a subprocess server."""
+  @abstractmethod
+  def get_port(self, model: ModelT) -> int:
+    """Returns the port the subprocess server is listening on."""
+    pass
+
+  @abstractmethod
+  def get_model_name(self) -> str:
+    """Returns the model name."""
+    pass
+
+  @abstractmethod
+  def check_connectivity(self, model: ModelT) -> None:
+    """Checks connectivity to the server and attempts to recover/mark for 
restart."""
+    pass
+
+
+class SubProcessModelServer:
+  """Manages the lifecycle of a generic subprocess model server."""
+  def __init__(self, handler_path: str, model_name: str, port: int = None, 
temp_dir: tempfile.TemporaryDirectory = None):
+    self._handler_path = handler_path
+    self._model_name = model_name
+    self._port = port
+    self._temp_dir = temp_dir
+    self._process = None
+    self._server_started = False
+    self._server_process_lock = threading.RLock()
+    self.start_server()
+
+  def start_server(self, retries=3):
+    with self._server_process_lock:
+      if not self._server_started:
+        if self._process:
+          logging.info("Terminating existing generic subprocess model server 
before restart")
+          try:
+            self._process.terminate()
+            self._process.wait(timeout=5)
+          except Exception:
+            try:
+              self._process.kill()
+            except Exception:
+              pass
+          self._process = None
+          self._port = None
+
+        from apache_beam.utils import subprocess_server
+        if self._port is None:
+          self._port, = subprocess_server.pick_port(None)
+        
+        cmd = [
+            sys.executable,
+            '-m',
+            'apache_beam.ml.inference.subprocess_server',
+            '--handler_path',
+            self._handler_path,
+            '--port',
+            str(self._port),
+        ]
+        logging.info("Starting generic model server with %s", cmd)
+        self._process = subprocess.Popen(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+        
+        # Emit the output of this command as info level logging.
+        def log_stdout():
+          line = self._process.stdout.readline()
+          while line:
+            logging.info(line.decode(errors='backslashreplace').rstrip())
+            line = self._process.stdout.readline()
+
+        t = threading.Thread(target=log_stdout)
+        t.daemon = True
+        t.start()
+
+      self.check_connectivity(retries)
+
+  def get_server_port(self) -> int:
+    if not self._server_started:
+      self.start_server()
+    return self._port
+
+  def check_connectivity(self, retries=3):
+    import urllib.request
+    import urllib.error
+    
+    url = f"http://localhost:{self._port}/v1/models";
+    attempts = 0
+    max_attempts = 12  # 12 * 5s = 60s timeout
+    while self._process.poll() is None and attempts < max_attempts:
+      try:
+        # Use standard library to check connectivity to avoid extra 
dependencies
+        req = urllib.request.Request(url, method="GET")
+        with urllib.request.urlopen(req, timeout=5) as response:
+          if response.status == 200:
+            self._server_started = True
+            return
+      except urllib.error.URLError:
+        pass
+      except Exception as e:
+        logging.warning("Error checking connectivity: %s", e)
+      attempts += 1
+      time.sleep(5)
+
+    if retries == 0:
+      self._server_started = False
+      raise Exception(
+          "Failed to start generic subprocess server, polling process exited 
with code " +
+          f"{self._process.poll()}. Next time a request is tried, the server 
will be restarted"
+      )
+    else:
+      self.start_server(retries - 1)

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   There are two critical issues in `check_connectivity` and `start_server`:
   1. **Race Condition / Redundant Restarts**: If multiple threads concurrently 
call `check_connectivity` after an inference failure, they will all call 
`start_server`. Although `start_server` is protected by 
`self._server_process_lock`, once the first thread finishes restarting the 
server and releases the lock, each waiting thread acquires the lock in turn and 
restarts the server again, terminating the newly started healthy server. We 
should hold `self._server_process_lock` for the entire `check_connectivity` 
method, so that subsequent threads immediately see the healthy server and 
return without redundant restarts. Because the lock is an `RLock` (re-entrant), 
the nested `check_connectivity` -> `start_server` -> `check_connectivity` calls 
made by the same thread will not deadlock.
   2. **No-op Restart Loop**: If the server crashes, `self._server_started` 
remains `True`. When `check_connectivity` is called, the polling loop is 
skipped because `self._process.poll()` is not `None`. It then calls 
`start_server(retries - 1)`. However, because `self._server_started` is still 
`True`, `start_server` skips the spawning block entirely and just calls 
`check_connectivity` again. This results in a loop of no-op retries that never 
actually restarts the process. We must reset `self._server_started = False` 
before calling `start_server` to ensure the server is actually restarted.
   
   ```python
     def check_connectivity(self, retries=3):
       with self._server_process_lock:
         import urllib.request
         import urllib.error
         
         url = f"http://localhost:{self._port}/v1/models";
         attempts = 0
         max_attempts = 12  # 12 * 5s = 60s timeout
         while self._process.poll() is None and attempts < max_attempts:
           try:
             # Use standard library to check connectivity to avoid extra 
dependencies
             req = urllib.request.Request(url, method="GET")
             with urllib.request.urlopen(req, timeout=5) as response:
               if response.status == 200:
                 self._server_started = True
                 return
           except urllib.error.URLError:
             pass
           except Exception as e:
             logging.warning("Error checking connectivity: %s", e)
           attempts += 1
           time.sleep(5)
   
         if retries == 0:
           self._server_started = False
           raise Exception(
               "Failed to start generic subprocess server, polling process 
exited with code " +
               f"{self._process.poll()}. Next time a request is tried, the 
server will be restarted"
           )
         else:
           self._server_started = False
           self.start_server(retries - 1)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to