gemini-code-assist[bot] commented on code in PR #39074:
URL: https://github.com/apache/beam/pull/39074#discussion_r3461188359


##########
sdks/python/apache_beam/ml/inference/agent_development_kit.py:
##########
@@ -165,11 +193,54 @@ def load_model(self) -> "Runner":
     Returns:
       A fully initialised :class:`~google.adk.runners.Runner`.
     """
+    local_model = None
+    underlying_model = None
+    
+    if self._underlying_model_handler is not None:
+      underlying_model = self._underlying_model_handler.load_model()
+      self._current_port = 
self._underlying_model_handler.get_port(underlying_model)
+      model_name = self._underlying_model_handler.get_model_name()
+
+      from google.adk.models.lite_llm import LiteLlm
+      local_model = LiteLlm(
+          model=model_name,
+          api_base=f"http://localhost:{self._current_port}/v1";
+      )
+
+    # Resolve agent and inject model
     if callable(self._agent_or_factory) and not isinstance(
         self._agent_or_factory, Agent):
-      agent = self._agent_or_factory()
+      import inspect
+      sig = inspect.signature(self._agent_or_factory)
+      params = list(sig.parameters.values())
+      
+      if len(params) == 1:
+        if local_model is None:
+          raise ValueError("Agent factory expects 1 argument but no local 
model was configured.")
+        agent = self._agent_or_factory(local_model)
+      elif len(params) == 0:
+        agent = self._agent_or_factory()
+        if local_model is not None:
+          if not isinstance(agent.model, BeamPlaceholderModel) and agent.model 
is not None:
+            raise ValueError(
+                f"Agent model must be BeamPlaceholderModel or None when using 
local model. "
+                f"Found: {agent.model}")
+          self._set_agent_model(agent, local_model, is_root=True)
+      else:
+        raise ValueError("Agent factory must take 0 or 1 argument.")

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The current factory inspection logic accepts only factories that declare 
exactly 0 or exactly 1 parameter. As a result, valid factories with optional 
parameters are rejected with a `ValueError`: `def factory(model, config=None)` 
declares two parameters and is always rejected, and `def factory(model=None)` 
is rejected whenever no local model is configured, even though both are 
perfectly valid to call. Inspecting only the required parameters (those 
without default values) makes the factory invocation much more robust.
   
   ```python
         import inspect
         sig = inspect.signature(self._agent_or_factory)
         params = list(sig.parameters.values())
         required_params = [
             p for p in params
             if p.default is inspect.Parameter.empty and p.kind not in (
                 inspect.Parameter.VAR_POSITIONAL, 
inspect.Parameter.VAR_KEYWORD)
         ]
         
         if len(required_params) == 1:
           if local_model is None:
             raise ValueError("Agent factory expects 1 argument but no local 
model was configured.")
           agent = self._agent_or_factory(local_model)
         elif len(required_params) == 0:
           if local_model is not None and len(params) > 0:
             agent = self._agent_or_factory(local_model)
           else:
             agent = self._agent_or_factory()
             if local_model is not None:
               if not isinstance(agent.model, BeamPlaceholderModel) and 
agent.model is not None:
                 raise ValueError(
                     f"Agent model must be BeamPlaceholderModel or None when 
using local model. "
                     f"Found: {agent.model}")
               self._set_agent_model(agent, local_model, is_root=True)
         else:
           raise ValueError("Agent factory must take 0 or 1 required argument.")
   ```



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -406,6 +408,219 @@ def should_garbage_collect_on_timeout(self) -> bool:
     return self.share_model_across_processes()
 
 
+class SubprocessModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
+  """Base class for model handlers that spin up a subprocess server."""
+  @abstractmethod
+  def get_port(self, model: ModelT) -> int:
+    """Returns the port the subprocess server is listening on."""
+    pass
+
+  @abstractmethod
+  def get_model_name(self) -> str:
+    """Returns the model name."""
+    pass
+
+  @abstractmethod
+  def check_connectivity(self, model: ModelT) -> None:
+    """Checks connectivity to the server and attempts to recover/mark for 
restart."""
+    pass
+
+
+class SubProcessModelServer:
+  """Manages the lifecycle of a generic subprocess model server."""
+  def __init__(self, handler_path: str, model_name: str, port: int = None, 
temp_dir: tempfile.TemporaryDirectory = None):
+    self._handler_path = handler_path
+    self._model_name = model_name
+    self._port = port
+    self._temp_dir = temp_dir
+    self._process = None
+    self._server_started = False
+    self._server_process_lock = threading.RLock()
+    self.start_server()
+
+  def start_server(self, retries=3):
+    with self._server_process_lock:
+      if not self._server_started:
+        if self._process:
+          logging.info("Terminating existing generic subprocess model server 
before restart")
+          try:
+            self._process.terminate()
+            self._process.wait(timeout=5)
+          except Exception:
+            try:
+              self._process.kill()
+            except Exception:
+              pass
+          self._process = None
+          self._port = None
+
+        from apache_beam.utils import subprocess_server
+        if self._port is None:
+          self._port, = subprocess_server.pick_port(None)
+        
+        cmd = [
+            sys.executable,
+            '-m',
+            'apache_beam.ml.inference.subprocess_server',
+            '--handler_path',
+            self._handler_path,
+            '--port',
+            str(self._port),
+        ]
+        logging.info("Starting generic model server with %s", cmd)
+        self._process = subprocess.Popen(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+        
+        # Emit the output of this command as info level logging.
+        def log_stdout():
+          line = self._process.stdout.readline()
+          while line:
+            logging.info(line.decode(errors='backslashreplace').rstrip())
+            line = self._process.stdout.readline()
+
+        t = threading.Thread(target=log_stdout)
+        t.daemon = True
+        t.start()
+
+      self.check_connectivity(retries)
+
+  def get_server_port(self) -> int:
+    if not self._server_started:
+      self.start_server()
+    return self._port
+
+  def check_connectivity(self, retries=3):
+    import urllib.request
+    import urllib.error
+    
+    url = f"http://localhost:{self._port}/v1/models";
+    attempts = 0
+    max_attempts = 12  # 12 * 5s = 60s timeout
+    while self._process.poll() is None and attempts < max_attempts:
+      try:
+        # Use standard library to check connectivity to avoid extra 
dependencies
+        req = urllib.request.Request(url, method="GET")
+        with urllib.request.urlopen(req, timeout=5) as response:
+          if response.status == 200:
+            self._server_started = True
+            return
+      except urllib.error.URLError:
+        pass
+      except Exception as e:
+        logging.warning("Error checking connectivity: %s", e)
+      attempts += 1
+      time.sleep(5)
+
+    if retries == 0:
+      self._server_started = False
+      raise Exception(
+          "Failed to start generic subprocess server, polling process exited 
with code " +
+          f"{self._process.poll()}. Next time a request is tried, the server 
will be restarted"
+      )
+    else:
+      self.start_server(retries - 1)
+
+  def __del__(self):
+    if self._process:
+      logging.info("Terminating generic subprocess model server")
+      try:
+        self._process.terminate()
+        self._process.wait(timeout=5)
+      except Exception:
+        try:
+          self._process.kill()
+        except Exception:
+          pass
+    if self._temp_dir:
+      try:
+        self._temp_dir.cleanup()
+      except Exception:
+        pass
+
+
+class SubProcessModel(SubprocessModelHandler[ExampleT, PredictionT, Any]):
+  """Wrapper to adapt any ModelHandler to SubprocessModelHandler."""
+  def __init__(
+      self,
+      handler: ModelHandler[ExampleT, PredictionT, ModelT],
+      model_name: str):
+    super().__init__()
+    self._handler = handler
+    self._model_name = model_name
+    self._handler_path = None
+
+  def load_model(self) -> Any:
+    if isinstance(self._handler, SubprocessModelHandler):
+      return self._handler.load_model()
+      
+    import tempfile
+    temp_dir = tempfile.TemporaryDirectory(prefix="beam-subprocess-")
+    self._handler_path = os.path.join(temp_dir.name, "handler.pickle")
+    with open(self._handler_path, "wb") as f:
+      pickle.dump(self._handler, f)
+      
+    return SubProcessModelServer(self._handler_path, self._model_name, 
temp_dir=temp_dir)
+
+  def run_inference(self, batch, model, inference_args=None):
+    if isinstance(model, SubProcessModelServer):
+      return self._run_inference_via_http(batch, model.get_server_port(), 
inference_args)
+    else:
+      return self._handler.run_inference(batch, model, inference_args)
+
+  def _run_inference_via_http(self, batch, port, inference_args):
+    import urllib.request
+    import urllib.error
+    
+    url = f"http://localhost:{port}/v1/beam/inference";
+    payload = {
+        "batch": batch,
+        "inference_args": inference_args
+    }
+    data = pickle.dumps(payload)
+    
+    req = urllib.request.Request(url, data=data, method="POST")
+    req.add_header('Content-Type', 'application/octet-stream')
+    
+    try:
+      with urllib.request.urlopen(req, timeout=30) as response:
+        resp_data = response.read()
+        results = pickle.loads(resp_data)
+        return results
+    except urllib.error.HTTPError as e:
+      try:
+        err_details = e.read().decode('utf-8')
+        logging.error("Subprocess server returned error: %s", err_details)
+      except Exception:
+        pass
+      logging.exception("Failed to run inference via HTTP raw endpoint 
(HTTPError)")
+      raise e
+    except Exception as e:
+      logging.exception("Failed to run inference via HTTP raw endpoint")
+      raise e
+
+  def get_port(self, model: Any) -> int:
+    if hasattr(model, 'get_server_port'):
+      return model.get_server_port()
+    elif hasattr(model, 'port'):
+      return model.port
+    raise ValueError(f"Could not determine port from model of type 
{type(model)}")
+
+  def get_model_name(self) -> str:
+    return self._model_name
+
+  def check_connectivity(self, model: Any) -> None:
+    if hasattr(model, 'check_connectivity'):
+      model.check_connectivity()
+    elif hasattr(self._handler, 'check_connectivity'):
+      self._handler.check_connectivity(model)
+
+  def share_model_across_processes(self) -> bool:
+    return self._handler.share_model_across_processes()
+
+  def __getattr__(self, name):
+    return getattr(self._handler, name)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   During unpickling or deepcopying, `__getattr__` can be called before 
`_handler` is restored in `self.__dict__`. In that state, evaluating 
`self._handler` inside `__getattr__` fails normal attribute lookup and calls 
`__getattr__('_handler')` again, which leads to infinite recursion and a 
`RecursionError`. Raising `AttributeError` when `_handler` is not yet present 
in `self.__dict__` avoids this issue.
   
   ```suggestion
     def __getattr__(self, name):
       if name == '_handler':
         raise AttributeError()
       if '_handler' not in self.__dict__:
         raise AttributeError(f"'{type(self).__name__}' object has no attribute 
'{name}'")
       return getattr(self._handler, name)
   ```



##########
sdks/python/apache_beam/ml/inference/agent_development_kit.py:
##########
@@ -73,6 +74,30 @@
   ADK_AVAILABLE = True
 except ImportError:
   ADK_AVAILABLE = False
+
+if ADK_AVAILABLE:
+  try:
+    from google.adk.models.base_llm import BaseLlm
+  except ImportError:
+    try:
+      from google.adk.models import BaseLlm
+    except ImportError:
+      BaseLlm = object
+      
+  class BeamPlaceholderModel(BaseLlm):
+    """Placeholder model to be used when the model will be injected by 
ADKAgentModelHandler."""
+    def __init__(self):
+      pass
+    async def generate_content_async(self, *args, **kwargs):
+      raise NotImplementedError("Placeholder model cannot be used for 
inference.")
+else:
+  class BeamPlaceholderModel(str):
+    """Placeholder model to be used when the model will be injected by 
ADKAgentModelHandler.
+    
+    Fallback when ADK is not available.
+    """
+    def __new__(cls):
+      return super().__new__(cls, "beam-placeholder-model")
   genai_Content = Any  # type: ignore[assignment, misc]
   genai_Part = Any  # type: ignore[assignment, misc]

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   When `ADK_AVAILABLE` is `False`, the `Agent` and `Runner` classes are not 
imported or defined. This will cause a `NameError` at runtime when 
`isinstance(tool, Agent)` is evaluated in `_set_agent_model` and 
`_update_agent_port`. Defining dummy fallback classes in the `else` block 
safely avoids this issue.
   
   ```suggestion
   else:
     class BeamPlaceholderModel(str):
       """Placeholder model to be used when the model will be injected by 
