Redistributed some responsibilities in order to remove getAggregatorValues() from EvaluationContext.
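(For context, a minimal sketch of how a caller reads an aggregator value after this change: the
SparkPipelineResult now delegates to the new SparkAggregators utility instead of an
EvaluationContext. The options setup and the aggregator name "myCounter" are hypothetical;
getAggregatorValue() and the SparkAggregators delegation are taken from the diff below.)

    import org.apache.beam.runners.spark.SparkPipelineOptions;
    import org.apache.beam.runners.spark.SparkPipelineResult;
    import org.apache.beam.runners.spark.SparkRunner;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;

    public class AggregatorValueExample {
      public static void main(String[] args) {
        SparkPipelineOptions options = PipelineOptionsFactory.as(SparkPipelineOptions.class);
        options.setRunner(SparkRunner.class);
        Pipeline pipeline = Pipeline.create(options);
        // ... transforms that increment a named Aggregator "myCounter" go here (hypothetical) ...
        SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
        result.waitUntilFinish();
        // The result resolves the value via SparkAggregators.valueOf(name, type, javaSparkContext)
        // rather than asking the EvaluationContext, as this commit introduces.
        Long count = result.getAggregatorValue("myCounter", Long.class);
        System.out.println("myCounter = " + count);
      }
    }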
Inferred expected exception handling according to existing codebase and tests.

Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/158378f0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/158378f0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/158378f0

Branch: refs/heads/master
Commit: 158378f0f682b80462b917002b895ddbf782d06d
Parents: b1a6793
Author: Stas Levin <[email protected]>
Authored: Sat Dec 3 00:47:39 2016 +0200
Committer: Sela <[email protected]>
Committed: Mon Dec 5 12:56:41 2016 +0200

----------------------------------------------------------------------
 .../beam/runners/spark/SparkPipelineResult.java | 76 ++++++++-------
 .../apache/beam/runners/spark/SparkRunner.java  | 35 +++++--
 .../beam/runners/spark/TestSparkRunner.java     |  1 +
 .../spark/aggregators/AccumulatorSingleton.java |  6 +-
 .../spark/aggregators/SparkAggregators.java     | 97 ++++++++++++++++++++
 .../spark/translation/EvaluationContext.java    | 20 +---
 .../spark/translation/SparkRuntimeContext.java  | 62 +------------
 .../spark/translation/TransformTranslator.java  | 10 +-
 .../streaming/StreamingTransformTranslator.java | 10 +-
 .../runners/spark/SparkPipelineStateTest.java   | 36 ++++----
 .../spark/aggregators/ClearAggregatorsRule.java | 37 ++++++++
 .../metrics/sink/ClearAggregatorsRule.java      | 33 -------
 .../metrics/sink/NamedAggregatorsTest.java      |  1 +
 .../streaming/EmptyStreamAssertionTest.java     |  2 +-
 .../ResumeFromCheckpointStreamingTest.java      |  9 +-
 15 files changed, 247 insertions(+), 188 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
index ec0610c..b1027a6 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineResult.java
@@ -23,7 +23,7 @@ import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
-import org.apache.beam.runners.spark.translation.EvaluationContext;
+import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.translation.SparkContextFactory;
 import org.apache.beam.sdk.AggregatorRetrievalException;
 import org.apache.beam.sdk.AggregatorValues;
@@ -31,7 +31,10 @@ import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.metrics.MetricResults;
 import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.spark.SparkException;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
 import org.joda.time.Duration;

 /**
@@ -40,29 +43,37 @@ import org.joda.time.Duration;
 public abstract class SparkPipelineResult implements PipelineResult {

   protected final Future pipelineExecution;
-  protected final EvaluationContext context;
+  protected JavaSparkContext javaSparkContext;

   protected PipelineResult.State state;

   SparkPipelineResult(final Future<?> pipelineExecution,
-                      final EvaluationContext evaluationContext) {
+                      final JavaSparkContext javaSparkContext) {
     this.pipelineExecution = pipelineExecution;
-    this.context = evaluationContext;
+    this.javaSparkContext = javaSparkContext;
     // pipelineExecution is expected to have started executing eagerly.
     state = State.RUNNING;
   }

-  private RuntimeException runtimeExceptionFrom(Throwable e) {
+  private RuntimeException runtimeExceptionFrom(final Throwable e) {
     return (e instanceof RuntimeException) ? (RuntimeException) e : new RuntimeException(e);
   }

-  private RuntimeException beamExceptionFrom(Throwable e) {
+  private RuntimeException beamExceptionFrom(final Throwable e) {
     // Scala doesn't declare checked exceptions in the bytecode, and the Java compiler
     // won't let you catch something that is not declared, so we can't catch
     // SparkException directly, instead we do an instanceof check.
-    return (e instanceof SparkException)
-        ? new Pipeline.PipelineExecutionException(e.getCause() != null ? e.getCause() : e)
-        : runtimeExceptionFrom(e);
+
+    if (e instanceof SparkException) {
+      if (e.getCause() != null && e.getCause() instanceof UserCodeException) {
+        UserCodeException userException = (UserCodeException) e.getCause();
+        return new Pipeline.PipelineExecutionException(userException.getCause());
+      } else if (e.getCause() != null) {
+        return new Pipeline.PipelineExecutionException(e.getCause());
+      }
+    }
+
+    return runtimeExceptionFrom(e);
   }

   protected abstract void stop();
@@ -70,8 +81,14 @@ public abstract class SparkPipelineResult implements PipelineResult {
   protected abstract State awaitTermination(Duration duration)
       throws TimeoutException, ExecutionException, InterruptedException;

-  public <T> T getAggregatorValue(String named, Class<T> resultType) {
-    return context.getAggregatorValue(named, resultType);
+  public <T> T getAggregatorValue(final String name, final Class<T> resultType) {
+    return SparkAggregators.valueOf(name, resultType, javaSparkContext);
+  }
+
+  @Override
+  public <T> AggregatorValues<T> getAggregatorValues(final Aggregator<?, T> aggregator)
+      throws AggregatorRetrievalException {
+    return SparkAggregators.valueOf(aggregator, javaSparkContext);
   }

   @Override
@@ -85,15 +102,15 @@ public abstract class SparkPipelineResult implements PipelineResult {
   }

   @Override
-  public State waitUntilFinish(Duration duration) {
+  public State waitUntilFinish(final Duration duration) {
     try {
       state = awaitTermination(duration);
-    } catch (TimeoutException e) {
+    } catch (final TimeoutException e) {
       state = null;
-    } catch (ExecutionException e) {
+    } catch (final ExecutionException e) {
       state = PipelineResult.State.FAILED;
       throw beamExceptionFrom(e.getCause());
-    } catch (Exception e) {
+    } catch (final Exception e) {
       state = PipelineResult.State.FAILED;
       throw beamExceptionFrom(e);
     } finally {
@@ -104,12 +121,6 @@ public abstract class SparkPipelineResult implements PipelineResult {
   }

   @Override
-  public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator)
-      throws AggregatorRetrievalException {
-    return context.getAggregatorValues(aggregator);
-  }
-
-  @Override
   public MetricResults metrics() {
     throw new UnsupportedOperationException("The SparkRunner does not currently support metrics.");
   }
@@ -130,17 +141,17 @@ public abstract class SparkPipelineResult implements PipelineResult {
   static class BatchMode extends SparkPipelineResult {

     BatchMode(final Future<?> pipelineExecution,
-              final EvaluationContext evaluationContext) {
-      super(pipelineExecution, evaluationContext);
+              final JavaSparkContext javaSparkContext) {
+      super(pipelineExecution, javaSparkContext);
     }

     @Override
     protected void stop() {
-      SparkContextFactory.stopSparkContext(context.getSparkContext());
+      SparkContextFactory.stopSparkContext(javaSparkContext);
     }

     @Override
-    protected State awaitTermination(Duration duration)
+    protected State awaitTermination(final Duration duration)
         throws TimeoutException, ExecutionException, InterruptedException {
       pipelineExecution.get(duration.getMillis(), TimeUnit.MILLISECONDS);
       return PipelineResult.State.DONE;
@@ -152,22 +163,25 @@ public abstract class SparkPipelineResult implements PipelineResult {
    */
   static class StreamingMode extends SparkPipelineResult {

+    private final JavaStreamingContext javaStreamingContext;
+
     StreamingMode(final Future<?> pipelineExecution,
-                  final EvaluationContext evaluationContext) {
-      super(pipelineExecution, evaluationContext);
+                  final JavaStreamingContext javaStreamingContext) {
+      super(pipelineExecution, javaStreamingContext.sparkContext());
+      this.javaStreamingContext = javaStreamingContext;
     }

     @Override
     protected void stop() {
-      context.getStreamingContext().stop(false, true);
-      SparkContextFactory.stopSparkContext(context.getSparkContext());
+      javaStreamingContext.stop(false, true);
+      SparkContextFactory.stopSparkContext(javaSparkContext);
     }

     @Override
-    protected State awaitTermination(Duration duration) throws TimeoutException,
+    protected State awaitTermination(final Duration duration) throws TimeoutException,
         ExecutionException, InterruptedException {
       pipelineExecution.get(duration.getMillis(), TimeUnit.MILLISECONDS);
-      if (context.getStreamingContext().awaitTerminationOrTimeout(duration.getMillis())) {
+      if (javaStreamingContext.awaitTerminationOrTimeout(duration.getMillis())) {
         return State.DONE;
       } else {
         return null;

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
index a8c600e..d51ee7d 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java
@@ -23,6 +23,9 @@ import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import org.apache.beam.runners.spark.aggregators.SparkAggregators;
+import org.apache.beam.runners.spark.aggregators.metrics.AggregatorMetricSource;
 import org.apache.beam.runners.spark.translation.EvaluationContext;
 import org.apache.beam.runners.spark.translation.SparkContextFactory;
 import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
@@ -45,7 +48,10 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
+import org.apache.spark.Accumulator;
+import org.apache.spark.SparkEnv$;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.metrics.MetricsSystem;
 import org.apache.spark.streaming.api.java.JavaStreamingContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -122,12 +128,25 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> {
     mOptions = options;
   }

+  private void registerMetrics(final SparkPipelineOptions opts, final JavaSparkContext jsc) {
+    final Accumulator<NamedAggregators> accum = SparkAggregators.getNamedAggregators(jsc);
+    final NamedAggregators initialValue = accum.value();
+
+    if (opts.getEnableSparkMetricSinks()) {
+      final MetricsSystem metricsSystem = SparkEnv$.MODULE$.get().metricsSystem();
+      final AggregatorMetricSource aggregatorMetricSource =
+          new AggregatorMetricSource(opts.getAppName(), initialValue);
+      // re-register the metrics in case of context re-use
+      metricsSystem.removeSource(aggregatorMetricSource);
+      metricsSystem.registerSource(aggregatorMetricSource);
+    }
+  }
+
   @Override
   public SparkPipelineResult run(final Pipeline pipeline) {
     LOG.info("Executing pipeline using the SparkRunner.");

     final SparkPipelineResult result;
-    final EvaluationContext evaluationContext;
     final Future<?> startPipeline;
     final ExecutorService executorService = Executors.newSingleThreadExecutor();
@@ -139,30 +158,26 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> {
       final JavaStreamingContext jssc =
           JavaStreamingContext.getOrCreate(mOptions.getCheckpointDir(), contextFactory);

-      // if recovering from checkpoint, we have to reconstruct the Evaluation instance.
-      evaluationContext =
-          contextFactory.getCtxt() == null
-              ? new EvaluationContext(jssc.sparkContext(), pipeline, jssc)
-              : contextFactory.getCtxt();
-
       startPipeline = executorService.submit(new Runnable() {

         @Override
        public void run() {
+          registerMetrics(mOptions, jssc.sparkContext());
           LOG.info("Starting streaming pipeline execution.");
           jssc.start();
         }
       });

-      result = new SparkPipelineResult.StreamingMode(startPipeline, evaluationContext);
+      result = new SparkPipelineResult.StreamingMode(startPipeline, jssc);
     } else {
       final JavaSparkContext jsc = SparkContextFactory.getSparkContext(mOptions);
-      evaluationContext = new EvaluationContext(jsc, pipeline);
+      final EvaluationContext evaluationContext = new EvaluationContext(jsc, pipeline);

       startPipeline = executorService.submit(new Runnable() {

         @Override
         public void run() {
+          registerMetrics(mOptions, jsc);
           pipeline.traverseTopologically(new Evaluator(new TransformTranslator.Translator(),
                                                        evaluationContext));
           evaluationContext.computeOutputs();
@@ -170,7 +185,7 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> {
         }
       });

-      result = new SparkPipelineResult.BatchMode(startPipeline, evaluationContext);
+      result = new SparkPipelineResult.BatchMode(startPipeline, jsc);
     }

     return result;

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index 9a67f9c..2c26d84 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -75,6 +75,7 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
   public SparkPipelineResult run(Pipeline pipeline) {
     TestPipelineOptions testPipelineOptions = pipeline.getOptions().as(TestPipelineOptions.class);
     SparkPipelineResult result = delegate.run(pipeline);
+    result.waitUntilFinish();
     assertThat(result, testPipelineOptions.getOnCreateMatcher());
     assertThat(result, testPipelineOptions.getOnSuccessMatcher());
     return result;

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AccumulatorSingleton.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AccumulatorSingleton.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AccumulatorSingleton.java
index bc7105f..883830e 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AccumulatorSingleton.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/AccumulatorSingleton.java
@@ -26,11 +26,11 @@ import org.apache.spark.api.java.JavaSparkContext;
  * For resilience, {@link Accumulator}s are required to be wrapped in a Singleton.
  * @see <a href="https://spark.apache.org/docs/1.6.3/streaming-programming-guide.html#accumulators-and-broadcast-variables">accumulators</a>
  */
-public class AccumulatorSingleton {
+class AccumulatorSingleton {

   private static volatile Accumulator<NamedAggregators> instance = null;

-  public static Accumulator<NamedAggregators> getInstance(JavaSparkContext jsc) {
+  static Accumulator<NamedAggregators> getInstance(JavaSparkContext jsc) {
     if (instance == null) {
       synchronized (AccumulatorSingleton.class) {
         if (instance == null) {
@@ -45,7 +45,7 @@ public class AccumulatorSingleton {
   }

   @VisibleForTesting
-  public static void clear() {
+  static void clear() {
     synchronized (AccumulatorSingleton.class) {
       instance = null;
     }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java
new file mode 100644
index 0000000..1b06691
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/aggregators/SparkAggregators.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.aggregators;
+
+import com.google.common.collect.ImmutableList;
+import java.util.Collection;
+import java.util.Map;
+import org.apache.beam.sdk.AggregatorValues;
+import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.spark.Accumulator;
+import org.apache.spark.api.java.JavaSparkContext;
+
+/**
+ * A utility class for retrieving aggregator values.
+ */
+public class SparkAggregators {
+
+  private static <T> AggregatorValues<T> valueOf(final Accumulator<NamedAggregators> accum,
+                                                 final Aggregator<?, T> aggregator) {
+    @SuppressWarnings("unchecked")
+    Class<T> valueType = (Class<T>) aggregator.getCombineFn().getOutputType().getRawType();
+    final T value = valueOf(accum, aggregator.getName(), valueType);
+
+    return new AggregatorValues<T>() {
+
+      @Override
+      public Collection<T> getValues() {
+        return ImmutableList.of(value);
+      }
+
+      @Override
+      public Map<String, T> getValuesAtSteps() {
+        throw new UnsupportedOperationException("getValuesAtSteps is not supported.");
+      }
+    };
+  }
+
+  private static <T> T valueOf(final Accumulator<NamedAggregators> accum,
+                               final String aggregatorName,
+                               final Class<T> typeClass) {
+    return accum.value().getValue(aggregatorName, typeClass);
+  }
+
+  /**
+   * Retrieves the {@link NamedAggregators} instance using the provided Spark context.
+   *
+   * @param jsc a Spark context to be used in order to retrieve the name
+   * {@link NamedAggregators} instance
+   * @return a {@link NamedAggregators} instance
+   */
+  public static Accumulator<NamedAggregators> getNamedAggregators(JavaSparkContext jsc) {
+    return AccumulatorSingleton.getInstance(jsc);
+  }
+
+  /**
+   * Retrieves the value of an aggregator from a SparkContext instance.
+   *
+   * @param aggregator The aggregator whose value to retrieve
+   * @param javaSparkContext The SparkContext instance
+   * @param <T> The type of the aggregator's output
+   * @return The value of the aggregator
+   */
+  public static <T> AggregatorValues<T> valueOf(final Aggregator<?, T> aggregator,
+                                                final JavaSparkContext javaSparkContext) {
+    return valueOf(getNamedAggregators(javaSparkContext), aggregator);
+  }
+
+  /**
+   * Retrieves the value of an aggregator from a SparkContext instance.
+   *
+   * @param name Name of the aggregator to retrieve the value of.
+   * @param typeClass Type class of value to be retrieved.
+   * @param <T> Type of object to be returned.
+   * @return The value of the aggregator.
+   */
+  public static <T> T valueOf(final String name,
+                              final Class<T> typeClass,
+                              final JavaSparkContext javaSparkContext) {
+    return valueOf(getNamedAggregators(javaSparkContext), name, typeClass);
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
index 425f114..a412e31 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
@@ -26,13 +26,9 @@ import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
 import org.apache.beam.runners.spark.translation.streaming.UnboundedDataset;
-import org.apache.beam.sdk.AggregatorRetrievalException;
-import org.apache.beam.sdk.AggregatorValues;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.util.WindowedValue;
@@ -65,11 +61,10 @@ public class EvaluationContext {
   public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) {
     this.jsc = jsc;
     this.pipeline = pipeline;
-    this.runtime = new SparkRuntimeContext(pipeline, jsc);
+    this.runtime = new SparkRuntimeContext(pipeline);
   }

-  public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline,
-      JavaStreamingContext jssc) {
+  public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline, JavaStreamingContext jssc) {
     this(jsc, pipeline);
     this.jssc = jssc;
   }
@@ -192,17 +187,6 @@ public class EvaluationContext {
     throw new IllegalStateException("Cannot resolve un-known PObject: " + value);
   }

-  public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator)
-      throws AggregatorRetrievalException {
-    return runtime.getAggregatorValues(AccumulatorSingleton.getInstance(jsc), aggregator);
-  }
-
-  public <T> T getAggregatorValue(String named, Class<T> resultType) {
-    return runtime.getAggregatorValue(AccumulatorSingleton.getInstance(jsc),
-                                      named,
-                                      resultType);
-  }
-
   /**
    * Retrieves an iterable of results associated with the PCollection passed in.
    *

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
index 564db39..01b6b54 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java
@@ -20,17 +20,11 @@ package org.apache.beam.runners.spark.translation;

 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
-import com.google.common.collect.ImmutableList;
 import java.io.IOException;
 import java.io.Serializable;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
-import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
-import org.apache.beam.runners.spark.aggregators.metrics.AggregatorMetricSource;
-import org.apache.beam.sdk.AggregatorValues;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
@@ -43,10 +37,6 @@ import org.apache.beam.sdk.transforms.Min;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.spark.Accumulator;
-import org.apache.spark.SparkEnv$;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.metrics.MetricsSystem;
-

 /**
  * The SparkRuntimeContext allows us to define useful features on the client side before our
@@ -61,12 +51,11 @@ public class SparkRuntimeContext implements Serializable {
   private final Map<String, Aggregator<?, ?>> aggregators = new HashMap<>();
   private transient CoderRegistry coderRegistry;

-  SparkRuntimeContext(Pipeline pipeline, JavaSparkContext jsc) {
+  SparkRuntimeContext(Pipeline pipeline) {
     this.serializedPipelineOptions = serializePipelineOptions(pipeline.getOptions());
-    registerMetrics(pipeline.getOptions().as(SparkPipelineOptions.class), jsc);
   }

-  private static String serializePipelineOptions(PipelineOptions pipelineOptions) {
+  private String serializePipelineOptions(PipelineOptions pipelineOptions) {
     try {
       return new ObjectMapper().writeValueAsString(pipelineOptions);
     } catch (JsonProcessingException e) {
@@ -82,53 +71,6 @@ public class SparkRuntimeContext implements Serializable {
     }
   }

-  private void registerMetrics(final SparkPipelineOptions opts, final JavaSparkContext jsc) {
-    final Accumulator<NamedAggregators> accum = AccumulatorSingleton.getInstance(jsc);
-    final NamedAggregators initialValue = accum.value();
-
-    if (opts.getEnableSparkMetricSinks()) {
-      final MetricsSystem metricsSystem = SparkEnv$.MODULE$.get().metricsSystem();
-      final AggregatorMetricSource aggregatorMetricSource =
-          new AggregatorMetricSource(opts.getAppName(), initialValue);
-      // re-register the metrics in case of context re-use
-      metricsSystem.removeSource(aggregatorMetricSource);
-      metricsSystem.registerSource(aggregatorMetricSource);
-    }
-  }
-
-  /**
-   * Retrieves corresponding value of an aggregator.
-   *
-   * @param accum The Spark Accumulator holding all Aggregators.
-   * @param aggregatorName Name of the aggregator to retrieve the value of.
-   * @param typeClass Type class of value to be retrieved.
-   * @param <T> Type of object to be returned.
-   * @return The value of the aggregator.
-   */
-  public <T> T getAggregatorValue(Accumulator<NamedAggregators> accum,
-                                  String aggregatorName,
-                                  Class<T> typeClass) {
-    return accum.value().getValue(aggregatorName, typeClass);
-  }
-
-  public <T> AggregatorValues<T> getAggregatorValues(Accumulator<NamedAggregators> accum,
-                                                     Aggregator<?, T> aggregator) {
-    @SuppressWarnings("unchecked")
-    Class<T> aggValueClass = (Class<T>) aggregator.getCombineFn().getOutputType().getRawType();
-    final T aggregatorValue = getAggregatorValue(accum, aggregator.getName(), aggValueClass);
-    return new AggregatorValues<T>() {
-      @Override
-      public Collection<T> getValues() {
-        return ImmutableList.of(aggregatorValue);
-      }
-
-      @Override
-      public Map<String, T> getValuesAtSteps() {
-        throw new UnsupportedOperationException("getValuesAtSteps is not supported.");
-      }
-    };
-  }
-
   public synchronized PipelineOptions getPipelineOptions() {
     return deserializePipelineOptions(serializedPipelineOptions);
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 66da181..e033ab1 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -33,8 +33,8 @@ import org.apache.avro.mapreduce.AvroJob;
 import org.apache.avro.mapreduce.AvroKeyInputFormat;
 import org.apache.beam.runners.core.AssignWindowsDoFn;
 import org.apache.beam.runners.spark.SparkRunner;
-import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.io.SourceRDD;
 import org.apache.beam.runners.spark.io.hadoop.HadoopIO;
@@ -126,7 +126,7 @@ public final class TransformTranslator {
         final KvCoder<K, V> coder = (KvCoder<K, V>) context.getInput(transform).getCoder();
         final Accumulator<NamedAggregators> accum =
-            AccumulatorSingleton.getInstance(context.getSparkContext());
+            SparkAggregators.getNamedAggregators(context.getSparkContext());
         context.putDataset(transform,
             new BoundedDataset<>(GroupCombineFunctions.groupByKey(inRDD, accum, coder,
@@ -249,7 +249,7 @@ public final class TransformTranslator {
         final WindowFn<Object, ?> windowFn =
             (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
         Accumulator<NamedAggregators> accum =
-            AccumulatorSingleton.getInstance(context.getSparkContext());
+            SparkAggregators.getNamedAggregators(context.getSparkContext());
         Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs =
             TranslationUtils.getSideInputs(transform.getSideInputs(), context);
         context.putDataset(transform,
@@ -281,7 +281,7 @@ public final class TransformTranslator {
         final WindowFn<Object, ?> windowFn =
             (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
         Accumulator<NamedAggregators> accum =
-            AccumulatorSingleton.getInstance(context.getSparkContext());
+            SparkAggregators.getNamedAggregators(context.getSparkContext());
         JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD
             .mapPartitionsToPair(
                 new MultiDoFnFunction<>(accum, transform.getFn(), context.getRuntimeContext(),
@@ -530,7 +530,7 @@ public final class TransformTranslator {
         WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn();
         OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn);
         Accumulator<NamedAggregators> accum =
-            AccumulatorSingleton.getInstance(context.getSparkContext());
+            SparkAggregators.getNamedAggregators(context.getSparkContext());
         context.putDataset(transform,
             new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, addWindowsDoFn,
                 context.getRuntimeContext(), null, null))));

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 6ed5b55..85d796a 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -24,8 +24,8 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import org.apache.beam.runners.core.AssignWindowsDoFn;
-import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.io.ConsoleIO;
 import org.apache.beam.runners.spark.io.CreateStream;
 import org.apache.beam.runners.spark.io.SparkUnboundedSource;
@@ -194,7 +194,7 @@ final class StreamingTransformTranslator {
       @Override
       public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> rdd) throws Exception {
         final Accumulator<NamedAggregators> accum =
-            AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context()));
+            SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
         return rdd.mapPartitions(
             new DoFnFunction<>(accum, addWindowsDoFn, runtimeContext, null, null));
       }
@@ -227,7 +227,7 @@ final class StreamingTransformTranslator {
       public JavaRDD<WindowedValue<KV<K, Iterable<V>>>> call(
           JavaRDD<WindowedValue<KV<K, V>>> rdd) throws Exception {
         final Accumulator<NamedAggregators> accum =
-            AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context()));
+            SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
         return GroupCombineFunctions.groupByKey(rdd, accum, coder, runtimeContext,
                                                 windowingStrategy);
       }
@@ -363,7 +363,7 @@ final class StreamingTransformTranslator {
       public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd)
          throws Exception {
         final Accumulator<NamedAggregators> accum =
-            AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context()));
+            SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
         return rdd.mapPartitions(
             new DoFnFunction<>(accum, transform.getFn(), runtimeContext, sideInputs, windowFn));
       }
@@ -396,7 +396,7 @@ final class StreamingTransformTranslator {
       public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
           JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
         final Accumulator<NamedAggregators> accum =
-            AccumulatorSingleton.getInstance(new JavaSparkContext(rdd.context()));
+            SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
         return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, transform.getFn(),
             runtimeContext, transform.getMainOutputTag(), sideInputs, windowFn));
       }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java
index 69cf1c4..54e210d 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPipelineStateTest.java
@@ -39,7 +39,6 @@ import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SimpleFunction;
-import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.joda.time.Duration;
@@ -52,9 +51,9 @@ import org.junit.rules.TestName;
  */
 public class SparkPipelineStateTest implements Serializable {

-  private static class UserException extends RuntimeException {
+  private static class MyCustomException extends RuntimeException {

-    UserException(String message) {
+    MyCustomException(final String message) {
       super(message);
     }
   }
@@ -76,13 +75,13 @@ public class SparkPipelineStateTest implements Serializable {
     return ParDo.of(new DoFn<String, String>() {

       @ProcessElement
-      public void processElement(ProcessContext c) {
+      public void processElement(final ProcessContext c) {
         System.out.println(prefix + " " + c.element());
       }
     });
   }

-  private PTransform<PBegin, PCollection<String>> getValues(SparkPipelineOptions options) {
+  private PTransform<PBegin, PCollection<String>> getValues(final SparkPipelineOptions options) {
     return options.isStreaming()
         ? CreateStream.fromQueue(STREAMING_WORDS)
         : Create.of(BATCH_WORDS);
@@ -98,7 +97,7 @@ public class SparkPipelineStateTest implements Serializable {
     return commonOptions.getOptions();
   }

-  private Pipeline getPipeline(SparkPipelineOptions options) {
+  private Pipeline getPipeline(final SparkPipelineOptions options) {

     final Pipeline pipeline = Pipeline.create(options);
     final String name = testName.getMethodName() + "(isStreaming=" + options.isStreaming() + ")";
@@ -110,7 +109,7 @@ public class SparkPipelineStateTest implements Serializable {
     return pipeline;
   }

-  private void testFailedPipeline(SparkPipelineOptions options) throws Exception {
+  private void testFailedPipeline(final SparkPipelineOptions options) throws Exception {

     SparkPipelineResult result = null;

@@ -121,18 +120,17 @@ public class SparkPipelineStateTest implements Serializable {
           .apply(MapElements.via(new SimpleFunction<String, String>() {

             @Override
-            public String apply(String input) {
-              throw new UserException(FAILED_THE_BATCH_INTENTIONALLY);
+            public String apply(final String input) {
+              throw new MyCustomException(FAILED_THE_BATCH_INTENTIONALLY);
             }
           }));

       result = (SparkPipelineResult) pipeline.run();
       result.waitUntilFinish();
-    } catch (Exception e) {
+    } catch (final Exception e) {
       assertThat(e, instanceOf(Pipeline.PipelineExecutionException.class));
-      assertThat(e.getCause(), instanceOf(UserCodeException.class));
-      assertThat(e.getCause().getCause(), instanceOf(UserException.class));
-      assertThat(e.getCause().getCause().getMessage(), is(FAILED_THE_BATCH_INTENTIONALLY));
+      assertThat(e.getCause(), instanceOf(MyCustomException.class));
+      assertThat(e.getCause().getMessage(), is(FAILED_THE_BATCH_INTENTIONALLY));
       assertThat(result.getState(), is(PipelineResult.State.FAILED));
       result.cancel();
       return;
@@ -141,11 +139,11 @@ public class SparkPipelineStateTest implements Serializable {
     fail("An injected failure did not affect the pipeline as expected.");
   }

-  private void testTimeoutPipeline(SparkPipelineOptions options) throws Exception {
+  private void testTimeoutPipeline(final SparkPipelineOptions options) throws Exception {

     final Pipeline pipeline = getPipeline(options);

-    SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
+    final SparkPipelineResult result = (SparkPipelineResult) pipeline.run();

     result.waitUntilFinish(Duration.millis(1));

@@ -154,22 +152,22 @@ public class SparkPipelineStateTest implements Serializable {
     result.cancel();
   }

-  private void testCanceledPipeline(SparkPipelineOptions options) throws Exception {
+  private void testCanceledPipeline(final SparkPipelineOptions options) throws Exception {

     final Pipeline pipeline = getPipeline(options);

-    SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
+    final SparkPipelineResult result = (SparkPipelineResult) pipeline.run();

     result.cancel();

     assertThat(result.getState(), is(PipelineResult.State.CANCELLED));
   }

-  private void testRunningPipeline(SparkPipelineOptions options) throws Exception {
+  private void testRunningPipeline(final SparkPipelineOptions options) throws Exception {

     final Pipeline pipeline = getPipeline(options);

-    SparkPipelineResult result = (SparkPipelineResult) pipeline.run();
+    final SparkPipelineResult result = (SparkPipelineResult) pipeline.run();

     assertThat(result.getState(), is(PipelineResult.State.RUNNING));

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/ClearAggregatorsRule.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/ClearAggregatorsRule.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/ClearAggregatorsRule.java
new file mode 100644
index 0000000..4e91d15
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/ClearAggregatorsRule.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.aggregators;
+
+import org.junit.rules.ExternalResource;
+
+/**
+ * A rule that clears the {@link org.apache.beam.runners.spark.aggregators.AccumulatorSingleton}
+ * which represents the Beam {@link org.apache.beam.sdk.transforms.Aggregator}s.
+ */
+public class ClearAggregatorsRule extends ExternalResource {
+
+  @Override
+  protected void before() throws Throwable {
+    clearNamedAggregators();
+  }
+
+  public void clearNamedAggregators() {
+    AccumulatorSingleton.clear();
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/ClearAggregatorsRule.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/ClearAggregatorsRule.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/ClearAggregatorsRule.java
deleted file mode 100644
index 52ae019..0000000
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/ClearAggregatorsRule.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.spark.aggregators.metrics.sink;
-
-import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
-import org.junit.rules.ExternalResource;
-
-/**
- * A rule that clears the {@link org.apache.beam.runners.spark.aggregators.AccumulatorSingleton}
- * which represents the Beam {@link org.apache.beam.sdk.transforms.Aggregator}s.
- */
-public class ClearAggregatorsRule extends ExternalResource {
-  @Override
-  protected void before() throws Throwable {
-    AccumulatorSingleton.clear();
-  }
-}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
index 6b36bcc..3b5dd21 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
@@ -27,6 +27,7 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.Set;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.runners.spark.aggregators.ClearAggregatorsRule;
 import org.apache.beam.runners.spark.examples.WordCount;
 import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
 import org.apache.beam.sdk.Pipeline;

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/EmptyStreamAssertionTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/EmptyStreamAssertionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/EmptyStreamAssertionTest.java
index e3561d6..e482945 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/EmptyStreamAssertionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/EmptyStreamAssertionTest.java
@@ -23,7 +23,7 @@ import static org.junit.Assert.fail;
 import java.io.Serializable;
 import java.util.Collections;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.aggregators.metrics.sink.ClearAggregatorsRule;
+import org.apache.beam.runners.spark.aggregators.ClearAggregatorsRule;
 import org.apache.beam.runners.spark.io.CreateStream;
 import org.apache.beam.runners.spark.translation.streaming.utils.PAssertStreaming;
 import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptionsForStreaming;

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/158378f0/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
index e0d71d4..945ee76 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
@@ -29,7 +29,7 @@ import java.util.Properties;
 import java.util.concurrent.TimeUnit;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.beam.runners.spark.SparkPipelineResult;
-import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
+import org.apache.beam.runners.spark.aggregators.ClearAggregatorsRule;
 import org.apache.beam.runners.spark.translation.streaming.utils.EmbeddedKafkaCluster;
 import org.apache.beam.runners.spark.translation.streaming.utils.PAssertStreaming;
 import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptionsForStreaming;
@@ -83,6 +83,9 @@ public class ResumeFromCheckpointStreamingTest {
   public SparkTestPipelineOptionsForStreaming commonOptions =
       new SparkTestPipelineOptionsForStreaming();

+  @Rule
+  public ClearAggregatorsRule clearAggregatorsRule = new ClearAggregatorsRule();
+
   @BeforeClass
   public static void init() throws IOException {
     EMBEDDED_ZOOKEEPER.startup();
@@ -132,8 +135,8 @@ public class ResumeFromCheckpointStreamingTest {
         equalTo(EXPECTED_AGG_FIRST));
   }

-  private static SparkPipelineResult runAgain(SparkPipelineOptions options) {
-    AccumulatorSingleton.clear();
+  private SparkPipelineResult runAgain(SparkPipelineOptions options) {
+    clearAggregatorsRule.clearNamedAggregators();
     // sleep before next run.
     Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
     return run(options);
