This is an automated email from the ASF dual-hosted git repository.

riteshghorse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 498c5c1d973 [Python] Added load model option args for PyTorch Model Handler (#26683)
498c5c1d973 is described below

commit 498c5c1d973958e509f7af2e182999089abffe9e
Author: Ritesh Ghorse <[email protected]>
AuthorDate: Tue May 16 16:47:22 2023 -0400

    [Python] Added load model option args for PyTorch Model Handler (#26683)
    
    * add load model option args for pytorch
    
    * change import order
    
    * update _load_model
    
    * update doc comments
    
    * rm .
    
    * pass with kwargs
    
    * Update pytorch_inference.py
---
 .../apache_beam/ml/inference/pytorch_inference.py  | 42 +++++++++++-------
 .../ml/inference/pytorch_inference_test.py         | 51 ++++++++++++++++++++++
 2 files changed, 78 insertions(+), 15 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 14abc202bdf..38a404c7a3a 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -84,7 +84,8 @@ def _load_model(
     state_dict_path: Optional[str],
     device: torch.device,
     model_params: Optional[Dict[str, Any]],
-    torch_script_model_path: Optional[str]):
+    torch_script_model_path: Optional[str],
+    load_model_args: Optional[Dict[str, Any]]):
   if device == torch.device('cuda') and not torch.cuda.is_available():
     logging.warning(
         "Model handler specified a 'GPU' device, but GPUs are not available. "
@@ -97,11 +98,11 @@ def _load_model(
     if not torch_script_model_path:
       file = FileSystems.open(state_dict_path, 'rb')
       model = model_class(**model_params)  # type: ignore[arg-type,misc]
-      state_dict = torch.load(file, map_location=device)
+      state_dict = torch.load(file, map_location=device, **load_model_args)
       model.load_state_dict(state_dict)
     else:
       file = FileSystems.open(torch_script_model_path, 'rb')
-      model = torch.jit.load(file, map_location=device)
+      model = torch.jit.load(file, map_location=device, **load_model_args)
   except RuntimeError as e:
     if device == torch.device('cuda'):
       message = "Loading the model onto a GPU device failed due to an " \
@@ -112,7 +113,8 @@ def _load_model(
           state_dict_path,
           torch.device('cpu'),
           model_params,
-          torch_script_model_path)
+          torch_script_model_path,
+          load_model_args)
     else:
       raise e
 
@@ -190,6 +192,7 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
       torch_script_model_path: Optional[str] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      load_model_args: Optional[Dict[str, Any]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for PyTorch.
 
@@ -222,6 +225,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
         batch will be fed into the inference_fn as a Sequence of Tensors.
       max_batch_size: the maximum batch size to use when batching inputs. This
         batch will be fed into the inference_fn as a Sequence of Tensors.
+      load_model_args: a dictionary of parameters passed to the torch.load
+        function to specify custom config for loading models.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
@@ -244,6 +249,7 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
     self._torch_script_model_path = torch_script_model_path
+    self._load_model_args = load_model_args if load_model_args else {}
     self._env_vars = kwargs.get('env_vars', {})
 
     _validate_constructor_args(
@@ -254,11 +260,12 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
     model, device = _load_model(
-        self._model_class,
-        self._state_dict_path,
-        self._device,
-        self._model_params,
-        self._torch_script_model_path
+        model_class=self._model_class,
+        state_dict_path=self._state_dict_path,
+        device=self._device,
+        model_params=self._model_params,
+        torch_script_model_path=self._torch_script_model_path,
+        load_model_args=self._load_model_args
     )
     self._device = device
     return model
@@ -404,6 +411,7 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
       torch_script_model_path: Optional[str] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      load_model_args: Optional[Dict[str, Any]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for PyTorch.
 
@@ -436,11 +444,13 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
       torch_script_model_path: Path to the torch script model.
          the model will be loaded using `torch.jit.load()`.
         `state_dict_path`, `model_class` and `model_params`
-         arguments will be disregarded..
+         arguments will be disregarded.
       min_batch_size: the minimum batch size to use when batching inputs. This
         batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
       max_batch_size: the maximum batch size to use when batching inputs. This
         batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
+      load_model_args: a dictionary of parameters passed to the torch.load
+        function to specify custom config for loading models.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
@@ -463,6 +473,7 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
     self._torch_script_model_path = torch_script_model_path
+    self._load_model_args = load_model_args if load_model_args else {}
     self._env_vars = kwargs.get('env_vars', {})
     _validate_constructor_args(
         state_dict_path=self._state_dict_path,
@@ -472,11 +483,12 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
   def load_model(self) -> torch.nn.Module:
     """Loads and initializes a Pytorch model for processing."""
     model, device = _load_model(
-        self._model_class,
-        self._state_dict_path,
-        self._device,
-        self._model_params,
-        self._torch_script_model_path
+        model_class=self._model_class,
+        state_dict_path=self._state_dict_path,
+        device=self._device,
+        model_params=self._model_params,
+        torch_script_model_path=self._torch_script_model_path,
+        load_model_args=self._load_model_args
     )
     self._device = device
     return model
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
index 76eea44fe55..2bcd56dbdf9 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -37,6 +37,7 @@ try:
   import torch
   from apache_beam.ml.inference.base import PredictionResult
   from apache_beam.ml.inference.base import RunInference
+  from apache_beam.ml.inference import pytorch_inference
   from apache_beam.ml.inference.pytorch_inference import default_keyed_tensor_inference_fn
   from apache_beam.ml.inference.pytorch_inference import default_tensor_inference_fn
   from apache_beam.ml.inference.pytorch_inference import make_keyed_tensor_model_fn
@@ -893,5 +894,55 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
       self.assertTrue((os.environ['FOO']) == 'bar')
 
 
[email protected]_pytorch
+class PytorchInferenceTestWithMocks(unittest.TestCase):
+  def setUp(self):
+    self._load_model = pytorch_inference._load_model
+    pytorch_inference._load_model = unittest.mock.MagicMock(
+        return_value=("model", "device"))
+    self.tmpdir = tempfile.mkdtemp()
+    self.state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
+                                   ('linear.bias', torch.Tensor([0.5]))])
+    self.torch_path = os.path.join(self.tmpdir, 'torch_model.pt')
+    torch.save(self.state_dict, self.torch_path)
+    self.model_params = {'input_dim': 2, 'output_dim': 1}
+
+  def tearDown(self):
+    pytorch_inference._load_model = self._load_model
+    shutil.rmtree(self.tmpdir)
+
+  def test_load_model_args_tensor(self):
+    load_model_args = {'weights_only': True}
+    model_handler = PytorchModelHandlerTensor(
+        state_dict_path=self.torch_path,
+        model_class=PytorchLinearRegression,
+        model_params=self.model_params,
+        load_model_args=load_model_args)
+    model_handler.load_model()
+    pytorch_inference._load_model.assert_called_with(
+        model_class=PytorchLinearRegression,
+        state_dict_path=self.torch_path,
+        device=torch.device('cpu'),
+        model_params=self.model_params,
+        torch_script_model_path=None,
+        load_model_args=load_model_args)
+
+  def test_load_model_args_keyed_tensor(self):
+    load_model_args = {'weights_only': True}
+    model_handler = PytorchModelHandlerKeyedTensor(
+        state_dict_path=self.torch_path,
+        model_class=PytorchLinearRegression,
+        model_params=self.model_params,
+        load_model_args=load_model_args)
+    model_handler.load_model()
+    pytorch_inference._load_model.assert_called_with(
+        model_class=PytorchLinearRegression,
+        state_dict_path=self.torch_path,
+        device=torch.device('cpu'),
+        model_params=self.model_params,
+        torch_script_model_path=None,
+        load_model_args=load_model_args)
+
+
 if __name__ == '__main__':
   unittest.main()

Reply via email to