This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch users/damccorm/batching-doc-string in repository https://gitbox.apache.org/repos/asf/beam.git
commit 4e7c86ec07cf1d7d100c6adfa80dbf7aa6dabc71 Author: Danny McCormick <[email protected]> AuthorDate: Thu Feb 9 11:33:45 2023 -0500 Add batching args to ModelHandlers docs --- sdks/python/apache_beam/ml/inference/pytorch_inference.py | 9 +++++++++ sdks/python/apache_beam/ml/inference/sklearn_inference.py | 13 +++++++++++++ 2 files changed, 22 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 693f5995627..520f133be5d 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -176,6 +176,10 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor, Otherwise, it will be CPU. inference_fn: the inference function to use during RunInference. default=_default_tensor_inference_fn + min_batch_size: the minimum batch size to use when batching inputs. This + batch will be fed into the inference_fn as a Sequence of Tensors. + max_batch_size: the maximum batch size to use when batching inputs. This + batch will be fed into the inference_fn as a Sequence of Tensors. **Supported Versions:** RunInference APIs in Apache Beam have been tested with PyTorch 1.9 and 1.10. @@ -364,6 +368,11 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor], Otherwise, it will be CPU. inference_fn: the function to invoke on run_inference. default = default_keyed_tensor_inference_fn + min_batch_size: the minimum batch size to use when batching inputs. This + batch will be fed into the inference_fn as a Sequence of Keyed Tensors. + max_batch_size: the maximum batch size to use when batching inputs. This + batch will be fed into the inference_fn as a Sequence of Keyed Tensors. + **Supported Versions:** RunInference APIs in Apache Beam have been tested on torch>=1.9.0,<1.14.0. diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index 64dde0b945c..b1c0ea7d7e3 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -104,6 +104,12 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray, default=pickle inference_fn: The inference function to use. default=_default_numpy_inference_fn + min_batch_size: the minimum batch size to use when batching inputs. This + batch will be fed into the inference_fn as a Sequence of Numpy + ndarrays. + max_batch_size: the maximum batch size to use when batching inputs. This + batch will be fed into the inference_fn as a Sequence of Numpy + ndarrays. """ self._model_uri = model_uri self._model_file_type = model_file_type @@ -211,6 +217,13 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame, default=pickle inference_fn: The inference function to use. default=_default_pandas_inference_fn + min_batch_size: the minimum batch size to use when batching inputs. This + batch will be fed into the inference_fn as a Sequence of Pandas + Dataframes. + max_batch_size: the maximum batch size to use when batching inputs. This + batch will be fed into the inference_fn as a Sequence of Pandas + Dataframes. + """ self._model_uri = model_uri self._model_file_type = model_file_type
