claudevdm commented on code in PR #33364:
URL: https://github.com/apache/beam/pull/33364#discussion_r1884211127


##########
sdks/python/apache_beam/ml/transforms/base.py:
##########
@@ -62,36 +68,31 @@
 # Output of the apply() method of BaseOperation.
 OperationOutputT = TypeVar('OperationOutputT')
 
+# Input to the EmbeddingTypeAdapter input_fn
+EmbeddingTypeAdapterInputT = TypeVar(
+    'EmbeddingTypeAdapterInputT')  # e.g., Chunk
+# Output of the EmbeddingTypeAdapter output_fn
+EmbeddingTypeAdapterOutputT = TypeVar(
+    'EmbeddingTypeAdapterOutputT')  # e.g., Embedding
 
-def _convert_list_of_dicts_to_dict_of_lists(
-    list_of_dicts: Sequence[dict[str, Any]]) -> dict[str, list[Any]]:
-  keys_to_element_list = collections.defaultdict(list)
-  input_keys = list_of_dicts[0].keys()
-  for d in list_of_dicts:
-    if set(d.keys()) != set(input_keys):
-      extra_keys = set(d.keys()) - set(input_keys) if len(
-          d.keys()) > len(input_keys) else set(input_keys) - set(d.keys())
-      raise RuntimeError(
-          f'All the dicts in the input data should have the same keys. '
-          f'Got: {extra_keys} instead.')
-    for key, value in d.items():
-      keys_to_element_list[key].append(value)
-  return keys_to_element_list
-
-
-def _convert_dict_of_lists_to_lists_of_dict(
-    dict_of_lists: dict[str, list[Any]]) -> list[dict[str, Any]]:
-  batch_length = len(next(iter(dict_of_lists.values())))
-  result: list[dict[str, Any]] = [{} for _ in range(batch_length)]
-  # all the values in the dict_of_lists should have same length
-  for key, values in dict_of_lists.items():
-    assert len(values) == batch_length, (
-        "This function expects all the values "
-        "in the dict_of_lists to have same length."
-        )
-    for i in range(len(values)):
-      result[i][key] = values[i]
-  return result
+
+@dataclass
+class EmbeddingTypeAdapter(Generic[EmbeddingTypeAdapterInputT,
+                                   EmbeddingTypeAdapterOutputT]):
+  """Adapts input types to text for embedding and converts output embeddings.
+
+    Args:
+        input_fn: Function to extract text for embedding from input type
+        output_fn: Function to create output type from input and embeddings
+    """
+  input_fn: Callable[[Sequence[EmbeddingTypeAdapterInputT]], List[str]]
+  output_fn: Callable[[Sequence[EmbeddingTypeAdapterInputT], Sequence[Any]],
+                      List[EmbeddingTypeAdapterOutputT]]
+
+  def __reduce__(self):
+    """Custom serialization that preserves type information during
+    jsonpickle."""

Review Comment:
   The generic type parameters somehow get lost during default jsonpickle 
encoding/decoding.
   
   This causes errors when loading transforms from artifacts, e.g., 
apache_beam/ml/transforms/embeddings/huggingface_test.py::SentenceTransformerEmbeddingsTest::test_mltransform_to_ptransform_with_sentence_transformer
 fails with `TypeError: Plain typing.Any is not valid as type argument`.
   
   ```
   ../../../../.pyenv/versions/3.9.4/lib/python3.9/typing.py:820: in <genexpr>
       params = tuple(_type_check(p, msg) for p in params)
   _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ 
   
   arg = typing.Any, msg = 'Parameters to generic types must be types.', is_argument = True
   
       def _type_check(arg, msg, is_argument=True):
           """Check that the argument is a type, and return it (internal helper).
       
           As a special case, accept None and return type(None) instead. Also wrap strings
           into ForwardRef instances. Consider several corner cases, for example plain
           special forms like Union are not valid, while Union[int, str] is OK, etc.
           The msg argument is a human-readable error message, e.g::
       
               "Union[arg, ...]: arg should be a type."
       
           We append the repr() of the actual value (truncated to 100 chars).
           """
           invalid_generic_forms = (Generic, Protocol)
           if is_argument:
               invalid_generic_forms = invalid_generic_forms + (ClassVar, Final)
       
           arg = _type_convert(arg)
           if (isinstance(arg, _GenericAlias) and
                   arg.__origin__ in invalid_generic_forms):
               raise TypeError(f"{arg} is not valid as type argument")
           if arg in (Any, NoReturn):
               return arg
           if isinstance(arg, _SpecialForm) or arg in (Generic, Protocol):
   >           raise TypeError(f"Plain {arg} is not valid as type argument")
   E           TypeError: Plain typing.Any is not valid as type argument
   
   ../../../../.pyenv/versions/3.9.4/lib/python3.9/typing.py:153: TypeError
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to