This is an automated email from the ASF dual-hosted git repository.
jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 9a921d51054 Deduplicate Base Embedding Handler Code (#31534)
9a921d51054 is described below
commit 9a921d51054b7cf1be8119c76dabca80d26c9bfc
Author: Jack McCluskey <[email protected]>
AuthorDate: Fri Jun 7 09:26:36 2024 -0400
Deduplicate Base Embedding Handler Code (#31534)
* Deduplicate Base Embedding Handler Code
* linting
---
sdks/python/apache_beam/ml/transforms/base.py | 131 +++++++++-----------------
1 file changed, 43 insertions(+), 88 deletions(-)
diff --git a/sdks/python/apache_beam/ml/transforms/base.py
b/sdks/python/apache_beam/ml/transforms/base.py
index 0d7db282bb8..678ab0882d2 100644
--- a/sdks/python/apache_beam/ml/transforms/base.py
+++ b/sdks/python/apache_beam/ml/transforms/base.py
@@ -587,9 +587,9 @@ class _MLTransformToPTransformMapper:
return _transform_attribute_manager.load_attributes(artifact_location)
-class _TextEmbeddingHandler(ModelHandler):
+class _EmbeddingHandler(ModelHandler):
"""
- A ModelHandler intended to be work on list[dict[str, str]] inputs.
+ A ModelHandler intended to work on list[dict[str, Any]] inputs.
The inputs to the model handler are expected to be a list of dicts.
@@ -597,12 +597,10 @@ class _TextEmbeddingHandler(ModelHandler):
PCollection[E] to a PCollection[P], this ModelHandler would take a
PCollection[Dict[str, E]] to a PCollection[Dict[str, P]].
- _TextEmbeddingHandler will accept an EmbeddingsManager instance, which
+ _EmbeddingHandler will accept an EmbeddingsManager instance, which
contains the details of the model to be loaded and the inference_fn to be
- used. The purpose of _TextEmbeddingHandler is to generate embeddings for
- text inputs using the EmbeddingsManager instance.
-
- If the input is not a text column, a RuntimeError will be raised.
+ used. The purpose of _EmbeddingHandler is to generate embeddings for
+ general inputs using the EmbeddingsManager instance.
This is an internal class and offers no backwards compatibility guarantees.
@@ -619,12 +617,9 @@ class _TextEmbeddingHandler(ModelHandler):
return model
def _validate_column_data(self, batch):
- if not isinstance(batch[0], (str, bytes)):
- raise TypeError(
- 'Embeddings can only be generated on Dict[str, str].'
- f'Got Dict[str, {type(batch[0])}] instead.')
+ pass
- def _validate_batch(self, batch: Sequence[Dict[str, List[str]]]):
+ def _validate_batch(self, batch: Sequence[Dict[str, Any]]):
if not batch or not isinstance(batch[0], dict):
raise TypeError(
'Expected data to be dicts, got '
@@ -676,8 +671,7 @@ class _TextEmbeddingHandler(ModelHandler):
def get_metrics_namespace(self) -> str:
return (
- self._underlying.get_metrics_namespace() or
- 'BeamML_TextEmbeddingHandler')
+ self._underlying.get_metrics_namespace() or 'BeamML_EmbeddingHandler')
def batch_elements_kwargs(self) -> Mapping[str, Any]:
batch_sizes_map = {}
@@ -694,7 +688,41 @@ class _TextEmbeddingHandler(ModelHandler):
pass
-class _ImageEmbeddingHandler(ModelHandler):
+class _TextEmbeddingHandler(_EmbeddingHandler):
+ """
+ A ModelHandler intended to work on list[dict[str, str]] inputs.
+
+ The inputs to the model handler are expected to be a list of dicts.
+
+ For example, if the original model is used with RunInference to take a
+ PCollection[E] to a PCollection[P], this ModelHandler would take a
+ PCollection[Dict[str, E]] to a PCollection[Dict[str, P]].
+
+ _TextEmbeddingHandler will accept an EmbeddingsManager instance, which
+ contains the details of the model to be loaded and the inference_fn to be
+ used. The purpose of _TextEmbeddingHandler is to generate embeddings for
+ text inputs using the EmbeddingsManager instance.
+
+ If the input is not a text column, a TypeError will be raised.
+
+ This is an internal class and offers no backwards compatibility guarantees.
+
+ Args:
+ embeddings_manager: An EmbeddingsManager instance.
+ """
+ def _validate_column_data(self, batch):
+ if not isinstance(batch[0], (str, bytes)):
+ raise TypeError(
+ 'Embeddings can only be generated on Dict[str, str].'
+ f'Got Dict[str, {type(batch[0])}] instead.')
+
+ def get_metrics_namespace(self) -> str:
+ return (
+ self._underlying.get_metrics_namespace() or
+ 'BeamML_TextEmbeddingHandler')
+
+
+class _ImageEmbeddingHandler(_EmbeddingHandler):
"""
A ModelHandler intended to work on list[dict[str, Image]] inputs.
@@ -717,15 +745,6 @@ class _ImageEmbeddingHandler(ModelHandler):
Args:
embeddings_manager: An EmbeddingsManager instance.
"""
- def __init__(self, embeddings_manager: EmbeddingsManager):
- self.embedding_config = embeddings_manager
- self._underlying = self.embedding_config.get_model_handler()
- self.columns = self.embedding_config.get_columns_to_apply()
-
- def load_model(self):
- model = self._underlying.load_model()
- return model
-
def _validate_column_data(self, batch):
# Don't want to require framework-specific imports
# here, so just catch columns of primitives for now.
@@ -734,71 +753,7 @@ class _ImageEmbeddingHandler(ModelHandler):
'Embeddings can only be generated on Dict[str, Image].'
f'Got Dict[str, {type(batch[0])}] instead.')
- def _validate_batch(self, batch: Sequence[Dict[str, List[Any]]]):
- if not batch or not isinstance(batch[0], dict):
- raise TypeError(
- 'Expected data to be dicts, got '
- f'{type(batch[0])} instead.')
-
- def _process_batch(
- self,
- dict_batch: Dict[str, List[Any]],
- model: ModelT,
- inference_args: Optional[Dict[str, Any]]) -> Dict[str, List[Any]]:
- result: Dict[str, List[Any]] = collections.defaultdict(list)
- input_keys = dict_batch.keys()
- missing_columns_in_data = set(self.columns) - set(input_keys)
- if missing_columns_in_data:
- raise RuntimeError(
- f'Data does not contain the following columns '
- f': {missing_columns_in_data}.')
- for key, batch in dict_batch.items():
- if key in self.columns:
- self._validate_column_data(batch)
- prediction = self._underlying.run_inference(
- batch, model, inference_args)
- if isinstance(prediction, np.ndarray):
- prediction = prediction.tolist()
- result[key] = prediction # type: ignore[assignment]
- else:
- result[key] = prediction # type: ignore[assignment]
- else:
- result[key] = batch
- return result
-
- def run_inference(
- self,
- batch: Sequence[Dict[str, List[str]]],
- model: ModelT,
- inference_args: Optional[Dict[str, Any]] = None,
- ) -> List[Dict[str, Union[List[float], List[str]]]]:
- """
- Runs inference on a batch of text inputs. The inputs are expected to be
- a list of dicts. Each dict should have the same keys, and the shape
- should be of the same size for a single key across the batch.
- """
- self._validate_batch(batch)
- dict_batch = _convert_list_of_dicts_to_dict_of_lists(list_of_dicts=batch)
- transformed_batch = self._process_batch(dict_batch, model, inference_args)
- return _convert_dict_of_lists_to_lists_of_dict(
- dict_of_lists=transformed_batch,
- )
-
def get_metrics_namespace(self) -> str:
return (
self._underlying.get_metrics_namespace() or
'BeamML_ImageEmbeddingHandler')
-
- def batch_elements_kwargs(self) -> Mapping[str, Any]:
- batch_sizes_map = {}
- if self.embedding_config.max_batch_size:
- batch_sizes_map['max_batch_size'] = self.embedding_config.max_batch_size
- if self.embedding_config.min_batch_size:
- batch_sizes_map['min_batch_size'] = self.embedding_config.min_batch_size
- return (self._underlying.batch_elements_kwargs() or batch_sizes_map)
-
- def __repr__(self):
- return self._underlying.__repr__()
-
- def validate_inference_args(self, _):
- pass