yeandy commented on code in PR #17470:
URL: https://github.com/apache/beam/pull/17470#discussion_r878083266
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -39,7 +41,12 @@ class FakeInferenceRunner(base.InferenceRunner):
   def __init__(self, clock=None):
     self._mock_clock = clock
 
-  def run_inference(self, batch: Any, model: Any) -> Iterable[Any]:
+  def run_inference(
+      self,
+      batch: Any,
+      model: Any,
+      prediction_params: Optional[Dict[str, Any]] = None,
Review Comment:
If I take out `prediction_params`, I get
```
    result_generator = self._inference_runner.run_inference(
>       examples, self._model, self._prediction_params)
E   TypeError: run_inference() takes 3 positional arguments but 4 were given [while running 'RunInference/ParDo(_RunInferenceDoFn)']

apache_beam/ml/inference/base.py:216: TypeError
```
If we make this an optional param, then maybe we can check whether
`_prediction_params` is `None`, like so:
```
if self._prediction_params is not None:
  result_generator = self._inference_runner.run_inference(
      examples, self._model, self._prediction_params)
else:
  result_generator = self._inference_runner.run_inference(
      examples, self._model)
```
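A variant of the same idea that keeps a single call site would be to assemble
the argument list once. This is just a sketch reusing the names above, not
code from the PR:
```
# Hypothetical sketch: build the positional arguments once and append
# prediction_params only when it was supplied. The names (examples,
# self._model, self._prediction_params) mirror the snippet above.
args = [examples, self._model]
if self._prediction_params is not None:
  args.append(self._prediction_params)
result_generator = self._inference_runner.run_inference(*args)
```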
Either way it feels a bit messy, and it breaks the uniform `run_inference`
interface, assuming we want to use this modification:
```
class InferenceRunner():
  def run_inference(
      self,
      batch: List[Any],
      model: Any,
      prediction_params: Optional[Dict[str, Any]] = None) -> Iterable[Any]:
    """Runs inferences on a batch of examples and
    returns an Iterable of Predictions."""
    raise NotImplementedError(type(self))
```
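For illustration, here is a minimal sketch of a runner implementing the
modified interface. The class name and echo behavior are made up; only the
signature mirrors the definition above:
```
from typing import Any, Dict, Iterable, List, Optional

class EchoInferenceRunner(InferenceRunner):
  """Hypothetical runner that just echoes its inputs; illustration only."""
  def run_inference(
      self,
      batch: List[Any],
      model: Any,
      prediction_params: Optional[Dict[str, Any]] = None) -> Iterable[Any]:
    # Treat an omitted prediction_params the same as an empty dict.
    params = prediction_params or {}
    return [(example, params) for example in batch]
```
With the default in place, the call site in `base.py` can keep passing three
positional arguments unconditionally, and only implementations that omit the
parameter (like `FakeInferenceRunner` before this change) hit the `TypeError`
above.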