echauchot commented on code in PR #22446: URL: https://github.com/apache/beam/pull/22446#discussion_r975280687
########## runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java: ########## @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupingTranslator.value; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static scala.collection.Iterator.single; + +import java.util.Collection; +import java.util.Map; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; +import scala.collection.Iterator; + +/** + * Translator for {@link Combine.Globally} using a Spark {@link Aggregator}. + * + * <p>To minimize the amount of data shuffled, this first reduces the data per partition using + * {@link Aggregator#reduce}, gathers the partial results (using {@code coalesce(1)}) and finally + * merges these using {@link Aggregator#merge}. + * + * <p>TODOs: + * <li>any missing features? + */ +class CombineGloballyTranslatorBatch<InT, AccT, OutT> Review Comment: this class LGTM. Nice to have used Spark aggregators directly ########## runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java: ########## @@ -0,0 +1,591 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mapEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mutablePairEncoder; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators.peekingIterator; + +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.BiFunction; +import java.util.function.BinaryOperator; +import java.util.function.Function; +import javax.annotation.Nullable; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.PeekingIterator; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; +import org.apache.spark.util.MutablePair; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.PolyNull; +import org.joda.time.Instant; + +public class Aggregators { Review Comment: This is the main big change. LGTM, the code logic seems correct; since the VR tests pass, it is correct. And the code is very readable, thanks. ########## runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java: ########## @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.io; + +import static java.util.stream.Collectors.toList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static scala.collection.JavaConverters.asScalaIterator; + +import java.io.Closeable; +import java.io.IOException; +import java.io.Serializable; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntSupplier; +import javax.annotation.Nullable; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; +import org.apache.spark.InterruptibleIterator; +import org.apache.spark.Partition; +import org.apache.spark.SparkContext; +import org.apache.spark.TaskContext; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Option; +import scala.collection.Iterator; +import scala.reflect.ClassTag; + +public class BoundedDatasetFactory { + private BoundedDatasetFactory() {} + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}. + * + * <p>Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization. 
+ * This makes this approach at the time being significantly less performant than creating a + * dataset from an RDD. + */ + public static <T> Dataset<WindowedValue<T>> createDatasetFromRows( + SparkSession session, + BoundedSource<T> source, + SerializablePipelineOptions options, + Encoder<WindowedValue<T>> encoder) { + Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + BeamTable<T> table = new BeamTable<>(source, params); + LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty()); + return Dataset.ofRows(session, logicalPlan).as(encoder); + } + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link RDD}. + * + * <p>This is currently the most efficient approach as it avoid any serialization overhead. + */ + public static <T> Dataset<WindowedValue<T>> createDatasetFromRDD( + SparkSession session, + BoundedSource<T> source, + SerializablePipelineOptions options, + Encoder<WindowedValue<T>> encoder) { + Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + RDD<WindowedValue<T>> rdd = new BoundedRDD<>(session.sparkContext(), source, params); + return session.createDataset(rdd, encoder); + } + + /** An {@link RDD} for a bounded Beam source. */ + private static class BoundedRDD<T> extends RDD<WindowedValue<T>> { Review Comment: Also, the calls to the Beam BoundedSource API are now clearer than when they were implemented via DataSource V2. Have you compared the performance of a simple source pipeline to check that there is no perf degradation attributable to this RDD-based impl? I guess Nexmark query0 (pardo with encode/decode roundtrip + metrics) would give a good clue on that. ########## runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java: ########## @@ -17,74 +17,264 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; +import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.spark.sql.functions.col; +import static 
org.apache.spark.sql.functions.collect_list; +import static org.apache.spark.sql.functions.explode; +import static org.apache.spark.sql.functions.max; +import static org.apache.spark.sql.functions.min; +import static org.apache.spark.sql.functions.struct; + import java.io.Serializable; import org.apache.beam.runners.core.InMemoryStateInternals; -import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.ReduceFnRunner; import org.apache.beam.runners.core.StateInternalsFactory; import org.apache.beam.runners.core.SystemReduceFn; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; -import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.catalyst.expressions.CreateArray; +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Tuple2; +import scala.collection.Iterator; +import scala.collection.Seq; +import scala.collection.immutable.List; +/** + * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the build-in aggregation + * function {@code collect_list} when applicable. Review Comment: good idea, avoiding the materialization done by ReduceFnRunner and using a native Spark aggregation instead is better, because it allows Spark to spill to disk instead of throwing an OOM. 
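To make the point about the native aggregation concrete, here is a minimal, self-contained sketch of the relational groupBy + collect_list pattern this translator relies on; the KV bean, the data, and the app name are made up for illustration and are not part of the PR.

```java
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.collect_list;

import java.util.Arrays;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class CollectListSketch {

  // Simple bean so createDataFrame can infer a schema (hypothetical stand-in for the exploded KVs).
  public static class KV implements java.io.Serializable {
    private String key;
    private int value;

    public KV() {}

    public KV(String key, int value) {
      this.key = key;
      this.value = value;
    }

    public String getKey() { return key; }
    public void setKey(String key) { this.key = key; }
    public int getValue() { return value; }
    public void setValue(int value) { this.value = value; }
  }

  public static void main(String[] args) {
    SparkSession session =
        SparkSession.builder().master("local[2]").appName("collect-list-sketch").getOrCreate();

    Dataset<Row> kvs =
        session.createDataFrame(
            Arrays.asList(new KV("a", 1), new KV("a", 2), new KV("b", 3)), KV.class);

    // Relational groupBy + collect_list: Spark runs the aggregation natively and can spill to
    // disk, instead of handing the whole group to user code as an in-memory iterable the way
    // ReduceFnRunner would.
    Dataset<Row> grouped = kvs.groupBy(col("key")).agg(collect_list(col("value")).as("values"));
    grouped.show();

    session.stop();
  }
}
```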
########## runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java: ########## @@ -17,98 +17,129 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; -import java.util.ArrayList; -import java.util.List; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; -import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; + +import java.util.Collection; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.WindowingStrategy; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; import scala.Tuple2; +import scala.collection.TraversableOnce; -@SuppressWarnings({ - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) -}) -class CombinePerKeyTranslatorBatch<K, InputT, AccumT, OutputT> - implements TransformTranslator< - PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> { +/** + * Translator for {@link Combine.PerKey} using {@link Dataset#groupByKey} with a Spark {@link + * Aggregator}. + * + * <ul> + * <li>When using the default global window, window information is dropped and restored after the + * aggregation. + * <li>For non-merging windows, windows are exploded and moved into a composite key for better + * distribution. After the aggregation, windowed values are restored from the composite key. + * <li>All other cases use an aggregator on windowed values that is optimized for the current + * windowing strategy. + * </ul> + * + * TODOs: + * <li>combine with context (CombineFnWithContext)? + * <li>combine with sideInputs? + * <li>other there other missing features? + */ +class CombinePerKeyTranslatorBatch<K, InT, AccT, OutT> Review Comment: this class LGTM modulo a wrong comment. Here also, clear code, thanks. ########## runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java: ########## @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.io; + +import static java.util.stream.Collectors.toList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static scala.collection.JavaConverters.asScalaIterator; + +import java.io.Closeable; +import java.io.IOException; +import java.io.Serializable; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntSupplier; +import javax.annotation.Nullable; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; +import org.apache.spark.InterruptibleIterator; +import org.apache.spark.Partition; +import org.apache.spark.SparkContext; +import org.apache.spark.TaskContext; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Option; +import scala.collection.Iterator; +import scala.reflect.ClassTag; + +public class BoundedDatasetFactory { + private BoundedDatasetFactory() {} + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}. + * + * <p>Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization. 
+ * This makes this approach at the time being significantly less performant than creating a + * dataset from an RDD. + */ + public static <T> Dataset<WindowedValue<T>> createDatasetFromRows( + SparkSession session, + BoundedSource<T> source, + SerializablePipelineOptions options, + Encoder<WindowedValue<T>> encoder) { + Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + BeamTable<T> table = new BeamTable<>(source, params); + LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty()); + return Dataset.ofRows(session, logicalPlan).as(encoder); + } + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link RDD}. + * + * <p>This is currently the most efficient approach as it avoid any serialization overhead. + */ + public static <T> Dataset<WindowedValue<T>> createDatasetFromRDD( + SparkSession session, + BoundedSource<T> source, + SerializablePipelineOptions options, + Encoder<WindowedValue<T>> encoder) { + Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + RDD<WindowedValue<T>> rdd = new BoundedRDD<>(session.sparkContext(), source, params); + return session.createDataset(rdd, encoder); + } + + /** An {@link RDD} for a bounded Beam source. */ + private static class BoundedRDD<T> extends RDD<WindowedValue<T>> { Review Comment: yes, good idea to extend RDD, it completely bypasses InternalRow and the whole native Spark source API! It is good to no longer use the Spark source API, because it has changed a lot over time: first DataSource V1, then DataSource V2, then breaking changes in DataSource V2 between Spark 2 and Spark 3 ... ########## runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java: ########## @@ -17,74 +17,264 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; +import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.collect_list; +import static 
org.apache.spark.sql.functions.explode; +import static org.apache.spark.sql.functions.max; +import static org.apache.spark.sql.functions.min; +import static org.apache.spark.sql.functions.struct; + import java.io.Serializable; import org.apache.beam.runners.core.InMemoryStateInternals; -import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.ReduceFnRunner; import org.apache.beam.runners.core.StateInternalsFactory; import org.apache.beam.runners.core.SystemReduceFn; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; -import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.catalyst.expressions.CreateArray; +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Tuple2; +import scala.collection.Iterator; +import scala.collection.Seq; +import scala.collection.immutable.List; +/** + * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the build-in aggregation + * function {@code collect_list} when applicable. + * + * <p>Note: Using {@code collect_list} isn't any worse than using {@link ReduceFnRunner}. In the + * latter case the entire group (iterator) has to be loaded into memory as well, risking OOM errors + * in both cases. 
When disabling {@link #useCollectList}, a more memory sensitive iterable is used Review Comment: remove "in both cases" ########## runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java: ########## @@ -17,98 +17,129 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; -import java.util.ArrayList; -import java.util.List; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; -import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; + +import java.util.Collection; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.WindowingStrategy; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; import scala.Tuple2; +import scala.collection.TraversableOnce; -@SuppressWarnings({ - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) -}) -class CombinePerKeyTranslatorBatch<K, InputT, AccumT, OutputT> - implements TransformTranslator< - PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> { +/** + * Translator for {@link Combine.PerKey} using {@link Dataset#groupByKey} with a Spark {@link + * Aggregator}. + * + * <ul> + * <li>When using the default global window, window information is dropped and restored after the + * aggregation. + * <li>For non-merging windows, windows are exploded and moved into a composite key for better + * distribution. After the aggregation, windowed values are restored from the composite key. + * <li>All other cases use an aggregator on windowed values that is optimized for the current + * windowing strategy. + * </ul> + * + * TODOs: + * <li>combine with context (CombineFnWithContext)? + * <li>combine with sideInputs? + * <li>other there other missing features? 
+ */ +class CombinePerKeyTranslatorBatch<K, InT, AccT, OutT> + extends GroupingTranslator<K, InT, OutT, Combine.PerKey<K, InT, OutT>> { @Override - public void translateTransform( - PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> transform, - AbstractTranslationContext context) { + public void translate(Combine.PerKey<K, InT, OutT> transform, Context cxt) { + WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy(); + CombineFn<InT, AccT, OutT> combineFn = (CombineFn<InT, AccT, OutT>) transform.getFn(); + + KvCoder<K, InT> inputCoder = (KvCoder<K, InT>) cxt.getInput().getCoder(); + KvCoder<K, OutT> outputCoder = (KvCoder<K, OutT>) cxt.getOutput().getCoder(); + + Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder); + Encoder<KV<K, InT>> inputEnc = cxt.encoderOf(inputCoder); + Encoder<WindowedValue<KV<K, OutT>>> wvOutputEnc = cxt.windowedEncoder(outputCoder); + Encoder<AccT> accumEnc = accumEncoder(combineFn, inputCoder.getValueCoder(), cxt); + + final Dataset<WindowedValue<KV<K, OutT>>> result; + + boolean globalGroupBy = eligibleForGlobalGroupBy(windowing, true); + boolean groupByWindow = eligibleForGroupByWindow(windowing, true); - Combine.PerKey combineTransform = (Combine.PerKey) transform; - @SuppressWarnings("unchecked") - final PCollection<KV<K, InputT>> input = (PCollection<KV<K, InputT>>) context.getInput(); - @SuppressWarnings("unchecked") - final PCollection<KV<K, OutputT>> output = (PCollection<KV<K, OutputT>>) context.getOutput(); - @SuppressWarnings("unchecked") - final Combine.CombineFn<InputT, AccumT, OutputT> combineFn = - (Combine.CombineFn<InputT, AccumT, OutputT>) combineTransform.getFn(); - WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy(); + if (globalGroupBy || groupByWindow) { + Aggregator<KV<K, InT>, ?, OutT> valueAgg = + Aggregators.value(combineFn, KV::getValue, accumEnc, cxt.valueEncoderOf(outputCoder)); - Dataset<WindowedValue<KV<K, InputT>>> inputDataset = context.getDataset(input); + if (globalGroupBy) { + // Drop window and group by key globally to run the aggregation (combineFn), afterwards the + // global window is restored + result = + cxt.getDataset(cxt.getInput()) + .groupByKey(valueKey(), keyEnc) + .mapValues(value(), inputEnc) + .agg(valueAgg.toColumn()) + .map(globalKV(), wvOutputEnc); + } else { + Encoder<Tuple2<BoundedWindow, K>> windowedKeyEnc = windowedKeyEnc(keyEnc, cxt); - KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder(); - Coder<K> keyCoder = inputCoder.getKeyCoder(); - KvCoder<K, OutputT> outputKVCoder = (KvCoder<K, OutputT>) output.getCoder(); - Coder<OutputT> outputCoder = outputKVCoder.getValueCoder(); + // Group by window and key to run the aggregation (combineFn) + result = + cxt.getDataset(cxt.getInput()) + .flatMap(explodeWindowedKey(value()), cxt.tupleEncoder(windowedKeyEnc, inputEnc)) + .groupByKey(fun1(Tuple2::_1), windowedKeyEnc) + .mapValues(fun1(Tuple2::_2), inputEnc) + .agg(valueAgg.toColumn()) + .map(windowedKV(), wvOutputEnc); + } + } else { + // Use an optimized aggregator for session window fns Review Comment: not only sessions but merging and non-merging windows as well. 
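For readers following along, the core of these Combine translators is wrapping the Beam CombineFn in a Spark Aggregator and feeding it to groupByKey(...).agg(...). Below is a stripped-down sketch of that wrapping only; the class name and constructor are hypothetical, and the PR's Aggregators additionally handle windows and timestamps, which this sketch deliberately omits.

```java
import java.util.Arrays;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.expressions.Aggregator;

// Illustrative only: delegates each Spark Aggregator callback to the Beam CombineFn.
class CombineFnAggregator<InT, AccT, OutT> extends Aggregator<InT, AccT, OutT> {
  private final CombineFn<InT, AccT, OutT> fn;
  private final Encoder<AccT> accEnc;
  private final Encoder<OutT> outEnc;

  CombineFnAggregator(CombineFn<InT, AccT, OutT> fn, Encoder<AccT> accEnc, Encoder<OutT> outEnc) {
    this.fn = fn;
    this.accEnc = accEnc;
    this.outEnc = outEnc;
  }

  @Override
  public AccT zero() {
    return fn.createAccumulator(); // fresh accumulator per partition
  }

  @Override
  public AccT reduce(AccT acc, InT in) {
    return fn.addInput(acc, in); // pre-aggregate within a partition
  }

  @Override
  public AccT merge(AccT a, AccT b) {
    return fn.mergeAccumulators(Arrays.asList(a, b)); // combine partial results across partitions
  }

  @Override
  public OutT finish(AccT acc) {
    return fn.extractOutput(acc);
  }

  @Override
  public Encoder<AccT> bufferEncoder() {
    return accEnc;
  }

  @Override
  public Encoder<OutT> outputEncoder() {
    return outEnc;
  }
}
```

A per-key use would then look roughly like ds.groupByKey(keyFn, keyEnc).agg(new CombineFnAggregator<>(fn, accEnc, outEnc).toColumn()), which is the shape of the globalGroupBy branch quoted above.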
########## runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java: ########## @@ -17,74 +17,264 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; +import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.collect_list; +import static org.apache.spark.sql.functions.explode; +import static org.apache.spark.sql.functions.max; +import static org.apache.spark.sql.functions.min; +import static org.apache.spark.sql.functions.struct; + import java.io.Serializable; import org.apache.beam.runners.core.InMemoryStateInternals; -import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.ReduceFnRunner; import org.apache.beam.runners.core.StateInternalsFactory; import org.apache.beam.runners.core.SystemReduceFn; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; -import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.WindowingStrategy; +import 
org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.catalyst.expressions.CreateArray; +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Tuple2; +import scala.collection.Iterator; +import scala.collection.Seq; +import scala.collection.immutable.List; +/** + * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the build-in aggregation + * function {@code collect_list} when applicable. + * + * <p>Note: Using {@code collect_list} isn't any worse than using {@link ReduceFnRunner}. In the + * latter case the entire group (iterator) has to be loaded into memory as well, risking OOM errors + * in both cases. When disabling {@link #useCollectList}, a more memory sensitive iterable is used + * that can be traversed just once. Attempting to traverse the iterable again will throw. + * + * <ul> + * <li>When using the default global window, window information is dropped and restored after the + * aggregation. + * <li>For non-merging windows, windows are exploded and moved into a composite key for better + * distribution. Though, to keep the amount of shuffled data low, this is only done if values + * are assigned to a single window or if there are only few keys and distributing data is + * important. After the aggregation, windowed values are restored from the composite key. + * <li>All other cases are implemented using the SDK {@link ReduceFnRunner}. + * </ul> + */ class GroupByKeyTranslatorBatch<K, V> - implements TransformTranslator< - PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>> { + extends GroupingTranslator<K, V, Iterable<V>, GroupByKey<K, V>> { + + /** Literal of binary encoded Pane info. */ + private static final Expression PANE_NO_FIRING = lit(toByteArray(NO_FIRING, PaneInfoCoder.of())); + + /** Defaults for value in single global window. 
*/ + private static final List<Expression> GLOBAL_WINDOW_DETAILS = + windowDetails(lit(new byte[][] {EMPTY_BYTE_ARRAY})); + + private boolean useCollectList = true; + + public GroupByKeyTranslatorBatch() {} + + public GroupByKeyTranslatorBatch(boolean useCollectList) { + this.useCollectList = useCollectList; + } @Override - public void translateTransform( - PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> transform, - AbstractTranslationContext context) { - - @SuppressWarnings("unchecked") - final PCollection<KV<K, V>> inputPCollection = (PCollection<KV<K, V>>) context.getInput(); - Dataset<WindowedValue<KV<K, V>>> input = context.getDataset(inputPCollection); - WindowingStrategy<?, ?> windowingStrategy = inputPCollection.getWindowingStrategy(); - KvCoder<K, V> kvCoder = (KvCoder<K, V>) inputPCollection.getCoder(); - Coder<V> valueCoder = kvCoder.getValueCoder(); - - // group by key only - Coder<K> keyCoder = kvCoder.getKeyCoder(); - KeyValueGroupedDataset<K, WindowedValue<KV<K, V>>> groupByKeyOnly = - input.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder)); - - // group also by windows - WindowedValue.FullWindowedValueCoder<KV<K, Iterable<V>>> outputCoder = - WindowedValue.FullWindowedValueCoder.of( - KvCoder.of(keyCoder, IterableCoder.of(valueCoder)), - windowingStrategy.getWindowFn().windowCoder()); - Dataset<WindowedValue<KV<K, Iterable<V>>>> output = - groupByKeyOnly.flatMapGroups( - new GroupAlsoByWindowViaOutputBufferFn<>( - windowingStrategy, - new InMemoryStateInternalsFactory<>(), - SystemReduceFn.buffering(valueCoder), - context.getSerializableOptions()), - EncoderHelpers.fromBeamCoder(outputCoder)); - - context.putDataset(context.getOutput(), output); + public void translate(GroupByKey<K, V> transform, Context cxt) { + WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy(); + TimestampCombiner tsCombiner = windowing.getTimestampCombiner(); + + Dataset<WindowedValue<KV<K, V>>> input = cxt.getDataset(cxt.getInput()); + + KvCoder<K, V> inputCoder = (KvCoder<K, V>) cxt.getInput().getCoder(); + KvCoder<K, Iterable<V>> outputCoder = (KvCoder<K, Iterable<V>>) cxt.getOutput().getCoder(); + + Encoder<V> valueEnc = cxt.valueEncoderOf(inputCoder); + Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder); + + // In batch we can ignore triggering and allowed lateness parameters Review Comment: in batch we can ignore allowed lateness, but not triggers, as triggers can be set on processing time in the global window (see Nexmark query 12, which sets a trigger that fires repeatedly once a given period of processing time has passed since the first element in the pane). ########## runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java: ########## @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.io; + +import static java.util.stream.Collectors.toList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static scala.collection.JavaConverters.asScalaIterator; + +import java.io.Closeable; +import java.io.IOException; +import java.io.Serializable; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntSupplier; +import javax.annotation.Nullable; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; +import org.apache.spark.InterruptibleIterator; +import org.apache.spark.Partition; +import org.apache.spark.SparkContext; +import org.apache.spark.TaskContext; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Option; +import scala.collection.Iterator; +import scala.reflect.ClassTag; + +public class BoundedDatasetFactory { + private BoundedDatasetFactory() {} + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}. + * + * <p>Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization. + * This makes this approach at the time being significantly less performant than creating a + * dataset from an RDD. 
+ */ + public static <T> Dataset<WindowedValue<T>> createDatasetFromRows( + SparkSession session, + BoundedSource<T> source, + SerializablePipelineOptions options, + Encoder<WindowedValue<T>> encoder) { + Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + BeamTable<T> table = new BeamTable<>(source, params); + LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty()); + return Dataset.ofRows(session, logicalPlan).as(encoder); + } + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link RDD}. + * + * <p>This is currently the most efficient approach as it avoid any serialization overhead. + */ + public static <T> Dataset<WindowedValue<T>> createDatasetFromRDD( + SparkSession session, + BoundedSource<T> source, + SerializablePipelineOptions options, + Encoder<WindowedValue<T>> encoder) { + Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + RDD<WindowedValue<T>> rdd = new BoundedRDD<>(session.sparkContext(), source, params); + return session.createDataset(rdd, encoder); Review Comment: It is way better than using the user source API as before! Indeed, in addition to the InternalRow serialization, the old approach required passing the Beam source as a String (with the ugly base64 serialization :vomiting_face: ), as there was no way of passing anything other than Strings to the user source API. ########## runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java: ########## @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.io; + +import static java.util.stream.Collectors.toList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static scala.collection.JavaConverters.asScalaIterator; + +import java.io.Closeable; +import java.io.IOException; +import java.io.Serializable; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntSupplier; +import javax.annotation.Nullable; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; +import org.apache.spark.InterruptibleIterator; +import org.apache.spark.Partition; +import org.apache.spark.SparkContext; +import org.apache.spark.TaskContext; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Option; +import scala.collection.Iterator; +import scala.reflect.ClassTag; + +public class BoundedDatasetFactory { + private BoundedDatasetFactory() {} + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}. + * + * <p>Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization. Review Comment: Yeah, it was the same when using the user source API: we ended up with an InternalRow that required serialization. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
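As a footnote to the exchange above: the difference boils down to the two entry points quoted from BoundedDatasetFactory, Dataset.ofRows over a Table (where every element has to be produced as an InternalRow) versus session.createDataset over an RDD with an explicit Encoder. A trivial, self-contained sketch of the latter call follows; the in-memory RDD and app name here are made-up stand-ins for the PR's BoundedRDD, not its actual code.

```java
import java.util.Arrays;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class DatasetFromRddSketch {
  public static void main(String[] args) {
    SparkSession session =
        SparkSession.builder().master("local[2]").appName("rdd-dataset-sketch").getOrCreate();
    JavaSparkContext jsc = new JavaSparkContext(session.sparkContext());

    // Stand-in for the elements a BoundedReader would produce.
    RDD<String> rdd = jsc.parallelize(Arrays.asList("a", "b", "c")).rdd();

    // createDataset(RDD, Encoder) builds the Dataset directly on top of the RDD with the given
    // encoder (the approach createDatasetFromRDD takes in the PR), rather than going through the
    // Table / InternalRow scan path of createDatasetFromRows.
    Dataset<String> ds = session.createDataset(rdd, Encoders.STRING());
    ds.show();

    session.stop();
  }
}
```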
