yeandy commented on code in PR #21868:
URL: https://github.com/apache/beam/pull/21868#discussion_r898505276


##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -112,41 +128,48 @@ def run_inference(
     return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
 
   def get_num_bytes(self, batch: Sequence[torch.Tensor]) -> int:
-    """Returns the number of bytes of data for a batch of Tensors."""
+    """
+    Returns:
+      The number of bytes of data for a batch of Tensors.
+    """
     return sum((el.element_size() for tensor in batch for el in tensor))
 
   def get_metrics_namespace(self) -> str:
     """
-    Returns a namespace for metrics collected by the RunInference transform.
+    Returns:
+       A namespace for metrics collected by the RunInference transform.
     """
     return 'RunInferencePytorch'
 
 
 class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
                                                   PredictionResult,
                                                   torch.nn.Module]):
-  """ Implementation of the ModelHandler interface for PyTorch.
-
-      NOTE: This API and its implementation are under development and
-      do not provide backward compatibility guarantees.
-  """
   def __init__(
       self,
       state_dict_path: str,
       model_class: Callable[..., torch.nn.Module],
       model_params: Dict[str, Any],
       device: str = 'CPU'):
-    """
-    Initializes a PytorchModelHandlerKeyedTensor
-    :param state_dict_path: path to the saved dictionary of the model state.
-    :param model_class: class of the Pytorch model that defines the model
-    structure.
-    :param device: the device on which you wish to run the model. If
-    ``device = GPU`` then a GPU device will be used if it is available.
-    Otherwise, it will be CPU.
+    """Implementation of the ModelHandler interface for PyTorch.
+
+    Example Usage:
+      pcoll | RunInference(

Review Comment:
   ```suggestion
       Example Usage::
   
         pcoll | RunInference(
   ```



##########
sdks/python/apache_beam/ml/inference/sklearn_inference.py:
##########
@@ -76,13 +77,21 @@ def _validate_inference_args(inference_args):
 class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
                                             PredictionResult,
                                             BaseEstimator]):
-  """ Implementation of the ModelHandler interface for scikit-learn
-      using numpy arrays as input.
-  """
   def __init__(
       self,
       model_uri: str,
       model_file_type: ModelFileType = ModelFileType.PICKLE):
+    """ Implementation of the ModelHandler interface for scikit-learn
+    using numpy arrays as input.
+
+    Example Usage:
+      pcoll | RunInference(SklearnModelHandlerNumpy(model_uri="my_uri"))

Review Comment:
   ```suggestion
       Example Usage::
   
         pcoll | RunInference(SklearnModelHandlerNumpy(model_uri="my_uri"))
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to