gemini-code-assist[bot] commented on code in PR #38451:
URL: https://github.com/apache/beam/pull/38451#discussion_r3226646982


##########
sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml:
##########
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+pipelines:
+  - pipeline:
+      type: chain
+      transforms:
+        - type: Create
+          config:
+            elements:
+              - text: "I love Apache Beam!"
+              - text: "I hate this error."
+        - type: RunInference
+          config:
+            model_handler:
+              type: "HuggingFacePipeline"
+              config:
+                task: "text-classification"
+                inference_fn:
+                  callable: |
+                    def real_inference(batch, pipeline, inference_args):
+                      predictions = pipeline(batch, **inference_args) 
+                      
+                      # If it's a single dictionary (batch size of 1), wrap it in a list
+                      if isinstance(predictions, dict):
+                        predictions = [predictions]
+                      
+                      return {
+                        'label': [p['label'] for p in predictions],
+                        'score': [p['score'] for p in predictions]
+                      }
+                preprocess:
+                  callable: 'lambda x: x.text'
+        - type: MapToFields
+          config:
+            language: python
+            fields:
+              text: text
+              sentiment:
+                callable: 'lambda x: x.inference.inference["label"]'

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   Based on the `run_inference` implementation in `yaml_ml.py`, the output Row 
has an `inference` field containing the result from the model handler. If the 
handler returns a dictionary (which is the default for HuggingFace 
text-classification), then `x.inference` is that dictionary. Accessing 
`x.inference.inference["label"]` is incorrect and will result in an error. It 
should be `x.inference["label"]`.
   
   ```yaml
                   callable: 'lambda x: x.inference["label"]'
   ```



##########
sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml:
##########
@@ -0,0 +1,62 @@
+pipelines:
+  - pipeline:
+      type: chain
+      transforms:
+        - type: Create
+          config:
+            elements:
+              - text: "I love Apache Beam!"
+              - text: "I hate this error."
+        - type: RunInference
+          config:
+            model_handler:
+              type: "HuggingFacePipeline"
+              config:
+                task: "text-classification"
+                inference_fn:
+                  callable: |
+                    def real_inference(batch, pipeline, inference_args):
+                      predictions = pipeline(batch, **inference_args) 
+                      
+                      # If it's a single dictionary (batch size of 1), wrap it in a list
+                      if isinstance(predictions, dict):
+                        predictions = [predictions]
+                      
+                      return {
+                        'label': [p['label'] for p in predictions],
+                        'score': [p['score'] for p in predictions]
+                      }

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The `real_inference` function must return an iterable (such as a list) of 
predictions where each element corresponds to an input in the batch. Currently, 
it returns a single dictionary containing lists of labels and scores. This will 
cause `RunInference` to either raise a `ValueError` (due to a length mismatch 
with the input batch) or yield the dictionary keys (`'label'`, `'score'`) as 
individual results, which will break downstream transforms. Since the 
HuggingFace pipeline already returns a list of dictionaries for the 
`text-classification` task, you can simply return the `predictions` list 
directly.
   
   ```yaml
                       def real_inference(batch, pipeline, inference_args):
                         predictions = pipeline(batch, **inference_args)
                         if isinstance(predictions, dict):
                           predictions = [predictions]
                         return predictions
   ```
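
   A minimal sketch of the suggested fix that runs without Beam or HuggingFace.
`fake_pipeline` is a hypothetical stand-in for the text-classification pipeline
(its labels and scores are made up). `columns_to_rows` is a hypothetical helper
for cases where a column-oriented dict is easier to build inside the function:
it converts `{'label': [...], 'score': [...]}` into one dictionary per input,
which is the shape the comment asks `real_inference` to return. Note that
iterating over the column-oriented dict itself yields only its keys.

```python
from typing import Any, Dict, List, Sequence


def fake_pipeline(batch: Sequence[str], **kwargs: Any) -> List[Dict[str, Any]]:
  # Hypothetical stand-in for a HuggingFace text-classification pipeline:
  # returns one {'label', 'score'} dict per input string (made-up values).
  return [
      {'label': 'POSITIVE' if 'love' in text else 'NEGATIVE', 'score': 0.5}
      for text in batch
  ]


def real_inference(batch, pipeline, inference_args):
  # Shape suggested in the review: a list with one prediction per input.
  predictions = pipeline(batch, **inference_args)
  if isinstance(predictions, dict):
    predictions = [predictions]
  return predictions


def columns_to_rows(columns: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
  # Hypothetical helper: {'label': [a, b], 'score': [x, y]} becomes
  # [{'label': a, 'score': x}, {'label': b, 'score': y}].
  lengths = {len(values) for values in columns.values()}
  if len(lengths) > 1:
    raise ValueError(f'Columns have mismatched lengths: {sorted(lengths)}')
  keys = list(columns)
  return [dict(zip(keys, row)) for row in zip(*columns.values())]


batch = ['I love Apache Beam!', 'I hate this error.']
predictions = real_inference(batch, fake_pipeline, {})
print(predictions)

# Column-oriented shape returned by the original function.
columns = {
    'label': [p['label'] for p in predictions],
    'score': [p['score'] for p in predictions],
}
print(list(columns))  # ['label', 'score']: iterating the dict gives its keys
print(columns_to_rows(columns) == predictions)  # True
```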



##########
sdks/python/apache_beam/yaml/yaml_ml.py:
##########
@@ -282,6 +282,55 @@ def inference_output_type(self):
                                           ('model_id', Optional[str])])
 
 
+@ModelHandlerProvider.register_handler_type('HuggingFacePipeline')
+class HuggingFacePipelineProvider(ModelHandlerProvider):
+  def __init__(
+      self,
+      task: Optional[str] = None,
+      model: Optional[str] = None,
+      preprocess: Optional[dict[str, str]] = None,
+      postprocess: Optional[dict[str, str]] = None,
+      device: Optional[Any] = None,
+      inference_fn: Optional[dict[str, str]] = None,
+      load_pipeline_args: Optional[dict[str, Any]] = None,

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Using built-in generics such as `dict[...]` in type hints requires Python 3.9+
(PEP 585) or `from __future__ import annotations`. Without that import, these
annotations raise a `TypeError` on Python 3.8 as soon as the module is imported,
so if this module still has to load on Python 3.8, use `typing.Dict` instead.
Additionally, `preprocess` and `postprocess` configurations in YAML can be
complex dictionaries (e.g., containing `type` and `config` keys), so
`Optional[Dict[str, Any]]` is a more appropriate type hint than
`Optional[Dict[str, str]]`.
   
   ```suggestion
         preprocess: Optional[Dict[str, Any]] = None,
         postprocess: Optional[Dict[str, Any]] = None,
         device: Optional[Any] = None,
         inference_fn: Optional[Dict[str, Any]] = None,
         load_pipeline_args: Optional[Dict[str, Any]] = None,
   ```
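
   A minimal sketch of the suggested annotations on a hypothetical stand-in class
(the real provider subclasses `ModelHandlerProvider` and builds a model handler,
which this sketch omits). It shows that the `typing.Dict` hints resolve with
`typing.get_type_hints` and that the constructor takes the `{'callable': ...}`
mappings used in the YAML test. Type hints are not enforced at runtime; they
document intent and are what static checkers read. The `load_pipeline_args`
value below is illustrative only.

```python
from typing import Any, Dict, Optional, get_type_hints


class HuggingFacePipelineProviderSketch:
  # Hypothetical stand-in: only the constructor signature mirrors the
  # suggestion; there is no ModelHandlerProvider base class or handler here.
  def __init__(
      self,
      task: Optional[str] = None,
      model: Optional[str] = None,
      preprocess: Optional[Dict[str, Any]] = None,
      postprocess: Optional[Dict[str, Any]] = None,
      device: Optional[Any] = None,
      inference_fn: Optional[Dict[str, Any]] = None,
      load_pipeline_args: Optional[Dict[str, Any]] = None):
    self.task = task
    self.model = model
    self.preprocess = preprocess
    self.postprocess = postprocess
    self.device = device
    self.inference_fn = inference_fn
    # Defaulting to a fresh empty dict avoids a shared mutable default.
    self.load_pipeline_args = load_pipeline_args or {}


# typing.Dict annotations also evaluate on Python 3.8.
hints = get_type_hints(HuggingFacePipelineProviderSketch.__init__)
print(hints['preprocess'] == Optional[Dict[str, Any]])  # True

provider = HuggingFacePipelineProviderSketch(
    task='text-classification',
    preprocess={'callable': 'lambda x: x.text'},
    # Illustrative non-string value that Dict[str, str] would mis-describe.
    load_pipeline_args={'trust_remote_code': False})
```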



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
