This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f49b450fc15 Evaluate removal of RDD caching for MEMORY_ONLY in the Spark Dataset runner (#25327)
f49b450fc15 is described below

commit f49b450fc1535c5d9bb191ecf200f04e69fa0de1
Author: Moritz Mack <[email protected]>
AuthorDate: Fri Mar 3 15:27:50 2023 +0100

    Evaluate removal of RDD caching for MEMORY_ONLY in the Spark Dataset runner (#25327)
---
 .../translation/PipelineTranslator.java            | 18 +-----
 .../translation/batch/ParDoTranslatorBatch.java    | 65 ++++++----------------
 2 files changed, 19 insertions(+), 64 deletions(-)

diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
index ea8441b0bf1..75fd6353123 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
@@ -21,7 +21,6 @@ import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_
 import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
 import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
 import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
-import static org.apache.spark.storage.StorageLevel.MEMORY_ONLY;
 
 import java.io.IOException;
 import java.io.Serializable;
@@ -231,7 +230,6 @@ public abstract class PipelineTranslator {
     private final PipelineOptions options;
     private final Supplier<PipelineOptions> optionsSupplier;
     private final StorageLevel storageLevel;
-    private final boolean isMemoryOnly;
 
     private final Set<TranslationResult<?>> leaves;
 
@@ -244,7 +242,6 @@ public abstract class PipelineTranslator {
       this.options = options;
       this.optionsSupplier = new BroadcastOptions(sparkSession, options);
       this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
-      this.isMemoryOnly = storageLevel.equals(MEMORY_ONLY());
       this.encoders = new HashMap<>();
       this.leaves = new HashSet<>();
     }
@@ -294,18 +291,9 @@ public abstract class PipelineTranslator {
       TranslationResult<T> result = getResult(pCollection);
       result.dataset = dataset;
 
-      if (!cache && isMemoryOnly) {
-        result.resetPlanComplexity(); // cached as RDD in memory which breaks linage
-      } else if (cache && result.usages() > 1) {
-        if (isMemoryOnly) {
-          // Cache as RDD in-memory only, this helps to also break linage of complex query plans.
-          LOG.info("Dataset {} will be cached in-memory as RDD for reuse.", result.name);
-          result.dataset = sparkSession.createDataset(dataset.rdd().persist(), dataset.encoder());
-          result.resetPlanComplexity();
-        } else {
-          LOG.info("Dataset {} will be cached for reuse.", result.name);
-          dataset.persist(storageLevel); // use NONE to disable
-        }
+      if (cache && result.usages() > 1) {
+        LOG.info("Dataset {} will be cached for reuse.", result.name);
+        dataset.persist(storageLevel); // use NONE to disable
       }
 
       if (result.estimatePlanComplexity() > PLAN_COMPLEXITY_THRESHOLD) {
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
index 40d26be8a8b..da40e4c9c50 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
@@ -23,7 +23,6 @@ import static org.apache.beam.runners.spark.structuredstreaming.translation.util
 import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
 import static org.apache.spark.sql.functions.col;
-import static org.apache.spark.storage.StorageLevel.MEMORY_ONLY;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -51,14 +50,12 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
 import org.apache.spark.broadcast.Broadcast;
-import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.TypedColumn;
 import org.apache.spark.storage.StorageLevel;
 import scala.Tuple2;
 import scala.collection.TraversableOnce;
-import scala.reflect.ClassTag;
 
 /**
  * Translator for {@link ParDo.MultiOutput} based on {@link DoFnRunners#simpleRunner}.
@@ -73,12 +70,6 @@ class ParDoTranslatorBatch<InputT, OutputT>
     extends TransformTranslator<
         PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>> {
 
-  private static final ClassTag<WindowedValue<Object>> WINDOWED_VALUE_CTAG =
-      ClassTag.apply(WindowedValue.class);
-
-  private static final ClassTag<Tuple2<Integer, WindowedValue<Object>>> TUPLE2_CTAG =
-      ClassTag.apply(Tuple2.class);
-
   ParDoTranslatorBatch() {
     super(0.2f);
   }
@@ -148,46 +139,22 @@ class ParDoTranslatorBatch<InputT, OutputT>
 
       SparkCommonPipelineOptions opts = cxt.getOptions().as(SparkCommonPipelineOptions.class);
       StorageLevel storageLevel = StorageLevel.fromString(opts.getStorageLevel());
-      // If using storage level MEMORY_ONLY, it's best to persist the dataset as RDD to avoid any
-      // serialization / use of encoders. Persisting a Dataset, even if using a "deserialized"
-      // storage level, involves converting the data to the internal representation (InternalRow)
-      // by use of an encoder.
-      // For any other storage level, persist as Dataset, so we can select columns by TupleTag
-      // individually without restoring the entire row.
-      // In both cases caching of the outputs in the translation context is disabled to avoid
-      // caching the same data twice.
-      if (MEMORY_ONLY().equals(storageLevel)) {
-
-        RDD<Tuple2<Integer, WindowedValue<Object>>> allTagsRDD =
-            inputDs.rdd().mapPartitions(doFnMapper, false, TUPLE2_CTAG);
-        allTagsRDD.persist();
-
-        // divide into separate output datasets per tag
-        for (TupleTag<?> tag : outputs.keySet()) {
-          int colIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag");
-          RDD<WindowedValue<Object>> rddByTag =
-              allTagsRDD.flatMap(selectByColumnIdx(colIdx), WINDOWED_VALUE_CTAG);
-          cxt.putDataset(
-              cxt.getOutput((TupleTag) tag),
-              cxt.getSparkSession().createDataset(rddByTag, encoders.get(colIdx)),
-              false);
-        }
-      } else {
-        // Persist as wide rows with one column per TupleTag to support different schemas
-        Dataset<Tuple2<Integer, WindowedValue<Object>>> allTagsDS =
-            inputDs.mapPartitions(doFnMapper, oneOfEncoder(encoders));
-        allTagsDS.persist(storageLevel);
-
-        // divide into separate output datasets per tag
-        for (TupleTag<?> tag : outputs.keySet()) {
-          int colIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag");
-          // Resolve specific column matching the tuple tag (by id)
-          TypedColumn<Tuple2<Integer, WindowedValue<Object>>, WindowedValue<Object>> col =
-              (TypedColumn) col(Integer.toString(colIdx)).as(encoders.get(colIdx));
-
-          cxt.putDataset(
-              cxt.getOutput((TupleTag) tag), allTagsDS.filter(col.isNotNull()).select(col), false);
-        }
+
+      // Persist as wide rows with one column per TupleTag to support different schemas
+      Dataset<Tuple2<Integer, WindowedValue<Object>>> allTagsDS =
+          inputDs.mapPartitions(doFnMapper, oneOfEncoder(encoders));
+      allTagsDS.persist(storageLevel);
+
+      // divide into separate output datasets per tag
+      for (TupleTag<?> tag : outputs.keySet()) {
+        int colIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag");
+        // Resolve specific column matching the tuple tag (by id)
+        TypedColumn<Tuple2<Integer, WindowedValue<Object>>, WindowedValue<Object>> col =
+            (TypedColumn) col(Integer.toString(colIdx)).as(encoders.get(colIdx));
+
+        // Caching of the returned outputs is disabled to avoid caching the same data twice.
+        cxt.putDataset(
+            cxt.getOutput((TupleTag) tag), allTagsDS.filter(col.isNotNull()).select(col), false);
       }
     } else {
       PCollection<OutputT> output = cxt.getOutput(mainOut);

Reply via email to