AnandInguva commented on code in PR #25321:
URL: https://github.com/apache/beam/pull/25321#discussion_r1101906346


##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -400,11 +430,30 @@ def __init__(
       logging.info("Device is set to CPU")
       self._device = torch.device('cpu')
     self._model_class = model_class
-    self._model_params = model_params
+    self._model_params = model_params if model_params else {}
     self._inference_fn = inference_fn
-    self._use_torch_script_format = use_torch_script_format
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
+    self._torch_script_model_path = torch_script_model_path
+
+    self.validate_constructor_args()
+
+  def validate_constructor_args(self):

Review Comment:
   Yes, I will take care of refactoring in the next PR.
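For context, the diff above conditionally populates `self._batching_kwargs` so that only explicitly provided batch-size bounds are passed downstream, rather than forwarding `None` values. A minimal standalone sketch of that pattern (the helper name is hypothetical; only the `min_batch_size`/`max_batch_size` keys come from the diff):

```python
def build_batching_kwargs(min_batch_size=None, max_batch_size=None):
  """Collect only the batching bounds that were explicitly set.

  Mirrors the diff's conditional population of self._batching_kwargs:
  omitted arguments simply do not appear in the returned dict, so
  downstream defaults are left untouched.
  """
  kwargs = {}
  if min_batch_size is not None:
    kwargs['min_batch_size'] = min_batch_size
  if max_batch_size is not None:
    kwargs['max_batch_size'] = max_batch_size
  return kwargs
```

Returning a dict with only the set keys (instead of keys mapped to `None`) lets the caller splat it into a batching transform without overriding its defaults.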



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
