This is an automated email from the ASF dual-hosted git repository.
riteshghorse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 498c5c1d973 [Python] Added load model option args for PyTorch Model
Handler (#26683)
498c5c1d973 is described below
commit 498c5c1d973958e509f7af2e182999089abffe9e
Author: Ritesh Ghorse <[email protected]>
AuthorDate: Tue May 16 16:47:22 2023 -0400
[Python] Added load model option args for PyTorch Model Handler (#26683)
* add load model option args for pytorch
* change import order
* update _load_model
* update doc comments
* rm .
* pass with kwargs
* Update pytorch_inference.py
---
.../apache_beam/ml/inference/pytorch_inference.py | 42 +++++++++++-------
.../ml/inference/pytorch_inference_test.py | 51 ++++++++++++++++++++++
2 files changed, 78 insertions(+), 15 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 14abc202bdf..38a404c7a3a 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -84,7 +84,8 @@ def _load_model(
state_dict_path: Optional[str],
device: torch.device,
model_params: Optional[Dict[str, Any]],
- torch_script_model_path: Optional[str]):
+ torch_script_model_path: Optional[str],
+ load_model_args: Optional[Dict[str, Any]]):
if device == torch.device('cuda') and not torch.cuda.is_available():
logging.warning(
"Model handler specified a 'GPU' device, but GPUs are not available. "
@@ -97,11 +98,11 @@ def _load_model(
if not torch_script_model_path:
file = FileSystems.open(state_dict_path, 'rb')
model = model_class(**model_params) # type: ignore[arg-type,misc]
- state_dict = torch.load(file, map_location=device)
+ state_dict = torch.load(file, map_location=device, **load_model_args)
model.load_state_dict(state_dict)
else:
file = FileSystems.open(torch_script_model_path, 'rb')
- model = torch.jit.load(file, map_location=device)
+ model = torch.jit.load(file, map_location=device, **load_model_args)
except RuntimeError as e:
if device == torch.device('cuda'):
message = "Loading the model onto a GPU device failed due to an " \
@@ -112,7 +113,8 @@ def _load_model(
state_dict_path,
torch.device('cpu'),
model_params,
- torch_script_model_path)
+ torch_script_model_path,
+ load_model_args)
else:
raise e
@@ -190,6 +192,7 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
torch_script_model_path: Optional[str] = None,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
+ load_model_args: Optional[Dict[str, Any]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for PyTorch.
@@ -222,6 +225,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
batch will be fed into the inference_fn as a Sequence of Tensors.
max_batch_size: the maximum batch size to use when batching inputs. This
batch will be fed into the inference_fn as a Sequence of Tensors.
+ load_model_args: a dictionary of parameters passed to the torch.load
+ function to specify custom config for loading models.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.
@@ -244,6 +249,7 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
if max_batch_size is not None:
self._batching_kwargs['max_batch_size'] = max_batch_size
self._torch_script_model_path = torch_script_model_path
+ self._load_model_args = load_model_args if load_model_args else {}
self._env_vars = kwargs.get('env_vars', {})
_validate_constructor_args(
@@ -254,11 +260,12 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
model, device = _load_model(
- self._model_class,
- self._state_dict_path,
- self._device,
- self._model_params,
- self._torch_script_model_path
+ model_class=self._model_class,
+ state_dict_path=self._state_dict_path,
+ device=self._device,
+ model_params=self._model_params,
+ torch_script_model_path=self._torch_script_model_path,
+ load_model_args=self._load_model_args
)
self._device = device
return model
@@ -404,6 +411,7 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
torch_script_model_path: Optional[str] = None,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
+ load_model_args: Optional[Dict[str, Any]] = None,
**kwargs):
"""Implementation of the ModelHandler interface for PyTorch.
@@ -436,11 +444,13 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
torch_script_model_path: Path to the torch script model.
the model will be loaded using `torch.jit.load()`.
`state_dict_path`, `model_class` and `model_params`
- arguments will be disregarded..
+ arguments will be disregarded.
min_batch_size: the minimum batch size to use when batching inputs. This
batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
max_batch_size: the maximum batch size to use when batching inputs. This
batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
+ load_model_args: a dictionary of parameters passed to the torch.load
+ function to specify custom config for loading models.
kwargs: 'env_vars' can be used to set environment variables
before loading the model.
@@ -463,6 +473,7 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
if max_batch_size is not None:
self._batching_kwargs['max_batch_size'] = max_batch_size
self._torch_script_model_path = torch_script_model_path
+ self._load_model_args = load_model_args if load_model_args else {}
self._env_vars = kwargs.get('env_vars', {})
_validate_constructor_args(
state_dict_path=self._state_dict_path,
@@ -472,11 +483,12 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
def load_model(self) -> torch.nn.Module:
"""Loads and initializes a Pytorch model for processing."""
model, device = _load_model(
- self._model_class,
- self._state_dict_path,
- self._device,
- self._model_params,
- self._torch_script_model_path
+ model_class=self._model_class,
+ state_dict_path=self._state_dict_path,
+ device=self._device,
+ model_params=self._model_params,
+ torch_script_model_path=self._torch_script_model_path,
+ load_model_args=self._load_model_args
)
self._device = device
return model
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
index 76eea44fe55..2bcd56dbdf9 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference_test.py
@@ -37,6 +37,7 @@ try:
import torch
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
+ from apache_beam.ml.inference import pytorch_inference
from apache_beam.ml.inference.pytorch_inference import default_keyed_tensor_inference_fn
from apache_beam.ml.inference.pytorch_inference import default_tensor_inference_fn
from apache_beam.ml.inference.pytorch_inference import make_keyed_tensor_model_fn
@@ -893,5 +894,55 @@ class PytorchRunInferencePipelineTest(unittest.TestCase):
self.assertTrue((os.environ['FOO']) == 'bar')
+@pytest.mark.uses_pytorch
+class PytorchInferenceTestWithMocks(unittest.TestCase):
+ def setUp(self):
+ self._load_model = pytorch_inference._load_model
+ pytorch_inference._load_model = unittest.mock.MagicMock(
+ return_value=("model", "device"))
+ self.tmpdir = tempfile.mkdtemp()
+ self.state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
+ ('linear.bias', torch.Tensor([0.5]))])
+ self.torch_path = os.path.join(self.tmpdir, 'torch_model.pt')
+ torch.save(self.state_dict, self.torch_path)
+ self.model_params = {'input_dim': 2, 'output_dim': 1}
+
+ def tearDown(self):
+ pytorch_inference._load_model = self._load_model
+ shutil.rmtree(self.tmpdir)
+
+ def test_load_model_args_tensor(self):
+ load_model_args = {'weights_only': True}
+ model_handler = PytorchModelHandlerTensor(
+ state_dict_path=self.torch_path,
+ model_class=PytorchLinearRegression,
+ model_params=self.model_params,
+ load_model_args=load_model_args)
+ model_handler.load_model()
+ pytorch_inference._load_model.assert_called_with(
+ model_class=PytorchLinearRegression,
+ state_dict_path=self.torch_path,
+ device=torch.device('cpu'),
+ model_params=self.model_params,
+ torch_script_model_path=None,
+ load_model_args=load_model_args)
+
+ def test_load_model_args_keyed_tensor(self):
+ load_model_args = {'weights_only': True}
+ model_handler = PytorchModelHandlerKeyedTensor(
+ state_dict_path=self.torch_path,
+ model_class=PytorchLinearRegression,
+ model_params=self.model_params,
+ load_model_args=load_model_args)
+ model_handler.load_model()
+ pytorch_inference._load_model.assert_called_with(
+ model_class=PytorchLinearRegression,
+ state_dict_path=self.torch_path,
+ device=torch.device('cpu'),
+ model_params=self.model_params,
+ torch_script_model_path=None,
+ load_model_args=load_model_args)
+
+
if __name__ == '__main__':
unittest.main()