Repository: beam Updated Branches: refs/heads/master c045b0ec2 -> 82b7b8613
[BEAM-649] Analyse DAG to determine if RDD/DStream has to be cached or not Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/daa10ddb Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/daa10ddb Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/daa10ddb Branch: refs/heads/master Commit: daa10ddbf7c1cc07962ab1bced0f14485f3739cb Parents: 9ac1ffc Author: Jean-Baptiste Onofré <[email protected]> Authored: Thu Mar 2 17:28:50 2017 +0100 Committer: Jean-Baptiste Onofré <[email protected]> Committed: Thu Mar 23 17:26:11 2017 +0100 ---------------------------------------------------------------------- .../apache/beam/runners/spark/SparkRunner.java | 64 ++++++++++++++++++-- .../spark/translation/BoundedDataset.java | 3 +- .../spark/translation/EvaluationContext.java | 61 ++++++++++++------- .../SparkRunnerStreamingContextFactory.java | 4 ++ .../apache/beam/runners/spark/CacheTest.java | 61 +++++++++++++++++++ .../spark/translation/StorageLevelTest.java | 6 +- 6 files changed, 168 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/daa10ddb/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java index de648fc..fc5d4af 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java @@ -38,6 +38,7 @@ import org.apache.beam.runners.spark.translation.TransformEvaluator; import org.apache.beam.runners.spark.translation.TransformTranslator; import org.apache.beam.runners.spark.translation.streaming.Checkpoint.CheckpointDir; import 
org.apache.beam.runners.spark.translation.streaming.SparkRunnerStreamingContextFactory; +import org.apache.beam.runners.spark.translation.streaming.StreamingTransformTranslator; import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.WatermarksListener; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.Read; @@ -90,6 +91,7 @@ import org.slf4j.LoggerFactory; public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { private static final Logger LOG = LoggerFactory.getLogger(SparkRunner.class); + /** * Options used in this pipeline runner. */ @@ -143,10 +145,14 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { final SparkPipelineResult result; final Future<?> startPipeline; + + final SparkPipelineTranslator translator; + final ExecutorService executorService = Executors.newSingleThreadExecutor(); MetricsEnvironment.setMetricsSupported(true); + // visit the pipeline to determine the translation mode detectTranslationMode(pipeline); if (mOptions.isStreaming()) { @@ -157,6 +163,11 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { JavaStreamingContext.getOrCreate(checkpointDir.getSparkCheckpointDir().toString(), contextFactory); + // update cache candidates + translator = new StreamingTransformTranslator.Translator( + new TransformTranslator.Translator()); + updateCacheCandidates(pipeline, translator, contextFactory.getEvaluationContext()); + // Checkpoint aggregator/metrics values jssc.addStreamingListener( new JavaStreamingListenerWrapper( @@ -191,8 +202,13 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { result = new SparkPipelineResult.StreamingMode(startPipeline, jssc); } else { + // create the evaluation context final JavaSparkContext jsc = SparkContextFactory.getSparkContext(mOptions); final EvaluationContext evaluationContext = new EvaluationContext(jsc, pipeline); + translator = new TransformTranslator.Translator(); + + // update 
the cache candidates + updateCacheCandidates(pipeline, translator, evaluationContext); initAccumulators(mOptions, jsc); @@ -200,8 +216,7 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { @Override public void run() { - pipeline.traverseTopologically(new Evaluator(new TransformTranslator.Translator(), - evaluationContext)); + pipeline.traverseTopologically(new Evaluator(translator, evaluationContext)); evaluationContext.computeOutputs(); LOG.info("Batch pipeline execution complete."); } @@ -240,9 +255,7 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { } /** - * Detect the translation mode for the pipeline and change options in case streaming - * translation is needed. - * @param pipeline + * Visit the pipeline to determine the translation mode (batch/streaming). */ private void detectTranslationMode(Pipeline pipeline) { TranslationModeDetector detector = new TranslationModeDetector(); @@ -254,6 +267,17 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { } /** + * Traverses the pipeline with a CacheVisitor to update/populate the cache candidates. + */ + private void updateCacheCandidates( + Pipeline pipeline, + SparkPipelineTranslator translator, + EvaluationContext evaluationContext) { + CacheVisitor cacheVisitor = new CacheVisitor(translator, evaluationContext); + pipeline.traverseTopologically(cacheVisitor); + } + + /** * The translation mode of the Beam Pipeline. */ enum TranslationMode { @@ -298,6 +322,36 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { } /** + * Traverses the pipeline to populate the candidates for caching.
+ */ + static class CacheVisitor extends Evaluator { + + protected CacheVisitor( + SparkPipelineTranslator translator, + EvaluationContext evaluationContext) { + super(translator, evaluationContext); + } + + @Override + public void doVisitTransform(TransformHierarchy.Node node) { + // we populate cache candidates by updating the map with inputs of each node. + // The goal is to detect the PCollections accessed more than one time, and so enable cache + // on the underlying RDDs or DStreams. + + for (TaggedPValue input : node.getInputs()) { + PValue value = input.getValue(); + if (value instanceof PCollection) { + long count = 1L; + if (ctxt.getCacheCandidates().get(value) != null) { + count = ctxt.getCacheCandidates().get(value) + 1; + } + ctxt.getCacheCandidates().put((PCollection) value, count); + } + } + } + } + + /** * Evaluator on the pipeline. */ @SuppressWarnings("WeakerAccess") http://git-wip-us.apache.org/repos/asf/beam/blob/daa10ddb/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java index 6e4ffc7..652c753 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java @@ -99,7 +99,8 @@ public class BoundedDataset<T> implements Dataset { @Override public void cache(String storageLevel) { - rdd.persist(StorageLevel.fromString(storageLevel)); + // populate the rdd if needed + getRDD().persist(StorageLevel.fromString(storageLevel)); } @Override http://git-wip-us.apache.org/repos/asf/beam/blob/daa10ddb/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java 
---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index 329e047..643749d 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -21,12 +21,14 @@ package org.apache.beam.runners.spark.translation; import static com.google.common.base.Preconditions.checkArgument; import com.google.common.collect.Iterables; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.AppliedPTransform; @@ -52,10 +54,10 @@ public class EvaluationContext { private final Map<PValue, Dataset> datasets = new LinkedHashMap<>(); private final Map<PValue, Dataset> pcollections = new LinkedHashMap<>(); private final Set<Dataset> leaves = new LinkedHashSet<>(); - private final Set<PValue> multiReads = new LinkedHashSet<>(); private final Map<PValue, Object> pobjects = new LinkedHashMap<>(); private AppliedPTransform<?, ?, ?> currentTransform; private final SparkPCollectionView pviews = new SparkPCollectionView(); + private final Map<PCollection, Long> cacheCandidates = new HashMap<>(); public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) { this.jsc = jsc; @@ -116,6 +118,15 @@ public class EvaluationContext { return currentTransform.getOutputs(); } + private boolean shouldCache(PValue pvalue) { + if ((pvalue instanceof PCollection) + && cacheCandidates.containsKey(pvalue) + && 
cacheCandidates.get(pvalue) > 1) { + return true; + } + return false; + } + public void putDataset(PTransform<?, ? extends PValue> transform, Dataset dataset) { putDataset(getOutput(transform), dataset); } @@ -126,13 +137,30 @@ public class EvaluationContext { } catch (IllegalStateException e) { // name not set, ignore } + if (shouldCache(pvalue)) { + dataset.cache(storageLevel()); + } datasets.put(pvalue, dataset); leaves.add(dataset); } <T> void putBoundedDatasetFromValues( PTransform<?, ? extends PValue> transform, Iterable<T> values, Coder<T> coder) { - datasets.put(getOutput(transform), new BoundedDataset<>(values, jsc, coder)); + PValue output = getOutput(transform); + if (shouldCache(output)) { + // eagerly create the RDD, as it will be reused. + Iterable<WindowedValue<T>> elems = Iterables.transform(values, + WindowingHelpers.<T>windowValueFunction()); + WindowedValue.ValueOnlyWindowedValueCoder<T> windowCoder = + WindowedValue.getValueOnlyCoder(coder); + JavaRDD<WindowedValue<T>> rdd = + getSparkContext().parallelize(CoderHelpers.toByteArrays(elems, windowCoder)) + .map(CoderHelpers.fromByteFunction(windowCoder)); + putDataset(transform, new BoundedDataset<>(rdd)); + } else { + // create a BoundedDataset that would create a RDD on demand + datasets.put(getOutput(transform), new BoundedDataset<>(values, jsc, coder)); + } } public Dataset borrowDataset(PTransform<? extends PValue, ?> transform) { @@ -142,12 +170,6 @@ public class EvaluationContext { public Dataset borrowDataset(PValue pvalue) { Dataset dataset = datasets.get(pvalue); leaves.remove(dataset); - if (multiReads.contains(pvalue)) { - // Ensure the RDD is marked as cached - dataset.cache(storageLevel()); - } else { - multiReads.add(pvalue); - } return dataset; } @@ -157,8 +179,6 @@ public class EvaluationContext { */ public void computeOutputs() { for (Dataset dataset : leaves) { - // cache so that any subsequent get() is cheap. 
- dataset.cache(storageLevel()); dataset.action(); // force computation. } } @@ -186,18 +206,6 @@ public class EvaluationContext { } /** - * Retrieves an iterable of results associated with the PCollection passed in. - * - * @param pcollection Collection we wish to translate. - * @param <T> Type of elements contained in collection. - * @return Natively types result associated with collection. - */ - <T> Iterable<T> get(PCollection<T> pcollection) { - Iterable<WindowedValue<T>> windowedValues = getWindowedValues(pcollection); - return Iterables.transform(windowedValues, WindowingHelpers.<T>unwindowValueFunction()); - } - - /** * Retrun the current views creates in the pipepline. * * @return SparkPCollectionView @@ -220,6 +228,15 @@ public class EvaluationContext { pviews.putPView(view, value, coder); } + /** + * Get the map of cache candidates held by the evaluation context. + * + * @return The current {@link Map} of cache candidates. + */ + public Map<PCollection, Long> getCacheCandidates() { + return this.cacheCandidates; + } + <T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) { @SuppressWarnings("unchecked") BoundedDataset<T> boundedDataset = (BoundedDataset<T>) datasets.get(pcollection); http://git-wip-us.apache.org/repos/asf/beam/blob/daa10ddb/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java index 7048be6..c298886 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java +++
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/SparkRunnerStreamingContextFactory.java @@ -91,6 +91,10 @@ public class SparkRunnerStreamingContextFactory implements JavaStreamingContextF return jssc; } + public EvaluationContext getEvaluationContext() { + return this.ctxt; + } + private void checkpoint(JavaStreamingContext jssc) { Path rootCheckpointPath = checkpointDir.getRootCheckpointDir(); Path sparkCheckpointPath = checkpointDir.getSparkCheckpointDir(); http://git-wip-us.apache.org/repos/asf/beam/blob/daa10ddb/runners/spark/src/test/java/org/apache/beam/runners/spark/CacheTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/CacheTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/CacheTest.java new file mode 100644 index 0000000..c3b48d8 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/CacheTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark; + +import static org.junit.Assert.assertEquals; + +import org.apache.beam.runners.spark.translation.EvaluationContext; +import org.apache.beam.runners.spark.translation.SparkContextFactory; +import org.apache.beam.runners.spark.translation.TransformTranslator; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.apache.spark.api.java.JavaSparkContext; +import org.junit.Rule; +import org.junit.Test; + +/** + * This test checks how the cache candidates map is populated by the runner when evaluating the + * pipeline. + */ +public class CacheTest { + + @Rule + public final transient PipelineRule pipelineRule = PipelineRule.batch(); + + @Test + public void cacheCandidatesUpdaterTest() throws Exception { + Pipeline pipeline = pipelineRule.createPipeline(); + PCollection<String> pCollection = pipeline.apply(Create.of("foo", "bar")); + // first read + pCollection.apply(Count.<String>globally()); + // second read + // as we access the same PCollection two times, the Spark runner does optimization and so + // will cache the RDD representing this PCollection + pCollection.apply(Count.<String>globally()); + + JavaSparkContext jsc = SparkContextFactory.getSparkContext(pipelineRule.getOptions()); + EvaluationContext ctxt = new EvaluationContext(jsc, pipeline); + SparkRunner.CacheVisitor cacheVisitor = + new SparkRunner.CacheVisitor(new TransformTranslator.Translator(), ctxt); + pipeline.traverseTopologically(cacheVisitor); + assertEquals(2L, (long) ctxt.getCacheCandidates().get(pCollection)); + } + +} http://git-wip-us.apache.org/repos/asf/beam/blob/daa10ddb/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java ---------------------------------------------------------------------- diff --git 
a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java index 4dc5dee..2b7b87b 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java @@ -37,9 +37,9 @@ public class StorageLevelTest { @Test public void test() throws Exception { pipelineRule.getOptions().setStorageLevel("DISK_ONLY"); - Pipeline p = pipelineRule.createPipeline(); + Pipeline pipeline = pipelineRule.createPipeline(); - PCollection<String> pCollection = p.apply(Create.of("foo")); + PCollection<String> pCollection = pipeline.apply(Create.of("foo")); // by default, the Spark runner doesn't cache the RDD if it accessed only one time. // So, to "force" the caching of the RDD, we have to call the RDD at least two time. @@ -50,7 +50,7 @@ public class StorageLevelTest { PAssert.thatSingleton(output).isEqualTo("Disk Serialized 1x Replicated"); - p.run(); + pipeline.run(); } }
