gemini-code-assist[bot] commented on code in PR #39136:
URL: https://github.com/apache/beam/pull/39136#discussion_r3486741923


##########
runners/kafka-streams/src/main/java/org/apache/beam/runners/kafka/streams/translation/GroupByKeyProcessor.java:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.kafka.streams.translation;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.sdk.values.WindowedValues;
+import org.apache.kafka.streams.processor.api.Processor;
+import org.apache.kafka.streams.processor.api.ProcessorContext;
+import org.apache.kafka.streams.processor.api.Record;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.KeyValueStore;
+
+/** Processor for GroupByKey. */
+class GroupByKeyProcessor
+    implements Processor<
+        byte[],
+        KStreamsPayload<KV<Object, Object>>,
+        byte[],
+        KStreamsPayload<KV<Object, Iterable<Object>>>> {
+
+  private final String stateStoreName;
+  private final String transformId;
+  private final Coder<List<WindowedValue<KV<Object, Object>>>> listCoder;
+
+  private ProcessorContext<byte[], KStreamsPayload<KV<Object, 
Iterable<Object>>>> context;
+  private KeyValueStore<byte[], byte[]> stateStore;
+
+  GroupByKeyProcessor(
+      String stateStoreName,
+      String transformId,
+      Coder<WindowedValue<KV<Object, Object>>> inputCoder) {
+    this.stateStoreName = stateStoreName;
+    this.transformId = transformId;
+    this.listCoder = ListCoder.of(inputCoder);
+  }

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   Refactor the processor to store the individual element coder (`inputCoder`) 
   instead of wrapping it in a `ListCoder`. This enables an append-only state 
   layout: each new element's encoding is appended to the bytes already stored 
   for its key, instead of decoding and re-encoding the entire list of buffered 
   elements on every new record. After this change the `ListCoder` import is 
   unused and should be removed.
   
   ```suggestion
     /** Coder for a single buffered input element (not for a list of elements). */
     private final Coder<WindowedValue<KV<Object, Object>>> inputCoder;

     private ProcessorContext<byte[], KStreamsPayload<KV<Object, Iterable<Object>>>> context;
     private KeyValueStore<byte[], byte[]> stateStore;

     /**
      * Creates the processor.
      *
      * @param stateStoreName name of the key-value state store used to buffer input elements
      * @param transformId id of the pipeline transform this processor was created for
      * @param inputCoder coder for one windowed input element
      */
     GroupByKeyProcessor(
         String stateStoreName,
         String transformId,
         Coder<WindowedValue<KV<Object, Object>>> inputCoder) {
       this.inputCoder = inputCoder;
       this.stateStoreName = stateStoreName;
       this.transformId = transformId;
     }
   ```



##########
runners/kafka-streams/src/main/java/org/apache/beam/runners/kafka/streams/translation/GroupByKeyProcessor.java:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.kafka.streams.translation;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.sdk.values.WindowedValues;
+import org.apache.kafka.streams.processor.api.Processor;
+import org.apache.kafka.streams.processor.api.ProcessorContext;
+import org.apache.kafka.streams.processor.api.Record;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.KeyValueStore;
+
+/** Processor for GroupByKey. */
+class GroupByKeyProcessor
+    implements Processor<
+        byte[],
+        KStreamsPayload<KV<Object, Object>>,
+        byte[],
+        KStreamsPayload<KV<Object, Iterable<Object>>>> {
+
+  private final String stateStoreName;
+  private final String transformId;
+  private final Coder<List<WindowedValue<KV<Object, Object>>>> listCoder;
+
+  private ProcessorContext<byte[], KStreamsPayload<KV<Object, 
Iterable<Object>>>> context;
+  private KeyValueStore<byte[], byte[]> stateStore;
+
+  GroupByKeyProcessor(
+      String stateStoreName,
+      String transformId,
+      Coder<WindowedValue<KV<Object, Object>>> inputCoder) {
+    this.stateStoreName = stateStoreName;
+    this.transformId = transformId;
+    this.listCoder = ListCoder.of(inputCoder);
+  }
+
+  @Override
+  public void init(
+      ProcessorContext<byte[], KStreamsPayload<KV<Object, Iterable<Object>>>> 
context) {
+    this.context = context;
+    this.stateStore = context.getStateStore(stateStoreName);
+  }
+
+  @Override
+  public void process(Record<byte[], KStreamsPayload<KV<Object, Object>>> 
record) {
+    KStreamsPayload<KV<Object, Object>> payload = record.value();
+
+    if (payload.isData()) {
+      byte[] keyBytes = record.key();
+      byte[] existingBytes = stateStore.get(keyBytes);
+      List<WindowedValue<KV<Object, Object>>> list;
+      if (existingBytes == null) {
+        list = new ArrayList<>();
+      } else {
+        try {
+          list = listCoder.decode(new ByteArrayInputStream(existingBytes));
+        } catch (IOException e) {
+          throw new RuntimeException("Failed to decode buffered GroupByKey 
state", e);
+        }
+      }
+      list.add(payload.getData());
+      ByteArrayOutputStream os = new ByteArrayOutputStream();
+      try {
+        listCoder.encode(list, os);
+      } catch (IOException e) {
+        throw new RuntimeException("Failed to encode buffered GroupByKey 
state", e);
+      }
+      stateStore.put(keyBytes, os.toByteArray());
+    } else {
+      WatermarkPayload watermark = payload.asWatermark();
+      if (watermark.getWatermarkMillis() == 
BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()) {
+        try (KeyValueIterator<byte[], byte[]> iterator = stateStore.all()) {
+          while (iterator.hasNext()) {
+            org.apache.kafka.streams.KeyValue<byte[], byte[]> kv = 
iterator.next();
+            List<WindowedValue<KV<Object, Object>>> buffered;
+            try {
+              buffered = listCoder.decode(new ByteArrayInputStream(kv.value));
+            } catch (IOException e) {
+              throw new RuntimeException("Failed to decode buffered GroupByKey 
state on emit", e);
+            }
+            if (!buffered.isEmpty()) {
+              List<Object> values = new ArrayList<>();
+              for (WindowedValue<KV<Object, Object>> wv : buffered) {
+                values.add(wv.getValue().getValue());
+              }
+              Object key = buffered.get(0).getValue().getKey();
+              WindowedValue<KV<Object, Iterable<Object>>> outWv =
+                  WindowedValues.valueInGlobalWindow(KV.of(key, values));
+              context.forward(
+                  new Record<>(kv.key, KStreamsPayload.data(outWv), 
record.timestamp()));
+            }
+          }
+        }
+        // Since we fired everything for the global window, we can optionally 
clear the store here.
+        // But the pipeline is finishing.
+
+        // Forward the watermark downstream
+        context.forward(
+            new Record<>(
+                record.key(),
+                KStreamsPayload.watermark(
+                    watermark.getWatermarkMillis(),
+                    watermark.getSourcePartition(),
+                    watermark.getTotalSourcePartitions()),
+                record.timestamp()));
+      }
+    }
+  }

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   This refactoring addresses several critical issues in the `process` method:
   
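   1. **Performance**: Instead of decoding and re-encoding the entire list of 
      elements for every incoming record, we append the serialized bytes of the 
      new element to the existing byte array in the state store, and decode the 
      elements sequentially from the stream when emitting. This removes the 
      per-record deserialization and re-serialization of every buffered element. 
      Note that each `put` still copies the whole stored value (and, if 
      changelogging is enabled, writes it to the changelog). Total bytes copied 
      per key therefore remain $O(N^2)$ in the number of buffered elements. 
      Storing one entry per element (e.g. a composite key of the grouping key 
      plus a sequence number, read back with a range scan) would be needed to 
      make buffering linear.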
   2. **State Leak**: We collect the keys and delete them from the state store 
after forwarding the grouped elements, preventing persistent state leaks.
   3. **Watermark Forwarding**: We forward all watermarks downstream (not just 
the terminal one) to allow downstream watermark progression and prevent 
pipeline stalls.
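   4. **Output Timestamp**: We compute the minimum timestamp of the buffered 
      elements (the semantics of `TimestampCombiner.EARLIEST`) and use 
      `WindowedValues.timestampedValueInGlobalWindow` instead of 
      `valueInGlobalWindow`. The latter stamps every output with 
      `TIMESTAMP_MIN_VALUE`, which misrepresents event time and can place 
      results in the wrong windows if they are re-windowed downstream. Note 
      that Beam's default windowing strategy uses 
      `TimestampCombiner.END_OF_WINDOW`. Ideally the combiner should be taken 
      from the input PCollection's windowing strategy.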
   
   ```java
     @Override
     public void process(Record<byte[], KStreamsPayload<KV<Object, Object>>> 
record) {
       KStreamsPayload<KV<Object, Object>> payload = record.value();
   
       if (payload.isData()) {
         byte[] keyBytes = record.key();
         byte[] existingBytes = stateStore.get(keyBytes);
         ByteArrayOutputStream os = new ByteArrayOutputStream();
         if (existingBytes != null) {
           try {
             os.write(existingBytes);
           } catch (IOException e) {
             throw new RuntimeException("Failed to append to buffered 
GroupByKey state", e);
           }
         }
         try {
           inputCoder.encode(payload.getData(), os);
         } catch (IOException e) {
           throw new RuntimeException("Failed to encode buffered GroupByKey 
state", e);
         }
         stateStore.put(keyBytes, os.toByteArray());
       } else {
         WatermarkPayload watermark = payload.asWatermark();
         if (watermark.getWatermarkMillis() == 
BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()) {
           List<byte[]> keysToDelete = new ArrayList<>();
           try (KeyValueIterator<byte[], byte[]> iterator = stateStore.all()) {
             while (iterator.hasNext()) {
               org.apache.kafka.streams.KeyValue<byte[], byte[]> kv = 
iterator.next();
               keysToDelete.add(kv.key);
               List<WindowedValue<KV<Object, Object>>> buffered = new 
ArrayList<>();
               try (ByteArrayInputStream bis = new 
ByteArrayInputStream(kv.value)) {
                 while (bis.available() > 0) {
                   buffered.add(inputCoder.decode(bis));
                 }
               } catch (IOException e) {
                 throw new RuntimeException("Failed to decode buffered 
GroupByKey state on emit", e);
               }
               if (!buffered.isEmpty()) {
                 org.joda.time.Instant minTimestamp = 
BoundedWindow.TIMESTAMP_MAX_VALUE;
                 List<Object> values = new ArrayList<>();
                 for (WindowedValue<KV<Object, Object>> wv : buffered) {
                   values.add(wv.getValue().getValue());
                   if (wv.getTimestamp().isBefore(minTimestamp)) {
                     minTimestamp = wv.getTimestamp();
                   }
                 }
                 Object key = buffered.get(0).getValue().getKey();
                 WindowedValue<KV<Object, Iterable<Object>>> outWv =
                     WindowedValue.timestampedValueInGlobalWindow(KV.of(key, 
values), minTimestamp);
                 context.forward(
                     new Record<>(kv.key, KStreamsPayload.data(outWv), 
record.timestamp()));
               }
             }
           }
           for (byte[] key : keysToDelete) {
             stateStore.delete(key);
           }
         }
   
         // Forward the watermark downstream
         context.forward(
             new Record<>(
                 record.key(),
                 KStreamsPayload.watermark(
                     watermark.getWatermarkMillis(),
                     watermark.getSourcePartition(),
                     watermark.getTotalSourcePartitions()),
                 record.timestamp()));
       }
     }
   ```



##########
runners/kafka-streams/src/test/java/org/apache/beam/runners/kafka/streams/translation/GroupByKeyTranslatorTest.java:
##########
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.kafka.streams.translation;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Properties;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.construction.PipelineTranslation;
+import org.apache.beam.sdk.util.construction.RehydratedComponents;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.sdk.values.WindowedValues;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TestInputTopic;
+import org.apache.kafka.streams.Topology;
+import org.apache.kafka.streams.TopologyTestDriver;
+import org.apache.kafka.streams.processor.api.Processor;
+import org.apache.kafka.streams.processor.api.ProcessorContext;
+import org.apache.kafka.streams.processor.api.ProcessorSupplier;
+import org.apache.kafka.streams.processor.api.Record;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.junit.Test;
+
+/** Tests for {@link GroupByKeyTranslator}. */
+public class GroupByKeyTranslatorTest {
+
+  @Test
+  public void groupByKeyBuffersAndFiresAtTerminalWatermark() throws Exception {
+    KafkaStreamsTranslationContext context = 
KafkaStreamsPipelineTranslatorTest.newContext();
+
+    Pipeline sdkPipeline = Pipeline.create(PipelineOptionsFactory.create());
+    sdkPipeline
+        .apply("create", Create.empty(KvCoder.of(StringUtf8Coder.of(), 
StringUtf8Coder.of())))
+        .apply("gbk", GroupByKey.create());
+
+    RunnerApi.Pipeline pipeline = PipelineTranslation.toProto(sdkPipeline);
+    String gbkTransformId = null;
+    for (String id : pipeline.getComponents().getTransformsMap().keySet()) {
+      if (pipeline
+          .getComponents()
+          .getTransformsOrThrow(id)
+          .getSpec()
+          .getUrn()
+          .equals("beam:transform:group_by_key:v1")) {
+        gbkTransformId = id;
+      }
+    }
+
+    RunnerApi.PTransform gbkTransform =
+        pipeline.getComponents().getTransformsOrThrow(gbkTransformId);
+    String inputPCollId = 
Iterables.getOnlyElement(gbkTransform.getInputsMap().values());
+    context.registerPCollectionProducer(inputPCollId, "mock-parent");
+
+    RehydratedComponents components = 
RehydratedComponents.forComponents(pipeline.getComponents());
+    Coder<?> rawInputCoder =
+        components.getCoder(
+            
pipeline.getComponents().getPcollectionsOrThrow(inputPCollId).getCoderId());
+    @SuppressWarnings("unchecked")
+    Coder<WindowedValue<KV<String, String>>> inputCoder =
+        (Coder<WindowedValue<KV<String, String>>>) rawInputCoder;
+    KStreamsPayloadSerde<KV<String, String>> payloadSerde = new 
KStreamsPayloadSerde<>(inputCoder);
+
+    Topology topology = context.getTopology();
+    topology.addSource(
+        "mock-source",
+        Serdes.ByteArray().deserializer(),
+        payloadSerde.deserializer(),
+        "mock-topic");
+    // Just a pass-through to satisfy the parent reference
+    topology.addProcessor("mock-parent", PassThroughProcessor::new, 
"mock-source");
+
+    new GroupByKeyTranslator().translate(gbkTransformId, pipeline, context);
+
+    CapturingProcessor capture = new CapturingProcessor();
+    topology.addProcessor("capture", capture, gbkTransformId);
+
+    try (TopologyTestDriver driver = new TopologyTestDriver(topology, 
baseProps())) {
+      TestInputTopic<byte[], KStreamsPayload<KV<String, String>>> inputTopic =
+          driver.createInputTopic(
+              "mock-topic", Serdes.ByteArray().serializer(), 
payloadSerde.serializer());
+
+      // Send elements
+      inputTopic.pipeInput(
+          new byte[0], 
KStreamsPayload.data(WindowedValues.valueInGlobalWindow(KV.of("k1", "v1"))));
+      inputTopic.pipeInput(
+          new byte[0], 
KStreamsPayload.data(WindowedValues.valueInGlobalWindow(KV.of("k1", "v2"))));
+      inputTopic.pipeInput(
+          new byte[0], 
KStreamsPayload.data(WindowedValues.valueInGlobalWindow(KV.of("k2", "v3"))));
+
+      // No output yet
+      assertThat(capture.received.size(), is(0));
+
+      // Send terminal watermark
+      inputTopic.pipeInput(
+          new byte[0],
+          
KStreamsPayload.watermark(BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis(), 0, 1));
+    }
+
+    // Now it should have fired. It fires 1 for each key, plus the watermark
+    assertThat(
+        capture.received.size(),
+        is(3)); // 2 data + 1 watermark (wait, maybe order is deterministic?)
+
+    int dataCount = 0;
+    int watermarkCount = 0;
+    for (KStreamsPayload<?> payload : capture.received) {
+      if (payload.isData()) {
+        dataCount++;
+        KV<?, ?> kv = (KV<?, ?>) payload.getData().getValue();
+        Iterable<?> iter = (Iterable<?>) kv.getValue();
+        if ("k1".equals(kv.getKey())) {
+          assertThat(Iterables.size(iter), is(2));
+        } else if ("k2".equals(kv.getKey())) {
+          assertThat(Iterables.size(iter), is(1));
+        }
+      } else {
+        watermarkCount++;
+        assertThat(
+            payload.asWatermark().getWatermarkMillis(),
+            is(BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()));
+      }
+    }
+    assertThat(dataCount, is(2));
+    assertThat(watermarkCount, is(1));

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Update the test to verify that non-terminal watermarks are correctly 
forwarded downstream without triggering early emission, and that the terminal 
watermark is also forwarded alongside the emitted data.
   
   ```java
         // No output yet
         assertThat(capture.received.size(), is(0));
   
         // Send non-terminal watermark
         inputTopic.pipeInput(
             new byte[0],
             KStreamsPayload.watermark(1000L, 0, 1));
   
         // Watermark should be forwarded, but no data emitted yet
         assertThat(capture.received.size(), is(1));
         assertThat(capture.received.get(0).isData(), is(false));
         assertThat(capture.received.get(0).asWatermark().getWatermarkMillis(), 
is(1000L));
   
         // Send terminal watermark
         inputTopic.pipeInput(
             new byte[0],
             
KStreamsPayload.watermark(BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis(), 0, 1));
       }
   
       // Now it should have fired. It fires 1 for each key, plus the watermarks
       assertThat(
           capture.received.size(),
           is(4)); // 2 data + 2 watermarks
   
       int dataCount = 0;
       int watermarkCount = 0;
       for (KStreamsPayload<?> payload : capture.received) {
         if (payload.isData()) {
           dataCount++;
           KV<?, ?> kv = (KV<?, ?>) payload.getData().getValue();
           Iterable<?> iter = (Iterable<?>) kv.getValue();
           if ("k1".equals(kv.getKey())) {
             assertThat(Iterables.size(iter), is(2));
           } else if ("k2".equals(kv.getKey())) {
             assertThat(Iterables.size(iter), is(1));
           }
         } else {
           watermarkCount++;
           long wm = payload.asWatermark().getWatermarkMillis();
           assertThat(
               wm == 1000L || wm == 
BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis(),
               is(true));
         }
       }
       assertThat(dataCount, is(2));
       assertThat(watermarkCount, is(2));
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to