gemini-code-assist[bot] commented on code in PR #38255:
URL: https://github.com/apache/beam/pull/38255#discussion_r3116591950


##########
runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java:
##########
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.io;
+
+import static java.util.stream.Collectors.toList;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList;
+import static 
org.apache.beam.sdk.values.WindowedValues.timestampedValueInGlobalWindow;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import static scala.collection.JavaConverters.asScalaIterator;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.IntSupplier;
+import java.util.function.Supplier;
+import javax.annotation.CheckForNull;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.BoundedSource.BoundedReader;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.values.WindowedValue;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.apache.spark.InterruptibleIterator;
+import org.apache.spark.Partition;
+import org.apache.spark.SparkContext;
+import org.apache.spark.TaskContext;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.classic.Dataset$;
+import org.apache.spark.sql.connector.catalog.SupportsRead;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.catalog.TableCapability;
+import org.apache.spark.sql.connector.read.Batch;
+import org.apache.spark.sql.connector.read.InputPartition;
+import org.apache.spark.sql.connector.read.PartitionReader;
+import org.apache.spark.sql.connector.read.PartitionReaderFactory;
+import org.apache.spark.sql.connector.read.Scan;
+import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import scala.Option;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+
+public class BoundedDatasetFactory {
+  private BoundedDatasetFactory() {}
+
+  /**
+   * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link 
Table}.
+   *
+   * <p>Unfortunately tables are expected to return an {@link InternalRow}, 
requiring serialization.
+   * For the time being, this makes the approach significantly less performant 
than creating a
+   * dataset from an RDD.
+   */
+  public static <T> Dataset<WindowedValue<T>> createDatasetFromRows(
+      SparkSession session,
+      BoundedSource<T> source,
+      Supplier<PipelineOptions> options,
+      Encoder<WindowedValue<T>> encoder) {
+    Params<T> params = new Params<>(encoder, options, 
session.sparkContext().defaultParallelism());
+    BeamTable<T> table = new BeamTable<>(source, params);
+    LogicalPlan logicalPlan = DataSourceV2Relation.create(table, 
Option.empty(), Option.empty());
+    // In Spark 4.0+, Dataset$ moved to org.apache.spark.sql.classic and its 
ofRows() now
+    // takes the classic SparkSession subclass. The runtime instance returned 
by
+    // SparkSession.builder() is always a classic.SparkSession, so the 
downcast is safe and
+    // avoids reflection.
+    return (Dataset<WindowedValue<T>>)
+        Dataset$.MODULE$
+            .ofRows((org.apache.spark.sql.classic.SparkSession) session, 
logicalPlan)
+            .as(encoder);
+  }
+
+  /**
+   * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link 
RDD}.
+   *
+   * <p>This is currently the most efficient approach as it avoids any 
serialization overhead.
+   */
+  public static <T> Dataset<WindowedValue<T>> createDatasetFromRDD(
+      SparkSession session,
+      BoundedSource<T> source,
+      Supplier<PipelineOptions> options,
+      Encoder<WindowedValue<T>> encoder) {
+    Params<T> params = new Params<>(encoder, options, 
session.sparkContext().defaultParallelism());
+    RDD<WindowedValue<T>> rdd = new BoundedRDD<>(session.sparkContext(), 
source, params);
+    return session.createDataset(rdd, encoder);
+  }
+
+  /** An {@link RDD} for a bounded Beam source. */
+  private static class BoundedRDD<T> extends RDD<WindowedValue<T>> {
+    final BoundedSource<T> source;
+    final Params<T> params;
+
+    public BoundedRDD(SparkContext sc, BoundedSource<T> source, Params<T> 
params) {
+      super(sc, emptyList(), ClassTag.apply(WindowedValue.class));
+      this.source = source;
+      this.params = params;
+    }
+
+    @Override
+    public Iterator<WindowedValue<T>> compute(Partition split, TaskContext 
context) {
+      return new InterruptibleIterator<>(
+          context,
+          asScalaIterator(new SourcePartitionIterator<>((SourcePartition<T>) 
split, params)));
+    }
+
+    @Override
+    public Partition[] getPartitions() {
+      return SourcePartition.partitionsOf(source, params).toArray(new 
Partition[0]);
+    }
+  }
+
+  /** A Spark {@link Table} for a bounded Beam source supporting batch reads 
only. */
+  private static class BeamTable<T> implements Table, SupportsRead {
+    final BoundedSource<T> source;
+    final Params<T> params;
+
+    BeamTable(BoundedSource<T> source, Params<T> params) {
+      this.source = source;
+      this.params = params;
+    }
+
+    public Encoder<WindowedValue<T>> getEncoder() {
+      return params.encoder;
+    }
+
+    @Override
+    public ScanBuilder newScanBuilder(CaseInsensitiveStringMap ignored) {
+      return () ->
+          new Scan() {
+            @Override
+            public StructType readSchema() {
+              return params.encoder.schema();
+            }
+
+            @Override
+            public Batch toBatch() {
+              return new BeamBatch<>(source, params);
+            }
+          };
+    }
+
+    @Override
+    public String name() {
+      return "BeamSource<" + source.getClass().getName() + ">";
+    }
+
+    @Override
+    public StructType schema() {
+      return params.encoder.schema();
+    }
+
+    @Override
+    public Set<TableCapability> capabilities() {
+      return ImmutableSet.of(TableCapability.BATCH_READ);
+    }
+
+    private static class BeamBatch<T> implements Batch, Serializable {
+      final BoundedSource<T> source;
+      final Params<T> params;
+
+      private BeamBatch(BoundedSource<T> source, Params<T> params) {
+        this.source = source;
+        this.params = params;
+      }
+
+      @Override
+      public InputPartition[] planInputPartitions() {
+        return SourcePartition.partitionsOf(source, params).toArray(new 
InputPartition[0]);
+      }
+
+      @Override
+      public PartitionReaderFactory createReaderFactory() {
+        return p -> new BeamPartitionReader<>(((SourcePartition<T>) p), 
params);
+      }
+    }
+
+    private static class BeamPartitionReader<T> implements 
PartitionReader<InternalRow> {
+      final SourcePartitionIterator<T> iterator;
+      final Serializer<WindowedValue<T>> serializer;
+      transient @Nullable InternalRow next;
+
+      BeamPartitionReader(SourcePartition<T> partition, Params<T> params) {
+        iterator = new SourcePartitionIterator<>(partition, params);
+        serializer = ((ExpressionEncoder<WindowedValue<T>>) 
params.encoder).createSerializer();
+      }
+
+      @Override
+      public boolean next() throws IOException {
+        if (iterator.hasNext()) {
+          next = serializer.apply(iterator.next());
+          return true;
+        }
+        return false;
+      }
+
+      @Override
+      public InternalRow get() {
+        if (next == null) {
+          throw new IllegalStateException("Next not available");
+        }
+        return next;
+      }
+
+      @Override
+      public void close() throws IOException {
+        next = null;
+        iterator.close();
+      }
+    }
+  }
+
+  /** A Spark partition wrapping the partitioned Beam {@link BoundedSource}. */
+  private static class SourcePartition<T> implements Partition, InputPartition 
{
+    final BoundedSource<T> source;
+    final int index;
+
+    SourcePartition(BoundedSource<T> source, IntSupplier idxSupplier) {
+      this.source = source;
+      this.index = idxSupplier.getAsInt();
+    }
+
+    static <T> List<SourcePartition<T>> partitionsOf(BoundedSource<T> source, 
Params<T> params) {
+      try {
+        PipelineOptions options = params.options.get();
+        long desiredSize = source.getEstimatedSizeBytes(options) / 
params.numPartitions;
+        List<BoundedSource<T>> split = (List<BoundedSource<T>>) 
source.split(desiredSize, options);

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The cast to `List<BoundedSource<T>>` is unchecked and unnecessary. Since 
`source.split` returns `List<? extends BoundedSource<T>>`, you can use that 
type directly. This is compatible with the subsequent stream processing and 
avoids the unchecked-cast compiler warning (and the heap pollution risk that 
warning signals).
   
   ```suggestion
           List<? extends BoundedSource<T>> split = source.split(desiredSize, 
options);
   ```



##########
runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java:
##########
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+
+import java.lang.reflect.Constructor;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder;
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders;
+import org.apache.spark.sql.catalyst.encoders.AgnosticExpressionPathEncoder;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.expressions.BoundReference;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke;
+import org.apache.spark.sql.catalyst.expressions.objects.NewInstance;
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import scala.Option;
+import scala.collection.Iterator;
+import scala.collection.immutable.Seq;
+import scala.reflect.ClassTag;
+
+public class EncoderFactory {
+  // default constructor to reflectively create static invoke expressions
+  private static final Constructor<StaticInvoke> STATIC_INVOKE_CONSTRUCTOR =
+      (Constructor<StaticInvoke>) StaticInvoke.class.getConstructors()[0];

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Relying on `getConstructors()[0]` to find the `StaticInvoke` constructor is 
fragile because `Class.getConstructors()` returns its elements in no 
particular order. It is safer to look up the constructor by its parameter 
types (via `Class.getConstructor(Class<?>...)`) or to iterate through all 
constructors and select the one with the expected number of parameters.



##########
runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java:
##########
@@ -0,0 +1,610 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invoke;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invokeIfNotNull;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.match;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
+import static org.apache.spark.sql.types.DataTypes.BinaryType;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.LongType;
+
+import java.math.BigDecimal;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import 
org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.sdk.values.WindowedValues;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.catalyst.SerializerBuildHelper;
+import 
org.apache.spark.sql.catalyst.SerializerBuildHelper.MapElementInformation;
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.expressions.BoundReference;
+import org.apache.spark.sql.catalyst.expressions.Coalesce;
+import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct;
+import org.apache.spark.sql.catalyst.expressions.EqualTo;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.GetStructField;
+import org.apache.spark.sql.catalyst.expressions.If;
+import org.apache.spark.sql.catalyst.expressions.IsNotNull;
+import org.apache.spark.sql.catalyst.expressions.IsNull;
+import org.apache.spark.sql.catalyst.expressions.Literal;
+import org.apache.spark.sql.catalyst.expressions.Literal$;
+import org.apache.spark.sql.catalyst.expressions.MapKeys;
+import org.apache.spark.sql.catalyst.expressions.MapValues;
+import org.apache.spark.sql.catalyst.expressions.objects.MapObjects$;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.types.ObjectType;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.MutablePair;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+import scala.Option;
+import scala.Some;
+import scala.Tuple2;
+import scala.collection.IndexedSeq;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
+/** {@link Encoders} utility class. */
+public class EncoderHelpers {
+  private static final DataType OBJECT_TYPE = new ObjectType(Object.class);
+  private static final DataType TUPLE2_TYPE = new ObjectType(Tuple2.class);
+  private static final DataType WINDOWED_VALUE = new 
ObjectType(WindowedValue.class);
+  private static final DataType KV_TYPE = new ObjectType(KV.class);
+  private static final DataType MUTABLE_PAIR_TYPE = new 
ObjectType(MutablePair.class);
+  private static final DataType LIST_TYPE = new ObjectType(List.class);
+
+  // Collections / maps of these types can be (de)serialized without 
(de)serializing each member
+  private static final Set<Class<?>> PRIMITIV_TYPES =
+      ImmutableSet.of(
+          Boolean.class,
+          Byte.class,
+          Short.class,
+          Integer.class,
+          Long.class,
+          Float.class,
+          Double.class);
+
+  // Default encoders by class
+  private static final Map<Class<?>, Encoder<?>> DEFAULT_ENCODERS = new 
ConcurrentHashMap<>();
+
+  // Factory for default encoders by class
+  private static @Nullable Encoder<?> encoderFactory(Class<?> cls) {
+    if (cls.equals(PaneInfo.class)) {
+      return paneInfoEncoder();
+    } else if (cls.equals(GlobalWindow.class)) {
+      return binaryEncoder(GlobalWindow.Coder.INSTANCE, false);
+    } else if (cls.equals(IntervalWindow.class)) {
+      return binaryEncoder(IntervalWindowCoder.of(), false);
+    } else if (cls.equals(Instant.class)) {
+      return instantEncoder();
+    } else if (cls.equals(String.class)) {
+      return Encoders.STRING();
+    } else if (cls.equals(Boolean.class)) {
+      return Encoders.BOOLEAN();
+    } else if (cls.equals(Integer.class)) {
+      return Encoders.INT();
+    } else if (cls.equals(Long.class)) {
+      return Encoders.LONG();
+    } else if (cls.equals(Float.class)) {
+      return Encoders.FLOAT();
+    } else if (cls.equals(Double.class)) {
+      return Encoders.DOUBLE();
+    } else if (cls.equals(BigDecimal.class)) {
+      return Encoders.DECIMAL();
+    } else if (cls.equals(byte[].class)) {
+      return Encoders.BINARY();
+    } else if (cls.equals(Byte.class)) {
+      return Encoders.BYTE();
+    } else if (cls.equals(Short.class)) {
+      return Encoders.SHORT();
+    }
+    return null;
+  }
+
+  @SuppressWarnings({"nullness", "methodref.return"}) // computeIfAbsent 
allows null returns
+  private static <T> @Nullable Encoder<T> getOrCreateDefaultEncoder(Class<? 
super T> cls) {
+    return (Encoder<T>) DEFAULT_ENCODERS.computeIfAbsent(cls, 
EncoderHelpers::encoderFactory);
+  }
+
+  /** Gets or creates a default {@link Encoder} for {@link T}. */
+  public static <T> Encoder<T> encoderOf(Class<? super T> cls) {
+    Encoder<T> enc = getOrCreateDefaultEncoder(cls);
+    if (enc == null) {
+      throw new IllegalArgumentException("No default coder available for class 
" + cls);
+    }
+    return enc;
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} for {@link T} of {@link 
DataTypes#BinaryType BinaryType}
+   * delegating to a Beam {@link Coder} underneath.
+   *
+   * <p>Note: For common types, if available, default Spark {@link Encoder}s 
are used instead.
+   *
+   * @param coder Beam {@link Coder}
+   */
+  public static <T> Encoder<T> encoderFor(Coder<T> coder) {
+    Encoder<T> enc = 
getOrCreateDefaultEncoder(coder.getEncodedTypeDescriptor().getRawType());
+    return enc != null ? enc : binaryEncoder(coder, true);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} for {@link T} of {@link StructType} with 
fields {@code value},
+   * {@code timestamp}, {@code window} and {@code pane}.
+   *
+   * @param value {@link Encoder} to encode field `{@code value}`.
+   * @param window {@link Encoder} to encode individual windows in field 
`{@code window}`
+   */
+  public static <T, W extends BoundedWindow> Encoder<WindowedValue<T>> 
windowedValueEncoder(
+      Encoder<T> value, Encoder<W> window) {
+    Encoder<Instant> timestamp = encoderOf(Instant.class);
+    Encoder<PaneInfo> paneInfo = encoderOf(PaneInfo.class);
+    Encoder<Collection<W>> windows = collectionEncoder(window);
+    Expression serializer =
+        serializeWindowedValue(rootRef(WINDOWED_VALUE, true), value, 
timestamp, windows, paneInfo);
+    Expression deserializer =
+        deserializeWindowedValue(
+            rootCol(serializer.dataType()), value, timestamp, windows, 
paneInfo);
+    return EncoderFactory.create(serializer, deserializer, 
WindowedValue.class);
+  }
+
+  /**
+   * Creates a one-of Spark {@link Encoder} of {@link StructType} where each 
alternative is
+   * represented as a column / field named by its index, each with a separate 
{@link Encoder}.
+   *
+   * <p>Externally this is represented as tuple {@code (index, data)} where an 
index corresponds to
+   * an {@link Encoder} in the provided list.
+   *
+   * @param encoders {@link Encoder}s for each alternative.
+   */
+  public static <T> Encoder<Tuple2<Integer, T>> oneOfEncoder(List<Encoder<T>> 
encoders) {
+    Expression serializer = serializeOneOf(rootRef(TUPLE2_TYPE, true), 
encoders);
+    Expression deserializer = deserializeOneOf(rootCol(serializer.dataType()), 
encoders);
+    return EncoderFactory.create(serializer, deserializer, Tuple2.class);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} for {@link KV} of {@link StructType} with 
fields {@code key}
+   * and {@code value}.
+   *
+   * @param key {@link Encoder} to encode field `{@code key}`.
+   * @param value {@link Encoder} to encode field `{@code value}`
+   */
+  public static <K, V> Encoder<KV<K, V>> kvEncoder(Encoder<K> key, Encoder<V> 
value) {
+    Expression serializer = serializeKV(rootRef(KV_TYPE, true), key, value);
+    Expression deserializer = deserializeKV(rootCol(serializer.dataType()), 
key, value);
+    return EncoderFactory.create(serializer, deserializer, KV.class);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link 
Collection}s with nullable
+   * elements.
+   *
+   * @param enc {@link Encoder} to encode collection elements
+   */
+  public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> enc) {
+    return collectionEncoder(enc, true);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link 
Collection}s.
+   *
+   * @param enc {@link Encoder} to encode collection elements
+   * @param nullable Allow nullable collection elements
+   */
+  public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> enc, 
boolean nullable) {
+    DataType type = new ObjectType(Collection.class);
+    Expression serializer = serializeSeq(rootRef(type, true), enc, nullable);
+    Expression deserializer = deserializeSeq(rootCol(serializer.dataType()), 
enc, nullable, true);
+    return EncoderFactory.create(serializer, deserializer, Collection.class);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} of {@link MapType} that deserializes to 
{@link MapT}.
+   *
+   * @param key {@link Encoder} to encode keys
+   * @param value {@link Encoder} to encode values
+   * @param cls Specific class to use, supported are {@link HashMap} and 
{@link TreeMap}
+   */
+  public static <MapT extends Map<K, V>, K, V> Encoder<MapT> mapEncoder(
+      Encoder<K> key, Encoder<V> value, Class<MapT> cls) {
+    Expression serializer = mapSerializer(rootRef(new ObjectType(cls), true), 
key, value);
+    Expression deserializer = mapDeserializer(rootCol(serializer.dataType()), 
key, value, cls);
+    return EncoderFactory.create(serializer, deserializer, cls);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} for Spark's {@link MutablePair} of {@link 
StructType} with
+   * fields `{@code _1}` and `{@code _2}`.
+   *
+   * <p>This is intended to be used in places such as aggregators.
+   *
+   * @param enc1 {@link Encoder} to encode `{@code _1}`
+   * @param enc2 {@link Encoder} to encode `{@code _2}`
+   */
+  public static <T1, T2> Encoder<MutablePair<T1, T2>> mutablePairEncoder(
+      Encoder<T1> enc1, Encoder<T2> enc2) {
+    Expression serializer = serializeMutablePair(rootRef(MUTABLE_PAIR_TYPE, 
true), enc1, enc2);
+    Expression deserializer = 
deserializeMutablePair(rootCol(serializer.dataType()), enc1, enc2);
+    return EncoderFactory.create(serializer, deserializer, MutablePair.class);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} for {@link PaneInfo} of {@link 
DataTypes#BinaryType
+   * BinaryType}.
+   */
+  private static Encoder<PaneInfo> paneInfoEncoder() {
+    DataType type = new ObjectType(PaneInfo.class);
+    return EncoderFactory.create(
+        invokeIfNotNull(Utils.class, "paneInfoToBytes", BinaryType, 
rootRef(type, false)),
+        invokeIfNotNull(Utils.class, "paneInfoFromBytes", type, 
rootCol(BinaryType)),
+        PaneInfo.class);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} for Joda {@link Instant} of {@link 
DataTypes#LongType
+   * LongType}.
+   */
+  private static Encoder<Instant> instantEncoder() {
+    DataType type = new ObjectType(Instant.class);
+    Expression instant = rootRef(type, true);
+    Expression millis = rootCol(LongType);
+    return EncoderFactory.create(
+        nullSafe(instant, invoke(instant, "getMillis", LongType, false)),
+        nullSafe(millis, invoke(Instant.class, "ofEpochMilli", type, millis)),
+        Instant.class);
+  }
+
+  /**
+   * Creates a Spark {@link Encoder} for {@link T} of {@link 
DataTypes#BinaryType BinaryType}
+   * delegating to a Beam {@link Coder} underneath.
+   *
+   * @param coder Beam {@link Coder}
+   * @param nullable Whether to allow nullable items
+   */
+  private static <T> Encoder<T> binaryEncoder(Coder<T> coder, boolean 
nullable) {
+    Literal litCoder = lit(coder, Coder.class);
+    // T could be private, use OBJECT_TYPE for code generation to not risk an 
IllegalAccessError
+    return EncoderFactory.create(
+        invokeIfNotNull(
+            CoderHelpers.class,
+            "toByteArray",
+            BinaryType,
+            rootRef(OBJECT_TYPE, nullable),
+            litCoder),
+        invokeIfNotNull(
+            CoderHelpers.class, "fromByteArray", OBJECT_TYPE, 
rootCol(BinaryType), litCoder),
+        coder.getEncodedTypeDescriptor().getRawType());
+  }
+
+  private static <T, W extends BoundedWindow> Expression 
serializeWindowedValue(
+      Expression in,
+      Encoder<T> valueEnc,
+      Encoder<Instant> timestampEnc,
+      Encoder<Collection<W>> windowsEnc,
+      Encoder<PaneInfo> paneEnc) {
+    return serializerObject(
+        in,
+        tuple("value", serializeField(in, valueEnc, "getValue")),
+        tuple("timestamp", serializeField(in, timestampEnc, "getTimestamp")),
+        tuple("windows", serializeField(in, windowsEnc, "getWindows")),
+        tuple("paneInfo", serializeField(in, paneEnc, "getPaneInfo")));
+  }
+
+  private static Expression serializerObject(Expression in, Tuple2<String, 
Expression>... fields) {
+    return SerializerBuildHelper.createSerializerForObject(in, seqOf(fields));
+  }
+
+  private static <T, W extends BoundedWindow> Expression 
deserializeWindowedValue(
+      Expression in,
+      Encoder<T> valueEnc,
+      Encoder<Instant> timestampEnc,
+      Encoder<Collection<W>> windowsEnc,
+      Encoder<PaneInfo> paneEnc) {
+    Expression value = deserializeField(in, valueEnc, 0, "value");
+    Expression windows = deserializeField(in, windowsEnc, 2, "windows");
+    Expression timestamp = deserializeField(in, timestampEnc, 1, "timestamp");
+    Expression paneInfo = deserializeField(in, paneEnc, 3, "paneInfo");
+    // set timestamp to end of window (maxTimestamp) if null
+    timestamp =
+        ifNotNull(timestamp, invoke(Utils.class, "maxTimestamp", 
timestamp.dataType(), windows));
+    Expression[] fields = new Expression[] {value, timestamp, windows, 
paneInfo};
+
+    return nullSafe(paneInfo, invoke(WindowedValues.class, "of", 
WINDOWED_VALUE, fields));
+  }
+
+  private static <K, V> Expression serializeMutablePair(
+      Expression in, Encoder<K> enc1, Encoder<V> enc2) {
+    return serializerObject(
+        in,
+        tuple("_1", serializeField(in, enc1, "_1")),
+        tuple("_2", serializeField(in, enc2, "_2")));
+  }
+
+  private static <K, V> Expression deserializeMutablePair(
+      Expression in, Encoder<K> enc1, Encoder<V> enc2) {
+    Expression field1 = deserializeField(in, enc1, 0, "_1");
+    Expression field2 = deserializeField(in, enc2, 1, "_2");
+    return invoke(MutablePair.class, "apply", MUTABLE_PAIR_TYPE, field1, 
field2);
+  }
+
+  private static <K, V> Expression serializeKV(
+      Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) {
+    return serializerObject(
+        in,
+        tuple("key", serializeField(in, keyEnc, "getKey")),
+        tuple("value", serializeField(in, valueEnc, "getValue")));
+  }
+
+  private static <K, V> Expression deserializeKV(
+      Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) {
+    Expression key = deserializeField(in, keyEnc, 0, "key");
+    Expression value = deserializeField(in, valueEnc, 1, "value");
+    return invoke(KV.class, "of", KV_TYPE, key, value);
+  }
+
+  public static <T> Expression serializeOneOf(Expression in, List<Encoder<T>> 
encoders) {
+    Expression type = invoke(in, "_1", IntegerType, false);
+    Expression[] args = new Expression[encoders.size() * 2];
+    for (int i = 0; i < encoders.size(); i++) {
+      args[i * 2] = lit(String.valueOf(i));
+      args[i * 2 + 1] = serializeOneOfField(in, type, encoders.get(i), i);
+    }
+    return new CreateNamedStruct(seqOf(args));
+  }
+
+  public static <T> Expression deserializeOneOf(Expression in, 
List<Encoder<T>> encoders) {
+    Expression[] args = new Expression[encoders.size()];
+    for (int i = 0; i < encoders.size(); i++) {
+      args[i] = deserializeOneOfField(in, encoders.get(i), i);
+    }
+    return new Coalesce(seqOf(args));
+  }
+
+  private static <T> Expression serializeOneOfField(
+      Expression in, Expression type, Encoder<T> enc, int typeIdx) {
+    Expression litNull = lit(null, serializedType(enc));
+    Expression value = invoke(in, "_2", deserializedType(enc), false);
+    return new If(new EqualTo(type, lit(typeIdx)), serialize(value, enc), 
litNull);
+  }
+
+  private static <T> Expression deserializeOneOfField(Expression in, 
Encoder<T> enc, int idx) {
+    GetStructField field = new GetStructField(in, idx, Option.empty());
+    Expression litNull = lit(null, TUPLE2_TYPE);
+    Expression newTuple =
+        EncoderFactory.newInstance(Tuple2.class, TUPLE2_TYPE, lit(idx), 
deserialize(field, enc));
+    return new If(new IsNull(field), litNull, newTuple);
+  }
+
+  private static <T> Expression serializeField(Expression in, Encoder<T> enc, 
String getterName) {
+    Expression ref = 
serializer(enc).collect(match(BoundReference.class)).head();
+    return serialize(invoke(in, getterName, ref.dataType(), ref.nullable()), 
enc);
+  }
+
+  private static <T> Expression deserializeField(
+      Expression in, Encoder<T> enc, int idx, String name) {
+    return deserialize(new GetStructField(in, idx, new Some<>(name)), enc);
+  }
+
+  // Note: Currently this doesn't support nullable primitive values
+  private static <K, V> Expression mapSerializer(Expression map, Encoder<K> 
key, Encoder<V> value) {
+    DataType keyType = deserializedType(key);
+    DataType valueType = deserializedType(value);
+    return SerializerBuildHelper.createSerializerForMap(
+        map,
+        new MapElementInformation(keyType, false, e -> serialize(e, key)),
+        new MapElementInformation(valueType, false, e -> serialize(e, value)));
+  }
+
+  private static <MapT extends Map<K, V>, K, V> Expression mapDeserializer(
+      Expression in, Encoder<K> key, Encoder<V> value, Class<MapT> cls) {
+    Preconditions.checkArgument(cls.isAssignableFrom(HashMap.class) || 
cls.equals(TreeMap.class));
+    Expression keys = deserializeSeq(new MapKeys(in), key, false, false);
+    Expression values = deserializeSeq(new MapValues(in), value, false, false);
+    String fn = cls.equals(TreeMap.class) ? "toTreeMap" : "toMap";
+    return invoke(
+        Utils.class, fn, new ObjectType(cls), keys, values, mapItemType(key), 
mapItemType(value));
+  }
+
+  // serialized type for primitive types (avoid boxing!), otherwise the 
deserialized type
+  private static Literal mapItemType(Encoder<?> enc) {
+    return lit(isPrimitiveEnc(enc) ? serializedType(enc) : 
deserializedType(enc), DataType.class);
+  }
+
+  private static <T> Expression serializeSeq(Expression in, Encoder<T> enc, 
boolean nullable) {
+    if (isPrimitiveEnc(enc)) {
+      Expression array = invoke(in, "toArray", new ObjectType(Object[].class), 
false);
+      return SerializerBuildHelper.createSerializerForGenericArray(
+          array, serializedType(enc), nullable);
+    }
+    Expression seq = invoke(Utils.class, "toSeq", new ObjectType(Seq.class), 
in);
+    return MapObjects$.MODULE$.apply(
+        exp -> serialize(exp, enc), seq, deserializedType(enc), nullable, 
Option.empty());
+  }
+
+  private static <T> Expression deserializeSeq(
+      Expression in, Encoder<T> enc, boolean nullable, boolean exposeAsJava) {
+    DataType type = serializedType(enc); // input type is the serializer 
result type
+    if (isPrimitiveEnc(enc)) {
+      // Spark may reuse unsafe array data, if directly exposed it must be 
copied before
+      return exposeAsJava
+          ? invoke(Utils.class, "copyToList", LIST_TYPE, in, lit(type, 
DataType.class))
+          : in;
+    }
+    Option<Class<?>> optCls = exposeAsJava ? Option.apply(List.class) : 
Option.empty();
+    // MapObjects will always copy
+    return MapObjects$.MODULE$.apply(exp -> deserialize(exp, enc), in, type, 
nullable, optCls);
+  }
+
+  private static <T> boolean isPrimitiveEnc(Encoder<T> enc) {
+    return PRIMITIV_TYPES.contains(enc.clsTag().runtimeClass());
+  }
+
+  private static <T> Expression serialize(Expression input, Encoder<T> enc) {
+    return serializer(enc).transformUp(replace(BoundReference.class, input));
+  }
+
+  private static <T> Expression deserialize(Expression input, Encoder<T> enc) {
+    return deserializer(enc).transformUp(replace(GetColumnByOrdinal.class, 
input));
+  }
+
+  /**
+   * Wraps an {@link Encoder} as an {@link ExpressionEncoder}. In Spark 4.x, 
built-in encoders (e.g.
+   * {@code Encoders.INT()}) are {@link AgnosticEncoder} subclasses rather 
than {@link
+   * ExpressionEncoder}s, so we convert them on demand.
+   */
+  @SuppressWarnings("unchecked")
+  private static <T> ExpressionEncoder<T> toExpressionEncoder(Encoder<T> enc) {
+    if (enc instanceof ExpressionEncoder) {
+      return (ExpressionEncoder<T>) enc;
+    } else if (enc instanceof AgnosticEncoder) {
+      return ExpressionEncoder.apply((AgnosticEncoder<T>) enc);
+    }
+    throw new IllegalArgumentException("Unsupported encoder type: " + 
enc.getClass());
+  }
+
+  private static <T> Expression serializer(Encoder<T> enc) {
+    return toExpressionEncoder(enc).objSerializer();
+  }
+
+  private static <T> Expression deserializer(Encoder<T> enc) {
+    return toExpressionEncoder(enc).objDeserializer();
+  }
+
+  private static <T> DataType serializedType(Encoder<T> enc) {
+    return toExpressionEncoder(enc).objSerializer().dataType();
+  }
+
+  private static <T> DataType deserializedType(Encoder<T> enc) {
+    return toExpressionEncoder(enc).objDeserializer().dataType();
+  }
+
+  private static Expression rootRef(DataType dt, boolean nullable) {
+    return new BoundReference(0, dt, nullable);
+  }
+
+  private static Expression rootCol(DataType dt) {
+    return new GetColumnByOrdinal(0, dt);
+  }
+
+  private static Expression nullSafe(Expression in, Expression out) {
+    return new If(new IsNull(in), lit(null, out.dataType()), out);
+  }
+
+  private static Expression ifNotNull(Expression expr, Expression otherwise) {
+    return new If(new IsNotNull(expr), expr, otherwise);
+  }
+
+  private static <T extends @NonNull Object> Expression lit(T t) {
+    return Literal$.MODULE$.apply(t);
+  }
+
+  @SuppressWarnings("nullness") // literal NULL is allowed
+  private static <T> Expression lit(@Nullable T t, DataType dataType) {
+    return new Literal(t, dataType);
+  }
+
+  private static <T extends @NonNull Object> Literal lit(T obj, Class<? 
extends T> cls) {
+    return Literal.fromObject(obj, new ObjectType(cls));
+  }
+
+  /** Encoder / expression utils that are called from generated code. */
+  public static class Utils {
+
+    public static PaneInfo paneInfoFromBytes(byte[] bytes) {
+      return CoderHelpers.fromByteArray(bytes, PaneInfoCoder.of());
+    }
+
+    public static byte[] paneInfoToBytes(PaneInfo paneInfo) {
+      return CoderHelpers.toByteArray(paneInfo, PaneInfoCoder.of());
+    }
+
+    /** The end of the only window (max timestamp). */
+    public static Instant maxTimestamp(Iterable<BoundedWindow> windows) {
+      return Iterables.getOnlyElement(windows).maxTimestamp();
+    }

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   `Iterables.getOnlyElement(windows)` will throw an `IllegalArgumentException` 
if the `WindowedValue` is associated with multiple windows (e.g., after a 
sliding window assignment). This will cause the pipeline to crash at runtime. 
The method should instead return the maximum timestamp across all associated 
windows to safely handle multi-windowed data.
   
   ```suggestion
       public static Instant maxTimestamp(Iterable<? extends BoundedWindow> 
windows) {
         Instant maxTimestamp = null;
         for (BoundedWindow window : windows) {
           Instant timestamp = window.maxTimestamp();
           if (maxTimestamp == null || timestamp.isAfter(maxTimestamp)) {
             maxTimestamp = timestamp;
           }
         }
         return maxTimestamp;
       }
   ```



##########
runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java:
##########
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static 
org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.collect_list;
+import static org.apache.spark.sql.functions.explode;
+import static org.apache.spark.sql.functions.max;
+import static org.apache.spark.sql.functions.min;
+import static org.apache.spark.sql.functions.struct;
+
+import java.io.Serializable;
+import org.apache.beam.runners.core.InMemoryStateInternals;
+import org.apache.beam.runners.core.ReduceFnRunner;
+import org.apache.beam.runners.core.StateInternalsFactory;
+import org.apache.beam.runners.core.SystemReduceFn;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import 
org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
+import 
org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.sdk.values.WindowedValues;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import scala.Tuple2;
+import scala.collection.Iterator;
+import scala.collection.JavaConverters;
+import scala.collection.immutable.List;
+
+/**
+ * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the 
built-in aggregation
+ * function {@code collect_list} when applicable.
+ *
+ * <p>Note: Using {@code collect_list} isn't any worse than using {@link 
ReduceFnRunner}. In the
+ * latter case the entire group (iterator) has to be loaded into memory as 
well. Either way there's
+ * a risk of OOM errors. When enabling {@link
+ * SparkCommonPipelineOptions#getPreferGroupByKeyToHandleHugeValues()}, a more 
memory sensitive
+ * iterable is used that can be traversed just once. Attempting to traverse 
the iterable again will
+ * throw.
+ *
+ * <ul>
+ *   <li>When using the default global window, window information is dropped 
and restored after the
+ *       aggregation.
+ *   <li>For non-merging windows, windows are exploded and moved into a 
composite key for better
+ *       distribution. Though, to keep the amount of shuffled data low, this 
is only done if values
+ *       are assigned to a single window or if there are only few keys and 
distributing data is
+ *       important. After the aggregation, windowed values are restored from 
the composite key.
+ *   <li>All other cases are implemented using the SDK {@link ReduceFnRunner}.
+ * </ul>
+ */
+class GroupByKeyTranslatorBatch<K, V>
+    extends TransformTranslator<
+        PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, 
V>> {
+
+  /** Literal of binary encoded Pane info. */
+  private static final Column PANE_NO_FIRING = lit(toByteArray(NO_FIRING, 
PaneInfoCoder.of()));
+
+  /** Defaults for value in single global window. */
+  private static final List<Column> GLOBAL_WINDOW_DETAILS =
+      windowDetails(lit(new byte[][] {EMPTY_BYTE_ARRAY}));
+
+  GroupByKeyTranslatorBatch() {
+    super(0.2f);
+  }
+
+  @Override
+  public void translate(GroupByKey<K, V> transform, Context cxt) {
+    WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
+    TimestampCombiner tsCombiner = windowing.getTimestampCombiner();
+
+    Dataset<WindowedValue<KV<K, V>>> input = cxt.getDataset(cxt.getInput());
+
+    KvCoder<K, V> inputCoder = (KvCoder<K, V>) cxt.getInput().getCoder();
+    KvCoder<K, Iterable<V>> outputCoder = (KvCoder<K, Iterable<V>>) 
cxt.getOutput().getCoder();
+
+    Encoder<V> valueEnc = cxt.valueEncoderOf(inputCoder);
+    Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder);
+
+    // In batch we can ignore triggering and allowed lateness parameters
+    final Dataset<WindowedValue<KV<K, Iterable<V>>>> result;
+
+    boolean useCollectList =
+        !cxt.getOptions()
+            .as(SparkCommonPipelineOptions.class)
+            .getPreferGroupByKeyToHandleHugeValues();
+    if (useCollectList && eligibleForGlobalGroupBy(windowing, false)) {
+      // Collects all values per key in memory. This might be problematic if 
there's
+      // few keys only
+      // or some highly skewed distribution.
+      result =
+          input
+              .groupBy(col("value.key").as("key"))
+              .agg(collect_list(col("value.value")).as("values"), 
timestampAggregator(tsCombiner))
+              .select(
+                  inGlobalWindow(
+                      keyValue(col("key").as(keyEnc), 
col("values").as(iterableEnc(valueEnc))),
+                      windowTimestamp(tsCombiner)));
+
+    } else if (eligibleForGlobalGroupBy(windowing, true)) {
+      // Produces an iterable that can be traversed exactly once. However, on 
the plus
+      // side, data is
+      // not collected in memory until serialized or done by the user.
+      result =
+          cxt.getDataset(cxt.getInput())
+              .groupByKey(valueKey(), keyEnc)
+              .mapValues(valueValue(), cxt.valueEncoderOf(inputCoder))
+              .mapGroups(fun2((k, it) -> KV.of(k, iterableOnce(it))), 
cxt.kvEncoderOf(outputCoder))
+              .map(fun1(WindowedValues::valueInGlobalWindow), 
cxt.windowedEncoder(outputCoder));
+
+    } else if (useCollectList
+        && eligibleForGroupByWindow(windowing, false)
+        && (windowing.getWindowFn().assignsToOneWindow() || 
transform.fewKeys())) {
+      // Using the window as part of the key should help to better distribute 
the
+      // data. However, if
+      // values are assigned to multiple windows, more data would be shuffled 
around.
+      // If there's few
+      // keys only, this is still valuable.
+      // Collects all values per key & window in memory.
+      result =
+          input
+              .select(explode(col("windows")).as("window"), col("value"), 
col("timestamp"))
+              .groupBy(col("value.key").as("key"), col("window"))
+              .agg(collect_list(col("value.value")).as("values"), 
timestampAggregator(tsCombiner))
+              .select(
+                  inSingleWindow(
+                      keyValue(col("key").as(keyEnc), 
col("values").as(iterableEnc(valueEnc))),
+                      col("window").as(cxt.windowEncoder()),
+                      windowTimestamp(tsCombiner)));
+
+    } else if (eligibleForGroupByWindow(windowing, true)
+        && (windowing.getWindowFn().assignsToOneWindow() || 
transform.fewKeys())) {
+      // Using the window as part of the key should help to better distribute 
the
+      // data. However, if
+      // values are assigned to multiple windows, more data would be shuffled 
around.
+      // If there's few
+      // keys only, this is still valuable.
+      // Produces an iterable that can be traversed exactly once. However, on 
the plus
+      // side, data is
+      // not collected in memory until serialized or done by the user.
+      Encoder<Tuple2<BoundedWindow, K>> windowedKeyEnc =
+          cxt.tupleEncoder(cxt.windowEncoder(), keyEnc);
+      result =
+          cxt.getDataset(cxt.getInput())
+              .flatMap(explodeWindowedKey(valueValue()), 
cxt.tupleEncoder(windowedKeyEnc, valueEnc))
+              .groupByKey(fun1(t -> t._1()), windowedKeyEnc)
+              .mapValues(fun1(t -> t._2()), valueEnc)
+              .mapGroups(
+                  fun2((wKey, it) -> windowedKV(wKey, 
iterableOnce((Iterator<V>) it))),

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The cast `(Iterator<V>)` is redundant as the `it` parameter is already 
inferred as `Iterator<V>` from the `mapGroups` method signature.
   
   ```suggestion
                     fun2((wKey, it) -> windowedKV(wKey, iterableOnce(it))),
   ```



##########
runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java:
##########
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace;
+import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+
+import java.lang.reflect.Constructor;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder;
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders;
+import org.apache.spark.sql.catalyst.encoders.AgnosticExpressionPathEncoder;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.expressions.BoundReference;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke;
+import org.apache.spark.sql.catalyst.expressions.objects.NewInstance;
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import scala.Option;
+import scala.collection.Iterator;
+import scala.collection.immutable.Seq;
+import scala.reflect.ClassTag;
+
+public class EncoderFactory {
+  // default constructor to reflectively create static invoke expressions
+  private static final Constructor<StaticInvoke> STATIC_INVOKE_CONSTRUCTOR =
+      (Constructor<StaticInvoke>) StaticInvoke.class.getConstructors()[0];
+
+  private static final Constructor<Invoke> INVOKE_CONSTRUCTOR =
+      (Constructor<Invoke>) Invoke.class.getConstructors()[0];
+
+  private static final Constructor<NewInstance> NEW_INSTANCE_CONSTRUCTOR =
+      (Constructor<NewInstance>) NewInstance.class.getConstructors()[0];
+
+  @SuppressWarnings({"nullness", "unchecked"})
+  static <T> ExpressionEncoder<T> create(
+      Expression serializer, Expression deserializer, Class<? super T> clazz) {
+    AgnosticEncoder<T> agnosticEncoder = new BeamAgnosticEncoder<>(serializer, 
deserializer, clazz);
+    return ExpressionEncoder.apply(agnosticEncoder, serializer, deserializer);
+  }
+
+  /**
+   * An {@link AgnosticEncoder} that implements both {@link 
AgnosticExpressionPathEncoder} (so that
+   * {@code SerializerBuildHelper} / {@code DeserializerBuildHelper} delegate 
to our pre-built
+   * expressions) and {@link AgnosticEncoders.StructEncoder} (so that {@code
+   * Dataset.select(TypedColumn)} creates an N-attribute plan instead of a 
1-attribute wrapped plan,
+   * preventing {@code FIELD_NUMBER_MISMATCH} errors).
+   *
+   * <p>The {@code toCatalyst} / {@code fromCatalyst} methods substitute the 
{@code input}
+   * expression into the pre-built serializer / deserializer via {@code 
transformUp}, so that when
+   * this encoder is nested inside a composite encoder (e.g. {@code 
Encoders.tuple}) the correct
+   * field-level expression is used in place of the root {@code 
BoundReference} / {@code
+   * GetColumnByOrdinal}.
+   */
+  @SuppressWarnings({"nullness", "unchecked", "deprecation"})
+  private static final class BeamAgnosticEncoder<T>
+      implements AgnosticExpressionPathEncoder<T>, 
AgnosticEncoders.StructEncoder<T> {
+
+    private final Expression serializer;
+    private final Expression deserializer;
+    private final Class<? super T> clazz;
+    private final Seq<AgnosticEncoders.EncoderField> encoderFields;
+
+    BeamAgnosticEncoder(Expression serializer, Expression deserializer, 
Class<? super T> clazz) {
+      this.serializer = serializer;
+      this.deserializer = deserializer;
+      this.clazz = clazz;
+      this.encoderFields = buildFields(serializer.dataType());
+    }
+
+    private static Seq<AgnosticEncoders.EncoderField> buildFields(DataType dt) 
{
+      if (dt instanceof StructType) {
+        StructField[] structFields = ((StructType) dt).fields();
+        List<AgnosticEncoders.EncoderField> fields = new 
ArrayList<>(structFields.length);
+        for (StructField sf : structFields) {
+          fields.add(
+              new AgnosticEncoders.EncoderField(
+                  sf.name(),
+                  new FieldEncoder<>(sf.dataType(), sf.nullable()),
+                  sf.nullable(),
+                  sf.metadata(),
+                  Option.empty(),
+                  Option.empty()));
+        }
+        return seqOf(fields.toArray(new AgnosticEncoders.EncoderField[0]));
+      } else {
+        // Non-struct: wrap in a single "value" field so StructEncoder sees 
one field.
+        return seqOf(
+            new AgnosticEncoders.EncoderField(
+                "value",
+                new FieldEncoder<>(dt, true),
+                true,
+                Metadata.empty(),
+                Option.empty(),
+                Option.empty()));
+      }
+    }
+
+    // --- AgnosticExpressionPathEncoder ---
+
+    @Override
+    public Expression toCatalyst(Expression input) {
+      return serializer.transformUp(replace(BoundReference.class, input));
+    }
+
+    @Override
+    public Expression fromCatalyst(Expression input) {
+      return deserializer.transformUp(replace(GetColumnByOrdinal.class, 
input));
+    }
+
+    // --- AgnosticEncoders.StructEncoder ---
+
+    @Override
+    public Seq<AgnosticEncoders.EncoderField> fields() {
+      return encoderFields;
+    }
+
+    @Override
+    public boolean isStruct() {
+      return true;
+    }
+
+    @Override
+    public void
+        
org$apache$spark$sql$catalyst$encoders$AgnosticEncoders$StructEncoder$_setter_$isStruct_$eq(
+            boolean v) {
+      // no-op: isStruct() is implemented directly above

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The use of the Scala-mangled setter name 
`org$apache$spark$sql$catalyst$encoders$AgnosticEncoders$StructEncoder$_setter_$isStruct_$eq`
 is extremely fragile as it depends on internal Scala compiler implementation 
details for trait field setters. While this is a known workaround when 
implementing Scala traits with concrete `val` fields in Java, it should be clearly 
documented as a dependency on Spark's internal Scala structure.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to