gemini-code-assist[bot] commented on code in PR #35774:
URL: https://github.com/apache/beam/pull/35774#discussion_r2267988918


##########
sdks/python/apache_beam/yaml/yaml_ml.py:
##########
@@ -512,13 +503,36 @@ def ml_transform(
     raise ValueError(
         'tensorflow-transform must be installed to use this MLTransform')
   options.YamlOptions.check_enabled(pcoll.pipeline, 'ML')
-  # TODO(robertwb): Perhaps _config_to_obj could be pushed into MLTransform
-  # itself for better cross-language support?
-  return pcoll | MLTransform(
+  result_ml_transform = MLTransform(
       write_artifact_location=write_artifact_location,
       read_artifact_location=read_artifact_location,
       transforms=[_config_to_obj(t) for t in transforms] if transforms else [])
 
-
-if tft is not None:
+  if transforms and any(t.get('type', '').endswith('Embeddings')
+                        for t in transforms):
+    from apache_beam.typehints import List
+    try:
+      if pcoll.element_type:
+        new_fields = named_fields_from_element_type(pcoll.element_type)
+        columns_to_change = set()
+        for t_spec in transforms:
+          if t_spec.get('type', '').endswith('Embeddings'):
+            columns_to_change.update(
+                t_spec.get('config', {}).get('columns', []))
+
+        final_fields = []
+        for name, typ in new_fields:
+          if name in columns_to_change:
+            final_fields.append((name, List[float]))
+          else:
+            final_fields.append((name, typ))
+        output_schema = RowTypeConstraint.from_fields(final_fields)
+        return pcoll | result_ml_transform.with_output_types(output_schema)
+    except TypeError:
+      # If we can't get a schema, just return the result.
+      pass
+  return pcoll | result_ml_transform

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This block of code can be refactored for better readability and efficiency. 
The current implementation scans `transforms` twice to detect embedding 
transforms: once in `any()` and again in the `for` loop. Instead, filter the 
embedding transforms into a list once, then reuse that list both as the 
condition and as the source of the columns to change. Using comprehensions 
can also make the code more concise and Pythonic.
   
   ```python
     if transforms:
       embedding_transforms = [
           t for t in transforms if t.get('type', '').endswith('Embeddings')
       ]
       if embedding_transforms:
         from apache_beam.typehints import List
         try:
           if pcoll.element_type:
             columns_to_change = {
                 col
                 for t_spec in embedding_transforms
                 for col in t_spec.get('config', {}).get('columns', [])
             }
             new_fields = named_fields_from_element_type(pcoll.element_type)
             final_fields = [
                 (name, List[float] if name in columns_to_change else typ)
                 for name, typ in new_fields
             ]
             output_schema = RowTypeConstraint.from_fields(final_fields)
             return pcoll | result_ml_transform.with_output_types(output_schema)
         except TypeError:
           # If we can't get a schema, just return the result.
           pass
     return pcoll | result_ml_transform
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@beam.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to