This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 128d5989121 [BEAM-14218] Add resource location hints to base inference runner. (#17448)
128d5989121 is described below

commit 128d5989121155549e9bef823529153a3de6520e
Author: Ryan Thompson <[email protected]>
AuthorDate: Fri May 27 16:55:12 2022 -0400

    [BEAM-14218] Add resource location hints to base inference runner. (#17448)
    
    * added docs and location hints
    
    * simplified model hints to come from the loader
    
    * yapf
    
    * added optional clock
    
    * moved clock
    
    * fixed typo
    
    * keep clock in right space
    
    * fixed bad hints
---
 sdks/python/apache_beam/ml/inference/base.py | 25 ++++++++++++++++++++-----
 1 file changed, 20 insertions(+), 5 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 49753c4e7a3..7c7ef50e234 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -39,6 +39,7 @@ from typing import Generic
 from typing import Iterable
 from typing import List
 from typing import Mapping
+from typing import Optional
 from typing import TypeVar
 
 import apache_beam as beam
@@ -57,7 +58,7 @@ _SECOND_TO_MICROSECOND = 1_000_000
 T = TypeVar('T')
 
 
-class InferenceRunner():
+class InferenceRunner:
   """Implements running inferences for a framework."""
   def run_inference(self, batch: List[Any], model: Any) -> Iterable[Any]:
     """Runs inferences on a batch of examples and
@@ -83,25 +84,38 @@ class ModelLoader(Generic[T]):
     """Returns an implementation of InferenceRunner for this model."""
     raise NotImplementedError(type(self))
 
+  def get_resource_hints(self) -> dict:
+    """Returns resource hints for the transform."""
+    return {}
+
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     """Returns kwargs suitable for beam.BatchElements."""
     return {}
 
 
 class RunInference(beam.PTransform):
-  """An extensible transform for running inferences."""
-  def __init__(self, model_loader: ModelLoader, clock=None):
+  """An extensible transform for running inferences.
+  Args:
+      model_loader: An implementation of ModelLoader.
+      clock: A clock implementing get_current_time_in_microseconds.
+  """
+  def __init__(
+      self, model_loader: ModelLoader, clock: Optional["_Clock"] = None):
     self._model_loader = model_loader
     self._clock = clock
 
   # TODO(BEAM-14208): Add batch_size back off in the case there
   # are functional reasons large batch sizes cannot be handled.
   def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
+    resource_hints = self._model_loader.get_resource_hints()
     return (
         pcoll
         # TODO(BEAM-14044): Hook into the batching DoFn APIs.
         | beam.BatchElements(**self._model_loader.batch_elements_kwargs())
-        | beam.ParDo(_RunInferenceDoFn(self._model_loader, self._clock)))
+        | (
+            beam.ParDo(_RunInferenceDoFn(
+                self._model_loader,
+                self._clock)).with_resource_hints(**resource_hints)))
 
 
 class _MetricsCollector:
@@ -155,7 +169,8 @@ class _MetricsCollector:
 
 class _RunInferenceDoFn(beam.DoFn):
   """A DoFn implementation generic to frameworks."""
-  def __init__(self, model_loader: ModelLoader, clock=None):
+  def __init__(
+      self, model_loader: ModelLoader, clock: Optional["_Clock"] = None):
     self._model_loader = model_loader
     self._inference_runner = model_loader.get_inference_runner()
     self._shared_model_handle = shared.Shared()

Reply via email to