ADKAgentModelHandler.
       
       Fallback when ADK is not available.
       """
       def __new__(cls):
         return super().__new__(cls, "beam-placeholder-model")
     class Agent:
       pass
     class Runner:
       pass
     genai_Content = Any  # type: ignore[assignment, misc]
     genai_Part = Any  # type: ignore[assignment, misc]
   ```



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -406,6 +408,219 @@ def should_garbage_collect_on_timeout(self) -> bool:
     return self.share_model_across_processes()
 
 
+class SubprocessModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
+  """Base class for model handlers that spin up a subprocess server."""
+  @abstractmethod
+  def get_port(self, model: ModelT) -> int:
+    """Returns the port the subprocess server is listening on."""
+    pass
+
+  @abstractmethod
+  def get_model_name(self) -> str:
+    """Returns the model name."""
+    pass
+
+  @abstractmethod
+  def check_connectivity(self, model: ModelT) -> None:
+    """Checks connectivity to the server and attempts to recover/mark for 
restart."""
+    pass
+
+
+class SubProcessModelServer:
+  """Manages the lifecycle of a generic subprocess model server."""
+  def __init__(self, handler_path: str, model_name: str, port: int = None, 
temp_dir: tempfile.TemporaryDirectory = None):
+    self._handler_path = handler_path
+    self._model_name = model_name
+    self._port = port
+    self._temp_dir = temp_dir
+    self._process = None
+    self._server_started = False
+    self._server_process_lock = threading.RLock()
+    self.start_server()
+
+  def start_server(self, retries=3):
+    with self._server_process_lock:
+      if not self._server_started:
+        if self._process:
+          logging.info("Terminating existing generic subprocess model server 
before restart")
+          try:
+            self._process.terminate()
+            self._process.wait(timeout=5)
+          except Exception:
+            try:
+              self._process.kill()
+            except Exception:
+              pass
+          self._process = None
+          self._port = None
+
+        from apache_beam.utils import subprocess_server
+        if self._port is None:
+          self._port, = subprocess_server.pick_port(None)
+        
+        cmd = [
+            sys.executable,
+            '-m',
+            'apache_beam.ml.inference.subprocess_server',
+            '--handler_path',
+            self._handler_path,
+            '--port',
+            str(self._port),
+        ]
+        logging.info("Starting generic model server with %s", cmd)
+        self._process = subprocess.Popen(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+        
+        # Emit the output of this command as info level logging.
+        def log_stdout():
+          line = self._process.stdout.readline()
+          while line:
+            logging.info(line.decode(errors='backslashreplace').rstrip())
+            line = self._process.stdout.readline()
+
+        t = threading.Thread(target=log_stdout)
+        t.daemon = True
+        t.start()
+
+      self.check_connectivity(retries)
+
+  def get_server_port(self) -> int:
+    if not self._server_started:
+      self.start_server()
+    return self._port
+
+  def check_connectivity(self, retries=3):
+    import urllib.request
+    import urllib.error
+    
+    url = f"http://localhost:{self._port}/v1/models";
+    attempts = 0
+    max_attempts = 12  # 12 * 5s = 60s timeout
+    while self._process.poll() is None and attempts < max_attempts:
+      try:
+        # Use standard library to check connectivity to avoid extra 
dependencies
+        req = urllib.request.Request(url, method="GET")
+        with urllib.request.urlopen(req, timeout=5) as response:
+          if response.status == 200:
+            self._server_started = True
+            return
+      except urllib.error.URLError:
+        pass
+      except Exception as e:
+        logging.warning("Error checking connectivity: %s", e)
+      attempts += 1
+      time.sleep(5)
+
+    if retries == 0:
+      self._server_started = False
+      raise Exception(
+          "Failed to start generic subprocess server, polling process exited 
with code " +
+          f"{self._process.poll()}. Next time a request is tried, the server 
will be restarted"
+      )
+    else:
+      self.start_server(retries - 1)
+
+  def __del__(self):
+    if self._process:
+      logging.info("Terminating generic subprocess model server")
+      try:
+        self._process.terminate()
+        self._process.wait(timeout=5)
+      except Exception:
+        try:
+          self._process.kill()
+        except Exception:
+          pass
+    if self._temp_dir:
+      try:
+        self._temp_dir.cleanup()
+      except Exception:
+        pass
+
+
+class SubProcessModel(SubprocessModelHandler[ExampleT, PredictionT, Any]):
+  """Wrapper to adapt any ModelHandler to SubprocessModelHandler."""
+  def __init__(
+      self,
+      handler: ModelHandler[ExampleT, PredictionT, ModelT],
+      model_name: str):
+    super().__init__()
+    self._handler = handler
+    self._model_name = model_name
+    self._handler_path = None

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The HTTP inference timeout is currently hardcoded to 30 seconds in 
`_run_inference_via_http`. For heavy ML models or large batches, this can 
easily lead to premature timeouts. Adding a configurable `timeout` parameter 
(in seconds, defaulting to a safer 300) provides better flexibility and 
production reliability.
   
   ```suggestion
     def __init__(
         self,
         handler: ModelHandler[ExampleT, PredictionT, ModelT],
         model_name: str,
         timeout: int = 300):
       super().__init__()
       self._handler = handler
       self._model_name = model_name
       self._timeout = timeout
       self._handler_path = None
   ```



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -406,6 +408,219 @@ def should_garbage_collect_on_timeout(self) -> bool:
     return self.share_model_across_processes()
 
 
+class SubprocessModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
+  """Base class for model handlers that spin up a subprocess server."""
+  @abstractmethod
+  def get_port(self, model: ModelT) -> int:
+    """Returns the port the subprocess server is listening on."""
+    pass
+
+  @abstractmethod
+  def get_model_name(self) -> str:
+    """Returns the model name."""
+    pass
+
+  @abstractmethod
+  def check_connectivity(self, model: ModelT) -> None:
+    """Checks connectivity to the server and attempts to recover/mark for 
restart."""
+    pass
+
+
+class SubProcessModelServer:
+  """Manages the lifecycle of a generic subprocess model server."""
+  def __init__(self, handler_path: str, model_name: str, port: int = None, 
temp_dir: tempfile.TemporaryDirectory = None):
+    self._handler_path = handler_path
+    self._model_name = model_name
+    self._port = port
+    self._temp_dir = temp_dir
+    self._process = None
+    self._server_started = False
+    self._server_process_lock = threading.RLock()
+    self.start_server()
+
+  def start_server(self, retries=3):
+    with self._server_process_lock:
+      if not self._server_started:
+        if self._process:
+          logging.info("Terminating existing generic subprocess model server 
before restart")
+          try:
+            self._process.terminate()
+            self._process.wait(timeout=5)
+          except Exception:
+            try:
+              self._process.kill()
+            except Exception:
+              pass
+          self._process = None
+          self._port = None
+
+        from apache_beam.utils import subprocess_server
+        if self._port is None:
+          self._port, = subprocess_server.pick_port(None)
+        
+        cmd = [
+            sys.executable,
+            '-m',
+            'apache_beam.ml.inference.subprocess_server',
+            '--handler_path',
+            self._handler_path,
+            '--port',
+            str(self._port),
+        ]
+        logging.info("Starting generic model server with %s", cmd)
+        self._process = subprocess.Popen(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+        
+        # Emit the output of this command as info level logging.
+        def log_stdout():
+          line = self._process.stdout.readline()
+          while line:
+            logging.info(line.decode(errors='backslashreplace').rstrip())
+            line = self._process.stdout.readline()
+
+        t = threading.Thread(target=log_stdout)
+        t.daemon = True
+        t.start()
+
+      self.check_connectivity(retries)
+
+  def get_server_port(self) -> int:
+    if not self._server_started:
+      self.start_server()
+    return self._port
+
+  def check_connectivity(self, retries=3):
+    import urllib.request
+    import urllib.error
+    
+    url = f"http://localhost:{self._port}/v1/models";
+    attempts = 0
+    max_attempts = 12  # 12 * 5s = 60s timeout
+    while self._process.poll() is None and attempts < max_attempts:
+      try:
+        # Use standard library to check connectivity to avoid extra 
dependencies
+        req = urllib.request.Request(url, method="GET")
+        with urllib.request.urlopen(req, timeout=5) as response:
+          if response.status == 200:
+            self._server_started = True
+            return
+      except urllib.error.URLError:
+        pass
+      except Exception as e:
+        logging.warning("Error checking connectivity: %s", e)
+      attempts += 1
+      time.sleep(5)
+
+    if retries == 0:
+      self._server_started = False
+      raise Exception(
+          "Failed to start generic subprocess server, polling process exited 
with code " +
+          f"{self._process.poll()}. Next time a request is tried, the server 
will be restarted"
+      )
+    else:
+      self.start_server(retries - 1)
+
+  def __del__(self):
+    if self._process:
+      logging.info("Terminating generic subprocess model server")
+      try:
+        self._process.terminate()
+        self._process.wait(timeout=5)
+      except Exception:
+        try:
+          self._process.kill()
+        except Exception:
+          pass
+    if self._temp_dir:
+      try:
+        self._temp_dir.cleanup()
+      except Exception:
+        pass
+
+
+class SubProcessModel(SubprocessModelHandler[ExampleT, PredictionT, Any]):
+  """Wrapper to adapt any ModelHandler to SubprocessModelHandler."""
+  def __init__(
+      self,
+      handler: ModelHandler[ExampleT, PredictionT, ModelT],
+      model_name: str):
+    super().__init__()
+    self._handler = handler
+    self._model_name = model_name
+    self._handler_path = None
+
+  def load_model(self) -> Any:
+    if isinstance(self._handler, SubprocessModelHandler):
+      return self._handler.load_model()
+      
+    import tempfile
+    temp_dir = tempfile.TemporaryDirectory(prefix="beam-subprocess-")
+    self._handler_path = os.path.join(temp_dir.name, "handler.pickle")
+    with open(self._handler_path, "wb") as f:
+      pickle.dump(self._handler, f)
+      
+    return SubProcessModelServer(self._handler_path, self._model_name, 
temp_dir=temp_dir)
+
+  def run_inference(self, batch, model, inference_args=None):
+    if isinstance(model, SubProcessModelServer):
+      return self._run_inference_via_http(batch, model.get_server_port(), 
inference_args)
+    else:
+      return self._handler.run_inference(batch, model, inference_args)
+
+  def _run_inference_via_http(self, batch, port, inference_args):
+    import urllib.request
+    import urllib.error
+    
+    url = f"http://localhost:{port}/v1/beam/inference";
+    payload = {
+        "batch": batch,
+        "inference_args": inference_args
+    }
+    data = pickle.dumps(payload)
+    
+    req = urllib.request.Request(url, data=data, method="POST")
+    req.add_header('Content-Type', 'application/octet-stream')
+    
+    try:
+      with urllib.request.urlopen(req, timeout=30) as response:
+        resp_data = response.read()
+        results = pickle.loads(resp_data)
+        return results

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Use the configurable `self._timeout` instead of the hardcoded 30-second 
timeout to prevent timeout failures on larger batches or slower workers.
   
   ```suggestion
       try:
         with urllib.request.urlopen(req, timeout=self._timeout) as response:
           resp_data = response.read()
           results = pickle.loads(resp_data)
           return results
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to