TheNeuralBit commented on code in PR #17470:
URL: https://github.com/apache/beam/pull/17470#discussion_r870606313


##########
sdks/python/apache_beam/ml/inference/pytorch.py:
##########
@@ -45,11 +46,22 @@ def run_inference(self, batch: List[torch.Tensor],
     This method stacks the list of Tensors in a vectorized format to optimize
     the inference call.
     """
-
-    batch = torch.stack(batch)
-    if batch.device != self._device:
-      batch = batch.to(self._device)
-    predictions = model(batch)
+    if isinstance(batch[0], dict):

Review Comment:
   Fair point - another option that @robertwb brought up offline would be to 
have multiple implementations, e.g. one for the single-argument special case 
and another for the **kwargs case. (I'm not asking you to do this now; it's 
just something to think about.)
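
   A rough sketch of how the split might look (the function names and 
signatures below are hypothetical, not part of this PR; the dict case assumes 
every element of the batch has the same keys):

   ```python
   import torch
   from typing import Dict, Iterable, List


   def _run_inference_single_tensor(
       batch: List[torch.Tensor],
       model: torch.nn.Module,
       device: torch.device) -> Iterable[torch.Tensor]:
     # Single-argument special case: stack the Tensors into one batched
     # Tensor and call the model once.
     stacked = torch.stack(batch)
     if stacked.device != device:
       stacked = stacked.to(device)
     return model(stacked)


   def _run_inference_keyword_tensors(
       batch: List[Dict[str, torch.Tensor]],
       model: torch.nn.Module,
       device: torch.device) -> Iterable[torch.Tensor]:
     # **kwargs case: stack each keyword's Tensors across the batch and pass
     # them to the model by name (hypothetical helper, assumes uniform keys).
     keyed_tensors = {
         key: torch.stack([example[key] for example in batch]).to(device)
         for key in batch[0]
     }
     return model(**keyed_tensors)
   ```

   A dispatcher (or separate handler classes) could then route batches to one 
or the other, which would keep the `isinstance(batch[0], dict)` check out of 
`run_inference` itself.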



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
