yeandy commented on code in PR #21810:
URL: https://github.com/apache/beam/pull/21810#discussion_r895621888
##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -134,3 +135,185 @@ def get_metrics_namespace(self) -> str:
Returns a namespace for metrics collected by the RunInference transform.
"""
return 'RunInferencePytorch'
+
+
+@experimental()
+class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
+ PredictionResult,
+ torch.nn.Module]):
+ """ Implementation of the ModelHandler interface for PyTorch.
+
+ NOTE: This API and its implementation are under development and
+ do not provide backward compatibility guarantees.
+ """
+ def __init__(
+ self,
+ state_dict_path: str,
+ model_class: Callable[..., torch.nn.Module],
+ model_params: Dict[str, Any],
+ device: str = 'CPU'):
+ """
+ Initializes a PytorchModelHandler
+ :param state_dict_path: path to the saved dictionary of the model state.
+ :param model_class: class of the Pytorch model that defines the model
+ structure.
+ :param device: the device on which you wish to run the model. If
+ ``device = GPU`` then a GPU device will be used if it is available.
+ Otherwise, it will be CPU.
+
+ See https://pytorch.org/tutorials/beginner/saving_loading_models.html
+ for details
+ """
+ self._state_dict_path = state_dict_path
+ if device == 'GPU' and torch.cuda.is_available():
+ self._device = torch.device('cuda')
+ else:
+ self._device = torch.device('cpu')
+ self._model_class = model_class
+ self._model_params = model_params
+
+ def load_model(self) -> torch.nn.Module:
+ """Loads and initializes a Pytorch model for processing."""
+ model = self._model_class(**self._model_params)
+ model.to(self._device)
+ file = FileSystems.open(self._state_dict_path, 'rb')
+ model.load_state_dict(torch.load(file))
+ model.eval()
+ return model
+
+ def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:
+ """
+ Converts samples to a style matching given device.
+
+ Note: A user may pass in device='GPU' but if GPU is not detected in the
+ environment it must be converted back to CPU.
+ """
+ if examples.device != self._device:
+ examples = examples.to(self._device)
+ return examples
+
+ def run_inference(
+ self, batch: List[torch.Tensor], model: torch.nn.Module,
+ **kwargs) -> Iterable[PredictionResult]:
+ """
+ Runs inferences on a batch of Tensors and returns an Iterable of
+ Tensor Predictions.
+
+ This method stacks the list of Tensors in a vectorized format to optimize
+ the inference call.
+ """
+ prediction_params = kwargs.get('prediction_params', {})
+ batched_tensors = torch.stack(batch)
+ batched_tensors = self._convert_to_device(batched_tensors)
+ predictions = model(batched_tensors, **prediction_params)
+ return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+ def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
+ """Returns the number of bytes of data for a batch of Tensors."""
+ return sum((el.element_size() for tensor in batch for el in tensor))
+
+ def get_metrics_namespace(self) -> str:
+ """
+ Returns a namespace for metrics collected by the RunInference transform.
+ """
+ return 'RunInferencePytorch'
+
+
+@experimental()
+class PytorchModelHandlerKeyedTensor(ModelHandler[torch.Tensor,
+ PredictionResult,
+ torch.nn.Module]):
+ """ Implementation of the ModelHandler interface for PyTorch.
+
+ NOTE: This API and its implementation are under development and
+ do not provide backward compatibility guarantees.
+ """
+ def __init__(
+ self,
+ state_dict_path: str,
+ model_class: Callable[..., torch.nn.Module],
+ model_params: Dict[str, Any],
+ device: str = 'CPU'):
+ """
+ Initializes a PytorchModelHandler
+ :param state_dict_path: path to the saved dictionary of the model state.
+ :param model_class: class of the Pytorch model that defines the model
+ structure.
+ :param device: the device on which you wish to run the model. If
+ ``device = GPU`` then a GPU device will be used if it is available.
+ Otherwise, it will be CPU.
+
+ See https://pytorch.org/tutorials/beginner/saving_loading_models.html
+ for details
+ """
+ self._state_dict_path = state_dict_path
+ if device == 'GPU' and torch.cuda.is_available():
+ self._device = torch.device('cuda')
+ else:
+ self._device = torch.device('cpu')
+ self._model_class = model_class
+ self._model_params = model_params
+
+ def load_model(self) -> torch.nn.Module:
+ """Loads and initializes a Pytorch model for processing."""
+ model = self._model_class(**self._model_params)
+ model.to(self._device)
+ file = FileSystems.open(self._state_dict_path, 'rb')
+ model.load_state_dict(torch.load(file))
+ model.eval()
+ return model
Review Comment:
Done.
##########
sdks/python/apache_beam/ml/inference/pytorch_inference.py:
##########
@@ -134,3 +135,185 @@ def get_metrics_namespace(self) -> str:
Returns a namespace for metrics collected by the RunInference transform.
"""
return 'RunInferencePytorch'
+
+
+@experimental()
+class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
+ PredictionResult,
+ torch.nn.Module]):
+ """ Implementation of the ModelHandler interface for PyTorch.
+
+ NOTE: This API and its implementation are under development and
+ do not provide backward compatibility guarantees.
+ """
+ def __init__(
+ self,
+ state_dict_path: str,
+ model_class: Callable[..., torch.nn.Module],
+ model_params: Dict[str, Any],
+ device: str = 'CPU'):
+ """
+ Initializes a PytorchModelHandler
+ :param state_dict_path: path to the saved dictionary of the model state.
+ :param model_class: class of the Pytorch model that defines the model
+ structure.
+ :param device: the device on which you wish to run the model. If
+ ``device = GPU`` then a GPU device will be used if it is available.
+ Otherwise, it will be CPU.
+
+ See https://pytorch.org/tutorials/beginner/saving_loading_models.html
+ for details
+ """
+ self._state_dict_path = state_dict_path
+ if device == 'GPU' and torch.cuda.is_available():
+ self._device = torch.device('cuda')
+ else:
+ self._device = torch.device('cpu')
+ self._model_class = model_class
+ self._model_params = model_params
+
+ def load_model(self) -> torch.nn.Module:
+ """Loads and initializes a Pytorch model for processing."""
+ model = self._model_class(**self._model_params)
+ model.to(self._device)
+ file = FileSystems.open(self._state_dict_path, 'rb')
+ model.load_state_dict(torch.load(file))
+ model.eval()
+ return model
+
+ def _convert_to_device(self, examples: torch.Tensor) -> torch.Tensor:
+ """
+ Converts samples to a style matching given device.
+
+ Note: A user may pass in device='GPU' but if GPU is not detected in the
+ environment it must be converted back to CPU.
+ """
+ if examples.device != self._device:
+ examples = examples.to(self._device)
+ return examples
+
+ def run_inference(
+ self, batch: List[torch.Tensor], model: torch.nn.Module,
+ **kwargs) -> Iterable[PredictionResult]:
+ """
+ Runs inferences on a batch of Tensors and returns an Iterable of
+ Tensor Predictions.
+
+ This method stacks the list of Tensors in a vectorized format to optimize
+ the inference call.
+ """
+ prediction_params = kwargs.get('prediction_params', {})
+ batched_tensors = torch.stack(batch)
+ batched_tensors = self._convert_to_device(batched_tensors)
+ predictions = model(batched_tensors, **prediction_params)
+ return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+ def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
+ """Returns the number of bytes of data for a batch of Tensors."""
+ return sum((el.element_size() for tensor in batch for el in tensor))
+
+ def get_metrics_namespace(self) -> str:
+ """
+ Returns a namespace for metrics collected by the RunInference transform.
+ """
+ return 'RunInferencePytorch'
+
+
+@experimental()
+class PytorchModelHandlerKeyedTensor(ModelHandler[torch.Tensor,
Review Comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]