[BEAM-85, BEAM-298] Make TestPipeline a JUnit Rule checking proper usage

Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/443b25a4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/443b25a4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/443b25a4

Branch: refs/heads/gearpump-runner
Commit: 443b25a4d11201fb88f40da437ec7aab4b3e273f
Parents: 33b7ca7
Author: Stas Levin <stasle...@gmail.com>
Authored: Tue Dec 13 19:27:41 2016 +0200
Committer: Kenneth Knowles <k...@google.com>
Committed: Sat Dec 17 14:11:39 2016 -0800

----------------------------------------------------------------------
 .../apache/beam/sdk/testing/TestPipeline.java   | 207 ++++++++++++++++---
 .../beam/sdk/testing/TestPipelineTest.java      | 183 ++++++++++++++--
 2 files changed, 344 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/443b25a4/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java
index 493d4cc..49ac3af 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestPipeline.java
@@ -23,12 +23,17 @@ import com.fasterxml.jackson.databind.JsonNode;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.node.ObjectNode;
 import com.google.common.base.Optional;
+import com.google.common.base.Predicate;
+import com.google.common.base.Predicates;
 import com.google.common.base.Strings;
+import com.google.common.collect.FluentIterable;
 import com.google.common.collect.Iterators;
 import java.io.IOException;
 import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
 import java.util.Map.Entry;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.Pipeline;
@@ -39,34 +44,39 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptions.CheckEnabled;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.runners.PipelineRunner;
+import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.util.IOChannelUtils;
 import org.apache.beam.sdk.util.TestCredential;
 import org.junit.experimental.categories.Category;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runners.model.Statement;
 
 /**
- * A creator of test pipelines that can be used inside of tests that can be
- * configured to run locally or against a remote pipeline runner.
+ * A creator of test pipelines that can be used inside of tests that can be configured to run
+ * locally or against a remote pipeline runner.
  *
- * <p>It is recommended to tag hand-selected tests for this purpose using the
- * {@link RunnableOnService} {@link Category} annotation, as each test run against a pipeline runner
- * will utilize resources of that pipeline runner.
+ * <p>It is recommended to tag hand-selected tests for this purpose using the {@link
+ * RunnableOnService} {@link Category} annotation, as each test run against a pipeline runner will
+ * utilize resources of that pipeline runner.
  *
  * <p>In order to run tests on a pipeline runner, the following conditions must be met:
+ *
  * <ul>
- *   <li>System property "beamTestPipelineOptions" must contain a JSON delimited list of pipeline
- *   options. For example:
- *   <pre>{@code [
+ * <li>System property "beamTestPipelineOptions" must contain a JSON delimited list of pipeline
+ *     options. For example:
+ *     <pre>{@code [
  *     "--runner=org.apache.beam.runners.dataflow.testing.TestDataflowRunner",
  *     "--project=mygcpproject",
  *     "--stagingLocation=gs://mygcsbucket/path"
  *     ]}</pre>
  *     Note that the set of pipeline options required is pipeline runner specific.
- *   </li>
- *   <li>Jars containing the SDK and test classes must be available on the classpath.</li>
+ * <li>Jars containing the SDK and test classes must be available on the classpath.
  * </ul>
  *
  * <p>Use {@link PAssert} for tests, as it integrates with this test harness in both direct and
  * remote execution modes. For example:
+ *
  * <pre>{@code
  * Pipeline p = TestPipeline.create();
  * PCollection<Integer> output = ...
@@ -76,19 +86,136 @@ import org.junit.experimental.categories.Category;
  * p.run();
  * }</pre>
  *
- * <p>For pipeline runners, it is required that they must throw an {@link AssertionError}
- * containing the message from the {@link PAssert} that failed.
+ * <p>For pipeline runners, it is required that they must throw an {@link AssertionError} containing
+ * the message from the {@link PAssert} that failed.
  */
