gemini-code-assist[bot] commented on code in PR #35774:
URL: https://github.com/apache/beam/pull/35774#discussion_r2267988918
##########
sdks/python/apache_beam/yaml/yaml_ml.py:
##########
@@ -512,13 +503,36 @@ def ml_transform(
raise ValueError(
'tensorflow-transform must be installed to use this MLTransform')
options.YamlOptions.check_enabled(pcoll.pipeline, 'ML')
- # TODO(robertwb): Perhaps _config_to_obj could be pushed into MLTransform
- # itself for better cross-language support?
- return pcoll | MLTransform(
+ result_ml_transform = MLTransform(
write_artifact_location=write_artifact_location,
read_artifact_location=read_artifact_location,
transforms=[_config_to_obj(t) for t in transforms] if transforms else [])
-
-if tft is not None:
+ if transforms and any(t.get('type', '').endswith('Embeddings')
+ for t in transforms):
+ from apache_beam.typehints import List
+ try:
+ if pcoll.element_type:
+ new_fields = named_fields_from_element_type(pcoll.element_type)
+ columns_to_change = set()
+ for t_spec in transforms:
+ if t_spec.get('type', '').endswith('Embeddings'):
+ columns_to_change.update(
+ t_spec.get('config', {}).get('columns', []))
+
+ final_fields = []
+ for name, typ in new_fields:
+ if name in columns_to_change:
+ final_fields.append((name, List[float]))
+ else:
+ final_fields.append((name, typ))
+ output_schema = RowTypeConstraint.from_fields(final_fields)
+ return pcoll | result_ml_transform.with_output_types(output_schema)
+ except TypeError:
+ # If we can't get a schema, just return the result.
+ pass
+ return pcoll | result_ml_transform
Review Comment:

This block of code can be refactored for better readability.
The current implementation scans `transforms` twice, once in `any()` and
again in the `for` loop, repeating the `endswith('Embeddings')` check on
every transform each time. Collecting the embedding transforms once up front
and then processing only that list removes the duplicated check, and using
comprehensions makes the code more concise and Pythonic.
```python
  if transforms:
    embedding_transforms = [
        t for t in transforms if t.get('type', '').endswith('Embeddings')
    ]
    if embedding_transforms:
      from apache_beam.typehints import List
      try:
        if pcoll.element_type:
          columns_to_change = {
              col
              for t_spec in embedding_transforms
              for col in t_spec.get('config', {}).get('columns', [])
          }
          new_fields = named_fields_from_element_type(pcoll.element_type)
          final_fields = [
              (name, List[float] if name in columns_to_change else typ)
              for name, typ in new_fields
          ]
          output_schema = RowTypeConstraint.from_fields(final_fields)
          return pcoll | result_ml_transform.with_output_types(output_schema)
      except TypeError:
        # If we can't get a schema, just return the result.
        pass
  return pcoll | result_ml_transform
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]