damccorm commented on code in PR #25321:
URL: https://github.com/apache/beam/pull/25321#discussion_r1101716682


##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -201,20 +207,30 @@ def __init__(
       logging.info("Device is set to CPU")
       self._device = torch.device('cpu')
     self._model_class = model_class
-    self._model_params = model_params
+    self._model_params = model_params if model_params else {}
     self._inference_fn = inference_fn
-    self._use_torch_script_format = use_torch_script_format
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
+    self._torch_script_model_path = torch_script_model_path
 
-    self._validate_func_args()
+    self.validate_constructor_args()
 
-  def _validate_func_args(self):
-    if not self._use_torch_script_format and (self._model_class is None or
-                                              self._model_params is None):
+  def validate_constructor_args(self):
+    if self._state_dict_path and not self._model_class:

Review Comment:
   ```suggestion
       if bool(self._state_dict_path) != bool(self._model_class):
   ```
   
   or alternatively:
   
   ```suggestion
       if (self._state_dict_path and not self._model_class) or (not self._state_dict_path and self._model_class):
   ```
   
   We just need to also catch the case where they provide the model_class but not the state_dict_path.



##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -400,11 +430,30 @@ def __init__(
       logging.info("Device is set to CPU")
       self._device = torch.device('cpu')
     self._model_class = model_class
-    self._model_params = model_params
+    self._model_params = model_params if model_params else {}
     self._inference_fn = inference_fn
-    self._use_torch_script_format = use_torch_script_format
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
+    self._torch_script_model_path = torch_script_model_path
+
+    self.validate_constructor_args()
+
+  def validate_constructor_args(self):

Review Comment:
   Since this function is duplicated in both model handlers, can we share it as a single `_validate_constructor_args` helper (like `_load_model`)?



##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -201,20 +207,30 @@ def __init__(
       logging.info("Device is set to CPU")
       self._device = torch.device('cpu')
     self._model_class = model_class
-    self._model_params = model_params
+    self._model_params = model_params if model_params else {}
     self._inference_fn = inference_fn
-    self._use_torch_script_format = use_torch_script_format
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
+    self._torch_script_model_path = torch_script_model_path
 
-    self._validate_func_args()
+    self.validate_constructor_args()
 
-  def _validate_func_args(self):
-    if not self._use_torch_script_format and (self._model_class is None or
-                                              self._model_params is None):
+  def validate_constructor_args(self):
+    if self._state_dict_path and not self._model_class:
       raise RuntimeError(
-          "Please pass both `model_class` and `model_params` to the torch "
-          "model handler when using it with PyTorch. "
-          "If you opt to load the entire that was saved using TorchScript, "
-          "set `use_torch_script_format` to True.")
+          "A state_dict_path has been supplied to the model "
+          "handler, but the required model_class is missing. "
+          "Please provide the model_class in order to "
+          "successfully load the state_dict_path.")
+
+    if self._torch_script_model_path:
+      if self._state_dict_path and self._model_class:

Review Comment:
   ```suggestion
         if self._state_dict_path or self._model_class:
   ```



##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -400,11 +430,30 @@ def __init__(
       logging.info("Device is set to CPU")
       self._device = torch.device('cpu')
     self._model_class = model_class
-    self._model_params = model_params
+    self._model_params = model_params if model_params else {}
     self._inference_fn = inference_fn
-    self._use_torch_script_format = use_torch_script_format
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
+    self._torch_script_model_path = torch_script_model_path
+
+    self.validate_constructor_args()
+
+  def validate_constructor_args(self):

Review Comment:
   If you're planning to do something like that in the follow-up refactor PR you mentioned, that's fine as well.



##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -201,20 +207,30 @@ def __init__(
       logging.info("Device is set to CPU")
       self._device = torch.device('cpu')
     self._model_class = model_class
-    self._model_params = model_params
+    self._model_params = model_params if model_params else {}
     self._inference_fn = inference_fn
-    self._use_torch_script_format = use_torch_script_format
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs['min_batch_size'] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs['max_batch_size'] = max_batch_size
+    self._torch_script_model_path = torch_script_model_path
 
-    self._validate_func_args()
+    self.validate_constructor_args()
 
-  def _validate_func_args(self):
-    if not self._use_torch_script_format and (self._model_class is None or
-                                              self._model_params is None):
+  def validate_constructor_args(self):
+    if self._state_dict_path and not self._model_class:
       raise RuntimeError(
-          "Please pass both `model_class` and `model_params` to the torch "
-          "model handler when using it with PyTorch. "
-          "If you opt to load the entire that was saved using TorchScript, "
-          "set `use_torch_script_format` to True.")
+          "A state_dict_path has been supplied to the model "
+          "handler, but the required model_class is missing. "
+          "Please provide the model_class in order to "
+          "successfully load the state_dict_path.")
+
+    if self._torch_script_model_path:
+      if self._state_dict_path and self._model_class:

Review Comment:
   (This shouldn't really matter if we check the case above.)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to