tkaymak commented on code in PR #38255: URL: https://github.com/apache/beam/pull/38255#discussion_r3116723009
########## runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java: ########## @@ -0,0 +1,610 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invoke; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invokeIfNotNull; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.match; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; +import static org.apache.spark.sql.types.DataTypes.BinaryType; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.LongType; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.Collection; +import 
java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.SerializerBuildHelper; +import org.apache.spark.sql.catalyst.SerializerBuildHelper.MapElementInformation; +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.Coalesce; +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GetStructField; +import org.apache.spark.sql.catalyst.expressions.If; +import 
org.apache.spark.sql.catalyst.expressions.IsNotNull; +import org.apache.spark.sql.catalyst.expressions.IsNull; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.catalyst.expressions.MapKeys; +import org.apache.spark.sql.catalyst.expressions.MapValues; +import org.apache.spark.sql.catalyst.expressions.objects.MapObjects$; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ObjectType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.MutablePair; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import scala.Option; +import scala.Some; +import scala.Tuple2; +import scala.collection.IndexedSeq; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +/** {@link Encoders} utility class. 
*/ +public class EncoderHelpers { + private static final DataType OBJECT_TYPE = new ObjectType(Object.class); + private static final DataType TUPLE2_TYPE = new ObjectType(Tuple2.class); + private static final DataType WINDOWED_VALUE = new ObjectType(WindowedValue.class); + private static final DataType KV_TYPE = new ObjectType(KV.class); + private static final DataType MUTABLE_PAIR_TYPE = new ObjectType(MutablePair.class); + private static final DataType LIST_TYPE = new ObjectType(List.class); + + // Collections / maps of these types can be (de)serialized without (de)serializing each member + private static final Set<Class<?>> PRIMITIV_TYPES = + ImmutableSet.of( + Boolean.class, + Byte.class, + Short.class, + Integer.class, + Long.class, + Float.class, + Double.class); + + // Default encoders by class + private static final Map<Class<?>, Encoder<?>> DEFAULT_ENCODERS = new ConcurrentHashMap<>(); + + // Factory for default encoders by class + private static @Nullable Encoder<?> encoderFactory(Class<?> cls) { + if (cls.equals(PaneInfo.class)) { + return paneInfoEncoder(); + } else if (cls.equals(GlobalWindow.class)) { + return binaryEncoder(GlobalWindow.Coder.INSTANCE, false); + } else if (cls.equals(IntervalWindow.class)) { + return binaryEncoder(IntervalWindowCoder.of(), false); + } else if (cls.equals(Instant.class)) { + return instantEncoder(); + } else if (cls.equals(String.class)) { + return Encoders.STRING(); + } else if (cls.equals(Boolean.class)) { + return Encoders.BOOLEAN(); + } else if (cls.equals(Integer.class)) { + return Encoders.INT(); + } else if (cls.equals(Long.class)) { + return Encoders.LONG(); + } else if (cls.equals(Float.class)) { + return Encoders.FLOAT(); + } else if (cls.equals(Double.class)) { + return Encoders.DOUBLE(); + } else if (cls.equals(BigDecimal.class)) { + return Encoders.DECIMAL(); + } else if (cls.equals(byte[].class)) { + return Encoders.BINARY(); + } else if (cls.equals(Byte.class)) { + return Encoders.BYTE(); + } else if 
(cls.equals(Short.class)) { + return Encoders.SHORT(); + } + return null; + } + + @SuppressWarnings({"nullness", "methodref.return"}) // computeIfAbsent allows null returns + private static <T> @Nullable Encoder<T> getOrCreateDefaultEncoder(Class<? super T> cls) { + return (Encoder<T>) DEFAULT_ENCODERS.computeIfAbsent(cls, EncoderHelpers::encoderFactory); + } + + /** Gets or creates a default {@link Encoder} for {@link T}. */ + public static <T> Encoder<T> encoderOf(Class<? super T> cls) { + Encoder<T> enc = getOrCreateDefaultEncoder(cls); + if (enc == null) { + throw new IllegalArgumentException("No default coder available for class " + cls); + } + return enc; + } + + /** + * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType} + * delegating to a Beam {@link Coder} underneath. + * + * <p>Note: For common types, if available, default Spark {@link Encoder}s are used instead. + * + * @param coder Beam {@link Coder} + */ + public static <T> Encoder<T> encoderFor(Coder<T> coder) { + Encoder<T> enc = getOrCreateDefaultEncoder(coder.getEncodedTypeDescriptor().getRawType()); + return enc != null ? enc : binaryEncoder(coder, true); + } + + /** + * Creates a Spark {@link Encoder} for {@link T} of {@link StructType} with fields {@code value}, + * {@code timestamp}, {@code window} and {@code pane}. + * + * @param value {@link Encoder} to encode field `{@code value}`. 
+ * @param window {@link Encoder} to encode individual windows in field `{@code window}` + */ + public static <T, W extends BoundedWindow> Encoder<WindowedValue<T>> windowedValueEncoder( + Encoder<T> value, Encoder<W> window) { + Encoder<Instant> timestamp = encoderOf(Instant.class); + Encoder<PaneInfo> paneInfo = encoderOf(PaneInfo.class); + Encoder<Collection<W>> windows = collectionEncoder(window); + Expression serializer = + serializeWindowedValue(rootRef(WINDOWED_VALUE, true), value, timestamp, windows, paneInfo); + Expression deserializer = + deserializeWindowedValue( + rootCol(serializer.dataType()), value, timestamp, windows, paneInfo); + return EncoderFactory.create(serializer, deserializer, WindowedValue.class); + } + + /** + * Creates a one-of Spark {@link Encoder} of {@link StructType} where each alternative is + * represented as column / field named by its index with a separate {@link Encoder} each. + * + * <p>Externally this is represented as tuple {@code (index, data)} where an index corresponds to + * an {@link Encoder} in the provided list. + * + * @param encoders {@link Encoder}s for each alternative. + */ + public static <T> Encoder<Tuple2<Integer, T>> oneOfEncoder(List<Encoder<T>> encoders) { + Expression serializer = serializeOneOf(rootRef(TUPLE2_TYPE, true), encoders); + Expression deserializer = deserializeOneOf(rootCol(serializer.dataType()), encoders); + return EncoderFactory.create(serializer, deserializer, Tuple2.class); + } + + /** + * Creates a Spark {@link Encoder} for {@link KV} of {@link StructType} with fields {@code key} + * and {@code value}. + * + * @param key {@link Encoder} to encode field `{@code key}`. 
+ * @param value {@link Encoder} to encode field `{@code value}` + */ + public static <K, V> Encoder<KV<K, V>> kvEncoder(Encoder<K> key, Encoder<V> value) { + Expression serializer = serializeKV(rootRef(KV_TYPE, true), key, value); + Expression deserializer = deserializeKV(rootCol(serializer.dataType()), key, value); + return EncoderFactory.create(serializer, deserializer, KV.class); + } + + /** + * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s with nullable + * elements. + * + * @param enc {@link Encoder} to encode collection elements + */ + public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> enc) { + return collectionEncoder(enc, true); + } + + /** + * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s. + * + * @param enc {@link Encoder} to encode collection elements + * @param nullable Allow nullable collection elements + */ + public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> enc, boolean nullable) { + DataType type = new ObjectType(Collection.class); + Expression serializer = serializeSeq(rootRef(type, true), enc, nullable); + Expression deserializer = deserializeSeq(rootCol(serializer.dataType()), enc, nullable, true); + return EncoderFactory.create(serializer, deserializer, Collection.class); + } + + /** + * Creates a Spark {@link Encoder} of {@link MapType} that deserializes to {@link MapT}. 
+ * + * @param key {@link Encoder} to encode keys + * @param value {@link Encoder} to encode values + * @param cls Specific class to use, supported are {@link HashMap} and {@link TreeMap} + */ + public static <MapT extends Map<K, V>, K, V> Encoder<MapT> mapEncoder( + Encoder<K> key, Encoder<V> value, Class<MapT> cls) { + Expression serializer = mapSerializer(rootRef(new ObjectType(cls), true), key, value); + Expression deserializer = mapDeserializer(rootCol(serializer.dataType()), key, value, cls); + return EncoderFactory.create(serializer, deserializer, cls); + } + + /** + * Creates a Spark {@link Encoder} for Spark's {@link MutablePair} of {@link StructType} with + * fields `{@code _1}` and `{@code _2}`. + * + * <p>This is intended to be used in places such as aggregators. + * + * @param enc1 {@link Encoder} to encode `{@code _1}` + * @param enc2 {@link Encoder} to encode `{@code _2}` + */ + public static <T1, T2> Encoder<MutablePair<T1, T2>> mutablePairEncoder( + Encoder<T1> enc1, Encoder<T2> enc2) { + Expression serializer = serializeMutablePair(rootRef(MUTABLE_PAIR_TYPE, true), enc1, enc2); + Expression deserializer = deserializeMutablePair(rootCol(serializer.dataType()), enc1, enc2); + return EncoderFactory.create(serializer, deserializer, MutablePair.class); + } + + /** + * Creates a Spark {@link Encoder} for {@link PaneInfo} of {@link DataTypes#BinaryType + * BinaryType}. + */ + private static Encoder<PaneInfo> paneInfoEncoder() { + DataType type = new ObjectType(PaneInfo.class); + return EncoderFactory.create( + invokeIfNotNull(Utils.class, "paneInfoToBytes", BinaryType, rootRef(type, false)), + invokeIfNotNull(Utils.class, "paneInfoFromBytes", type, rootCol(BinaryType)), + PaneInfo.class); + } + + /** + * Creates a Spark {@link Encoder} for Joda {@link Instant} of {@link DataTypes#LongType + * LongType}. 
+ */ + private static Encoder<Instant> instantEncoder() { + DataType type = new ObjectType(Instant.class); + Expression instant = rootRef(type, true); + Expression millis = rootCol(LongType); + return EncoderFactory.create( + nullSafe(instant, invoke(instant, "getMillis", LongType, false)), + nullSafe(millis, invoke(Instant.class, "ofEpochMilli", type, millis)), + Instant.class); + } + + /** + * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType} + * delegating to a Beam {@link Coder} underneath. + * + * @param coder Beam {@link Coder} + * @param nullable If to allow nullable items + */ + private static <T> Encoder<T> binaryEncoder(Coder<T> coder, boolean nullable) { + Literal litCoder = lit(coder, Coder.class); + // T could be private, use OBJECT_TYPE for code generation to not risk an IllegalAccessError + return EncoderFactory.create( + invokeIfNotNull( + CoderHelpers.class, + "toByteArray", + BinaryType, + rootRef(OBJECT_TYPE, nullable), + litCoder), + invokeIfNotNull( + CoderHelpers.class, "fromByteArray", OBJECT_TYPE, rootCol(BinaryType), litCoder), + coder.getEncodedTypeDescriptor().getRawType()); + } + + private static <T, W extends BoundedWindow> Expression serializeWindowedValue( + Expression in, + Encoder<T> valueEnc, + Encoder<Instant> timestampEnc, + Encoder<Collection<W>> windowsEnc, + Encoder<PaneInfo> paneEnc) { + return serializerObject( + in, + tuple("value", serializeField(in, valueEnc, "getValue")), + tuple("timestamp", serializeField(in, timestampEnc, "getTimestamp")), + tuple("windows", serializeField(in, windowsEnc, "getWindows")), + tuple("paneInfo", serializeField(in, paneEnc, "getPaneInfo"))); + } + + private static Expression serializerObject(Expression in, Tuple2<String, Expression>... 
fields) { + return SerializerBuildHelper.createSerializerForObject(in, seqOf(fields)); + } + + private static <T, W extends BoundedWindow> Expression deserializeWindowedValue( + Expression in, + Encoder<T> valueEnc, + Encoder<Instant> timestampEnc, + Encoder<Collection<W>> windowsEnc, + Encoder<PaneInfo> paneEnc) { + Expression value = deserializeField(in, valueEnc, 0, "value"); + Expression windows = deserializeField(in, windowsEnc, 2, "windows"); + Expression timestamp = deserializeField(in, timestampEnc, 1, "timestamp"); + Expression paneInfo = deserializeField(in, paneEnc, 3, "paneInfo"); + // set timestamp to end of window (maxTimestamp) if null + timestamp = + ifNotNull(timestamp, invoke(Utils.class, "maxTimestamp", timestamp.dataType(), windows)); + Expression[] fields = new Expression[] {value, timestamp, windows, paneInfo}; + + return nullSafe(paneInfo, invoke(WindowedValues.class, "of", WINDOWED_VALUE, fields)); + } + + private static <K, V> Expression serializeMutablePair( + Expression in, Encoder<K> enc1, Encoder<V> enc2) { + return serializerObject( + in, + tuple("_1", serializeField(in, enc1, "_1")), + tuple("_2", serializeField(in, enc2, "_2"))); + } + + private static <K, V> Expression deserializeMutablePair( + Expression in, Encoder<K> enc1, Encoder<V> enc2) { + Expression field1 = deserializeField(in, enc1, 0, "_1"); + Expression field2 = deserializeField(in, enc2, 1, "_2"); + return invoke(MutablePair.class, "apply", MUTABLE_PAIR_TYPE, field1, field2); + } + + private static <K, V> Expression serializeKV( + Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) { + return serializerObject( + in, + tuple("key", serializeField(in, keyEnc, "getKey")), + tuple("value", serializeField(in, valueEnc, "getValue"))); + } + + private static <K, V> Expression deserializeKV( + Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) { + Expression key = deserializeField(in, keyEnc, 0, "key"); + Expression value = deserializeField(in, valueEnc, 1, "value"); + 
return invoke(KV.class, "of", KV_TYPE, key, value); + } + + public static <T> Expression serializeOneOf(Expression in, List<Encoder<T>> encoders) { + Expression type = invoke(in, "_1", IntegerType, false); + Expression[] args = new Expression[encoders.size() * 2]; + for (int i = 0; i < encoders.size(); i++) { + args[i * 2] = lit(String.valueOf(i)); + args[i * 2 + 1] = serializeOneOfField(in, type, encoders.get(i), i); + } + return new CreateNamedStruct(seqOf(args)); + } + + public static <T> Expression deserializeOneOf(Expression in, List<Encoder<T>> encoders) { + Expression[] args = new Expression[encoders.size()]; + for (int i = 0; i < encoders.size(); i++) { + args[i] = deserializeOneOfField(in, encoders.get(i), i); + } + return new Coalesce(seqOf(args)); + } + + private static <T> Expression serializeOneOfField( + Expression in, Expression type, Encoder<T> enc, int typeIdx) { + Expression litNull = lit(null, serializedType(enc)); + Expression value = invoke(in, "_2", deserializedType(enc), false); + return new If(new EqualTo(type, lit(typeIdx)), serialize(value, enc), litNull); + } + + private static <T> Expression deserializeOneOfField(Expression in, Encoder<T> enc, int idx) { + GetStructField field = new GetStructField(in, idx, Option.empty()); + Expression litNull = lit(null, TUPLE2_TYPE); + Expression newTuple = + EncoderFactory.newInstance(Tuple2.class, TUPLE2_TYPE, lit(idx), deserialize(field, enc)); + return new If(new IsNull(field), litNull, newTuple); + } + + private static <T> Expression serializeField(Expression in, Encoder<T> enc, String getterName) { + Expression ref = serializer(enc).collect(match(BoundReference.class)).head(); + return serialize(invoke(in, getterName, ref.dataType(), ref.nullable()), enc); + } + + private static <T> Expression deserializeField( + Expression in, Encoder<T> enc, int idx, String name) { + return deserialize(new GetStructField(in, idx, new Some<>(name)), enc); + } + + // Note: Currently this doesn't support nullable 
primitive values + private static <K, V> Expression mapSerializer(Expression map, Encoder<K> key, Encoder<V> value) { + DataType keyType = deserializedType(key); + DataType valueType = deserializedType(value); + return SerializerBuildHelper.createSerializerForMap( + map, + new MapElementInformation(keyType, false, e -> serialize(e, key)), + new MapElementInformation(valueType, false, e -> serialize(e, value))); + } + + private static <MapT extends Map<K, V>, K, V> Expression mapDeserializer( + Expression in, Encoder<K> key, Encoder<V> value, Class<MapT> cls) { + Preconditions.checkArgument(cls.isAssignableFrom(HashMap.class) || cls.equals(TreeMap.class)); + Expression keys = deserializeSeq(new MapKeys(in), key, false, false); + Expression values = deserializeSeq(new MapValues(in), value, false, false); + String fn = cls.equals(TreeMap.class) ? "toTreeMap" : "toMap"; + return invoke( + Utils.class, fn, new ObjectType(cls), keys, values, mapItemType(key), mapItemType(value)); + } + + // serialized type for primitive types (avoid boxing!), otherwise the deserialized type + private static Literal mapItemType(Encoder<?> enc) { + return lit(isPrimitiveEnc(enc) ? 
serializedType(enc) : deserializedType(enc), DataType.class); + } + + private static <T> Expression serializeSeq(Expression in, Encoder<T> enc, boolean nullable) { + if (isPrimitiveEnc(enc)) { + Expression array = invoke(in, "toArray", new ObjectType(Object[].class), false); + return SerializerBuildHelper.createSerializerForGenericArray( + array, serializedType(enc), nullable); + } + Expression seq = invoke(Utils.class, "toSeq", new ObjectType(Seq.class), in); + return MapObjects$.MODULE$.apply( + exp -> serialize(exp, enc), seq, deserializedType(enc), nullable, Option.empty()); + } + + private static <T> Expression deserializeSeq( + Expression in, Encoder<T> enc, boolean nullable, boolean exposeAsJava) { + DataType type = serializedType(enc); // input type is the serializer result type + if (isPrimitiveEnc(enc)) { + // Spark may reuse unsafe array data, if directly exposed it must be copied before + return exposeAsJava + ? invoke(Utils.class, "copyToList", LIST_TYPE, in, lit(type, DataType.class)) + : in; + } + Option<Class<?>> optCls = exposeAsJava ? Option.apply(List.class) : Option.empty(); + // MapObjects will always copy + return MapObjects$.MODULE$.apply(exp -> deserialize(exp, enc), in, type, nullable, optCls); + } + + private static <T> boolean isPrimitiveEnc(Encoder<T> enc) { + return PRIMITIV_TYPES.contains(enc.clsTag().runtimeClass()); + } + + private static <T> Expression serialize(Expression input, Encoder<T> enc) { + return serializer(enc).transformUp(replace(BoundReference.class, input)); + } + + private static <T> Expression deserialize(Expression input, Encoder<T> enc) { + return deserializer(enc).transformUp(replace(GetColumnByOrdinal.class, input)); + } + + /** + * Wraps an {@link Encoder} as an {@link ExpressionEncoder}. In Spark 4.x, built-in encoders (e.g. + * {@code Encoders.INT()}) are {@link AgnosticEncoder} subclasses rather than {@link + * ExpressionEncoder}s, so we convert them on demand. 
+ */ + @SuppressWarnings("unchecked") + private static <T> ExpressionEncoder<T> toExpressionEncoder(Encoder<T> enc) { + if (enc instanceof ExpressionEncoder) { + return (ExpressionEncoder<T>) enc; + } else if (enc instanceof AgnosticEncoder) { + return ExpressionEncoder.apply((AgnosticEncoder<T>) enc); + } + throw new IllegalArgumentException("Unsupported encoder type: " + enc.getClass()); + } + + private static <T> Expression serializer(Encoder<T> enc) { + return toExpressionEncoder(enc).objSerializer(); + } + + private static <T> Expression deserializer(Encoder<T> enc) { + return toExpressionEncoder(enc).objDeserializer(); + } + + private static <T> DataType serializedType(Encoder<T> enc) { + return toExpressionEncoder(enc).objSerializer().dataType(); + } + + private static <T> DataType deserializedType(Encoder<T> enc) { + return toExpressionEncoder(enc).objDeserializer().dataType(); + } + + private static Expression rootRef(DataType dt, boolean nullable) { + return new BoundReference(0, dt, nullable); + } + + private static Expression rootCol(DataType dt) { + return new GetColumnByOrdinal(0, dt); + } + + private static Expression nullSafe(Expression in, Expression out) { + return new If(new IsNull(in), lit(null, out.dataType()), out); + } + + private static Expression ifNotNull(Expression expr, Expression otherwise) { + return new If(new IsNotNull(expr), expr, otherwise); + } + + private static <T extends @NonNull Object> Expression lit(T t) { + return Literal$.MODULE$.apply(t); + } + + @SuppressWarnings("nullness") // literal NULL is allowed + private static <T> Expression lit(@Nullable T t, DataType dataType) { + return new Literal(t, dataType); + } + + private static <T extends @NonNull Object> Literal lit(T obj, Class<? extends T> cls) { + return Literal.fromObject(obj, new ObjectType(cls)); + } + + /** Encoder / expression utils that are called from generated code. 
*/ + public static class Utils { + + public static PaneInfo paneInfoFromBytes(byte[] bytes) { + return CoderHelpers.fromByteArray(bytes, PaneInfoCoder.of()); + } + + public static byte[] paneInfoToBytes(PaneInfo paneInfo) { + return CoderHelpers.toByteArray(paneInfo, PaneInfoCoder.of()); + } + + /** The end of the only window (max timestamp). */ + public static Instant maxTimestamp(Iterable<BoundedWindow> windows) { + return Iterables.getOnlyElement(windows).maxTimestamp(); + } Review Comment: Done in 334dcd33 — fixed in both the shared base (`runners/spark/src/.../EncoderHelpers.java`) and the Spark 4 override. The method now iterates and returns the max `maxTimestamp()` across all windows, with a `Preconditions.checkNotNull` to fail loudly on the (not expected in practice) empty-windows case rather than silently returning `null`. ########## runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java: ########## @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.io; + +import static java.util.stream.Collectors.toList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.sdk.values.WindowedValues.timestampedValueInGlobalWindow; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static scala.collection.JavaConverters.asScalaIterator; + +import java.io.Closeable; +import java.io.IOException; +import java.io.Serializable; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntSupplier; +import java.util.function.Supplier; +import javax.annotation.CheckForNull; +import javax.annotation.Nullable; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.values.WindowedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.spark.InterruptibleIterator; +import org.apache.spark.Partition; +import org.apache.spark.SparkContext; +import org.apache.spark.TaskContext; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.classic.Dataset$; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.Table; +import 
org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Option; +import scala.collection.Iterator; +import scala.reflect.ClassTag; + +public class BoundedDatasetFactory { + private BoundedDatasetFactory() {} + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}. + * + * <p>Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization. + * This makes this approach at the time being significantly less performant than creating a + * dataset from an RDD. + */ + public static <T> Dataset<WindowedValue<T>> createDatasetFromRows( + SparkSession session, + BoundedSource<T> source, + Supplier<PipelineOptions> options, + Encoder<WindowedValue<T>> encoder) { + Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + BeamTable<T> table = new BeamTable<>(source, params); + LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty()); + // In Spark 4.0+, Dataset$ moved to org.apache.spark.sql.classic and its ofRows() now + // takes the classic SparkSession subclass. The runtime instance returned by + // SparkSession.builder() is always a classic.SparkSession, so the downcast is safe and + // avoids reflection. 
+ return (Dataset<WindowedValue<T>>) + Dataset$.MODULE$ + .ofRows((org.apache.spark.sql.classic.SparkSession) session, logicalPlan) + .as(encoder); + } + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link RDD}. + * + * <p>This is currently the most efficient approach as it avoids any serialization overhead. + */ + public static <T> Dataset<WindowedValue<T>> createDatasetFromRDD( + SparkSession session, + BoundedSource<T> source, + Supplier<PipelineOptions> options, + Encoder<WindowedValue<T>> encoder) { + Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + RDD<WindowedValue<T>> rdd = new BoundedRDD<>(session.sparkContext(), source, params); + return session.createDataset(rdd, encoder); + } + + /** An {@link RDD} for a bounded Beam source. */ + private static class BoundedRDD<T> extends RDD<WindowedValue<T>> { + final BoundedSource<T> source; + final Params<T> params; + + public BoundedRDD(SparkContext sc, BoundedSource<T> source, Params<T> params) { + super(sc, emptyList(), ClassTag.apply(WindowedValue.class)); + this.source = source; + this.params = params; + } + + @Override + public Iterator<WindowedValue<T>> compute(Partition split, TaskContext context) { + return new InterruptibleIterator<>( + context, + asScalaIterator(new SourcePartitionIterator<>((SourcePartition<T>) split, params))); + } + + @Override + public Partition[] getPartitions() { + return SourcePartition.partitionsOf(source, params).toArray(new Partition[0]); + } + } + + /** A Spark {@link Table} for a bounded Beam source supporting batch reads only. 
*/ + private static class BeamTable<T> implements Table, SupportsRead { + final BoundedSource<T> source; + final Params<T> params; + + BeamTable(BoundedSource<T> source, Params<T> params) { + this.source = source; + this.params = params; + } + + public Encoder<WindowedValue<T>> getEncoder() { + return params.encoder; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap ignored) { + return () -> + new Scan() { + @Override + public StructType readSchema() { + return params.encoder.schema(); + } + + @Override + public Batch toBatch() { + return new BeamBatch<>(source, params); + } + }; + } + + @Override + public String name() { + return "BeamSource<" + source.getClass().getName() + ">"; + } + + @Override + public StructType schema() { + return params.encoder.schema(); + } + + @Override + public Set<TableCapability> capabilities() { + return ImmutableSet.of(TableCapability.BATCH_READ); + } + + private static class BeamBatch<T> implements Batch, Serializable { + final BoundedSource<T> source; + final Params<T> params; + + private BeamBatch(BoundedSource<T> source, Params<T> params) { + this.source = source; + this.params = params; + } + + @Override + public InputPartition[] planInputPartitions() { + return SourcePartition.partitionsOf(source, params).toArray(new InputPartition[0]); + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return p -> new BeamPartitionReader<>(((SourcePartition<T>) p), params); + } + } + + private static class BeamPartitionReader<T> implements PartitionReader<InternalRow> { + final SourcePartitionIterator<T> iterator; + final Serializer<WindowedValue<T>> serializer; + transient @Nullable InternalRow next; + + BeamPartitionReader(SourcePartition<T> partition, Params<T> params) { + iterator = new SourcePartitionIterator<>(partition, params); + serializer = ((ExpressionEncoder<WindowedValue<T>>) params.encoder).createSerializer(); + } + + @Override + public boolean next() throws IOException { + if 
(iterator.hasNext()) { + next = serializer.apply(iterator.next()); + return true; + } + return false; + } + + @Override + public InternalRow get() { + if (next == null) { + throw new IllegalStateException("Next not available"); + } + return next; + } + + @Override + public void close() throws IOException { + next = null; + iterator.close(); + } + } + } + + /** A Spark partition wrapping the partitioned Beam {@link BoundedSource}. */ + private static class SourcePartition<T> implements Partition, InputPartition { + final BoundedSource<T> source; + final int index; + + SourcePartition(BoundedSource<T> source, IntSupplier idxSupplier) { + this.source = source; + this.index = idxSupplier.getAsInt(); + } + + static <T> List<SourcePartition<T>> partitionsOf(BoundedSource<T> source, Params<T> params) { + try { + PipelineOptions options = params.options.get(); + long desiredSize = source.getEstimatedSizeBytes(options) / params.numPartitions; + List<BoundedSource<T>> split = (List<BoundedSource<T>>) source.split(desiredSize, options); Review Comment: Done in 9f53d5ab — fixed in both the shared base and the Spark 4 override. `source.split` already returns `List<? extends BoundedSource<T>>`, so the cast was both unchecked and unnecessary. ########## runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java: ########## @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat; 
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.collect_list; +import static org.apache.spark.sql.functions.explode; +import static org.apache.spark.sql.functions.max; +import static org.apache.spark.sql.functions.min; +import static org.apache.spark.sql.functions.struct; + +import java.io.Serializable; +import org.apache.beam.runners.core.InMemoryStateInternals; +import org.apache.beam.runners.core.ReduceFnRunner; +import org.apache.beam.runners.core.StateInternalsFactory; +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.spark.SparkCommonPipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowedValue; +import 
org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Tuple2; +import scala.collection.Iterator; +import scala.collection.JavaConverters; +import scala.collection.immutable.List; + +/** + * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the build-in aggregation + * function {@code collect_list} when applicable. + * + * <p>Note: Using {@code collect_list} isn't any worse than using {@link ReduceFnRunner}. In the + * latter case the entire group (iterator) has to be loaded into memory as well. Either way there's + * a risk of OOM errors. When enabling {@link + * SparkCommonPipelineOptions#getPreferGroupByKeyToHandleHugeValues()}, a more memory sensitive + * iterable is used that can be traversed just once. Attempting to traverse the iterable again will + * throw. + * + * <ul> + * <li>When using the default global window, window information is dropped and restored after the + * aggregation. + * <li>For non-merging windows, windows are exploded and moved into a composite key for better + * distribution. Though, to keep the amount of shuffled data low, this is only done if values + * are assigned to a single window or if there are only few keys and distributing data is + * important. After the aggregation, windowed values are restored from the composite key. + * <li>All other cases are implemented using the SDK {@link ReduceFnRunner}. + * </ul> + */ +class GroupByKeyTranslatorBatch<K, V> + extends TransformTranslator< + PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>> { + + /** Literal of binary encoded Pane info. 
*/ + private static final Column PANE_NO_FIRING = lit(toByteArray(NO_FIRING, PaneInfoCoder.of())); + + /** Defaults for value in single global window. */ + private static final List<Column> GLOBAL_WINDOW_DETAILS = + windowDetails(lit(new byte[][] {EMPTY_BYTE_ARRAY})); + + GroupByKeyTranslatorBatch() { + super(0.2f); + } + + @Override + public void translate(GroupByKey<K, V> transform, Context cxt) { + WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy(); + TimestampCombiner tsCombiner = windowing.getTimestampCombiner(); + + Dataset<WindowedValue<KV<K, V>>> input = cxt.getDataset(cxt.getInput()); + + KvCoder<K, V> inputCoder = (KvCoder<K, V>) cxt.getInput().getCoder(); + KvCoder<K, Iterable<V>> outputCoder = (KvCoder<K, Iterable<V>>) cxt.getOutput().getCoder(); + + Encoder<V> valueEnc = cxt.valueEncoderOf(inputCoder); + Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder); + + // In batch we can ignore triggering and allowed lateness parameters + final Dataset<WindowedValue<KV<K, Iterable<V>>>> result; + + boolean useCollectList = + !cxt.getOptions() + .as(SparkCommonPipelineOptions.class) + .getPreferGroupByKeyToHandleHugeValues(); + if (useCollectList && eligibleForGlobalGroupBy(windowing, false)) { + // Collects all values per key in memory. This might be problematic if there's + // few keys only + // or some highly skewed distribution. + result = + input + .groupBy(col("value.key").as("key")) + .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner)) + .select( + inGlobalWindow( + keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))), + windowTimestamp(tsCombiner))); + + } else if (eligibleForGlobalGroupBy(windowing, true)) { + // Produces an iterable that can be traversed exactly once. However, on the plus + // side, data is + // not collected in memory until serialized or done by the user. 
+ result = + cxt.getDataset(cxt.getInput()) + .groupByKey(valueKey(), keyEnc) + .mapValues(valueValue(), cxt.valueEncoderOf(inputCoder)) + .mapGroups(fun2((k, it) -> KV.of(k, iterableOnce(it))), cxt.kvEncoderOf(outputCoder)) + .map(fun1(WindowedValues::valueInGlobalWindow), cxt.windowedEncoder(outputCoder)); + + } else if (useCollectList + && eligibleForGroupByWindow(windowing, false) + && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) { + // Using the window as part of the key should help to better distribute the + // data. However, if + // values are assigned to multiple windows, more data would be shuffled around. + // If there's few + // keys only, this is still valuable. + // Collects all values per key & window in memory. + result = + input + .select(explode(col("windows")).as("window"), col("value"), col("timestamp")) + .groupBy(col("value.key").as("key"), col("window")) + .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner)) + .select( + inSingleWindow( + keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))), + col("window").as(cxt.windowEncoder()), + windowTimestamp(tsCombiner))); + + } else if (eligibleForGroupByWindow(windowing, true) + && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) { + // Using the window as part of the key should help to better distribute the + // data. However, if + // values are assigned to multiple windows, more data would be shuffled around. + // If there's few + // keys only, this is still valuable. + // Produces an iterable that can be traversed exactly once. However, on the plus + // side, data is + // not collected in memory until serialized or done by the user. 
+ Encoder<Tuple2<BoundedWindow, K>> windowedKeyEnc = + cxt.tupleEncoder(cxt.windowEncoder(), keyEnc); + result = + cxt.getDataset(cxt.getInput()) + .flatMap(explodeWindowedKey(valueValue()), cxt.tupleEncoder(windowedKeyEnc, valueEnc)) + .groupByKey(fun1(t -> t._1()), windowedKeyEnc) + .mapValues(fun1(t -> t._2()), valueEnc) + .mapGroups( + fun2((wKey, it) -> windowedKV(wKey, iterableOnce((Iterator<V>) it))), Review Comment: Done in 8bebceda — dropped the cast (override only; the analogous call in the base translator already does `iterableOnce(it)` without a cast). ########## runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java: ########## @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; + +import java.lang.reflect.Constructor; +import java.util.ArrayList; +import java.util.List; +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders; +import org.apache.spark.sql.catalyst.encoders.AgnosticExpressionPathEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.objects.Invoke; +import org.apache.spark.sql.catalyst.expressions.objects.NewInstance; +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.Option; +import scala.collection.Iterator; +import scala.collection.immutable.Seq; +import scala.reflect.ClassTag; + +public class EncoderFactory { + // default constructor to reflectively create static invoke expressions + private static final Constructor<StaticInvoke> STATIC_INVOKE_CONSTRUCTOR = + (Constructor<StaticInvoke>) StaticInvoke.class.getConstructors()[0]; Review Comment: Done in 9c071c5e — replaced `getConstructors()[0]` with a `primaryConstructor(...)` helper that picks the public constructor with the most parameters. 
The downstream `switch` on `getParameterCount()` already dispatches on the right argument shape per Spark version, so this just makes the choice deterministic across JVMs / Spark releases. Comment block above the constants explains the intent. ########## runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java: ########## @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; + +import java.lang.reflect.Constructor; +import java.util.ArrayList; +import java.util.List; +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder; +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders; +import org.apache.spark.sql.catalyst.encoders.AgnosticExpressionPathEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.objects.Invoke; +import org.apache.spark.sql.catalyst.expressions.objects.NewInstance; +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import scala.Option; +import scala.collection.Iterator; +import scala.collection.immutable.Seq; +import scala.reflect.ClassTag; + +public class EncoderFactory { + // default constructor to reflectively create static invoke expressions + private static final Constructor<StaticInvoke> STATIC_INVOKE_CONSTRUCTOR = + (Constructor<StaticInvoke>) StaticInvoke.class.getConstructors()[0]; + + private static final Constructor<Invoke> INVOKE_CONSTRUCTOR = + (Constructor<Invoke>) Invoke.class.getConstructors()[0]; + + private static final Constructor<NewInstance> NEW_INSTANCE_CONSTRUCTOR = + (Constructor<NewInstance>) 
NewInstance.class.getConstructors()[0]; + + @SuppressWarnings({"nullness", "unchecked"}) + static <T> ExpressionEncoder<T> create( + Expression serializer, Expression deserializer, Class<? super T> clazz) { + AgnosticEncoder<T> agnosticEncoder = new BeamAgnosticEncoder<>(serializer, deserializer, clazz); + return ExpressionEncoder.apply(agnosticEncoder, serializer, deserializer); + } + + /** + * An {@link AgnosticEncoder} that implements both {@link AgnosticExpressionPathEncoder} (so that + * {@code SerializerBuildHelper} / {@code DeserializerBuildHelper} delegate to our pre-built + * expressions) and {@link AgnosticEncoders.StructEncoder} (so that {@code + * Dataset.select(TypedColumn)} creates an N-attribute plan instead of a 1-attribute wrapped plan, + * preventing {@code FIELD_NUMBER_MISMATCH} errors). + * + * <p>The {@code toCatalyst} / {@code fromCatalyst} methods substitute the {@code input} + * expression into the pre-built serializer / deserializer via {@code transformUp}, so that when + * this encoder is nested inside a composite encoder (e.g. {@code Encoders.tuple}) the correct + * field-level expression is used in place of the root {@code BoundReference} / {@code + * GetColumnByOrdinal}. + */ + @SuppressWarnings({"nullness", "unchecked", "deprecation"}) + private static final class BeamAgnosticEncoder<T> + implements AgnosticExpressionPathEncoder<T>, AgnosticEncoders.StructEncoder<T> { + + private final Expression serializer; + private final Expression deserializer; + private final Class<? super T> clazz; + private final Seq<AgnosticEncoders.EncoderField> encoderFields; + + BeamAgnosticEncoder(Expression serializer, Expression deserializer, Class<? 
super T> clazz) { + this.serializer = serializer; + this.deserializer = deserializer; + this.clazz = clazz; + this.encoderFields = buildFields(serializer.dataType()); + } + + private static Seq<AgnosticEncoders.EncoderField> buildFields(DataType dt) { + if (dt instanceof StructType) { + StructField[] structFields = ((StructType) dt).fields(); + List<AgnosticEncoders.EncoderField> fields = new ArrayList<>(structFields.length); + for (StructField sf : structFields) { + fields.add( + new AgnosticEncoders.EncoderField( + sf.name(), + new FieldEncoder<>(sf.dataType(), sf.nullable()), + sf.nullable(), + sf.metadata(), + Option.empty(), + Option.empty())); + } + return seqOf(fields.toArray(new AgnosticEncoders.EncoderField[0])); + } else { + // Non-struct: wrap in a single "value" field so StructEncoder sees one field. + return seqOf( + new AgnosticEncoders.EncoderField( + "value", + new FieldEncoder<>(dt, true), + true, + Metadata.empty(), + Option.empty(), + Option.empty())); + } + } + + // --- AgnosticExpressionPathEncoder --- + + @Override + public Expression toCatalyst(Expression input) { + return serializer.transformUp(replace(BoundReference.class, input)); + } + + @Override + public Expression fromCatalyst(Expression input) { + return deserializer.transformUp(replace(GetColumnByOrdinal.class, input)); + } + + // --- AgnosticEncoders.StructEncoder --- + + @Override + public Seq<AgnosticEncoders.EncoderField> fields() { + return encoderFields; + } + + @Override + public boolean isStruct() { + return true; + } + + @Override + public void + org$apache$spark$sql$catalyst$encoders$AgnosticEncoders$StructEncoder$_setter_$isStruct_$eq( + boolean v) { + // no-op: isStruct() is implemented directly above Review Comment: Done in 9c071c5e — added a Javadoc on the method explaining that this is the synthetic setter the Scala compiler emits for trait `val` fields (`<trait>$_setter_<field>_$eq`), why we have to implement it from Java, and what to expect if Spark removes or 
renames the underlying `isStruct` field. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
