damccorm commented on code in PR #33364:
URL: https://github.com/apache/beam/pull/33364#discussion_r1884541855


##########
sdks/python/apache_beam/ml/rag/embeddings/huggingface.py:
##########
@@ -0,0 +1,63 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""RAG-specific embedding implementations using HuggingFace models."""
+
+from typing import Optional
+
+import apache_beam as beam
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.rag.embeddings.base import create_rag_adapter
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.transforms.base import EmbeddingsManager
+from apache_beam.ml.transforms.base import _TextEmbeddingHandler
+from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformer
+from apache_beam.ml.transforms.embeddings.huggingface import _SentenceTransformerModelHandler
+
+
+class HuggingfaceTextEmbeddings(EmbeddingsManager):
+  def __init__(
+      self, model_name: str, *, max_seq_length: Optional[int] = None, **kwargs):
+    """Utilizes huggingface SentenceTransformer embeddings for RAG pipeline.
+
+        Args:
+            model_name: Name of the sentence-transformers model to use
+            max_seq_length: Maximum sequence length for the model
+            **kwargs: Additional arguments including ModelHandlers arguments

Review Comment:
   Should we link to the base class as well since these will be passed through?



##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -62,36 +68,31 @@
 # Output of the apply() method of BaseOperation.
 OperationOutputT = TypeVar('OperationOutputT')
 
+# Input to the EmbeddingTypeAdapter input_fn
+EmbeddingTypeAdapterInputT = TypeVar(
+    'EmbeddingTypeAdapterInputT')  # e.g., Chunk
+# Output of the EmbeddingTypeAdapter output_fn
+EmbeddingTypeAdapterOutputT = TypeVar(
+    'EmbeddingTypeAdapterOutputT')  # e.g., Embedding
 
-def _convert_list_of_dicts_to_dict_of_lists(
-    list_of_dicts: Sequence[dict[str, Any]]) -> dict[str, list[Any]]:
-  keys_to_element_list = collections.defaultdict(list)
-  input_keys = list_of_dicts[0].keys()
-  for d in list_of_dicts:
-    if set(d.keys()) != set(input_keys):
-      extra_keys = set(d.keys()) - set(input_keys) if len(
-          d.keys()) > len(input_keys) else set(input_keys) - set(d.keys())
-      raise RuntimeError(
-          f'All the dicts in the input data should have the same keys. '
-          f'Got: {extra_keys} instead.')
-    for key, value in d.items():
-      keys_to_element_list[key].append(value)
-  return keys_to_element_list
-
-
-def _convert_dict_of_lists_to_lists_of_dict(
-    dict_of_lists: dict[str, list[Any]]) -> list[dict[str, Any]]:
-  batch_length = len(next(iter(dict_of_lists.values())))
-  result: list[dict[str, Any]] = [{} for _ in range(batch_length)]
-  # all the values in the dict_of_lists should have same length
-  for key, values in dict_of_lists.items():
-    assert len(values) == batch_length, (
-        "This function expects all the values "
-        "in the dict_of_lists to have same length."
-        )
-    for i in range(len(values)):
-      result[i][key] = values[i]
-  return result
+
+@dataclass
+class EmbeddingTypeAdapter(Generic[EmbeddingTypeAdapterInputT,
+                                   EmbeddingTypeAdapterOutputT]):
+  """Adapts input types to text for embedding and converts output embeddings.
+
+    Args:
+        input_fn: Function to extract text for embedding from input type
+        output_fn: Function to create output type from input and embeddings
+    """
+  input_fn: Callable[[Sequence[EmbeddingTypeAdapterInputT]], List[str]]
+  output_fn: Callable[[Sequence[EmbeddingTypeAdapterInputT], Sequence[Any]],
+                      List[EmbeddingTypeAdapterOutputT]]
+
+  def __reduce__(self):
+    """Custom serialization that preserves type information during
+    jsonpickle."""

Review Comment:
   Oh interesting - thanks for clarifying. I don't see any problem with this; I was mostly just curious.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to