Repository: beam Updated Branches: refs/heads/master 37e532188 -> 7a2fe68fd
[BEAM-1148] Port PAssert away from aggregators Separates evaluation of the assertion into a transform that outputs `SuccessOrFailure` from the reporting of failures. The latter happens in a separate composite transform making it possible to override the implementation. Introduces a default implementation that uses Metrics to count the number of successfully executed assertions as well as the number of failing assertions. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/e8f0922f Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/e8f0922f Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/e8f0922f Branch: refs/heads/master Commit: e8f0922f6f15dd3ab96f54ba4a5c1083269d70bd Parents: 37e5321 Author: Pablo <[email protected]> Authored: Wed Mar 29 14:49:53 2017 -0700 Committer: bchambers <[email protected]> Committed: Mon Apr 24 11:56:58 2017 -0700 ---------------------------------------------------------------------- .../beam/runners/spark/TestSparkRunner.java | 35 ++-- .../ResumeFromCheckpointStreamingTest.java | 72 ++++---- .../org/apache/beam/sdk/testing/PAssert.java | 164 ++++++++++++------- .../beam/sdk/testing/SuccessOrFailure.java | 82 ++++++++++ .../apache/beam/sdk/testing/PAssertTest.java | 55 +++++++ 5 files changed, 307 insertions(+), 101 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/e8f0922f/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java index 61fcaa9..10e98b8 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java +++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java @@ -40,6 +40,9 @@ import org.apache.beam.runners.spark.util.GlobalWatermarkHolder; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.io.BoundedReadFromUnboundedSource; +import org.apache.beam.sdk.metrics.MetricNameFilter; +import org.apache.beam.sdk.metrics.MetricResult; +import org.apache.beam.sdk.metrics.MetricsFilter; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsValidator; import org.apache.beam.sdk.runners.PTransformOverride; @@ -136,11 +139,15 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> { isOneOf(PipelineResult.State.STOPPED, PipelineResult.State.DONE)); // validate assertion succeeded (at least once). - int successAssertions = 0; - try { - successAssertions = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class); - } catch (NullPointerException e) { - // No assertions registered will cause an NPE here. + long successAssertions = 0; + Iterable<MetricResult<Long>> counterResults = result.metrics().queryMetrics( + MetricsFilter.builder() + .addNameFilter(MetricNameFilter.named(PAssert.class, PAssert.SUCCESS_COUNTER)) + .build()).counters(); + for (MetricResult<Long> counter : counterResults) { + if (counter.attempted().longValue() > 0) { + successAssertions++; + } } Integer expectedAssertions = testSparkPipelineOptions.getExpectedAssertions() != null ? testSparkPipelineOptions.getExpectedAssertions() : expectedNumberOfAssertions; @@ -149,18 +156,22 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> { "Expected %d successful assertions, but found %d.", expectedAssertions, successAssertions), successAssertions, - is(expectedAssertions)); + is(expectedAssertions.longValue())); // validate assertion didn't fail. 
- int failedAssertions = 0; - try { - failedAssertions = result.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class); - } catch (NullPointerException e) { - // No assertions registered will cause an NPE here. + long failedAssertions = 0; + Iterable<MetricResult<Long>> failCounterResults = result.metrics().queryMetrics( + MetricsFilter.builder() + .addNameFilter(MetricNameFilter.named(PAssert.class, PAssert.FAILURE_COUNTER)) + .build()).counters(); + for (MetricResult<Long> counter : failCounterResults) { + if (counter.attempted().longValue() > 0) { + failedAssertions++; + } } assertThat( String.format("Found %d failed assertions.", failedAssertions), failedAssertions, - is(0)); + is(0L)); LOG.info( String.format( http://git-wip-us.apache.org/repos/asf/beam/blob/e8f0922f/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java index 6cbf83a..1aa76a3 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java @@ -19,7 +19,6 @@ package org.apache.beam.runners.spark.translation.streaming; import static org.apache.beam.sdk.metrics.MetricMatchers.attemptedMetricsResult; import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; @@ -51,10 +50,10 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import 
org.apache.beam.sdk.io.kafka.KafkaIO; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.MetricNameFilter; +import org.apache.beam.sdk.metrics.MetricResult; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.metrics.MetricsFilter; import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; @@ -62,7 +61,6 @@ import org.apache.beam.sdk.transforms.Keys; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.WithKeys; @@ -94,8 +92,8 @@ import org.junit.experimental.categories.Category; * <p>Runs the pipeline reading from a Kafka backlog with a WM function that will move to infinity * on a EOF signal. * After resuming from checkpoint, a single output (guaranteed by the WM) is asserted, along with - * {@link Aggregator}s and {@link Metrics} values that are expected to resume from previous count - * and a side-input that is expected to recover as well. + * {@link Metrics} values that are expected to resume from previous count and a side-input that is + * expected to recover as well. */ public class ResumeFromCheckpointStreamingTest { private static final EmbeddedKafkaCluster.EmbeddedZookeeper EMBEDDED_ZOOKEEPER = @@ -161,16 +159,13 @@ public class ResumeFromCheckpointStreamingTest { // first run should expect EOT matching the last injected element. 
SparkPipelineResult res = run(pipelineRule, Optional.of(new Instant(400)), 0); - // assertions 1: - long processedMessages1 = res.getAggregatorValue("processedMessages", Long.class); - assertThat( - String.format( - "Expected %d processed messages count but found %d", 4, processedMessages1), - processedMessages1, - equalTo(4L)); + assertThat(res.metrics().queryMetrics(metricsFilter).counters(), hasItem(attemptedMetricsResult(ResumeFromCheckpointStreamingTest.class.getName(), "allMessages", "EOFShallNotPassFn", 4L))); + assertThat(res.metrics().queryMetrics(metricsFilter).counters(), + hasItem(attemptedMetricsResult(ResumeFromCheckpointStreamingTest.class.getName(), + "processedMessages", "EOFShallNotPassFn", 4L))); //--- between executions: @@ -186,27 +181,42 @@ public class ResumeFromCheckpointStreamingTest { // recovery should resume from last read offset, and read the second batch of input. res = runAgain(pipelineRule, 1); // assertions 2: - long processedMessages2 = res.getAggregatorValue("processedMessages", Long.class); - assertThat( - String.format("Expected %d processed messages count but found %d", 5, processedMessages2), - processedMessages2, - equalTo(5L)); + assertThat(res.metrics().queryMetrics(metricsFilter).counters(), + hasItem(attemptedMetricsResult(ResumeFromCheckpointStreamingTest.class.getName(), + "processedMessages", "EOFShallNotPassFn", 5L))); assertThat(res.metrics().queryMetrics(metricsFilter).counters(), hasItem(attemptedMetricsResult(ResumeFromCheckpointStreamingTest.class.getName(), "allMessages", "EOFShallNotPassFn", 6L))); - int successAssertions = res.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class); - res.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class); + long successAssertions = 0; + Iterable<MetricResult<Long>> counterResults = res.metrics().queryMetrics( + MetricsFilter.builder() + .addNameFilter(MetricNameFilter.named(PAssert.class, PAssert.SUCCESS_COUNTER)) + .build()).counters(); + for (MetricResult<Long> 
counter : counterResults) { + if (counter.attempted().longValue() > 0) { + successAssertions++; + } + } assertThat( String.format( - "Expected %d successful assertions, but found %d.", 1, successAssertions), + "Expected %d successful assertions, but found %d.", 1L, successAssertions), successAssertions, - is(1)); + is(1L)); // validate assertion didn't fail. - int failedAssertions = res.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class); + long failedAssertions = 0; + Iterable<MetricResult<Long>> failCounterResults = res.metrics().queryMetrics( + MetricsFilter.builder() + .addNameFilter(MetricNameFilter.named(PAssert.class, PAssert.FAILURE_COUNTER)) + .build()).counters(); + for (MetricResult<Long> counter : failCounterResults) { + if (counter.attempted().longValue() > 0) { + failedAssertions++; + } + } assertThat( String.format("Found %d failed assertions.", failedAssertions), failedAssertions, - is(0)); + is(0L)); } @@ -289,8 +299,8 @@ public class ResumeFromCheckpointStreamingTest { /** A pass-through fn that prevents EOF event from passing. 
*/ private static class EOFShallNotPassFn extends DoFn<String, String> { final PCollectionView<List<String>> view; - private final Aggregator<Long, Long> aggregator = - createAggregator("processedMessages", Sum.ofLongs()); + private final Counter aggregator = Metrics.counter( + ResumeFromCheckpointStreamingTest.class, "processedMessages"); Counter counter = Metrics.counter(ResumeFromCheckpointStreamingTest.class, "allMessages"); @@ -305,7 +315,7 @@ public class ResumeFromCheckpointStreamingTest { assertThat(c.sideInput(view), containsInAnyOrder("side1", "side2")); counter.inc(); if (!element.equals("EOF")) { - aggregator.addValue(1L); + aggregator.inc(); c.output(c.element()); } } @@ -330,10 +340,8 @@ public class ResumeFromCheckpointStreamingTest { } private static class AssertDoFn<T> extends DoFn<Iterable<T>, Void> { - private final Aggregator<Integer, Integer> success = - createAggregator(PAssert.SUCCESS_COUNTER, Sum.ofIntegers()); - private final Aggregator<Integer, Integer> failure = - createAggregator(PAssert.FAILURE_COUNTER, Sum.ofIntegers()); + private final Counter success = Metrics.counter(PAssert.class, PAssert.SUCCESS_COUNTER); + private final Counter failure = Metrics.counter(PAssert.class, PAssert.FAILURE_COUNTER); private final T[] expected; AssertDoFn(T[] expected) { @@ -344,9 +352,9 @@ public class ResumeFromCheckpointStreamingTest { public void processElement(ProcessContext c) throws Exception { try { assertThat(c.element(), containsInAnyOrder(expected)); - success.addValue(1); + success.inc(); } catch (Throwable t) { - failure.addValue(1); + failure.inc(); throw t; } } http://git-wip-us.apache.org/repos/asf/beam/blob/e8f0922f/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java index 92dca53..85b8c5f 100644 
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/PAssert.java @@ -40,9 +40,10 @@ import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.MapCoder; import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.runners.TransformHierarchy.Node; -import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; @@ -52,7 +53,6 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.SimpleFunction; -import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.WithKeys; @@ -107,9 +107,12 @@ import org.slf4j.LoggerFactory; public class PAssert { private static final Logger LOG = LoggerFactory.getLogger(PAssert.class); - public static final String SUCCESS_COUNTER = "PAssertSuccess"; public static final String FAILURE_COUNTER = "PAssertFailure"; + private static final Counter successCounter = Metrics.counter( + PAssert.class, PAssert.SUCCESS_COUNTER); + private static final Counter failureCounter = Metrics.counter( + PAssert.class, PAssert.FAILURE_COUNTER); private static int assertCount = 0; @@ -121,6 +124,79 @@ public class PAssert { private PAssert() {} /** + * A {@link DoFn} that counts the number of successful {@link SuccessOrFailure} in the + * input {@link PCollection} and counts them. If a failed {@link SuccessOrFailure} is + * encountered, it is counted and immediately raised. 
+ */ + private static final class DefaultConcludeFn extends DoFn<SuccessOrFailure, Void> { + + @ProcessElement + public void processElement(ProcessContext c) { + SuccessOrFailure e = c.element(); + if (e.isSuccess()) { + PAssert.successCounter.inc(); + } else { + PAssert.failureCounter.inc(); + throw e.assertionError() != null ? e.assertionError() : new AssertionError(e.toString()); + } + } + } + + /** + * Default transform to check that a PAssert was successful. This transform + * relies on two {@link Counter} objects from the Metrics API to count the number of + * successful and failed asserts. + * Runners that do not support the Metrics API should replace this transform with + * their own implementation. + */ + public static class DefaultConcludeTransform + extends PTransform<PCollection<SuccessOrFailure>, PCollection<Void>> { + public PCollection<Void> expand(PCollection<SuccessOrFailure> input) { + return input.apply(ParDo.of(new DefaultConcludeFn())); + } + } + + /** + * Track the place where an assertion is defined. + * This is necessary because the stack trace of a Throwable is a transient attribute, and can't + * be serialized. {@link PAssertionSite} helps track the stack trace + * of the place where an assertion is issued. A null message is stored as the empty string. + */ + public static class PAssertionSite implements Serializable { + private final String message; + private final StackTraceElement[] creationStackTrace; + + static PAssertionSite capture(String message) { + return new PAssertionSite(message, new Throwable().getStackTrace()); + } + + PAssertionSite() { + this(null, new StackTraceElement[0]); + } + + PAssertionSite(String message, StackTraceElement[] creationStackTrace) { + this.message = message == null ? "" : message; + this.creationStackTrace = creationStackTrace; + } + + public AssertionError wrap(Throwable t) { + AssertionError res = + new AssertionError( + message.isEmpty() ?
t.getMessage() : (message + ": " + t.getMessage()), t); + res.setStackTrace(creationStackTrace); + return res; + } + + public AssertionError wrap(String message) { + String outputMessage = (this.message == null || this.message.isEmpty()) + ? message : (this.message + ": " + message); + AssertionError res = new AssertionError(outputMessage); + res.setStackTrace(creationStackTrace); + return res; + } + } + + /** * Builder interface for assertions applicable to iterables and PCollection contents. */ public interface IterableAssert<T> { @@ -400,33 +476,11 @@ public class PAssert { //////////////////////////////////////////////////////////// - private static class PAssertionSite implements Serializable { - private final String message; - private final StackTraceElement[] creationStackTrace; - - static PAssertionSite capture(String message) { - return new PAssertionSite(message, new Throwable().getStackTrace()); - } - - PAssertionSite(String message, StackTraceElement[] creationStackTrace) { - this.message = message; - this.creationStackTrace = creationStackTrace; - } - - public AssertionError wrap(Throwable t) { - AssertionError res = - new AssertionError( - message.isEmpty() ? t.getMessage() : (message + ": " + t.getMessage()), t); - res.setStackTrace(creationStackTrace); - return res; - } - } - /** * An {@link IterableAssert} about the contents of a {@link PCollection}. This does not require * the runner to support side inputs. 
*/ - private static class PCollectionContentsAssert<T> implements IterableAssert<T> { + protected static class PCollectionContentsAssert<T> implements IterableAssert<T> { private final PCollection<T> actual; private final AssertionWindows rewindowingStrategy; private final SimpleFunction<Iterable<ValueInSingleWindow<T>>, Iterable<T>> paneExtractor; @@ -560,7 +614,8 @@ public class PAssert { return this; } - private static class MatcherCheckerFn<T> implements SerializableFunction<T, Void> { + /** Check that the passed-in matchers match the existing data. */ + protected static class MatcherCheckerFn<T> implements SerializableFunction<T, Void> { private SerializableMatcher<T> matcher; public MatcherCheckerFn(SerializableMatcher<T> matcher) { @@ -690,7 +745,8 @@ public class PAssert { SerializableFunction<Iterable<T>, Void> checkerFn) { actual.apply( "PAssert$" + (assertCount++), - new GroupThenAssertForSingleton<>(checkerFn, rewindowingStrategy, paneExtractor, site)); + new GroupThenAssertForSingleton<>( + checkerFn, rewindowingStrategy, paneExtractor, site)); return this; } @@ -1033,7 +1089,8 @@ public class PAssert { .apply("GroupGlobally", new GroupGlobally<T>(rewindowingStrategy)) .apply("GetPane", MapElements.via(paneExtractor)) .setCoder(IterableCoder.of(input.getCoder())) - .apply("RunChecks", ParDo.of(new GroupedValuesCheckerDoFn<>(checkerFn, site))); + .apply("RunChecks", ParDo.of(new GroupedValuesCheckerDoFn<>(checkerFn, site))) + .apply("VerifyAssertions", new DefaultConcludeTransform()); return PDone.in(input.getPipeline()); } @@ -1069,7 +1126,8 @@ public class PAssert { .apply("GroupGlobally", new GroupGlobally<Iterable<T>>(rewindowingStrategy)) .apply("GetPane", MapElements.via(paneExtractor)) .setCoder(IterableCoder.of(input.getCoder())) - .apply("RunChecks", ParDo.of(new SingletonCheckerDoFn<>(checkerFn, site))); + .apply("RunChecks", ParDo.of(new SingletonCheckerDoFn<>(checkerFn, site))) + .apply("VerifyAssertions", new DefaultConcludeTransform()); 
return PDone.in(input.getPipeline()); } @@ -1112,8 +1170,8 @@ public class PAssert { .apply("WindowToken", windowToken) .apply( "RunChecks", - ParDo.of(new SideInputCheckerDoFn<>(checkerFn, actual, site)).withSideInputs(actual)); - + ParDo.of(new SideInputCheckerDoFn<>(checkerFn, actual, site)).withSideInputs(actual)) + .apply("VerifyAssertions", new DefaultConcludeTransform()); return PDone.in(input.getPipeline()); } } @@ -1125,12 +1183,8 @@ public class PAssert { * <p>The input is ignored, but is {@link Integer} to be usable on runners that do not support * null values. */ - private static class SideInputCheckerDoFn<ActualT> extends DoFn<Integer, Void> { + private static class SideInputCheckerDoFn<ActualT> extends DoFn<Integer, SuccessOrFailure> { private final SerializableFunction<ActualT, Void> checkerFn; - private final Aggregator<Integer, Integer> success = - createAggregator(SUCCESS_COUNTER, Sum.ofIntegers()); - private final Aggregator<Integer, Integer> failure = - createAggregator(FAILURE_COUNTER, Sum.ofIntegers()); private final PCollectionView<ActualT> actual; private final PAssertionSite site; @@ -1146,7 +1200,7 @@ public class PAssert { @ProcessElement public void processElement(ProcessContext c) { ActualT actualContents = c.sideInput(actual); - doChecks(site, actualContents, checkerFn, success, failure); + c.output(doChecks(site, actualContents, checkerFn)); } } @@ -1157,12 +1211,8 @@ public class PAssert { * * <p>The singleton property is presumed, not enforced. 
*/ - private static class GroupedValuesCheckerDoFn<ActualT> extends DoFn<ActualT, Void> { + private static class GroupedValuesCheckerDoFn<ActualT> extends DoFn<ActualT, SuccessOrFailure> { private final SerializableFunction<ActualT, Void> checkerFn; - private final Aggregator<Integer, Integer> success = - createAggregator(SUCCESS_COUNTER, Sum.ofIntegers()); - private final Aggregator<Integer, Integer> failure = - createAggregator(FAILURE_COUNTER, Sum.ofIntegers()); private final PAssertionSite site; private GroupedValuesCheckerDoFn( @@ -1173,7 +1223,11 @@ public class PAssert { @ProcessElement public void processElement(ProcessContext c) { - doChecks(site, c.element(), checkerFn, success, failure); + try { + c.output(doChecks(site, c.element(), checkerFn)); + } catch (Throwable t) { + throw t; + } } } @@ -1185,12 +1239,9 @@ public class PAssert { * <p>The singleton property of the input {@link PCollection} is presumed, not enforced. However, * each input element must be a singleton iterable, or this will fail. 
*/ - private static class SingletonCheckerDoFn<ActualT> extends DoFn<Iterable<ActualT>, Void> { + private static class SingletonCheckerDoFn<ActualT> + extends DoFn<Iterable<ActualT>, SuccessOrFailure> { private final SerializableFunction<ActualT, Void> checkerFn; - private final Aggregator<Integer, Integer> success = - createAggregator(SUCCESS_COUNTER, Sum.ofIntegers()); - private final Aggregator<Integer, Integer> failure = - createAggregator(FAILURE_COUNTER, Sum.ofIntegers()); private final PAssertionSite site; private SingletonCheckerDoFn( @@ -1202,22 +1253,21 @@ public class PAssert { @ProcessElement public void processElement(ProcessContext c) { ActualT actualContents = Iterables.getOnlyElement(c.element()); - doChecks(site, actualContents, checkerFn, success, failure); + c.output(doChecks(site, actualContents, checkerFn)); } } - private static <ActualT> void doChecks( + protected static <ActualT> SuccessOrFailure doChecks( PAssertionSite site, ActualT actualContents, - SerializableFunction<ActualT, Void> checkerFn, - Aggregator<Integer, Integer> successAggregator, - Aggregator<Integer, Integer> failureAggregator) { + SerializableFunction<ActualT, Void> checkerFn) { + // A checker signals a failed assertion by throwing; convert that into a + // SuccessOrFailure so that reporting is left to the conclude transform. try { checkerFn.apply(actualContents); - successAggregator.addValue(1); } catch (Throwable t) { - failureAggregator.addValue(1); - throw site.wrap(t); + return SuccessOrFailure.failure(site, t.getMessage()); } + return SuccessOrFailure.success(); } http://git-wip-us.apache.org/repos/asf/beam/blob/e8f0922f/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SuccessOrFailure.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SuccessOrFailure.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SuccessOrFailure.java new file mode 100644 index 0000000..04e3c35 --- /dev/null +++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/SuccessOrFailure.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.testing; + +import com.google.common.base.MoreObjects; +import java.io.Serializable; +import javax.annotation.Nullable; +import org.apache.beam.sdk.coders.DefaultCoder; +import org.apache.beam.sdk.coders.SerializableCoder; + +/** + * Output of {@link PAssert}. Passed to a conclude function to act upon. + */ +@DefaultCoder(SerializableCoder.class) +public final class SuccessOrFailure implements Serializable { + // TODO Add a SerializableThrowable. instead of relying on PAssertionSite.(BEAM-1898) + + private final boolean isSuccess; + @Nullable + private final PAssert.PAssertionSite site; + @Nullable + private final String message; + + private SuccessOrFailure() { + this(true, null, null); + } + + private SuccessOrFailure( + boolean isSuccess, + @Nullable PAssert.PAssertionSite site, + @Nullable String message) { + this.isSuccess = isSuccess; + this.site = site; + this.message = message; + } + + public boolean isSuccess() { + return isSuccess; + } + + @Nullable + public AssertionError assertionError() { + return site == null ? 
null : site.wrap(message); + } + + public static SuccessOrFailure success() { + return new SuccessOrFailure(true, null, null); + } + + public static SuccessOrFailure failure(@Nullable PAssert.PAssertionSite site, + @Nullable String message) { + return new SuccessOrFailure(false, site, message); + } + + public static SuccessOrFailure failure(@Nullable PAssert.PAssertionSite site) { + return new SuccessOrFailure(false, site, null); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("isSuccess", isSuccess()) + .addValue(message) + .omitNullValues() + .toString(); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/e8f0922f/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java index 9d580e4..2ef892c 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/PAssertTest.java @@ -20,6 +20,7 @@ package org.apache.beam.sdk.testing; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -36,8 +37,10 @@ import java.util.regex.Pattern; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.testing.PAssert.PCollectionContentsAssert.MatcherCheckerFn; import 
org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.Sum; @@ -46,6 +49,7 @@ import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.transforms.windowing.SlidingWindows; import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.common.ElementByteSizeObserver; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; @@ -117,6 +121,44 @@ public class PAssertTest implements Serializable { } } + @Test + public void testFailureEncodedDecoded() throws IOException { + AssertionError error = null; + try { + assertEquals(0, 1); + } catch (AssertionError e) { + error = e; + } + SuccessOrFailure failure = SuccessOrFailure.failure( + new PAssert.PAssertionSite(error.getMessage(), error.getStackTrace())); + SerializableCoder<SuccessOrFailure> coder = SerializableCoder.of(SuccessOrFailure.class); + + byte[] encoded = CoderUtils.encodeToByteArray(coder, failure); + SuccessOrFailure res = CoderUtils.decodeFromByteArray(coder, encoded); + + // Should compare strings, because throwables are not directly comparable. 
+ assertEquals("Encode-decode failed SuccessOrFailure", + failure.assertionError().toString(), res.assertionError().toString()); + String resultStacktrace = Throwables.getStackTraceAsString(res.assertionError()); + String failureStacktrace = Throwables.getStackTraceAsString(failure.assertionError()); + assertThat(resultStacktrace, is(failureStacktrace)); + } + + @Test + public void testSuccessEncodedDecoded() throws IOException { + SuccessOrFailure success = SuccessOrFailure.success(); + SerializableCoder<SuccessOrFailure> coder = SerializableCoder.of(SuccessOrFailure.class); + + byte[] encoded = CoderUtils.encodeToByteArray(coder, success); + SuccessOrFailure res = CoderUtils.decodeFromByteArray(coder, encoded); + + assertEquals("Encode-decode successful SuccessOrFailure", + success.isSuccess(), res.isSuccess()); + assertEquals("Encode-decode successful SuccessOrFailure", + success.assertionError(), + res.assertionError()); + } + /** * A {@link PAssert} about the contents of a {@link PCollection} * must not require the contents of the {@link PCollection} to be @@ -452,6 +494,19 @@ public class PAssertTest implements Serializable { } @Test + public void testAssertionSiteIsCaptured() { + // This check should return a failure. + SuccessOrFailure res = PAssert.doChecks( + PAssert.PAssertionSite.capture("Captured assertion message."), + new Integer(10), + new MatcherCheckerFn(SerializableMatchers.contains(new Integer(11)))); + + String stacktrace = Throwables.getStackTraceAsString(res.assertionError()); + assertEquals(res.isSuccess(), false); + assertThat(stacktrace, containsString("PAssertionSite.capture")); + } + + @Test @Category(ValidatesRunner.class) public void testAssertionSiteIsCapturedWithMessage() throws Exception { PCollection<Long> vals = pipeline.apply(GenerateSequence.from(0).to(5));
