This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 0dac9b4455d Update run inference documentation (#21921)
0dac9b4455d is described below

commit 0dac9b4455df7ba60636e4ec98db82d74ff60fa9
Author: Ryan Thompson <[email protected]>
AuthorDate: Fri Jun 17 11:37:51 2022 -0400

    Update run inference documentation (#21921)
    
    * Note on batching
    
    * fixed docs
---
 sdks/python/apache_beam/ml/inference/base.py              |  5 +++++
 sdks/python/apache_beam/ml/inference/pytorch_inference.py | 10 ++++++----
 sdks/python/apache_beam/ml/inference/sklearn_inference.py | 12 ++++++------
 3 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index c0da0837170..371dc341647 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -93,6 +93,8 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     Args:
       batch: A sequence of examples or features.
       model: The model used to make inferences.
+      inference_args: Extra arguments for models whose inference call requires
+        extra parameters.
 
     Returns:
       An Iterable of Predictions.
@@ -250,6 +252,9 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
     Models for supported frameworks can be loaded via a URI. Supported services
     can also be used.
 
+    This transform attempts to batch examples using the beam.BatchElements
+    transform. Batching may be configured using the ModelHandler.
+
     Args:
         model_handler: An implementation of ModelHandler.
         clock: A clock implementing time_ns.
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index f313fc17227..0847f8d96c4 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -45,7 +45,7 @@ def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
   """
   Converts samples to a style matching given device.
 
-  Note: A user may pass in device='GPU' but if GPU is not detected in the
+  **NOTE:** A user may pass in device='GPU' but if GPU is not detected in the
   environment it must be converted back to CPU.
   """
   if examples.device != device:
@@ -64,7 +64,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
       device: str = 'CPU'):
     """Implementation of the ModelHandler interface for PyTorch.
 
-    Example Usage:
+    Example Usage::
+
       pcoll | RunInference(PytorchModelHandlerTensor(state_dict_path="my_uri"))
 
     Args:
@@ -153,11 +154,12 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
       device: str = 'CPU'):
     """Implementation of the ModelHandler interface for PyTorch.
 
-    Example Usage:
+    Example Usage::
+
       pcoll | RunInference(
       PytorchModelHandlerKeyedTensor(state_dict_path="my_uri"))
 
-    NOTE: This API and its implementation are under development and
+    **NOTE:** This API and its implementation are under development and
     do not provide backward compatibility guarantees.
 
     See https://pytorch.org/tutorials/beginner/saving_loading_models.html
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index fcfe252804d..e9d8aa65e92 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -84,7 +84,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
     """ Implementation of the ModelHandler interface for scikit-learn
     using numpy arrays as input.
 
-    Example Usage:
+    Example Usage::
+
       pcoll | RunInference(SklearnModelHandlerNumpy(model_uri="my_uri"))
 
     Args:
@@ -141,13 +142,12 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
     """Implementation of the ModelHandler interface for scikit-learn that
     supports pandas dataframes.
 
-    Example Usage:
-      pcoll | RunInference(SklearnModelHandlerPandas(model_uri="my_uri"))
+    Example Usage::
 
-    NOTE::
-      This API and its implementation are under development and
-      do not provide backward compatibility guarantees.
+      pcoll | RunInference(SklearnModelHandlerPandas(model_uri="my_uri"))
 
+    **NOTE:** This API and its implementation are under development and
+    do not provide backward compatibility guarantees.
 
     Args:
       model_uri: The URI to where the model is saved.

Reply via email to