This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 0dac9b4455d Update run inference documentation (#21921)
0dac9b4455d is described below
commit 0dac9b4455df7ba60636e4ec98db82d74ff60fa9
Author: Ryan Thompson <[email protected]>
AuthorDate: Fri Jun 17 11:37:51 2022 -0400
Update run inference documentation (#21921)
* Note on batching
* fixed docs
---
sdks/python/apache_beam/ml/inference/base.py | 5 +++++
sdks/python/apache_beam/ml/inference/pytorch_inference.py | 10 ++++++----
sdks/python/apache_beam/ml/inference/sklearn_inference.py | 12 ++++++------
3 files changed, 17 insertions(+), 10 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index c0da0837170..371dc341647 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -93,6 +93,8 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
Args:
batch: A sequence of examples or features.
model: The model used to make inferences.
+ inference_args: Extra arguments for models whose inference call requires
+ extra parameters.
Returns:
An Iterable of Predictions.
@@ -250,6 +252,9 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
Models for supported frameworks can be loaded via a URI. Supported services
can also be used.
+ This transform attempts to batch examples using the beam.BatchElements
+ transform. Batching may be configured using the ModelHandler.
+
Args:
model_handler: An implementation of ModelHandler.
clock: A clock implementing time_ns.
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index f313fc17227..0847f8d96c4 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -45,7 +45,7 @@ def _convert_to_device(examples: torch.Tensor, device) -> torch.Tensor:
"""
Converts samples to a style matching given device.
- Note: A user may pass in device='GPU' but if GPU is not detected in the
+ **NOTE:** A user may pass in device='GPU' but if GPU is not detected in the
environment it must be converted back to CPU.
"""
if examples.device != device:
@@ -64,7 +64,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
device: str = 'CPU'):
"""Implementation of the ModelHandler interface for PyTorch.
- Example Usage:
+ Example Usage::
+
pcoll | RunInference(PytorchModelHandlerTensor(state_dict_path="my_uri"))
Args:
@@ -153,11 +154,12 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
device: str = 'CPU'):
"""Implementation of the ModelHandler interface for PyTorch.
- Example Usage:
+ Example Usage::
+
pcoll | RunInference(
PytorchModelHandlerKeyedTensor(state_dict_path="my_uri"))
- NOTE: This API and its implementation are under development and
+ **NOTE:** This API and its implementation are under development and
do not provide backward compatibility guarantees.
See https://pytorch.org/tutorials/beginner/saving_loading_models.html
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index fcfe252804d..e9d8aa65e92 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -84,7 +84,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
""" Implementation of the ModelHandler interface for scikit-learn
using numpy arrays as input.
- Example Usage:
+ Example Usage::
+
pcoll | RunInference(SklearnModelHandlerNumpy(model_uri="my_uri"))
Args:
@@ -141,13 +142,12 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
"""Implementation of the ModelHandler interface for scikit-learn that
supports pandas dataframes.
- Example Usage:
- pcoll | RunInference(SklearnModelHandlerPandas(model_uri="my_uri"))
+ Example Usage::
- NOTE::
- This API and its implementation are under development and
- do not provide backward compatibility guarantees.
+ pcoll | RunInference(SklearnModelHandlerPandas(model_uri="my_uri"))
+ **NOTE:** This API and its implementation are under development and
+ do not provide backward compatibility guarantees.
Args:
model_uri: The URI to where the model is saved.