-public class TestPipeline extends Pipeline {
+public class TestPipeline extends Pipeline implements TestRule {
+
+  private static class PipelineRunEnforcement {
+
+    protected boolean enableAutoRunIfMissing;
+    protected final Pipeline pipeline;
+    private boolean runInvoked;
+
+    private PipelineRunEnforcement(final Pipeline pipeline) {
+      this.pipeline = pipeline;
+    }
+
+    private void enableAutoRunIfMissing(final boolean enable) {
+      enableAutoRunIfMissing = enable;
+    }
+
+    protected void beforePipelineExecution() {
+      runInvoked = true;
+    }
+
+    protected void afterTestCompletion() {
+      if (!runInvoked && enableAutoRunIfMissing) {
+        pipeline.run().waitUntilFinish();
+      }
+    }
+  }
+
+  private static class PipelineAbandonedNodeEnforcement extends PipelineRunEnforcement {
+
+    private List<TransformHierarchy.Node> runVisitedNodes;
+
+    private final Predicate<TransformHierarchy.Node> isPAssertNode =
+        new Predicate<TransformHierarchy.Node>() {
+
+          @Override
+          public boolean apply(final TransformHierarchy.Node node) {
+            return node.getTransform() instanceof PAssert.GroupThenAssert
+                || node.getTransform() instanceof PAssert.GroupThenAssertForSingleton
+                || node.getTransform() instanceof PAssert.OneSideInputAssert;
+          }
+        };
+
+    private static class NodeRecorder extends PipelineVisitor.Defaults {
+
+      private final List<TransformHierarchy.Node> visited = new LinkedList<>();
+
+      @Override
+      public void leaveCompositeTransform(final TransformHierarchy.Node node) {
+        visited.add(node);
+      }
+
+      @Override
+      public void visitPrimitiveTransform(final TransformHierarchy.Node node) {
+        visited.add(node);
+      }
+    }
+
+    private PipelineAbandonedNodeEnforcement(final TestPipeline pipeline) {
+      super(pipeline);
+    }
+
+    private List<TransformHierarchy.Node> recordPipelineNodes(final Pipeline pipeline) {
+      final NodeRecorder nodeRecorder = new NodeRecorder();
+      pipeline.traverseTopologically(nodeRecorder);
+      return nodeRecorder.visited;
+    }
+
+    private void verifyPipelineExecution() {
+      final List<TransformHierarchy.Node> pipelineNodes = recordPipelineNodes(pipeline);
+      if (runVisitedNodes != null && !runVisitedNodes.equals(pipelineNodes)) {
+        final boolean hasDanglingPAssert =
+            FluentIterable.from(pipelineNodes)
+                .filter(Predicates.not(Predicates.in(runVisitedNodes)))
+                .anyMatch(isPAssertNode);
+        if (hasDanglingPAssert) {
+          throw new AbandonedNodeException("The pipeline contains abandoned PAssert(s).");
+        } else {
+          throw new AbandonedNodeException("The pipeline contains abandoned PTransform(s).");
+        }
+      } else if (runVisitedNodes == null && !enableAutoRunIfMissing) {
+        throw new PipelineRunMissingException("The pipeline has not been run.");
+      }
+    }
+
+    @Override
+    protected void beforePipelineExecution() {
+      super.beforePipelineExecution();
+      runVisitedNodes = recordPipelineNodes(pipeline);
+    }
+
+    @Override
+    protected void afterTestCompletion() {
+      super.afterTestCompletion();
+      verifyPipelineExecution();
+    }
+  }
+
+  /**
+   * An exception thrown in case an abandoned {@link org.apache.beam.sdk.transforms.PTransform} is
+   * detected, that is, a {@link org.apache.beam.sdk.transforms.PTransform} that has not been run.
+   */
+  public static class AbandonedNodeException extends RuntimeException {
+
+    AbandonedNodeException(final String msg) {
+      super(msg);
+    }
+  }
+
+  /** An exception thrown in case a test finishes without invoking {@link Pipeline#run()}. */
+  public static class PipelineRunMissingException extends RuntimeException {
+
+    PipelineRunMissingException(final String msg) {
+      super(msg);
+    }
+  }
+
   static final String PROPERTY_BEAM_TEST_PIPELINE_OPTIONS = "beamTestPipelineOptions";
   static final String PROPERTY_USE_DEFAULT_DUMMY_RUNNER = "beamUseDummyRunner";
   private static final ObjectMapper MAPPER = new ObjectMapper();
 
+  private PipelineRunEnforcement enforcement = new PipelineAbandonedNodeEnforcement(this);
+
   /**
    * Creates and returns a new test pipeline.
    *
-   * <p>Use {@link PAssert} to add tests, then call
-   * {@link Pipeline#run} to execute the pipeline and check the tests.
+   * <p>Use {@link PAssert} to add tests, then call {@link Pipeline#run} to execute the pipeline and
+   * check the tests.
    */
   public static TestPipeline create() {
     return fromOptions(testingPipelineOptions());
@@ -98,16 +225,30 @@ public class TestPipeline extends Pipeline {
     return new TestPipeline(PipelineRunner.fromOptions(options), options);
   }
 
-  private TestPipeline(PipelineRunner<? extends PipelineResult> runner, PipelineOptions options) {
+  private TestPipeline(
+      final PipelineRunner<? extends PipelineResult> runner, final PipelineOptions options) {
     super(runner, options);
   }
 
+  @Override
+  public Statement apply(final Statement statement, final Description description) {
+    return new Statement() {
+
+      @Override
+      public void evaluate() throws Throwable {
+        statement.evaluate();
+        enforcement.afterTestCompletion();
+      }
+    };
+  }
+
   /**
-   * Runs this {@link TestPipeline}, unwrapping any {@code AssertionError}
-   * that is raised during testing.
+   * Runs this {@link TestPipeline}, unwrapping any {@code AssertionError} that is raised during
+   * testing.
    */
   @Override
   public PipelineResult run() {
+    enforcement.beforePipelineExecution();
     try {
       return super.run();
     } catch (RuntimeException exc) {
@@ -120,18 +261,28 @@ public class TestPipeline extends Pipeline {
     }
   }
 
+  public TestPipeline enableAbandonedNodeEnforcement(final boolean enable) {
+    enforcement =
+        enable ? new PipelineAbandonedNodeEnforcement(this) : new PipelineRunEnforcement(this);
+
+    return this;
+  }
+
+  public TestPipeline enableAutoRunIfMissing(final boolean enable) {
+    enforcement.enableAutoRunIfMissing(enable);
+    return this;
+  }
+
   @Override
   public String toString() {
     return "TestPipeline#" + getOptions().as(ApplicationNameOptions.class).getAppName();
   }
 
-  /**
-   * Creates {@link PipelineOptions} for testing.
-   */
+  /** Creates {@link PipelineOptions} for testing. */
   public static PipelineOptions testingPipelineOptions() {
     try {
-      @Nullable String beamTestPipelineOptions =
-          System.getProperty(PROPERTY_BEAM_TEST_PIPELINE_OPTIONS);
+      @Nullable
+      String beamTestPipelineOptions = System.getProperty(PROPERTY_BEAM_TEST_PIPELINE_OPTIONS);
 
       PipelineOptions options =
           Strings.isNullOrEmpty(beamTestPipelineOptions)
@@ -155,13 +306,15 @@ public class TestPipeline extends Pipeline {
       IOChannelUtils.registerIOFactoriesAllowOverride(options);
       return options;
     } catch (IOException e) {
-      throw new RuntimeException("Unable to instantiate test options from system property "
-          + PROPERTY_BEAM_TEST_PIPELINE_OPTIONS + ":"
-          + System.getProperty(PROPERTY_BEAM_TEST_PIPELINE_OPTIONS), e);
+      throw new RuntimeException(
+          "Unable to instantiate test options from system property "
+              + PROPERTY_BEAM_TEST_PIPELINE_OPTIONS
+              + ":"
+              + System.getProperty(PROPERTY_BEAM_TEST_PIPELINE_OPTIONS),
+          e);
     }
   }
 
-
   public static String[] convertToArgs(PipelineOptions options) {
     try {
       byte[] opts = MAPPER.writeValueAsBytes(options);

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/443b25a4/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestPipelineTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestPipelineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestPipelineTest.java
index 03563f3..d1797e7 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestPipelineTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestPipelineTest.java
@@ -24,30 +24,54 @@ import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertThat;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
+import java.io.IOException;
+import java.io.Serializable;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.Date;
 import java.util.List;
 import java.util.UUID;
+import org.apache.beam.sdk.AggregatorRetrievalException;
+import org.apache.beam.sdk.AggregatorValues;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.metrics.MetricResults;
 import org.apache.beam.sdk.options.ApplicationNameOptions;
 import org.apache.beam.sdk.options.GcpOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.runners.PipelineRunner;
+import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.values.PCollection;
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
+import org.joda.time.Duration;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
+import org.junit.rules.RuleChain;
 import org.junit.rules.TestRule;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
 /** Tests for {@link TestPipeline}. */
 @RunWith(JUnit4.class)
-public class TestPipelineTest {
-  @Rule public TestRule restoreSystemProperties = new RestoreSystemProperties();
-  @Rule public ExpectedException thrown = ExpectedException.none();
+public class TestPipelineTest implements Serializable {
+  private static final List<String> WORDS = Collections.singletonList("hi there");
+  private static final String DUMMY = "expected";
+
+  private final transient TestPipeline pipeline =
+      TestPipeline.fromOptions(pipelineOptions()).enableAbandonedNodeEnforcement(true);
+
+  private final transient ExpectedException exception = ExpectedException.none();
+
+  @Rule public transient TestRule restoreSystemProperties = new RestoreSystemProperties();
+  @Rule public transient ExpectedException thrown = ExpectedException.none();
+  @Rule public transient RuleChain ruleOrder = RuleChain.outerRule(exception).around(pipeline);
 
   @Test
   public void testCreationUsingDefaults() {
@@ -57,13 +81,13 @@ public class TestPipelineTest {
   @Test
   public void testCreationOfPipelineOptions() throws Exception {
     ObjectMapper mapper = new ObjectMapper();
-    String stringOptions = mapper.writeValueAsString(new String[]{
-      "--runner=org.apache.beam.sdk.testing.CrashingRunner",
-      "--project=testProject"
-    });
+    String stringOptions =
+        mapper.writeValueAsString(
+            new String[] {
+              "--runner=org.apache.beam.sdk.testing.CrashingRunner", "--project=testProject"
+            });
     System.getProperties().put("beamTestPipelineOptions", stringOptions);
-    GcpOptions options =
-        TestPipeline.testingPipelineOptions().as(GcpOptions.class);
+    GcpOptions options = TestPipeline.testingPipelineOptions().as(GcpOptions.class);
     assertEquals(CrashingRunner.class, options.getRunner());
     assertEquals(options.getProject(), "testProject");
   }
@@ -71,8 +95,10 @@ public class TestPipelineTest {
   @Test
   public void testCreationOfPipelineOptionsFromReallyVerboselyNamedTestCase() throws Exception {
     PipelineOptions options = TestPipeline.testingPipelineOptions();
-    assertThat(options.as(ApplicationNameOptions.class).getAppName(), startsWith(
-        "TestPipelineTest-testCreationOfPipelineOptionsFromReallyVerboselyNamedTestCase"));
+    assertThat(
+        options.as(ApplicationNameOptions.class).getAppName(),
+        startsWith(
+            "TestPipelineTest-testCreationOfPipelineOptionsFromReallyVerboselyNamedTestCase"));
   }
 
   @Test
@@ -96,13 +122,13 @@ public class TestPipelineTest {
 
   @Test
   public void testConvertToArgs() {
-    String[] args = new String[]{"--tempLocation=Test_Location"};
+    String[] args = new String[] {"--tempLocation=Test_Location"};
     PipelineOptions options = PipelineOptionsFactory.fromArgs(args).as(PipelineOptions.class);
     String[] arr = TestPipeline.convertToArgs(options);
     List<String> lst = Arrays.asList(arr);
     assertEquals(lst.size(), 2);
-    assertThat(lst, containsInAnyOrder("--tempLocation=Test_Location",
-          "--appName=TestPipelineTest"));
+    assertThat(
+        lst, containsInAnyOrder("--tempLocation=Test_Location", "--appName=TestPipelineTest"));
   }
 
   @Test
@@ -131,8 +157,8 @@ public class TestPipelineTest {
     opts.setOnSuccessMatcher(m2);
 
     String[] arr = TestPipeline.convertToArgs(opts);
-    TestPipelineOptions newOpts = PipelineOptionsFactory.fromArgs(arr)
-        .as(TestPipelineOptions.class);
+    TestPipelineOptions newOpts =
+        PipelineOptionsFactory.fromArgs(arr).as(TestPipelineOptions.class);
 
     assertEquals(m1, newOpts.getOnCreateMatcher());
     assertEquals(m2, newOpts.getOnSuccessMatcher());
@@ -150,12 +176,11 @@ public class TestPipelineTest {
     pipeline.run();
   }
 
-  /**
-   * TestMatcher is a matcher designed for testing matcher serialization/deserialization.
-   */
+  /** TestMatcher is a matcher designed for testing matcher serialization/deserialization. */
   public static class TestMatcher extends BaseMatcher<PipelineResult>
       implements SerializableMatcher<PipelineResult> {
     private final UUID uuid = UUID.randomUUID();
+
     @Override
     public boolean matches(Object o) {
       return true;
@@ -180,4 +205,124 @@ public class TestPipelineTest {
       return uuid.hashCode();
     }
   }
+
+  private static class DummyRunner extends PipelineRunner<PipelineResult> {
+
+    @SuppressWarnings("unused") // used by reflection
+    public static DummyRunner fromOptions(final PipelineOptions opts) {
+      return new DummyRunner();
+    }
+
+    @Override
+    public PipelineResult run(final Pipeline pipeline) {
+      return new PipelineResult() {
+
+        @Override
+        public State getState() {
+          return null;
+        }
+
+        @Override
+        public State cancel() throws IOException {
+          return null;
+        }
+
+        @Override
+        public State waitUntilFinish(final Duration duration) {
+          return null;
+        }
+
+        @Override
+        public State waitUntilFinish() {
+          return null;
+        }
+
+        @Override
+        public <T> AggregatorValues<T> getAggregatorValues(final Aggregator<?, T> aggregator)
+            throws AggregatorRetrievalException {
+          return null;
+        }
+
+        @Override
+        public MetricResults metrics() {
+          return null;
+        }
+      };
+    }
+  }
+
+  private static PipelineOptions pipelineOptions() {
+    final PipelineOptions pipelineOptions = PipelineOptionsFactory.create();
+    pipelineOptions.setRunner(DummyRunner.class);
+    return pipelineOptions;
+  }
+
+  private PCollection<String> pCollection() {
+    return addTransform(pipeline.apply(Create.of(WORDS).withCoder(StringUtf8Coder.of())));
+  }
+
+  private PCollection<String> addTransform(final PCollection<String> pCollection) {
+    return pCollection.apply(
+        MapElements.via(
+            new SimpleFunction<String, String>() {
+
+              @Override
+              public String apply(final String input) {
+                return DUMMY;
+              }
+            }));
+  }
+
+  @Test
+  public void testPipelineRunMissing() throws Throwable {
+    exception.expect(TestPipeline.PipelineRunMissingException.class);
+    PAssert.that(pCollection()).containsInAnyOrder(DUMMY);
+    // missing pipeline#run
+  }
+
+  @Test
+  public void testPipelineHasAbandonedPAssertNode() throws Throwable {
+    exception.expect(TestPipeline.AbandonedNodeException.class);
+    exception.expectMessage("PAssert");
+
+    final PCollection<String> pCollection = pCollection();
+    PAssert.that(pCollection).containsInAnyOrder(DUMMY);
+    pipeline.run().waitUntilFinish();
+
+    // dangling PAssert
+    PAssert.that(pCollection).containsInAnyOrder(DUMMY);
+  }
+
+  @Test
+  public void testPipelineHasAbandonedPTransformNode() throws Throwable {
+    exception.expect(TestPipeline.AbandonedNodeException.class);
+    exception.expectMessage("PTransform");
+
+    final PCollection<String> pCollection = pCollection();
+    PAssert.that(pCollection).containsInAnyOrder(DUMMY);
+    pipeline.run().waitUntilFinish();
+
+    // dangling PTransform
+    addTransform(pCollection);
+  }
+
+  @Test
+  public void testNormalFlowWithPAssert() throws Throwable {
+    PAssert.that(pCollection()).containsInAnyOrder(DUMMY);
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testAutoAddMissingRunFlow() throws Throwable {
+    PAssert.that(pCollection()).containsInAnyOrder(DUMMY);
+    // missing pipeline#run, but have it auto-added.
+    pipeline.enableAutoRunIfMissing(true);
+  }
+
+  @Test
+  public void testDisableStrictPAssertFlow() throws Throwable {
+    pCollection();
+    // dangling PTransform, but ignore it
+    pipeline.enableAbandonedNodeEnforcement(false);
+  }
 }

Reply via email to