This is an automated email from the ASF dual-hosted git repository.

goenka pushed a commit to branch spark-cache-stage
in repository https://gitbox.apache.org/repos/asf/beam.git
commit 4cb97fb3cf08ba9e2a57f8ca5f1a32496d90afbe
Author:     Robert Bradshaw <rober...@gmail.com>
AuthorDate: Fri May 24 10:49:00 2019 +0200

    Merge pull request #8558 [BEAM-7131] Spark: cache output to prevent re-computation
---
 .../SparkBatchPortablePipelineTranslator.java      |  94 +++++++++++----
 .../spark/translation/SparkTranslationContext.java |  24 +++-
 .../runners/spark/SparkPortableExecutionTest.java  | 126 +++++++++++++++++++--
 3 files changed, 211 insertions(+), 33 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
index 8e7796f..0180496 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
@@ -64,6 +64,7 @@ import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.storage.StorageLevel;
 import scala.Tuple2;
 
 /** Translates a bounded portable pipeline into a Spark job. */
@@ -112,6 +113,24 @@ public class SparkBatchPortablePipelineTranslator {
         QueryablePipeline.forTransforms(
             pipeline.getRootTransformIdsList(), pipeline.getComponents());
     for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
+      // Pre-scan pipeline to count which pCollections are consumed as inputs more than once so
+      // their corresponding RDDs can later be cached.
+      for (String inputId : transformNode.getTransform().getInputsMap().values()) {
+        context.incrementConsumptionCountBy(inputId, 1);
+      }
+      // Executable stage consists of two parts: computation and extraction. This means the result
+      // of computation is an intermediate RDD, which we might also need to cache.
+      if (transformNode.getTransform().getSpec().getUrn().equals(ExecutableStage.URN)) {
+        context.incrementConsumptionCountBy(
+            getExecutableStageIntermediateId(transformNode),
+            transformNode.getTransform().getOutputsMap().size());
+      }
+      for (String outputId : transformNode.getTransform().getOutputsMap().values()) {
+        WindowedValueCoder outputCoder = getWindowedValueCoder(outputId, pipeline.getComponents());
+        context.putCoder(outputId, outputCoder);
+      }
+    }
+    for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
       urnToTransformTranslator
           .getOrDefault(
               transformNode.getTransform().getSpec().getUrn(),
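A note on the pre-scan above: it is plain reference counting over the pipeline graph, run before any translation so that the context can later decide, at pushDataset time, whether an RDD will be consumed more than once. A minimal, self-contained sketch of the counting idea (the toy transform map and pCollection IDs here are hypothetical illustrations, not code from this commit):

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    /** Toy illustration of the consumption-count pre-scan. */
    public class ConsumptionCountSketch {
      public static void main(String[] args) {
        // Each entry maps a transform to the pCollection IDs it consumes,
        // mirroring getInputsMap() in the real translator.
        Map<String, List<String>> transformInputs = new HashMap<>();
        transformInputs.put("B", Arrays.asList("pc_A"));
        transformInputs.put("C", Arrays.asList("pc_A"));

        // First pass: count how often each pCollection is consumed.
        Map<String, Integer> consumptionCount = new HashMap<>();
        for (List<String> inputs : transformInputs.values()) {
          for (String inputId : inputs) {
            consumptionCount.merge(inputId, 1, Integer::sum);
          }
        }

        // Second pass (translation) would cache any RDD whose count exceeds 1.
        consumptionCount.forEach(
            (id, n) -> System.out.println(id + " consumed " + n + "x, cache: " + (n > 1)));
      }
    }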
@@ -141,18 +160,9 @@ public class SparkBatchPortablePipelineTranslator {
     RunnerApi.Components components = pipeline.getComponents();
     String inputId = getInputId(transformNode);
-    PCollection inputPCollection = components.getPcollectionsOrThrow(inputId);
     Dataset inputDataset = context.popDataset(inputId);
     JavaRDD<WindowedValue<KV<K, V>>> inputRdd = ((BoundedDataset<KV<K, V>>) inputDataset).getRDD();
-    PCollectionNode inputPCollectionNode = PipelineNode.pCollection(inputId, inputPCollection);
-    WindowedValueCoder<KV<K, V>> inputCoder;
-    try {
-      inputCoder =
-          (WindowedValueCoder)
-              WireCoders.instantiateRunnerWireCoder(inputPCollectionNode, components);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
+    WindowedValueCoder<KV<K, V>> inputCoder = getWindowedValueCoder(inputId, components);
     KvCoder<K, V> inputKvCoder = (KvCoder<K, V>) inputCoder.getValueCoder();
     Coder<K> inputKeyCoder = inputKvCoder.getKeyCoder();
     Coder<V> inputValueCoder = inputKvCoder.getValueCoder();
@@ -200,18 +210,18 @@ public class SparkBatchPortablePipelineTranslator {
     Dataset inputDataset = context.popDataset(inputPCollectionId);
     JavaRDD<WindowedValue<InputT>> inputRdd = ((BoundedDataset<InputT>) inputDataset).getRDD();
     Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
-    BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
+    BiMap<String, Integer> outputExtractionMap = createOutputMap(outputs.values());
 
     ImmutableMap.Builder<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
         broadcastVariablesBuilder = ImmutableMap.builder();
     for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
-      RunnerApi.Components components = stagePayload.getComponents();
+      RunnerApi.Components stagePayloadComponents = stagePayload.getComponents();
       String collectionId =
-          components
+          stagePayloadComponents
               .getTransformsOrThrow(sideInputId.getTransformId())
               .getInputsOrThrow(sideInputId.getLocalName());
       Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
-          broadcastSideInput(collectionId, components, context);
+          broadcastSideInput(collectionId, stagePayloadComponents, context);
       broadcastVariablesBuilder.put(collectionId, tuple2);
     }
 
@@ -219,14 +229,38 @@ public class SparkBatchPortablePipelineTranslator {
         new SparkExecutableStageFunction<>(
             stagePayload,
             context.jobInfo,
-            outputMap,
+            outputExtractionMap,
             broadcastVariablesBuilder.build(),
             MetricsAccumulator.getInstance());
     JavaRDD<RawUnionValue> staged = inputRdd.mapPartitions(function);
+    String intermediateId = getExecutableStageIntermediateId(transformNode);
+    context.pushDataset(
+        intermediateId,
+        new Dataset() {
+          @Override
+          public void cache(String storageLevel, Coder<?> coder) {
+            StorageLevel level = StorageLevel.fromString(storageLevel);
+            staged.persist(level);
+          }
+
+          @Override
+          public void action() {
+            // Empty function to force computation of RDD.
+            staged.foreach(TranslationUtils.emptyVoidFunction());
+          }
+
+          @Override
+          public void setName(String name) {
+            staged.setName(name);
+          }
+        });
+    // pop dataset to mark RDD as used
+    context.popDataset(intermediateId);
 
     for (String outputId : outputs.values()) {
       JavaRDD<WindowedValue<OutputT>> outputRdd =
-          staged.flatMap(new SparkExecutableStageExtractionFunction<>(outputMap.get(outputId)));
+          staged.flatMap(
+              new SparkExecutableStageExtractionFunction<>(outputExtractionMap.get(outputId)));
       context.pushDataset(outputId, new BoundedDataset<>(outputRdd));
     }
     if (outputs.isEmpty()) {
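The push/pop of the intermediate dataset above is what routes the stage's raw RawUnionValue output through SparkTranslationContext.pushDataset, where the pre-computed consumption count decides whether Dataset#cache is invoked before the per-output extraction RDDs are derived from it. Stripped of the Beam plumbing, the Spark-level effect is just persisting an RDD so several consumers share one computation. A minimal sketch under that reading (local master, toy data, all names hypothetical):

    import java.util.Arrays;
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.storage.StorageLevel;

    /** Minimal sketch: persist an RDD once so multiple consumers do not recompute it. */
    public class PersistSketch {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[1]").setAppName("persist-sketch");
        try (JavaSparkContext jsc = new JavaSparkContext(conf)) {
          JavaRDD<Integer> staged = jsc.parallelize(Arrays.asList(1, 2, 3, 4));
          // The equivalent of Dataset#cache above: parse a level, persist the RDD.
          staged.persist(StorageLevel.fromString("MEMORY_ONLY"));
          // Two consumers; without persist, the lineage would rerun once per action.
          long evens = staged.filter(x -> x % 2 == 0).count();
          long odds = staged.filter(x -> x % 2 == 1).count();
          System.out.println(evens + " even, " + odds + " odd");
        }
      }
    }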
@@ -249,17 +283,9 @@ public class SparkBatchPortablePipelineTranslator {
    */
   private static <T> Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<T>> broadcastSideInput(
       String collectionId, RunnerApi.Components components, SparkTranslationContext context) {
-    PCollection collection = components.getPcollectionsOrThrow(collectionId);
     @SuppressWarnings("unchecked")
     BoundedDataset<T> dataset = (BoundedDataset<T>) context.popDataset(collectionId);
-    PCollectionNode collectionNode = PipelineNode.pCollection(collectionId, collection);
-    WindowedValueCoder<T> coder;
-    try {
-      coder =
-          (WindowedValueCoder<T>) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
+    WindowedValueCoder<T> coder = getWindowedValueCoder(collectionId, components);
     List<byte[]> bytes = dataset.getBytes(coder);
     Broadcast<List<byte[]>> broadcast = context.getSparkContext().broadcast(bytes);
     return new Tuple2<>(broadcast, coder);
@@ -324,4 +350,22 @@ public class SparkBatchPortablePipelineTranslator {
   private static String getOutputId(PTransformNode transformNode) {
     return Iterables.getOnlyElement(transformNode.getTransform().getOutputsMap().values());
   }
+
+  private static <T> WindowedValueCoder<T> getWindowedValueCoder(
+      String pCollectionId, RunnerApi.Components components) {
+    PCollection pCollection = components.getPcollectionsOrThrow(pCollectionId);
+    PCollectionNode pCollectionNode = PipelineNode.pCollection(pCollectionId, pCollection);
+    WindowedValueCoder<T> coder;
+    try {
+      coder =
+          (WindowedValueCoder) WireCoders.instantiateRunnerWireCoder(pCollectionNode, components);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+    return coder;
+  }
+
+  private static String getExecutableStageIntermediateId(PTransformNode transformNode) {
+    return transformNode.getId();
+  }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
index 772e0d2..8c2cee8 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
@@ -17,12 +17,16 @@
  */
 package org.apache.beam.runners.spark.translation;
 
+import com.sun.istack.Nullable;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
 import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
 import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.spark.api.java.JavaSparkContext;
 
@@ -33,6 +37,9 @@ import org.apache.spark.api.java.JavaSparkContext;
 public class SparkTranslationContext {
   private final JavaSparkContext jsc;
   final JobInfo jobInfo;
+  // Map pCollection IDs to the number of times they are consumed as inputs.
+  private final Map<String, Integer> consumptionCount = new HashMap<>();
+  private final Map<String, Coder> coderMap = new HashMap<>();
   private final Map<String, Dataset> datasets = new LinkedHashMap<>();
   private final Set<Dataset> leaves = new LinkedHashSet<>();
   final SerializablePipelineOptions serializablePipelineOptions;
@@ -51,7 +58,13 @@ public class SparkTranslationContext {
   /** Add output of transform to context. */
   public void pushDataset(String pCollectionId, Dataset dataset) {
     dataset.setName(pCollectionId);
-    // TODO cache
+    SparkPipelineOptions sparkOptions =
+        serializablePipelineOptions.get().as(SparkPipelineOptions.class);
+    if (!sparkOptions.isCacheDisabled() && consumptionCount.getOrDefault(pCollectionId, 0) > 1) {
+      String storageLevel = sparkOptions.getStorageLevel();
+      @Nullable Coder coder = coderMap.get(pCollectionId);
+      dataset.cache(storageLevel, coder);
+    }
     datasets.put(pCollectionId, dataset);
     leaves.add(dataset);
   }
@@ -70,6 +83,15 @@ public class SparkTranslationContext {
     }
   }
 
+  void incrementConsumptionCountBy(String pCollectionId, int addend) {
+    int count = consumptionCount.getOrDefault(pCollectionId, 0);
+    consumptionCount.put(pCollectionId, count + addend);
+  }
+
+  void putCoder(String pCollectionId, Coder coder) {
+    coderMap.put(pCollectionId, coder);
+  }
+
   /** Generate a unique pCollection id number to identify runner-generated sinks. */
   public int nextSinkId() {
     return sinkId++;
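The caching decision in pushDataset is driven by two existing SparkPipelineOptions knobs, cacheDisabled and storageLevel, which the diff reads via isCacheDisabled() and getStorageLevel(). A short sketch of how a user would set them; the setter names follow Beam's standard options naming and the values shown are my understanding of the defaults, so treat both as assumptions rather than documentation:

    import org.apache.beam.runners.spark.SparkPipelineOptions;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;

    /** Sketch: the two options the new caching path consults. */
    public class CacheOptionsSketch {
      public static void main(String[] args) {
        SparkPipelineOptions options = PipelineOptionsFactory.as(SparkPipelineOptions.class);
        // Opt out of RDD caching entirely, or pick the storage level for cached datasets.
        options.setCacheDisabled(false);
        options.setStorageLevel("MEMORY_ONLY");
        System.out.println(options.isCacheDisabled() + " / " + options.getStorageLevel());
      }
    }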
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
index d7d3428..eb9dce0 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
@@ -17,9 +17,11 @@
  */
 package org.apache.beam.runners.spark;
 
+import java.io.File;
 import java.io.Serializable;
+import java.nio.file.FileSystems;
 import java.util.Collections;
-import java.util.concurrent.Executors;
+import java.util.UUID;
 import java.util.concurrent.TimeUnit;
 import org.apache.beam.model.jobmanagement.v1.JobApi.JobState.Enum;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -46,8 +48,11 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableLis
 import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.ListeningExecutorService;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.MoreExecutors;
 import org.junit.AfterClass;
+import org.junit.Assert;
 import org.junit.BeforeClass;
+import org.junit.ClassRule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -56,15 +61,13 @@ import org.slf4j.LoggerFactory;
  */
 public class SparkPortableExecutionTest implements Serializable {
 
+  @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();
   private static final Logger LOG = LoggerFactory.getLogger(SparkPortableExecutionTest.class);
-
   private static ListeningExecutorService sparkJobExecutor;
 
   @BeforeClass
   public static void setup() {
-    // Restrict this to only one thread to avoid multiple Spark clusters up at the same time
-    // which is not suitable for memory-constraint environments, i.e. Jenkins.
-    sparkJobExecutor = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1));
+    sparkJobExecutor = MoreExecutors.newDirectExecutorService();
   }
 
   @AfterClass
@@ -159,8 +162,117 @@ public class SparkPortableExecutionTest implements Serializable {
             pipelineProto,
             options.as(SparkPipelineOptions.class));
     jobInvocation.start();
-    while (jobInvocation.getState() != Enum.DONE) {
-      Thread.sleep(1000);
-    }
+    Assert.assertEquals(Enum.DONE, jobInvocation.getState());
   }
+
+  /**
+   * Verifies that each executable stage runs exactly once, even if that executable stage has
+   * multiple immediate outputs. While re-computation may be necessary in the event of failure,
+   * re-computation of a whole executable stage is expensive and can cause unexpected behavior when
+   * the executable stage has side effects (BEAM-7131).
+   *
+   * <pre>
+   *    |-> B -> GBK
+   * A -|
+   *    |-> C -> GBK
+   * </pre>
+   */
+  @Test(timeout = 120_000)
+  public void testExecStageWithMultipleOutputs() throws Exception {
+    PipelineOptions options = PipelineOptionsFactory.create();
+    options.setRunner(CrashingRunner.class);
+    options
+        .as(PortablePipelineOptions.class)
+        .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
+    Pipeline pipeline = Pipeline.create(options);
+    PCollection<KV<String, String>> a =
+        pipeline
+            .apply("impulse", Impulse.create())
+            .apply("A", ParDo.of(new DoFnWithSideEffect<>("A")));
+    PCollection<KV<String, String>> b = a.apply("B", ParDo.of(new DoFnWithSideEffect<>("B")));
+    PCollection<KV<String, String>> c = a.apply("C", ParDo.of(new DoFnWithSideEffect<>("C")));
+    // Use GBKs to force re-computation of executable stage unless cached.
+    b.apply(GroupByKey.create());
+    c.apply(GroupByKey.create());
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
+    JobInvocation jobInvocation =
+        SparkJobInvoker.createJobInvocation(
+            "testExecStageWithMultipleOutputs",
+            "testExecStageWithMultipleOutputsRetrievalToken",
+            sparkJobExecutor,
+            pipelineProto,
+            options.as(SparkPipelineOptions.class));
+    jobInvocation.start();
+    Assert.assertEquals(Enum.DONE, jobInvocation.getState());
+  }
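One detail worth calling out: the setup change above swaps the single-thread pool for MoreExecutors.newDirectExecutorService(), which runs submitted work on the calling thread. That is why these tests can assert Enum.DONE immediately after jobInvocation.start() instead of polling in a sleep loop. A tiny, self-contained illustration of the direct-executor behavior (plain Guava rather than Beam's vendored copy, purely for illustration):

    import com.google.common.util.concurrent.ListeningExecutorService;
    import com.google.common.util.concurrent.MoreExecutors;

    /** Sketch: a direct executor runs tasks synchronously on the caller's thread. */
    public class DirectExecutorSketch {
      public static void main(String[] args) {
        ListeningExecutorService executor = MoreExecutors.newDirectExecutorService();
        StringBuilder log = new StringBuilder();
        executor.submit(() -> log.append("ran"));
        // No polling or sleeping needed: the task finished before submit() returned.
        System.out.println(log); // prints "ran"
      }
    }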
+
+  /**
+   * Verifies that each executable stage runs exactly once, even if that executable stage has
+   * multiple downstream consumers. While re-computation may be necessary in the event of failure,
+   * re-computation of a whole executable stage is expensive and can cause unexpected behavior when
+   * the executable stage has side effects (BEAM-7131).
+   *
+   * <pre>
+   *           |-> G
+   * F -> GBK -|
+   *           |-> H
+   * </pre>
+   */
+  @Test(timeout = 120_000)
+  public void testExecStageWithMultipleConsumers() throws Exception {
+    PipelineOptions options = PipelineOptionsFactory.create();
+    options.setRunner(CrashingRunner.class);
+    options
+        .as(PortablePipelineOptions.class)
+        .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
+    Pipeline pipeline = Pipeline.create(options);
+    PCollection<KV<String, Iterable<String>>> f =
+        pipeline
+            .apply("impulse", Impulse.create())
+            .apply("F", ParDo.of(new DoFnWithSideEffect<>("F")))
+            // use GBK to prevent fusion of F, G, and H
+            .apply(GroupByKey.create());
+    f.apply("G", ParDo.of(new DoFnWithSideEffect<>("G")));
+    f.apply("H", ParDo.of(new DoFnWithSideEffect<>("H")));
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
+    JobInvocation jobInvocation =
+        SparkJobInvoker.createJobInvocation(
+            "testExecStageWithMultipleConsumers",
+            "testExecStageWithMultipleConsumersRetrievalToken",
+            sparkJobExecutor,
+            pipelineProto,
+            options.as(SparkPipelineOptions.class));
+    jobInvocation.start();
+    Assert.assertEquals(Enum.DONE, jobInvocation.getState());
+  }
+
+  /** A non-idempotent DoFn that cannot be run more than once without error. */
+  private class DoFnWithSideEffect<InputT> extends DoFn<InputT, KV<String, String>> {
+
+    private final String name;
+    private final File file;
+
+    DoFnWithSideEffect(String name) {
+      this.name = name;
+      String path =
+          FileSystems.getDefault()
+              .getPath(
+                  temporaryFolder.getRoot().getAbsolutePath(),
+                  String.format("%s-%s", this.name, UUID.randomUUID().toString()))
+              .toString();
+      file = new File(path);
+    }
+
+    @ProcessElement
+    public void process(ProcessContext context) throws Exception {
+      context.output(KV.of(name, name));
+      // Verify this DoFn has not run more than once by enacting a side effect via the local file
+      // system.
+      Assert.assertTrue(
+          String.format(
+              "Create file %s failed (DoFn %s should only have been run once).",
+              file.getAbsolutePath(), name),
+          file.createNewFile());
+    }
+  }
 }
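A closing note on the test strategy: DoFnWithSideEffect leans on File#createNewFile, which atomically creates the file and returns true only on the first call, so a second execution of the same stage trips the assertion. The mechanism in isolation (temp-file name and class are hypothetical):

    import java.io.File;
    import java.io.IOException;

    /** Sketch: createNewFile succeeds exactly once, so a re-run is detectable. */
    public class CreateNewFileSketch {
      public static void main(String[] args) throws IOException {
        File marker = File.createTempFile("marker", ".tmp");
        marker.delete(); // start from a clean slate
        System.out.println(marker.createNewFile()); // true: first "run"
        System.out.println(marker.createNewFile()); // false: a second "run" is caught
        marker.delete(); // clean up
      }
    }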