ryanthompson591 commented on code in PR #21868:
URL: https://github.com/apache/beam/pull/21868#discussion_r898091963
##########
sdks/python/apache_beam/ml/inference/sklearn_inference.py:
##########
@@ -76,29 +85,47 @@ def load_model(self) -> BaseEstimator:
def run_inference(
self, batch: Sequence[numpy.ndarray], model: BaseEstimator,
**kwargs) -> Iterable[PredictionResult]:
+ """Runs inferences on a batch of numpy arrays.
+
+ Returns:
+ An Iterable of type PredictionResult.
+ """
# vectorize data for better performance
vectorized_batch = numpy.stack(batch, axis=0)
predictions = model.predict(vectorized_batch)
return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
def get_num_bytes(self, batch: Sequence[pandas.DataFrame]) -> int:
- """Returns the number of bytes of data for a batch."""
+ """
+ Returns:
+ The number of bytes of data for a batch.
+ """
return sum(sys.getsizeof(element) for element in batch)
class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
PredictionResult,
BaseEstimator]):
- """ Implementation of the ModelHandler interface for scikit-learn that
- supports pandas dataframes.
-
- NOTE: This API and its implementation are under development and
- do not provide backward compatibility guarantees.
- """
def __init__(
self,
model_uri: str,
model_file_type: ModelFileType = ModelFileType.PICKLE):
+ """Implementation of the ModelHandler interface for scikit-learn that
+ supports pandas dataframes.
+
+ Example Usage:
+ pcol | RunInference(SklearnModelHandlerPandas(model_uri="my_uri"))
+
+ NOTE::
+ This API and its implementation are under development and
+ do not provide backward compatibility guarantees.
+
+
Review Comment:
Done.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]