yeandy commented on code in PR #17470:
URL: https://github.com/apache/beam/pull/17470#discussion_r878255796


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -96,7 +107,9 @@ def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
         pcoll
         # TODO(BEAM-14044): Hook into the batching DoFn APIs.
         | beam.BatchElements()
-        | beam.ParDo(_RunInferenceDoFn(self._model_loader, self._clock)))
+        | beam.ParDo(
+            _RunInferenceDoFn(
+                self._model_loader, self._prediction_params, self._clock)))

Review Comment:
   Note: the code has since been changed slightly to use `**kwargs` instead of `self._prediction_params`.
   
   Passing these extra prediction-related parameters in as side inputs makes more conceptual sense: the `batch` examples and the `prediction_params` kwargs are different kinds of input, but both are passed together into the PyTorch model's predict call (i.e., into its `forward` function):
   ```
   # Inside PytorchInferenceRunner
   def run_inference(batch, model, **kwargs):
       # Default to an empty dict so the unpacking below doesn't fail
       # when no prediction params are supplied.
       prediction_params = kwargs.get('prediction_params', {})
       return model(batch, **prediction_params)

   # ... which dispatches into the PyTorch model's forward:
   class PytorchModel(nn.Module):
       def forward(self, batch, **prediction_params):
           ...
   ```
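
   To make the flow concrete, here's a minimal, self-contained sketch of the same idea (the `ToyModel` / `run_inference` names below are illustrative only, not the actual Beam API): a model whose `forward` accepts extra prediction-time kwargs, and a runner that forwards `prediction_params` from `**kwargs` alongside the batch:
   ```
   import torch
   import torch.nn as nn

   class ToyModel(nn.Module):
       """Toy model whose forward accepts an extra prediction-time kwarg."""
       def __init__(self):
           super().__init__()
           self.linear = nn.Linear(4, 1)

       def forward(self, batch, scale=1.0):
           # 'scale' stands in for any prediction-time parameter.
           return self.linear(batch) * scale

   def run_inference(batch, model, **kwargs):
       # Forward prediction-time kwargs; default to no extra params.
       prediction_params = kwargs.get('prediction_params', {})
       with torch.no_grad():
           return model(batch, **prediction_params)

   model = ToyModel()
   batch = torch.randn(2, 4)
   predictions = run_inference(batch, model, prediction_params={'scale': 0.5})
   print(predictions.shape)  # torch.Size([2, 1])
   ```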


