damccorm commented on code in PR #26309:
URL: https://github.com/apache/beam/pull/26309#discussion_r1175775190


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -300,6 +359,116 @@ def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
   def update_model_path(self, model_path: Optional[str] = None):
     return self._unkeyed.update_model_path(model_path=model_path)
 
+  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._unkeyed.get_preprocess_fns()
+
+  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
+    return self._unkeyed.get_postprocess_fns()
+
+
+class _PreProcessingModelHandler(Generic[ExampleT, PredictionT, ModelT, PreT],
+                                 ModelHandler[PreT, PredictionT, ModelT]):
+  def __init__(
+      self,
+      base: ModelHandler[ExampleT, PredictionT, ModelT],
+      preprocess_fn: Callable[[PreT], ExampleT]):
+    """A ModelHandler that has a preprocessing function associated with it.
+
+    Args:
+      base: An implementation of the underlying model handler.
+      preprocess_fn: the preprocessing function to use.
+    """
+    self._base = base
+    self._preprocess_fn = preprocess_fn
+
+  def load_model(self) -> ModelT:
+    return self._base.load_model()
+
+  def run_inference(
+      self,
+      batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]],
+      model: ModelT,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Union[Iterable[PredictionT], Iterable[Tuple[KeyT, PredictionT]]]:
+    return self._base.run_inference(batch, model, inference_args)
+
+  def get_num_bytes(
+      self, batch: Sequence[Union[ExampleT, Tuple[KeyT, ExampleT]]]) -> int:
+    return self._base.get_num_bytes(batch)
+
+  def get_metrics_namespace(self) -> str:
+    return self._base.get_metrics_namespace()
+
+  def get_resource_hints(self):
+    return self._base.get_resource_hints()
+
+  def batch_elements_kwargs(self):
+    return self._base.batch_elements_kwargs()
+
+  def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
+    return self._base.validate_inference_args(inference_args)
+
+  def update_model_path(self, model_path: Optional[str] = None):
+    return self._base.update_model_path(model_path=model_path)
+
+  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:

Review Comment:
   > So you are saying now the RunInference will get the input as PreProcessT?
   
   Yes
   
   > ExampleT -> PreProcessT and then this is passed to RunInference? What 
happens if the user chains two preprocessing methods?
   
   Let's say you define:
   
   ```
   def pre1(x: X) -> Y:
     ...

   def pre2(y: Y) -> Z:
     ...

   mh = ModelHandler[Z, PredictionT, ModelT](...)
   mh2 = mh.with_preprocess_fn(pre2)
   mh1 = mh2.with_preprocess_fn(pre1)
   ```
   
   `mh` would have types `[Z, PredictionT, ModelT]`, `mh2` would have types `[Y, PredictionT, ModelT]`, and `mh1` would have types `[X, PredictionT, ModelT]`. So these are composable; as you stack more preprocessing fns, the resulting handler takes on the input type of the newest function.
   
   Note that `PreProcessT` doesn't mean the same thing for each of these. `PreProcessT` is a generic variable that takes on a type that is consistent for a given instantiation of an object, but not necessarily consistent across multiple objects (which is why you can have two different model handlers defined with different types associated with `ExampleT`).
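   
   As a side note, here is a minimal, self-contained sketch of that type flow. It is not the Beam implementation (in the diff above, `run_inference` just delegates, so the preprocessing appears to be applied elsewhere via `get_preprocess_fns`); `compose`, `pre1`, and `pre2` below are made up purely for illustration. The point is that applying the chained fns newest-first turns an `X` into the `Z` that the base handler expects:
   
   ```
   from typing import Callable
   
   def pre1(x: str) -> int:    # stands in for X -> Y
     return len(x)
   
   def pre2(y: int) -> float:  # stands in for Y -> Z
     return float(y)
   
   def compose(fns) -> Callable:
     """Apply fns left to right: the newest (outermost) fn runs first."""
     def composed(element):
       for fn in fns:
         element = fn(element)
       return element
     return composed
   
   # mh1 conceptually carries [pre1, pre2]; the composed fn maps X -> Z,
   # so the base handler (which expects Z) receives the right type.
   to_model_input = compose([pre1, pre2])
   print(to_model_input("abc"))  # 3.0
   ```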



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
