AnandInguva commented on code in PR #26309:
URL: https://github.com/apache/beam/pull/26309#discussion_r1175694787
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -300,6 +359,116 @@ def validate_inference_args(self, inference_args:
Optional[Dict[str, Any]]):
def update_model_path(self, model_path: Optional[str] = None):
return self._unkeyed.update_model_path(model_path=model_path)
+ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return self._unkeyed.get_preprocess_fns()
+
+ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return self._unkeyed.get_postprocess_fns()
+
+
+class _PreProcessingModelHandler(Generic[ExampleT, PredictionT, ModelT, PreT],
+ ModelHandler[PreT, PredictionT, ModelT]):
+ def __init__(
+ self,
+ base: ModelHandler[ExampleT, PredictionT, ModelT],
+ preprocess_fn: Callable[[PreT], ExampleT]):
+ """A ModelHandler that has a preprocessing function associated with it.
+
+ Args:
+ base: An implementation of the underlying model handler.
+ preprocess_fn: the preprocessing function to use.
+ """
+ self._base = base
+ self._preprocess_fn = preprocess_fn
+
+ def load_model(self) -> ModelT:
+ return self._base.load_model()
+
+ def run_inference(
+ self,
+ batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
+ model: ModelT,
+ inference_args: Optional[Dict[str, Any]] = None
+ ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
+ return self._base.run_inference(batch, model, inference_args)
+
+ def get_num_bytes(
+ self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
+ return self._base.get_num_bytes(batch)
+
+ def get_metrics_namespace(self) -> str:
+ return self._base.get_metrics_namespace()
+
+ def get_resource_hints(self):
+ return self._base.get_resource_hints()
+
+ def batch_elements_kwargs(self):
+ return self._base.batch_elements_kwargs()
+
+ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+ return self._base.validate_inference_args(inference_args)
+
+ def update_model_path(self, model_path: Optional[str] = None):
+ return self._base.update_model_path(model_path=model_path)
+
+ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
Review Comment:
So are you saying that RunInference will now get its input as `PreProcessT`?
That is, `ExampleT` -> `PreProcessT`, and then this is passed to RunInference?
What happens if the user chains two preprocessing methods?
`ExampleT` -> `PreProcessT` -> ?
For the `?` stage, does the `_PreProcessingModelHandler` treat `ExampleT` as
a type of `PreProcessT` and output `PreProcessT`?
Sorry if the question doesn't make sense.
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -174,6 +183,38 @@ def update_model_path(self, model_path: Optional[str] =
None):
"""Update the model paths produced by side inputs."""
pass
+ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ """Gets all preprocessing functions to be run before batching/inference.
+ Functions are in order that they should be applied."""
+ return []
+
+ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ """Gets all postprocessing functions to be run after inference.
+ Functions are in order that they should be applied."""
+ return []
+
+ def with_preprocess_fn(
+ self, fn: Callable[[PreProcessT], ExampleT]
+ ) -> 'ModelHandler[PreProcessT, PredictionT, ModelT, PreProcessT]':
Review Comment:
```suggestion
) -> ModelHandler[PreProcessT, PredictionT, ModelT, PreProcessT]:
```
Make a similar change to `with_postprocess_fn`. Was it intentional to use
single quotes?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]