gemini-code-assist[bot] commented on code in PR #36979:
URL: https://github.com/apache/beam/pull/36979#discussion_r2636217694


##########
sdks/python/apache_beam/coders/coder_impl.py:
##########
@@ -497,20 +510,35 @@ def encode_special_deterministic(self, value, stream):
       self.encode_type(type(value), stream)
       stream.write(value.SerializePartialToString(deterministic=True), True)
     elif dataclasses and dataclasses.is_dataclass(value):
-      stream.write_byte(DATACLASS_TYPE)
       if not type(value).__dataclass_params__.frozen:
         raise TypeError(
             "Unable to deterministically encode non-frozen '%s' of type '%s' "
             "for the input of '%s'" %
             (value, type(value), self.requires_deterministic_step_label))
-      self.encode_type(type(value), stream)
-      values = [
-          getattr(value, field.name) for field in dataclasses.fields(value)
-      ]
-      try:
-        self.iterable_coder_impl.encode_to_stream(values, stream, True)
-      except Exception as e:
-        raise TypeError(self._deterministic_encoding_error_msg(value)) from e
+      if dataclass_uses_kw_only(type(value)):
+        stream.write_byte(DATACLASS_KW_ONLY_TYPE)
+        self.encode_type(type(value), stream)
+        init_field_names = [
+            field.name for field in dataclasses.fields(value) if field.init
+        ]
+        stream.write_var_int64(len(init_field_names))
+        try:
+          for field_name in init_field_names:
+            stream.write(field_name.encode("utf-8"), True)
+            self.encode_to_stream(getattr(value, field_name), stream, True)
+        except Exception as e:
+          raise TypeError(self._deterministic_encoding_error_msg(value)) from e
+      else:  # Not using kw_only, we can pass parameters by position.
+        stream.write_byte(DATACLASS_TYPE)
+        self.encode_type(type(value), stream)
+        values = [
+            getattr(value, field.name) for field in dataclasses.fields(value)
+            if field.init
+        ]
+        try:
+          self.iterable_coder_impl.encode_to_stream(values, stream, True)
+        except Exception as e:
+          raise TypeError(self._deterministic_encoding_error_msg(value)) from e

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The two branches that encode dataclasses (with and without keyword-only 
fields) duplicate two things: the filtering of fields by `field.init` and the 
`try...except` block. Both can be hoisted out of the branches to reduce 
duplication and improve maintainability.
   
   ```suggestion
         init_fields = [
             field for field in dataclasses.fields(value) if field.init
         ]
         try:
           if dataclass_uses_kw_only(type(value)):
             stream.write_byte(DATACLASS_KW_ONLY_TYPE)
             self.encode_type(type(value), stream)
             stream.write_var_int64(len(init_fields))
             for field in init_fields:
               stream.write(field.name.encode("utf-8"), True)
               self.encode_to_stream(getattr(value, field.name), stream, True)
           else:  # Not using kw_only, we can pass parameters by position.
             stream.write_byte(DATACLASS_TYPE)
             self.encode_type(type(value), stream)
             values = [getattr(value, field.name) for field in init_fields]
             self.iterable_coder_impl.encode_to_stream(values, stream, True)
         except Exception as e:
           raise TypeError(self._deterministic_encoding_error_msg(value)) from e
   ```



##########
sdks/python/apache_beam/coders/coder_impl.py:
##########
@@ -616,6 +644,14 @@ def decode_from_stream(self, stream, nested):
       msg = cls()
       msg.ParseFromString(stream.read_all(True))
       return msg
+    elif t == DATACLASS_KW_ONLY_TYPE:
+      cls = self.decode_type(stream)
+      vlen = stream.read_var_int64()
+      fields = {}
+      for _ in range(vlen):
+        field_name = stream.read_all(True).decode('utf-8')
+        fields[field_name] = self.decode_from_stream(stream, True)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This loop, which builds the `fields` dictionary, can be written more 
concisely as a dictionary comprehension.
   
   ```suggestion
         fields = {
             stream.read_all(True).decode('utf-8'):
             self.decode_from_stream(stream, True)
             for _ in range(vlen)
         }
   ```
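
One point worth checking before applying this: the comprehension is only equivalent to the loop if the key expression is evaluated before the value expression, since both read from the same stream. Python guarantees that order from 3.8 onward. The sketch below shows the behavior with a toy token iterator; `tokens` and `read` are hypothetical stand-ins, not Beam APIs.

```python
# Toy stream: alternating field names and values, in the order the
# decoder would read them.
tokens = iter(["x", 1, "y", 2])


def read():
    """Return the next token, standing in for a stream read."""
    return next(tokens)


# On Python 3.8+, the key expression runs before the value expression,
# so each name is paired with the value that follows it.
fields = {read(): read() for _ in range(2)}
```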



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
