This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/batching-doc-string
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 4e7c86ec07cf1d7d100c6adfa80dbf7aa6dabc71
Author: Danny McCormick <[email protected]>
AuthorDate: Thu Feb 9 11:33:45 2023 -0500

    Add batching args to ModelHandlers docs
---
 sdks/python/apache_beam/ml/inference/pytorch_inference.py |  9 +++++++++
 sdks/python/apache_beam/ml/inference/sklearn_inference.py | 13 +++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py 
b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 693f5995627..520f133be5d 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -176,6 +176,10 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
         Otherwise, it will be CPU.
       inference_fn: the inference function to use during RunInference.
         default=_default_tensor_inference_fn
+      min_batch_size: the minimum batch size to use when batching inputs. This
+        batch will be fed into the inference_fn as a Sequence of Tensors.
+      max_batch_size: the maximum batch size to use when batching inputs. This
+        batch will be fed into the inference_fn as a Sequence of Tensors.
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     with PyTorch 1.9 and 1.10.
@@ -364,6 +368,11 @@ class 
PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
         Otherwise, it will be CPU.
       inference_fn: the function to invoke on run_inference.
         default = default_keyed_tensor_inference_fn
+      min_batch_size: the minimum batch size to use when batching inputs. This
+        batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
+      max_batch_size: the maximum batch size to use when batching inputs. This
+        batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
+
 
     **Supported Versions:** RunInference APIs in Apache Beam have been tested
     on torch>=1.9.0,<1.14.0.
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py 
b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index 64dde0b945c..b1c0ea7d7e3 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -104,6 +104,12 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
         default=pickle
       inference_fn: The inference function to use.
         default=_default_numpy_inference_fn
+      min_batch_size: the minimum batch size to use when batching inputs. This
+        batch will be fed into the inference_fn as a Sequence of NumPy
+        ndarrays.
+      max_batch_size: the maximum batch size to use when batching inputs. This
+        batch will be fed into the inference_fn as a Sequence of NumPy
+        ndarrays.
     """
     self._model_uri = model_uri
     self._model_file_type = model_file_type
@@ -211,6 +217,13 @@ class 
SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
         default=pickle
       inference_fn: The inference function to use.
         default=_default_pandas_inference_fn
+      min_batch_size: the minimum batch size to use when batching inputs. This
+        batch will be fed into the inference_fn as a Sequence of Pandas
+        DataFrames.
+      max_batch_size: the maximum batch size to use when batching inputs. This
+        batch will be fed into the inference_fn as a Sequence of Pandas
+        DataFrames.
+
     """
     self._model_uri = model_uri
     self._model_file_type = model_file_type

Reply via email to