This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a921d51054 Deduplicate Base Embedding Handler Code (#31534)
9a921d51054 is described below

commit 9a921d51054b7cf1be8119c76dabca80d26c9bfc
Author: Jack McCluskey <[email protected]>
AuthorDate: Fri Jun 7 09:26:36 2024 -0400

    Deduplicate Base Embedding Handler Code (#31534)
    
    * Deduplicate Base Embedding Handler Code
    
    * linting
---
 sdks/python/apache_beam/ml/transforms/base.py | 131 +++++++++-----------------
 1 file changed, 43 insertions(+), 88 deletions(-)

diff --git a/sdks/python/apache_beam/ml/transforms/base.py 
b/sdks/python/apache_beam/ml/transforms/base.py
index 0d7db282bb8..678ab0882d2 100644
--- a/sdks/python/apache_beam/ml/transforms/base.py
+++ b/sdks/python/apache_beam/ml/transforms/base.py
@@ -587,9 +587,9 @@ class _MLTransformToPTransformMapper:
     return _transform_attribute_manager.load_attributes(artifact_location)
 
 
-class _TextEmbeddingHandler(ModelHandler):
+class _EmbeddingHandler(ModelHandler):
   """
-  A ModelHandler intended to be work on list[dict[str, str]] inputs.
+  A ModelHandler intended to work on list[dict[str, Any]] inputs.
 
   The inputs to the model handler are expected to be a list of dicts.
 
@@ -597,12 +597,10 @@ class _TextEmbeddingHandler(ModelHandler):
   PCollection[E] to a PCollection[P], this ModelHandler would take a
   PCollection[Dict[str, E]] to a PCollection[Dict[str, P]].
 
-  _TextEmbeddingHandler will accept an EmbeddingsManager instance, which
+  _EmbeddingHandler will accept an EmbeddingsManager instance, which
   contains the details of the model to be loaded and the inference_fn to be
-  used. The purpose of _TextEmbeddingHandler is to generate embeddings for
-  text inputs using the EmbeddingsManager instance.
-
-  If the input is not a text column, a RuntimeError will be raised.
+  used. The purpose of _EmbeddingHandler is to generate embeddings for
+  general inputs using the EmbeddingsManager instance.
 
   This is an internal class and offers no backwards compatibility guarantees.
 
@@ -619,12 +617,9 @@ class _TextEmbeddingHandler(ModelHandler):
     return model
 
   def _validate_column_data(self, batch):
-    if not isinstance(batch[0], (str, bytes)):
-      raise TypeError(
-          'Embeddings can only be generated on Dict[str, str].'
-          f'Got Dict[str, {type(batch[0])}] instead.')
+    pass
 
-  def _validate_batch(self, batch: Sequence[Dict[str, List[str]]]):
+  def _validate_batch(self, batch: Sequence[Dict[str, Any]]):
     if not batch or not isinstance(batch[0], dict):
       raise TypeError(
           'Expected data to be dicts, got '
@@ -676,8 +671,7 @@ class _TextEmbeddingHandler(ModelHandler):
 
   def get_metrics_namespace(self) -> str:
     return (
-        self._underlying.get_metrics_namespace() or
-        'BeamML_TextEmbeddingHandler')
+        self._underlying.get_metrics_namespace() or 'BeamML_EmbeddingHandler')
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     batch_sizes_map = {}
@@ -694,7 +688,41 @@ class _TextEmbeddingHandler(ModelHandler):
     pass
 
 
-class _ImageEmbeddingHandler(ModelHandler):
+class _TextEmbeddingHandler(_EmbeddingHandler):
+  """
+  A ModelHandler intended to work on list[dict[str, str]] inputs.
+
+  The inputs to the model handler are expected to be a list of dicts.
+
+  For example, if the original model is used with RunInference to take a
+  PCollection[E] to a PCollection[P], this ModelHandler would take a
+  PCollection[Dict[str, E]] to a PCollection[Dict[str, P]].
+
+  _TextEmbeddingHandler will accept an EmbeddingsManager instance, which
+  contains the details of the model to be loaded and the inference_fn to be
+  used. The purpose of _TextEmbeddingHandler is to generate embeddings for
+  text inputs using the EmbeddingsManager instance.
+
+  If the input is not a text column, a RuntimeError will be raised.
+
+  This is an internal class and offers no backwards compatibility guarantees.
+
+  Args:
+    embeddings_manager: An EmbeddingsManager instance.
+  """
+  def _validate_column_data(self, batch):
+    if not isinstance(batch[0], (str, bytes)):
+      raise TypeError(
+          'Embeddings can only be generated on Dict[str, str].'
+          f'Got Dict[str, {type(batch[0])}] instead.')
+
+  def get_metrics_namespace(self) -> str:
+    return (
+        self._underlying.get_metrics_namespace() or
+        'BeamML_TextEmbeddingHandler')
+
+
+class _ImageEmbeddingHandler(_EmbeddingHandler):
   """
   A ModelHandler intended to work on list[dict[str, Image]] inputs.
 
@@ -717,15 +745,6 @@ class _ImageEmbeddingHandler(ModelHandler):
   Args:
     embeddings_manager: An EmbeddingsManager instance.
   """
-  def __init__(self, embeddings_manager: EmbeddingsManager):
-    self.embedding_config = embeddings_manager
-    self._underlying = self.embedding_config.get_model_handler()
-    self.columns = self.embedding_config.get_columns_to_apply()
-
-  def load_model(self):
-    model = self._underlying.load_model()
-    return model
-
   def _validate_column_data(self, batch):
     # Don't want to require framework-specific imports
     # here, so just catch columns of primitives for now.
@@ -734,71 +753,7 @@ class _ImageEmbeddingHandler(ModelHandler):
           'Embeddings can only be generated on Dict[str, Image].'
           f'Got Dict[str, {type(batch[0])}] instead.')
 
-  def _validate_batch(self, batch: Sequence[Dict[str, List[Any]]]):
-    if not batch or not isinstance(batch[0], dict):
-      raise TypeError(
-          'Expected data to be dicts, got '
-          f'{type(batch[0])} instead.')
-
-  def _process_batch(
-      self,
-      dict_batch: Dict[str, List[Any]],
-      model: ModelT,
-      inference_args: Optional[Dict[str, Any]]) -> Dict[str, List[Any]]:
-    result: Dict[str, List[Any]] = collections.defaultdict(list)
-    input_keys = dict_batch.keys()
-    missing_columns_in_data = set(self.columns) - set(input_keys)
-    if missing_columns_in_data:
-      raise RuntimeError(
-          f'Data does not contain the following columns '
-          f': {missing_columns_in_data}.')
-    for key, batch in dict_batch.items():
-      if key in self.columns:
-        self._validate_column_data(batch)
-        prediction = self._underlying.run_inference(
-            batch, model, inference_args)
-        if isinstance(prediction, np.ndarray):
-          prediction = prediction.tolist()
-          result[key] = prediction  # type: ignore[assignment]
-        else:
-          result[key] = prediction  # type: ignore[assignment]
-      else:
-        result[key] = batch
-    return result
-
-  def run_inference(
-      self,
-      batch: Sequence[Dict[str, List[str]]],
-      model: ModelT,
-      inference_args: Optional[Dict[str, Any]] = None,
-  ) -> List[Dict[str, Union[List[float], List[str]]]]:
-    """
-    Runs inference on a batch of text inputs. The inputs are expected to be
-    a list of dicts. Each dict should have the same keys, and the shape
-    should be of the same size for a single key across the batch.
-    """
-    self._validate_batch(batch)
-    dict_batch = _convert_list_of_dicts_to_dict_of_lists(list_of_dicts=batch)
-    transformed_batch = self._process_batch(dict_batch, model, inference_args)
-    return _convert_dict_of_lists_to_lists_of_dict(
-        dict_of_lists=transformed_batch,
-    )
-
   def get_metrics_namespace(self) -> str:
     return (
         self._underlying.get_metrics_namespace() or
         'BeamML_ImageEmbeddingHandler')
-
-  def batch_elements_kwargs(self) -> Mapping[str, Any]:
-    batch_sizes_map = {}
-    if self.embedding_config.max_batch_size:
-      batch_sizes_map['max_batch_size'] = self.embedding_config.max_batch_size
-    if self.embedding_config.min_batch_size:
-      batch_sizes_map['min_batch_size'] = self.embedding_config.min_batch_size
-    return (self._underlying.batch_elements_kwargs() or batch_sizes_map)
-
-  def __repr__(self):
-    return self._underlying.__repr__()
-
-  def validate_inference_args(self, _):
-    pass

Reply via email to