ryanthompson591 commented on code in PR #21806:
URL: https://github.com/apache/beam/pull/21806#discussion_r897040957


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -125,11 +130,14 @@ def load_model(self) -> ModelT:
     return self._unkeyed.load_model()
 
   def run_inference(
-      self, batch: Sequence[Tuple[KeyT, ExampleT]], model: ModelT,
-      **kwargs) -> Iterable[Tuple[KeyT, PredictionT]]:
+      self,
+      batch: Sequence[Tuple[KeyT, ExampleT]],
+      model: ModelT,
+      extra_kwargs: Optional[Dict[str, Any]] = None
+  ) -> Iterable[Tuple[KeyT, PredictionT]]:
     keys, unkeyed_batch = zip(*batch)
     return zip(
-        keys, self._unkeyed.run_inference(unkeyed_batch, model, **kwargs))
+        keys, self._unkeyed.run_inference(unkeyed_batch, model, extra_kwargs))

Review Comment:
   These arguments should be passed through only if they exist, and not otherwise.
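
   A minimal sketch of what "pass through only if present" could look like. All names here (`_FakeUnkeyed`, `_FakeKeyed`, `saw_kwargs`) are hypothetical stand-ins, not the actual Beam classes:

```python
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple


class _FakeUnkeyed:
  """Hypothetical stand-in for an unkeyed model handler."""
  def run_inference(self, batch, model, extra_kwargs=None):
    # Record what was forwarded, purely for illustration.
    self.saw_kwargs = extra_kwargs
    return [x * 2 for x in batch]


class _FakeKeyed:
  """Hypothetical keyed wrapper that forwards extra_kwargs only if supplied."""
  def __init__(self, unkeyed):
    self._unkeyed = unkeyed

  def run_inference(self, batch, model, extra_kwargs=None):
    keys, unkeyed_batch = zip(*batch)
    if extra_kwargs:
      # Forward only when the caller actually supplied arguments.
      results = self._unkeyed.run_inference(
          unkeyed_batch, model, extra_kwargs=extra_kwargs)
    else:
      results = self._unkeyed.run_inference(unkeyed_batch, model)
    return list(zip(keys, results))


unkeyed = _FakeUnkeyed()
keyed = _FakeKeyed(unkeyed)
out = keyed.run_inference([('a', 1), ('b', 2)], model=None)
# unkeyed.saw_kwargs stays None because nothing was forwarded.
```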



##########
sdks/python/apache_beam/ml/inference/pytorch_inference_test.py:
##########
@@ -237,10 +237,8 @@ def test_run_inference_kwargs_prediction_params(self):
     inference_runner = TestPytorchModelHandlerForInferenceOnly(
         torch.device('cpu'))
     predictions = inference_runner.run_inference(
-        batch=KWARGS_TORCH_EXAMPLES,
-        model=model,
-        prediction_params=prediction_params)
-    for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS):
+        batch=KEYED_TORCH_EXAMPLES, model=model, extra_kwargs=extra_kwargs)

Review Comment:
   I think we should keep these args anonymous (positional) rather than passing them as keyword arguments.
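
   For illustration, the two call styles side by side. The stub below is hypothetical (the `offset` parameter is invented for the example); both calls are equivalent, and the positional form is what the comment suggests:

```python
def run_inference(batch, model, extra_kwargs=None):
  # Hypothetical stand-in to contrast keyword vs positional invocation.
  return [b + (extra_kwargs or {}).get('offset', 0) for b in batch]


# Keyword style, as in the test under review:
kw = run_inference(batch=[1, 2], model=None, extra_kwargs={'offset': 10})

# Positional ("anonymous") style:
pos = run_inference([1, 2], None, {'offset': 10})
```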



##########
sdks/python/apache_beam/ml/inference/sklearn_inference.py:
##########
@@ -74,8 +77,11 @@ def load_model(self) -> BaseEstimator:
     return _load_model(self._model_uri, self._model_file_type)
 
   def run_inference(
-      self, batch: Sequence[numpy.ndarray], model: BaseEstimator,
-      **kwargs) -> Iterable[PredictionResult]:
+      self,
+      batch: Sequence[numpy.ndarray],
+      model: BaseEstimator,
+      extra_kwargs: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
Review Comment:
   No. Let's not put these extra args into sklearn. The model doesn't want them, so they don't need to be there.
   
   They would just be unused parameters. They don't need to be here, and they don't need to be in the TensorFlow implementations either.
   
   If anything, the only thing this PR should do here is remove **kwargs from this interface.
   
   If someone passes extra args that this interface isn't expecting, it should fail instead of silently appearing to work.
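
   A short sketch of the fail-loudly behavior the comment is asking for. The signature below is a simplified stand-in, not the actual sklearn handler: with no `**kwargs` in the signature, an unexpected argument raises `TypeError` instead of being swallowed:

```python
def run_inference(batch, model):
  # No **kwargs catch-all: unexpected arguments fail loudly.
  return [model(x) for x in batch]


caught = False
try:
  # Passing an argument the interface does not declare:
  run_inference([1, 2], model=lambda x: x, extra_kwargs={'foo': 1})
except TypeError:
  caught = True
```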



##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -84,8 +86,11 @@ def load_model(self) -> ModelT:
     """Loads and initializes a model for processing."""
     raise NotImplementedError(type(self))
 
-  def run_inference(self, batch: Sequence[ExampleT], model: ModelT,
-                    **kwargs) -> Iterable[PredictionT]:
+  def run_inference(
+      self,
+      batch: Sequence[ExampleT],
+      model: ModelT,
+      extra_kwargs: Optional[Dict[str, Any]] = None) -> Iterable[PredictionT]:

Review Comment:
   This is a bad name. Something like inference_arguments would be the right name, since these are not generic kwargs but rather inference-specific arguments.
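
   A hedged sketch of the renamed signature. The name `inference_args` and the toy `scale` parameter are illustrative assumptions, not the final API:

```python
from typing import Any, Callable, Dict, Iterable, Optional, Sequence


def run_inference(
    batch: Sequence[int],
    model: Callable[..., int],
    inference_args: Optional[Dict[str, Any]] = None) -> Iterable[int]:
  """Hypothetical signature using a purpose-revealing parameter name."""
  inference_args = inference_args or {}
  return [model(x, **inference_args) for x in batch]


out = list(
    run_inference(
        [1, 2],
        model=lambda x, scale=1: x * scale,
        inference_args={'scale': 3}))
```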



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
