AnandInguva commented on code in PR #26309:
URL: https://github.com/apache/beam/pull/26309#discussion_r1175905168
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -300,6 +359,116 @@ def validate_inference_args(self, inference_args:
Optional[Dict[str, Any]]):
def update_model_path(self, model_path: Optional[str] = None):
return self._unkeyed.update_model_path(model_path=model_path)
+ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return self._unkeyed.get_preprocess_fns()
+
+ def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+ return self._unkeyed.get_postprocess_fns()
+
+
+class _PreProcessingModelHandler(Generic[ExampleT, PredictionT, ModelT, PreT],
+ ModelHandler[PreT, PredictionT, ModelT]):
+ def __init__(
+ self,
+ base: ModelHandler[ExampleT, PredictionT, ModelT],
+ preprocess_fn: Callable[[PreT], ExampleT]):
+ """A ModelHandler that has a preprocessing function associated with it.
+
+ Args:
+ base: An implementation of the underlying model handler.
+ preprocess_fn: the preprocessing function to use.
+ """
+ self._base = base
+ self._preprocess_fn = preprocess_fn
+
+ def load_model(self) -> ModelT:
+ return self._base.load_model()
+
+ def run_inference(
+ self,
+ batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
+ model: ModelT,
+ inference_args: Optional[Dict[str, Any]] = None
+ ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
+ return self._base.run_inference(batch, model, inference_args)
+
+ def get_num_bytes(
+ self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
+ return self._base.get_num_bytes(batch)
+
+ def get_metrics_namespace(self) -> str:
+ return self._base.get_metrics_namespace()
+
+ def get_resource_hints(self):
+ return self._base.get_resource_hints()
+
+ def batch_elements_kwargs(self):
+ return self._base.batch_elements_kwargs()
+
+ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+ return self._base.validate_inference_args(inference_args)
+
+ def update_model_path(self, model_path: Optional[str] = None):
+ return self._base.update_model_path(model_path=model_path)
+
+ def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
Review Comment:
Thanks, that makes sense.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]