gemini-code-assist[bot] commented on code in PR #35692:
URL: https://github.com/apache/beam/pull/35692#discussion_r2244019119


##########
sdks/python/apache_beam/transforms/core_test.py:
##########
@@ -322,6 +324,33 @@ def test_typecheck_with_default(self):
             | beam.Map(lambda s: s.upper()).with_input_types(str))
 
 
+class CreateInferOutputSchemaTest(unittest.TestCase):
+  def test_multiple_types_for_field(self):
+    output_type = beam.Create([beam.Row(a=1),
+                               beam.Row(a='foo')]).infer_output_type(None)
+    self.assertEqual(
+        output_type,
+        row_type.RowTypeConstraint.from_fields([('a', typing.Union[int, str])]))
+
+  def test_single_type_for_field(self):
+    output_type = beam.Create([beam.Row(a=1),
+                               beam.Row(a=2)]).infer_output_type(None)
+    self.assertEqual(
+        output_type, row_type.RowTypeConstraint.from_fields([('a', int)]))
+
+  def test_optional_type_for_field(self):
+    output_type = beam.Create([beam.Row(a=1),
+                               beam.Row(a=None)]).infer_output_type(None)
+    self.assertEqual(
+        output_type,
+        row_type.RowTypeConstraint.from_fields([('a', typing.Optional[int])]))
+
+  def test_none_type_for_field_raises_error(self):
+    with self.assertRaisesRegex(TypeError,
+                                "('No types found for field %s', 'a')"):
+      beam.Create([beam.Row(a=None), beam.Row(a=None)]).infer_output_type(None)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
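   Once the suggested change in `core.py` is applied, the `TypeError` will carry
   a single formatted message rather than a format string plus a separate
   argument, so this test should be updated to assert the new error message.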
   
   ```suggestion
       with self.assertRaisesRegex(TypeError, "No types found for field a"):
         beam.Create([beam.Row(a=None), beam.Row(a=None)]).infer_output_type(None)
   ```



##########
sdks/python/apache_beam/yaml/extended_tests/databases/bigquery.yaml:
##########
@@ -74,4 +83,58 @@ pipelines:
               - {label: "389a"}
     options:
       project: "apache-beam-testing"
-      temp_location: "{TEMP_DIR}"
+      temp_location: "{TEMP_DIR_0}"
+
+  # ----------------------------------------------------------------------------
+
+  # New write to verify row restriction based on Timestamp and nullability
+  - pipeline:
+      type: chain
+      transforms:
+        - type: Create
+          config:
+            elements:
+              - {label: "4a", rank: 3, timestamp: "2024-07-14 00:00:00 UTC"}
+              - {label: "5a", rank: 4}
+              - {label: "6a", rank: 5, timestamp: "2024-07-14T02:00:00.123Z"}  
  

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
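   The timestamp values in these elements use two different formats
   (`2024-07-14 00:00:00 UTC` and `2024-07-14T02:00:00.123Z`). Consider using
   the same datetime format for all timestamp values for consistency.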



##########
sdks/python/apache_beam/transforms/core.py:
##########
@@ -3962,9 +3963,41 @@ def to_runner_api_parameter(self, context):
   def infer_output_type(self, unused_input_type):
     if not self.values:
       return typehints.Any
-    return typehints.Union[[
-        trivial_inference.instance_to_type(v) for v in self.values
-    ]]
+
+    # No field data - just use default Union.
+    if not hasattr(self.values[0], 'as_dict'):
+      return typehints.Union[[
+          trivial_inference.instance_to_type(v) for v in self.values
+      ]]
+
+    first_fields = self.values[0].as_dict().keys()
+
+    # Save field types for each field
+    field_types_by_field = defaultdict(set)
+    for row in self.values:
+      row_dict = row.as_dict()
+      for field in first_fields:
+        field_types_by_field[field].add(
+            trivial_inference.instance_to_type(row_dict.get(field)))
+
+    # Determine the appropriate type for each field
+    final_fields = []
+    for field in first_fields:
+      field_types = field_types_by_field[field]
+      non_none_types = {t for t in field_types if t is not type(None)}
+
+      if len(non_none_types) > 1:
+        final_type = typing.Union[tuple(non_none_types)]
+      elif len(non_none_types) == 1 and len(field_types) == 1:
+        final_type = non_none_types.pop()
+      elif len(non_none_types) == 1 and len(field_types) == 2:
+        final_type = typing.Optional[non_none_types.pop()]
+      else:
+        raise TypeError("No types found for field %s", field)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
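   The exception is raised with a format string and a separate argument, but
   exception constructors do not interpolate `%s`: both values are stored in the
   exception's `args` tuple, so the message renders as
   `('No types found for field %s', 'a')`. For better readability and adherence
   to standard Python practices, it's recommended to format the string before
   raising the exception. An f-string is a clean way to do this.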
   
   ```suggestion
           raise TypeError(f"No types found for field {field}")
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@beam.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to