http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e13cacb8/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java new file mode 100644 index 0000000..491363a --- /dev/null +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.direct; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.when; + +import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle; +import org.apache.beam.runners.direct.InProcessPipelineRunner.UncommittedBundle; +import org.apache.beam.sdk.coders.BigEndianLongCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.io.CountingSource; +import org.apache.beam.sdk.io.Read; +import org.apache.beam.sdk.io.Read.Bounded; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.AppliedPTransform; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollection; + +import com.google.common.collect.ImmutableList; + +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * Tests for {@link BoundedReadEvaluatorFactory}. 
+ */
+@RunWith(JUnit4.class)
+public class BoundedReadEvaluatorFactoryTest {
+  private BoundedSource<Long> source;
+  private PCollection<Long> longs;
+  private TransformEvaluatorFactory factory;
+  @Mock private InProcessEvaluationContext context;
+  private BundleFactory bundleFactory;
+
+  @Before
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+    source = CountingSource.upTo(10L);
+    TestPipeline p = TestPipeline.create();
+    longs = p.apply(Read.from(source));
+
+    factory = new BoundedReadEvaluatorFactory();
+    bundleFactory = InProcessBundleFactory.create();
+  }
+
+  @Test
+  public void boundedSourceInMemoryTransformEvaluatorProducesElements() throws Exception {
+    UncommittedBundle<Long> output = bundleFactory.createRootBundle(longs);
+    when(context.createRootBundle(longs)).thenReturn(output);
+
+    TransformEvaluator<?> evaluator =
+        factory.forApplication(longs.getProducingTransformInternal(), null, context);
+    InProcessTransformResult result = evaluator.finishBundle();
+    assertThat(result.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE));
+    assertThat(
+        output.commit(BoundedWindow.TIMESTAMP_MAX_VALUE).getElements(),
+        containsInAnyOrder(
+            gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L)));
+  }
+
+  /**
+   * Demonstrates that acquiring multiple {@link TransformEvaluator TransformEvaluators} for the
+   * same {@link Bounded Read.Bounded} application with the same evaluation context only produces
+   * the elements once.
+   */
+  @Test
+  public void boundedSourceInMemoryTransformEvaluatorAfterFinishIsEmpty() throws Exception {
+    UncommittedBundle<Long> output = bundleFactory.createRootBundle(longs);
+    when(context.createRootBundle(longs)).thenReturn(output);
+
+    TransformEvaluator<?> evaluator =
+        factory.forApplication(longs.getProducingTransformInternal(), null, context);
+    InProcessTransformResult result = evaluator.finishBundle();
+    assertThat(result.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE));
+    Iterable<? extends WindowedValue<Long>> outputElements =
+        output.commit(BoundedWindow.TIMESTAMP_MAX_VALUE).getElements();
+    assertThat(
+        outputElements,
+        containsInAnyOrder(
+            gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L)));
+
+    UncommittedBundle<Long> secondOutput = bundleFactory.createRootBundle(longs);
+    when(context.createRootBundle(longs)).thenReturn(secondOutput);
+    TransformEvaluator<?> secondEvaluator =
+        factory.forApplication(longs.getProducingTransformInternal(), null, context);
+    InProcessTransformResult secondResult = secondEvaluator.finishBundle();
+    // The second evaluator reads nothing, so it holds the watermark at the minimum timestamp.
+    assertThat(secondResult.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE));
+    assertThat(secondResult.getOutputBundles(), emptyIterable());
+    assertThat(
+        secondOutput.commit(BoundedWindow.TIMESTAMP_MAX_VALUE).getElements(), emptyIterable());
+    // The first bundle's contents are unaffected by the second evaluation.
+    assertThat(
+        outputElements,
+        containsInAnyOrder(
+            gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L)));
+  }
+
+  /**
+   * Demonstrates that multiple evaluators acquired from the factory are independent, but the
+   * elements in the source are only produced once.
+   */
+  @Test
+  public void boundedSourceEvaluatorSimultaneousEvaluations() throws Exception {
+    UncommittedBundle<Long> output = bundleFactory.createRootBundle(longs);
+    UncommittedBundle<Long> secondOutput = bundleFactory.createRootBundle(longs);
+    when(context.createRootBundle(longs)).thenReturn(output).thenReturn(secondOutput);
+
+    // create both evaluators before finishing either.
+    TransformEvaluator<?> evaluator =
+        factory.forApplication(longs.getProducingTransformInternal(), null, context);
+    TransformEvaluator<?> secondEvaluator =
+        factory.forApplication(longs.getProducingTransformInternal(), null, context);
+
+    InProcessTransformResult secondResult = secondEvaluator.finishBundle();
+
+    InProcessTransformResult result = evaluator.finishBundle();
+    assertThat(result.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE));
+    Iterable<? extends WindowedValue<Long>> outputElements =
+        output.commit(BoundedWindow.TIMESTAMP_MAX_VALUE).getElements();
+
+    assertThat(
+        outputElements,
+        containsInAnyOrder(
+            gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L)));
+    assertThat(secondResult.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE));
+    assertThat(secondResult.getOutputBundles(), emptyIterable());
+    assertThat(
+        secondOutput.commit(BoundedWindow.TIMESTAMP_MAX_VALUE).getElements(), emptyIterable());
+    assertThat(
+        outputElements,
+        containsInAnyOrder(
+            gw(1L), gw(2L), gw(4L), gw(8L), gw(9L), gw(7L), gw(6L), gw(5L), gw(3L), gw(0L)));
+  }
+
+  @Test
+  public void boundedSourceEvaluatorClosesReader() throws Exception {
+    // Named testSource so it does not shadow the CountingSource field built in setup().
+    TestSource<Long> testSource = new TestSource<>(BigEndianLongCoder.of(), 1L, 2L, 3L);
+
+    TestPipeline p = TestPipeline.create();
+    PCollection<Long> pcollection = p.apply(Read.from(testSource));
+    AppliedPTransform<?, ?, ?> sourceTransform = pcollection.getProducingTransformInternal();
+
+    UncommittedBundle<Long> output = bundleFactory.createRootBundle(pcollection);
+    when(context.createRootBundle(pcollection)).thenReturn(output);
+
+    TransformEvaluator<?> evaluator = factory.forApplication(sourceTransform, null, context);
+    evaluator.finishBundle();
+    CommittedBundle<Long> committed = output.commit(Instant.now());
+    assertThat(committed.getElements(), containsInAnyOrder(gw(2L), gw(3L), gw(1L)));
+    assertThat(TestSource.readerClosed, is(true));
+  }
+
+  @Test
+  public void boundedSourceEvaluatorNoElementsClosesReader() throws Exception {
+    TestSource<Long> testSource = new TestSource<>(BigEndianLongCoder.of());
+
+    TestPipeline p = TestPipeline.create();
+    PCollection<Long> pcollection = p.apply(Read.from(testSource));
+    AppliedPTransform<?, ?, ?> sourceTransform = pcollection.getProducingTransformInternal();
+
+    UncommittedBundle<Long> output = bundleFactory.createRootBundle(pcollection);
+    when(context.createRootBundle(pcollection)).thenReturn(output);
+
+    TransformEvaluator<?> evaluator = factory.forApplication(sourceTransform, null, context);
+    evaluator.finishBundle();
+    CommittedBundle<Long> committed = output.commit(Instant.now());
+    assertThat(committed.getElements(), emptyIterable());
+    assertThat(TestSource.readerClosed, is(true));
+  }
+
+  /**
+   * A {@link BoundedSource} over a fixed set of elements. The static {@link #readerClosed} flag
+   * is cleared whenever a TestSource is constructed and set when any {@link TestReader} closes.
+   */
+  private static class TestSource<T> extends BoundedSource<T> {
+    // NOTE(review): static rather than per-instance, presumably so the flag stays observable even
+    // if the runner reads from a copy of this source rather than this instance.
+    private static boolean readerClosed;
+    private final Coder<T> coder;
+    private final T[] elems;
+
+    public TestSource(Coder<T> coder, T... elems) {
+      this.elems = elems;
+      this.coder = coder;
+      readerClosed = false;
+    }
+
+    @Override
+    public List<? extends BoundedSource<T>> splitIntoBundles(
+        long desiredBundleSizeBytes, PipelineOptions options) throws Exception {
+      return ImmutableList.of(this);
+    }
+
+    @Override
+    public long getEstimatedSizeBytes(PipelineOptions options) throws Exception {
+      return 0;
+    }
+
+    @Override
+    public boolean producesSortedKeys(PipelineOptions options) throws Exception {
+      return false;
+    }
+
+    @Override
+    public BoundedSource.BoundedReader<T> createReader(PipelineOptions options) throws IOException {
+      return new TestReader<>(this, elems);
+    }
+
+    @Override
+    public void validate() {
+    }
+
+    @Override
+    public Coder<T> getDefaultOutputCoder() {
+      return coder;
+    }
+  }
+
+  /**
+   * A {@link BoundedReader} that yields the given elements in order and sets
+   * {@link TestSource#readerClosed} when closed.
+   */
+  private static class TestReader<T> extends BoundedReader<T> {
+    private final BoundedSource<T> source;
+    private final List<T> elems;
+    // Index of the current element: -1 before start(), elems.size() once the reader is exhausted.
+    private int index;
+
+    public TestReader(BoundedSource<T> source, T... elems) {
+      this.source = source;
+      this.elems = Arrays.asList(elems);
+      this.index = -1;
+    }
+
+    @Override
+    public BoundedSource<T> getCurrentSource() {
+      return source;
+    }
+
+    @Override
+    public boolean start() throws IOException {
+      return advance();
+    }
+
+    @Override
+    public boolean advance() throws IOException {
+      // Step past the last element (but never further) so getCurrent() can detect exhaustion.
+      if (index < elems.size()) {
+        index++;
+      }
+      return index < elems.size();
+    }
+
+    /**
+     * Returns the current element.
+     *
+     * @throws NoSuchElementException if {@link #start()} has not been called, or if the last call
+     *     to {@link #start()} or {@link #advance()} returned {@code false}
+     */
+    @Override
+    public T getCurrent() throws NoSuchElementException {
+      if (index < 0 || index >= elems.size()) {
+        throw new NoSuchElementException(
+            "getCurrent() called before start() or after the reader was exhausted");
+      }
+      return elems.get(index);
+    }
+
+    @Override
+    public void close() throws IOException {
+      TestSource.readerClosed = true;
+    }
+  }
+
+  private static WindowedValue<Long> gw(Long elem) {
+    return WindowedValue.valueInGlobalWindow(elem);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e13cacb8/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java new file mode 100644 index 0000000..b30e005 --- /dev/null +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PDone;
+
+import com.google.common.collect.ImmutableList;
+
+import org.hamcrest.Matchers;
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Unit tests for {@link CommittedResult}. A result built from a transform result and a list of
+ * committed bundles must report that transform and those bundles back unchanged.
+ */
+@RunWith(JUnit4.class)
+public class CommittedResultTest implements Serializable {
+  // NOTE(review): the fields are transient, presumably because only the anonymous PTransform
+  // below (which captures this instance) requires the class to be Serializable.
+  private transient TestPipeline p = TestPipeline.create();
+  private transient AppliedPTransform<?, ?, ?> transform =
+      AppliedPTransform.of("foo", p.begin(), PDone.in(p), new PTransform<PBegin, PDone>() {});
+  private transient BundleFactory bundleFactory = InProcessBundleFactory.create();
+
+  @Test
+  public void getTransformExtractsFromResult() {
+    List<InProcessPipelineRunner.CommittedBundle<?>> noOutputs =
+        Collections.<InProcessPipelineRunner.CommittedBundle<?>>emptyList();
+    CommittedResult committed =
+        CommittedResult.create(StepTransformResult.withoutHold(transform).build(), noOutputs);
+
+    assertThat(
+        committed.getTransform(), Matchers.<AppliedPTransform<?, ?, ?>>equalTo(transform));
+  }
+
+  @Test
+  public void getOutputsEqualInput() {
+    PCollection<Object> boundedPCollection =
+        PCollection.createPrimitiveOutputInternal(
+            p, WindowingStrategy.globalDefault(), PCollection.IsBounded.BOUNDED);
+    InProcessPipelineRunner.CommittedBundle<Object> boundedBundle =
+        bundleFactory.createRootBundle(boundedPCollection).commit(Instant.now());
+
+    PCollection<Object> unboundedPCollection =
+        PCollection.createPrimitiveOutputInternal(
+            p, WindowingStrategy.globalDefault(), PCollection.IsBounded.UNBOUNDED);
+    InProcessPipelineRunner.CommittedBundle<Object> unboundedBundle =
+        bundleFactory.createRootBundle(unboundedPCollection).commit(Instant.now());
+
+    List<? extends InProcessPipelineRunner.CommittedBundle<?>> outputs =
+        ImmutableList.of(boundedBundle, unboundedBundle);
+    CommittedResult committed =
+        CommittedResult.create(StepTransformResult.withoutHold(transform).build(), outputs);
+
+    assertThat(committed.getOutputs(), Matchers.containsInAnyOrder(outputs.toArray()));
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e13cacb8/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitorTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitorTest.java
new file mode 100644
index 0000000..353eef6
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ConsumerTrackingPipelineVisitorTest.java
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.emptyIterable;
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.sdk.io.CountingInput;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.PValue;
+
+import org.hamcrest.Matchers;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * Tests for {@link ConsumerTrackingPipelineVisitor}. They cover what the visitor records while
+ * traversing a pipeline (views, root transforms, consumers, unfinalized values and step names)
+ * and the rule that it must be traversed exactly once before those accessors are used.
+ */
+@RunWith(JUnit4.class)
+public class ConsumerTrackingPipelineVisitorTest implements Serializable {
+  @Rule public transient ExpectedException thrown = ExpectedException.none();
+
+  private transient TestPipeline p = TestPipeline.create();
+  private transient ConsumerTrackingPipelineVisitor visitor = new ConsumerTrackingPipelineVisitor();
+
+  @Test
+  public void getViewsReturnsViews() {
+    PCollection<String> lengths =
+        p.apply("listCreate", Create.of("foo", "bar")).apply(ParDo.of(lengthAsString()));
+    PCollectionView<List<String>> listView = lengths.apply(View.<String>asList());
+    PCollectionView<Object> singletonView =
+        p.apply("singletonCreate", Create.<Object>of(1, 2, 3)).apply(View.<Object>asSingleton());
+
+    p.traverseTopologically(visitor);
+
+    assertThat(
+        visitor.getViews(),
+        Matchers.<PCollectionView<?>>containsInAnyOrder(listView, singletonView));
+  }
+
+  @Test
+  public void getRootTransformsContainsPBegins() {
+    PCollection<String> created = p.apply(Create.of("foo", "bar"));
+    PCollection<Long> bounded = p.apply(CountingInput.upTo(1234L));
+    PCollection<Long> unbounded = p.apply(CountingInput.unbounded());
+
+    p.traverseTopologically(visitor);
+
+    assertThat(
+        visitor.getRootTransforms(),
+        Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
+            created.getProducingTransformInternal(),
+            bounded.getProducingTransformInternal(),
+            unbounded.getProducingTransformInternal()));
+  }
+
+  @Test
+  public void getRootTransformsContainsEmptyFlatten() {
+    PCollectionList<String> nothing = PCollectionList.<String>empty(p);
+    PCollection<String> flattenedNothing = nothing.apply(Flatten.<String>pCollections());
+
+    p.traverseTopologically(visitor);
+
+    assertThat(
+        visitor.getRootTransforms(),
+        Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
+            flattenedNothing.getProducingTransformInternal()));
+  }
+
+  @Test
+  public void getValueToConsumersSucceeds() {
+    PCollection<String> numbers = p.apply(Create.of("1", "2", "3"));
+    PCollection<String> lengths = numbers.apply(ParDo.of(lengthAsString()));
+    PCollection<String> merged =
+        PCollectionList.of(numbers).and(lengths).apply(Flatten.<String>pCollections());
+
+    p.traverseTopologically(visitor);
+
+    assertThat(
+        visitor.getValueToConsumers().get(numbers),
+        Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
+            lengths.getProducingTransformInternal(), merged.getProducingTransformInternal()));
+    assertThat(
+        visitor.getValueToConsumers().get(lengths),
+        Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(
+            merged.getProducingTransformInternal()));
+    assertThat(visitor.getValueToConsumers().get(merged), emptyIterable());
+  }
+
+  @Test
+  public void getUnfinalizedPValuesContainsDanglingOutputs() {
+    PCollection<String> numbers = p.apply(Create.of("1", "2", "3"));
+    PCollection<String> lengths = numbers.apply(ParDo.of(lengthAsString()));
+
+    p.traverseTopologically(visitor);
+
+    assertThat(visitor.getUnfinalizedPValues(), Matchers.<PValue>contains(lengths));
+  }
+
+  @Test
+  public void getUnfinalizedPValuesEmpty() {
+    p.apply(Create.of("1", "2", "3"))
+        .apply(ParDo.of(lengthAsString()))
+        .apply(discardOutput());
+
+    p.traverseTopologically(visitor);
+
+    assertThat(visitor.getUnfinalizedPValues(), emptyIterable());
+  }
+
+  @Test
+  public void getStepNamesContainsAllTransforms() {
+    PCollection<String> numbers = p.apply(Create.of("1", "2", "3"));
+    PCollection<String> lengths = numbers.apply(ParDo.of(lengthAsString()));
+    PDone done = lengths.apply(discardOutput());
+
+    p.traverseTopologically(visitor);
+
+    // Step names are assigned in traversal order.
+    assertThat(
+        visitor.getStepNames(),
+        Matchers.<AppliedPTransform<?, ?, ?>, String>hasEntry(
+            numbers.getProducingTransformInternal(), "s0"));
+    assertThat(
+        visitor.getStepNames(),
+        Matchers.<AppliedPTransform<?, ?, ?>, String>hasEntry(
+            lengths.getProducingTransformInternal(), "s1"));
+    assertThat(
+        visitor.getStepNames(),
+        Matchers.<AppliedPTransform<?, ?, ?>, String>hasEntry(
+            done.getProducingTransformInternal(), "s2"));
+  }
+
+  @Test
+  public void traverseMultipleTimesThrows() {
+    p.apply(Create.of(1, 2, 3));
+    p.traverseTopologically(visitor);
+
+    thrown.expect(IllegalStateException.class);
+    thrown.expectMessage(ConsumerTrackingPipelineVisitor.class.getSimpleName());
+    thrown.expectMessage("is finalized");
+    p.traverseTopologically(visitor);
+  }
+
+  @Test
+  public void traverseIndependentPathsSucceeds() {
+    p.apply("left", Create.of(1, 2, 3));
+    p.apply("right", Create.of("foo", "bar", "baz"));
+
+    p.traverseTopologically(visitor);
+  }
+
+  @Test
+  public void getRootTransformsWithoutVisitingThrows() {
+    expectNotCompletelyTraversed("getRootTransforms");
+    visitor.getRootTransforms();
+  }
+
+  @Test
+  public void getStepNamesWithoutVisitingThrows() {
+    expectNotCompletelyTraversed("getStepNames");
+    visitor.getStepNames();
+  }
+
+  @Test
+  public void getUnfinalizedPValuesWithoutVisitingThrows() {
+    expectNotCompletelyTraversed("getUnfinalizedPValues");
+    visitor.getUnfinalizedPValues();
+  }
+
+  @Test
+  public void getValueToConsumersWithoutVisitingThrows() {
+    expectNotCompletelyTraversed("getValueToConsumers");
+    visitor.getValueToConsumers();
+  }
+
+  @Test
+  public void getViewsWithoutVisitingThrows() {
+    expectNotCompletelyTraversed("getViews");
+    visitor.getViews();
+  }
+
+  /**
+   * Expects the rest of the test to throw an {@link IllegalStateException} that mentions
+   * "completely traversed" and the name of the accessor that was called too early.
+   */
+  private void expectNotCompletelyTraversed(String accessorName) {
+    thrown.expect(IllegalStateException.class);
+    thrown.expectMessage("completely traversed");
+    thrown.expectMessage(accessorName);
+  }
+
+  /** Returns a new {@link DoFn} that outputs the decimal length of each input string. */
+  private static DoFn<String, String> lengthAsString() {
+    return new DoFn<String, String>() {
+      @Override
+      public void processElement(DoFn<String, String>.ProcessContext c) throws Exception {
+        c.output(Integer.toString(c.element().length()));
+      }
+    };
+  }
+
+  /** Returns a new terminal transform that consumes its input and produces {@link PDone}. */
+  private static PTransform<PInput, PDone> discardOutput() {
+    return new PTransform<PInput, PDone>() {
+      @Override
+      public PDone apply(PInput input) {
+        return PDone.in(input.getPipeline());
+      }
+    };
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e13cacb8/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java
new file mode 100644
index 0000000..9a358dd
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.isA;
+
+import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.util.UserCodeException;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+
+import org.joda.time.Instant;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.Collections;
+
+/**
+ * Tests for {@link EncodabilityEnforcementFactory}. Enforcements built by the factory must reject
+ * elements whose coder fails to encode or decode them, and elements whose structural value
+ * differs after decoding when the coder claims to be consistent with equals.
+ */
+@RunWith(JUnit4.class)
+public class EncodabilityEnforcementFactoryTest {
+  @Rule public ExpectedException thrown = ExpectedException.none();
+  private EncodabilityEnforcementFactory factory = EncodabilityEnforcementFactory.create();
+  private BundleFactory bundleFactory = InProcessBundleFactory.create();
+
+  @Test
+  public void encodeFailsThrows() {
+    WindowedValue<Record> element = WindowedValue.valueInGlobalWindow(new Record());
+    ModelEnforcement<Record> enforcement = createEnforcement(new RecordNoEncodeCoder(), element);
+
+    expectUserCodeException(CoderException.class, "Encode not allowed");
+    enforcement.beforeElement(element);
+  }
+
+  @Test
+  public void decodeFailsThrows() {
+    WindowedValue<Record> element = WindowedValue.valueInGlobalWindow(new Record());
+    ModelEnforcement<Record> enforcement = createEnforcement(new RecordNoDecodeCoder(), element);
+
+    expectUserCodeException(CoderException.class, "Decode not allowed");
+    enforcement.beforeElement(element);
+  }
+
+  @Test
+  public void consistentWithEqualsStructuralValueNotEqualThrows() {
+    Record original =
+        new Record() {
+          @Override
+          public String toString() {
+            return "OriginalRecord";
+          }
+        };
+    WindowedValue<Record> element = WindowedValue.<Record>valueInGlobalWindow(original);
+    ModelEnforcement<Record> enforcement =
+        createEnforcement(new RecordStructuralValueCoder(), element);
+
+    expectUserCodeException(
+        IllegalArgumentException.class,
+        "does not maintain structural value equality",
+        RecordStructuralValueCoder.class.getSimpleName(),
+        "OriginalRecord");
+    enforcement.beforeElement(element);
+  }
+
+  @Test
+  public void notConsistentWithEqualsStructuralValueNotEqualSucceeds() {
+    TestPipeline p = TestPipeline.create();
+    PCollection<Record> records =
+        p.apply(
+            Create.of(new Record())
+                .withCoder(new RecordNotConsistentWithEqualsStructuralValueCoder()));
+    AppliedPTransform<?, ?, ?> consumer =
+        records.apply(Count.<Record>globally()).getProducingTransformInternal();
+
+    WindowedValue<Record> element = WindowedValue.<Record>valueInGlobalWindow(new Record());
+
+    CommittedBundle<Record> input =
+        bundleFactory.createRootBundle(records).add(element).commit(Instant.now());
+    ModelEnforcement<Record> enforcement = factory.forBundle(input, consumer);
+
+    processSingleElement(enforcement, input, consumer, element);
+  }
+
+  @Test
+  public void structurallyEqualResultsSucceeds() {
+    TestPipeline p = TestPipeline.create();
+    PCollection<Integer> integers = p.apply(Create.of(1).withCoder(VarIntCoder.of()));
+    AppliedPTransform<?, ?, ?> consumer =
+        integers.apply(Count.<Integer>globally()).getProducingTransformInternal();
+
+    WindowedValue<Integer> element = WindowedValue.valueInGlobalWindow(1);
+
+    CommittedBundle<Integer> input =
+        bundleFactory.createRootBundle(integers).add(element).commit(Instant.now());
+    ModelEnforcement<Integer> enforcement = factory.forBundle(input, consumer);
+
+    processSingleElement(enforcement, input, consumer, element);
+  }
+
+  /**
+   * Builds an enforcement for a bundle containing {@code element}. The bundle comes from a
+   * {@link Create} using {@code coder} and is consumed by a {@link Count#globally()} transform.
+   */
+  private <T> ModelEnforcement<T> createEnforcement(Coder<T> coder, WindowedValue<T> element) {
+    TestPipeline p = TestPipeline.create();
+    PCollection<T> created = p.apply(Create.<T>of().withCoder(coder));
+    AppliedPTransform<?, ?, ?> consumer =
+        created.apply(Count.<T>globally()).getProducingTransformInternal();
+    CommittedBundle<T> input =
+        bundleFactory.createRootBundle(created).add(element).commit(Instant.now());
+    return factory.forBundle(input, consumer);
+  }
+
+  /**
+   * Expects the rest of the test to throw a {@link UserCodeException} caused by an instance of
+   * {@code causeType}, with a message containing every entry of {@code messageFragments}.
+   */
+  private void expectUserCodeException(
+      Class<? extends Throwable> causeType, String... messageFragments) {
+    thrown.expect(UserCodeException.class);
+    thrown.expectCause(isA(causeType));
+    for (String fragment : messageFragments) {
+      thrown.expectMessage(fragment);
+    }
+  }
+
+  /**
+   * Drives {@code enforcement} through one element and then finishes {@code input}, reporting
+   * no output bundles.
+   */
+  private static <T> void processSingleElement(
+      ModelEnforcement<T> enforcement,
+      CommittedBundle<T> input,
+      AppliedPTransform<?, ?, ?> consumer,
+      WindowedValue<T> element) {
+    enforcement.beforeElement(element);
+    enforcement.afterElement(element);
+    enforcement.afterFinish(
+        input,
+        StepTransformResult.withoutHold(consumer).build(),
+        Collections.<CommittedBundle<?>>emptyList());
+  }
+
+  /** Returns a new {@link Record} whose {@code toString()} is {@code "DecodedRecord"}. */
+  private static Record newDecodedRecord() {
+    return new Record() {
+      @Override
+      public String toString() {
+        return "DecodedRecord";
+      }
+    };
+  }
+
+  /** An element type whose equality is object identity. */
+  private static class Record {}
+
+  /** A coder whose {@code encode} always fails. */
+  private static class RecordNoEncodeCoder extends AtomicCoder<Record> {
+    @Override
+    public void encode(Record value, OutputStream outStream, Coder.Context context)
+        throws CoderException, IOException {
+      throw new CoderException("Encode not allowed");
+    }
+
+    @Override
+    public Record decode(InputStream inStream, Coder.Context context)
+        throws CoderException, IOException {
+      return null;
+    }
+  }
+
+  /** A coder that writes nothing and whose {@code decode} always fails. */
+  private static class RecordNoDecodeCoder extends AtomicCoder<Record> {
+    @Override
+    public void encode(Record value, OutputStream outStream, Coder.Context context)
+        throws CoderException, IOException {}
+
+    @Override
+    public Record decode(InputStream inStream, Coder.Context context)
+        throws CoderException, IOException {
+      throw new CoderException("Decode not allowed");
+    }
+  }
+
+  /**
+   * A coder that claims consistency with equals but decodes to a new {@link Record}. The decoded
+   * value is therefore never equal to the original, and its structural value is the record itself.
+   */
+  private static class RecordStructuralValueCoder extends AtomicCoder<Record> {
+    @Override
+    public void encode(Record value, OutputStream outStream, Coder.Context context)
+        throws CoderException, IOException {}
+
+    @Override
+    public Record decode(InputStream inStream, Coder.Context context)
+        throws CoderException, IOException {
+      return newDecodedRecord();
+    }
+
+    @Override
+    public boolean consistentWithEquals() {
+      return true;
+    }
+
+    @Override
+    public Object structuralValue(Record value) {
+      return value;
+    }
+  }
+
+  /**
+   * Same as {@link RecordStructuralValueCoder}, except that it does not claim consistency with
+   * equals.
+   */
+  private static class RecordNotConsistentWithEqualsStructuralValueCoder
+      extends AtomicCoder<Record> {
+    @Override
+    public void encode(Record value, OutputStream outStream, Coder.Context context)
+        throws CoderException, IOException {}
+
+    @Override
+    public Record decode(InputStream inStream, Coder.Context context)
+        throws CoderException, IOException {
+      return newDecodedRecord();
+    }
+
+    @Override
+    public boolean consistentWithEquals() {
+      return false;
+    }
+
+    @Override
+    public Object structuralValue(Record value) {
+      return value;
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e13cacb8/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java
new file mode 100644
index 0000000..66a5106
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.beam.runners.direct; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.emptyIterable; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle; +import org.apache.beam.runners.direct.InProcessPipelineRunner.UncommittedBundle; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.AppliedPTransform; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionList; + +import org.hamcrest.Matchers; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link FlattenEvaluatorFactory}. 
+ */
+@RunWith(JUnit4.class)
+public class FlattenEvaluatorFactoryTest {
+  private final BundleFactory bundleFactory = InProcessBundleFactory.create();
+
+  /**
+   * Each input of a Flatten gets its own evaluator. Every element an evaluator processes (with its
+   * timestamp, windows and pane unchanged) must end up in the output bundle that the evaluation
+   * context hands out for that evaluator's input bundle, and nowhere else.
+   */
+  @Test
+  public void testFlattenInMemoryEvaluator() throws Exception {
+    TestPipeline p = TestPipeline.create();
+    PCollection<Integer> left = p.apply("left", Create.of(1, 2, 4));
+    PCollection<Integer> right = p.apply("right", Create.of(-1, 2, -4));
+    PCollectionList<Integer> list = PCollectionList.of(left).and(right);
+
+    PCollection<Integer> flattened = list.apply(Flatten.<Integer>pCollections());
+
+    CommittedBundle<Integer> leftBundle =
+        bundleFactory.createRootBundle(left).commit(Instant.now());
+    CommittedBundle<Integer> rightBundle =
+        bundleFactory.createRootBundle(right).commit(Instant.now());
+
+    InProcessEvaluationContext context = mock(InProcessEvaluationContext.class);
+
+    // One output bundle per input bundle, so each evaluator's output can be checked separately.
+    UncommittedBundle<Integer> flattenedLeftBundle = bundleFactory.createRootBundle(flattened);
+    UncommittedBundle<Integer> flattenedRightBundle = bundleFactory.createRootBundle(flattened);
+
+    when(context.createBundle(leftBundle, flattened)).thenReturn(flattenedLeftBundle);
+    when(context.createBundle(rightBundle, flattened)).thenReturn(flattenedRightBundle);
+
+    FlattenEvaluatorFactory factory = new FlattenEvaluatorFactory();
+    TransformEvaluator<Integer> leftSideEvaluator =
+        factory.forApplication(flattened.getProducingTransformInternal(), leftBundle, context);
+    TransformEvaluator<Integer> rightSideEvaluator =
+        factory.forApplication(flattened.getProducingTransformInternal(), rightBundle, context);
+
+    // Interleave the two sides and mix timestamps, empty windows and pane infos.
+    leftSideEvaluator.processElement(WindowedValue.valueInGlobalWindow(1));
+    rightSideEvaluator.processElement(WindowedValue.valueInGlobalWindow(-1));
+    leftSideEvaluator.processElement(
+        WindowedValue.timestampedValueInGlobalWindow(2, new Instant(1024)));
+    leftSideEvaluator.processElement(WindowedValue.valueInEmptyWindows(4, PaneInfo.NO_FIRING));
+    rightSideEvaluator.processElement(
+        WindowedValue.valueInEmptyWindows(2, PaneInfo.ON_TIME_AND_ONLY_FIRING));
+    rightSideEvaluator.processElement(
+        WindowedValue.timestampedValueInGlobalWindow(-4, new Instant(-4096)));
+
+    InProcessTransformResult rightSideResult = rightSideEvaluator.finishBundle();
+    InProcessTransformResult leftSideResult = leftSideEvaluator.finishBundle();
+
+    assertThat(
+        rightSideResult.getOutputBundles(),
+        Matchers.<UncommittedBundle<?>>contains(flattenedRightBundle));
+    assertThat(
+        rightSideResult.getTransform(),
+        Matchers.<AppliedPTransform<?, ?, ?>>equalTo(flattened.getProducingTransformInternal()));
+    assertThat(
+        leftSideResult.getOutputBundles(),
+        Matchers.<UncommittedBundle<?>>contains(flattenedLeftBundle));
+    assertThat(
+        leftSideResult.getTransform(),
+        Matchers.<AppliedPTransform<?, ?, ?>>equalTo(flattened.getProducingTransformInternal()));
+
+    assertThat(
+        flattenedLeftBundle.commit(Instant.now()).getElements(),
+        containsInAnyOrder(
+            WindowedValue.timestampedValueInGlobalWindow(2, new Instant(1024)),
+            WindowedValue.valueInEmptyWindows(4, PaneInfo.NO_FIRING),
+            WindowedValue.valueInGlobalWindow(1)));
+    assertThat(
+        flattenedRightBundle.commit(Instant.now()).getElements(),
+        containsInAnyOrder(
+            WindowedValue.valueInEmptyWindows(2, PaneInfo.ON_TIME_AND_ONLY_FIRING),
+            WindowedValue.timestampedValueInGlobalWindow(-4, new Instant(-4096)),
+            WindowedValue.valueInGlobalWindow(-1)));
+  }
+
+  /**
+   * Flattening an empty {@link PCollectionList}: the evaluator is created with a {@code null} input
+   * bundle and must finish without producing any output bundle.
+   */
+  @Test
+  public void testFlattenInMemoryEvaluatorWithEmptyPCollectionList() throws Exception {
+    TestPipeline p = TestPipeline.create();
+    PCollectionList<Integer> list = PCollectionList.empty(p);
+
+    PCollection<Integer> flattened = list.apply(Flatten.<Integer>pCollections());
+
+    InProcessEvaluationContext context = mock(InProcessEvaluationContext.class);
+
+    FlattenEvaluatorFactory factory = new FlattenEvaluatorFactory();
+    TransformEvaluator<Integer> emptyEvaluator =
+        factory.forApplication(flattened.getProducingTransformInternal(), null, context);
+
+    InProcessTransformResult result = emptyEvaluator.finishBundle();
+
+    assertThat(result.getOutputBundles(), emptyIterable());
+    assertThat(
+        result.getTransform(),
+        Matchers.<AppliedPTransform<?, ?, ?>>equalTo(flattened.getProducingTransformInternal()));
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e13cacb8/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ForwardingPTransformTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ForwardingPTransformTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ForwardingPTransformTest.java
new file mode 100644
index 0000000..9ea71d7
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ForwardingPTransformTest.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.values.PCollection;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+/**
+ * Tests for {@link ForwardingPTransform}: each method must forward to the mocked delegate.
+ */
+@RunWith(JUnit4.class)
+public class ForwardingPTransformTest {
+  @Rule public ExpectedException thrown = ExpectedException.none();
+
+  @Mock private PTransform<PCollection<Integer>, PCollection<String>> delegate;
+
+  private ForwardingPTransform<PCollection<Integer>, PCollection<String>> forwarding;
+
+  @Before
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+    forwarding =
+        new ForwardingPTransform<PCollection<Integer>, PCollection<String>>() {
+          @Override
+          protected PTransform<PCollection<Integer>, PCollection<String>> delegate() {
+            return delegate;
+          }
+        };
+  }
+
+  @Test
+  public void applyDelegates() {
+    @SuppressWarnings("unchecked")
+    PCollection<Integer> collection = mock(PCollection.class);
+    @SuppressWarnings("unchecked")
+    PCollection<String> output = mock(PCollection.class);
+    when(delegate.apply(collection)).thenReturn(output);
+    PCollection<String> result = forwarding.apply(collection);
+    assertThat(result, equalTo(output));
+  }
+
+  @Test
+  public void getNameDelegates() {
+    String name = "My_forwardingptransform-name;for!thisTest";
+    when(delegate.getName()).thenReturn(name);
+    assertThat(forwarding.getName(), equalTo(name));
+  }
+
+  @Test
+  public void validateDelegates() {
+    @SuppressWarnings("unchecked")
+    PCollection<Integer> input = mock(PCollection.class);
+    // validate() returns void, so delegation is observed through a throw. The distinctive message
+    // ensures the exception came from the delegate, not from an unrelated RuntimeException (e.g. a
+    // NullPointerException) raised by a broken forwarding implementation.
+    doThrow(new RuntimeException("validate was delegated")).when(delegate).validate(input);
+
+    thrown.expect(RuntimeException.class);
+    thrown.expectMessage("validate was delegated");
+    forwarding.validate(input);
+  }
+
+  @Test
+  public void getDefaultOutputCoderDelegates() throws Exception {
+    @SuppressWarnings("unchecked")
+    PCollection<Integer> input = mock(PCollection.class);
+    @SuppressWarnings("unchecked")
+    PCollection<String> output = mock(PCollection.class);
+    @SuppressWarnings("unchecked")
+    Coder<String> outputCoder = mock(Coder.class);
+
+    when(delegate.getDefaultOutputCoder(input, output)).thenReturn(outputCoder);
+    assertThat(forwarding.getDefaultOutputCoder(input, output), equalTo(outputCoder));
+  }
+
+  @Test
+  public void populateDisplayDataDelegates() {
+    DisplayData.Builder builder = mock(DisplayData.Builder.class);
+    // Same technique as validateDelegates: only the delegate's own exception may satisfy the rule.
+    doThrow(new RuntimeException("populateDisplayData was delegated"))
+        .when(delegate)
+        .populateDisplayData(builder);
+
+    thrown.expect(RuntimeException.class);
+    thrown.expectMessage("populateDisplayData was delegated");
+    forwarding.populateDisplayData(builder);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e13cacb8/runners/direct-java/src/test/java/org/apache/beam/runners/direct/GroupByKeyEvaluatorFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/GroupByKeyEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/GroupByKeyEvaluatorFactoryTest.java
new file mode 100644
index 0000000..267266d
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/GroupByKeyEvaluatorFactoryTest.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.direct; + +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle; +import org.apache.beam.runners.direct.InProcessPipelineRunner.UncommittedBundle; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.GroupByKeyViaGroupByKeyOnly.ReifyTimestampsAndWindows; +import org.apache.beam.sdk.util.KeyedWorkItem; +import org.apache.beam.sdk.util.KeyedWorkItems; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multiset; + +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link GroupByKeyEvaluatorFactory}. 
+ */
+@RunWith(JUnit4.class)
+public class GroupByKeyEvaluatorFactoryTest {
+  private final BundleFactory bundleFactory = InProcessBundleFactory.create();
+
+  /**
+   * Feeds six KVs over three keys ("foo" x3, "bar" x2, "baz" x1) to the evaluator and checks that
+   * each key's values arrive as exactly one {@link KeyedWorkItem} in the keyed bundle the
+   * evaluation context creates for that key.
+   */
+  @Test
+  public void testInMemoryEvaluator() throws Exception {
+    TestPipeline p = TestPipeline.create();
+    KV<String, Integer> firstFoo = KV.of("foo", -1);
+    KV<String, Integer> secondFoo = KV.of("foo", 1);
+    KV<String, Integer> thirdFoo = KV.of("foo", 3);
+    KV<String, Integer> firstBar = KV.of("bar", 22);
+    KV<String, Integer> secondBar = KV.of("bar", 12);
+    KV<String, Integer> firstBaz = KV.of("baz", Integer.MAX_VALUE);
+    PCollection<KV<String, Integer>> values =
+        p.apply(Create.of(firstFoo, firstBar, secondFoo, firstBaz, secondBar, thirdFoo));
+    PCollection<KV<String, WindowedValue<Integer>>> kvs =
+        values.apply(new ReifyTimestampsAndWindows<String, Integer>());
+    PCollection<KeyedWorkItem<String, Integer>> groupedKvs =
+        kvs.apply(new GroupByKeyEvaluatorFactory.InProcessGroupByKeyOnly<String, Integer>());
+
+    CommittedBundle<KV<String, WindowedValue<Integer>>> inputBundle =
+        bundleFactory.createRootBundle(kvs).commit(Instant.now());
+    InProcessEvaluationContext evaluationContext = mock(InProcessEvaluationContext.class);
+
+    UncommittedBundle<KeyedWorkItem<String, Integer>> fooBundle =
+        bundleFactory.createKeyedBundle(null, "foo", groupedKvs);
+    UncommittedBundle<KeyedWorkItem<String, Integer>> barBundle =
+        bundleFactory.createKeyedBundle(null, "bar", groupedKvs);
+    UncommittedBundle<KeyedWorkItem<String, Integer>> bazBundle =
+        bundleFactory.createKeyedBundle(null, "baz", groupedKvs);
+
+    when(evaluationContext.createKeyedBundle(inputBundle, "foo", groupedKvs)).thenReturn(fooBundle);
+    when(evaluationContext.createKeyedBundle(inputBundle, "bar", groupedKvs)).thenReturn(barBundle);
+    when(evaluationContext.createKeyedBundle(inputBundle, "baz", groupedKvs)).thenReturn(bazBundle);
+
+    // The input to a GroupByKey is assumed to be a KvCoder
+    @SuppressWarnings("unchecked")
+    Coder<String> keyCoder =
+        ((KvCoder<String, WindowedValue<Integer>>) kvs.getCoder()).getKeyCoder();
+    TransformEvaluator<KV<String, WindowedValue<Integer>>> evaluator =
+        new GroupByKeyEvaluatorFactory()
+            .forApplication(
+                groupedKvs.getProducingTransformInternal(), inputBundle, evaluationContext);
+
+    evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(firstFoo)));
+    evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(secondFoo)));
+    evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(thirdFoo)));
+    evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(firstBar)));
+    evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(secondBar)));
+    evaluator.processElement(WindowedValue.valueInEmptyWindows(gwValue(firstBaz)));
+
+    evaluator.finishBundle();
+
+    assertThat(
+        fooBundle.commit(Instant.now()).getElements(),
+        contains(
+            new KeyedWorkItemMatcher<String, Integer>(
+                KeyedWorkItems.elementsWorkItem(
+                    "foo",
+                    ImmutableSet.of(
+                        WindowedValue.valueInGlobalWindow(-1),
+                        WindowedValue.valueInGlobalWindow(1),
+                        WindowedValue.valueInGlobalWindow(3))),
+                keyCoder)));
+    assertThat(
+        barBundle.commit(Instant.now()).getElements(),
+        contains(
+            new KeyedWorkItemMatcher<String, Integer>(
+                KeyedWorkItems.elementsWorkItem(
+                    "bar",
+                    ImmutableSet.of(
+                        WindowedValue.valueInGlobalWindow(12),
+                        WindowedValue.valueInGlobalWindow(22))),
+                keyCoder)));
+    assertThat(
+        bazBundle.commit(Instant.now()).getElements(),
+        contains(
+            new KeyedWorkItemMatcher<String, Integer>(
+                KeyedWorkItems.elementsWorkItem(
+                    "baz",
+                    ImmutableSet.of(WindowedValue.valueInGlobalWindow(Integer.MAX_VALUE))),
+                keyCoder)));
+  }
+
+  /** Re-wraps the value of {@code kv} as a {@link WindowedValue} in the global window. */
+  private <K, V> KV<K, WindowedValue<V>> gwValue(KV<K, V> kv) {
+    return KV.of(kv.getKey(), WindowedValue.valueInGlobalWindow(kv.getValue()));
+  }
+
+  /**
+   * Matches a {@link WindowedValue} holding a {@link KeyedWorkItem} whose key (compared by the
+   * key coder's structural value) and elements (compared as a multiset, ignoring order) equal
+   * those of the expected work item.
+   */
+  private static class KeyedWorkItemMatcher<K, V>
+      extends BaseMatcher<WindowedValue<KeyedWorkItem<K, V>>> {
+    private final KeyedWorkItem<K, V> myWorkItem;
+    private final Coder<K> keyCoder;
+
+    public KeyedWorkItemMatcher(KeyedWorkItem<K, V> myWorkItem, Coder<K> keyCoder) {
+      this.myWorkItem = myWorkItem;
+      this.keyCoder = keyCoder;
+    }
+
+    @Override
+    public boolean matches(Object item) {
+      // instanceof is already false for null, so no separate null check is needed.
+      if (!(item instanceof WindowedValue)) {
+        return false;
+      }
+      @SuppressWarnings("unchecked")
+      WindowedValue<KeyedWorkItem<K, V>> that = (WindowedValue<KeyedWorkItem<K, V>>) item;
+      Multiset<WindowedValue<V>> myValues = HashMultiset.create(myWorkItem.elementsIterable());
+      Multiset<WindowedValue<V>> thatValues =
+          HashMultiset.create(that.getValue().elementsIterable());
+      try {
+        return myValues.equals(thatValues)
+            && keyCoder
+                .structuralValue(myWorkItem.key())
+                .equals(keyCoder.structuralValue(that.getValue().key()));
+      } catch (Exception e) {
+        // A key whose structural value cannot be computed is treated as a mismatch.
+        return false;
+      }
+    }
+
+    @Override
+    public void describeTo(Description description) {
+      description
+          .appendText("KeyedWorkItem<K, V> containing key ")
+          .appendValue(myWorkItem.key())
+          .appendText(" and values ")
+          .appendValueList("[", ", ", "]", myWorkItem.elementsIterable());
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e13cacb8/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactoryTest.java
new file mode 100644
index 0000000..557ebff
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactoryTest.java
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.direct; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.isA; +import static org.junit.Assert.assertThat; + +import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle; +import org.apache.beam.runners.direct.InProcessPipelineRunner.UncommittedBundle; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.IllegalMutationException; +import org.apache.beam.sdk.util.UserCodeException; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollection; + +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Tests for {@link ImmutabilityCheckingBundleFactory}. 
+ */
+@RunWith(JUnit4.class)
+public class ImmutabilityCheckingBundleFactoryTest {
+  @Rule public ExpectedException thrown = ExpectedException.none();
+  private ImmutabilityCheckingBundleFactory factory;
+  // byte[] elements: arrays are mutable in place, so they are a natural subject for mutation checks.
+  private PCollection<byte[]> created;
+  private PCollection<byte[]> transformed;
+
+  @Before
+  public void setup() {
+    TestPipeline p = TestPipeline.create();
+    created = p.apply(Create.<byte[]>of().withCoder(ByteArrayCoder.of()));
+    transformed = created.apply(ParDo.of(new IdentityDoFn<byte[]>()));
+    factory = ImmutabilityCheckingBundleFactory.create(InProcessBundleFactory.create());
+  }
+
+  @Test
+  public void noMutationRootBundleSucceeds() {
+    UncommittedBundle<byte[]> root = factory.createRootBundle(created);
+    byte[] array = new byte[] {0, 1, 2};
+    root.add(WindowedValue.valueInGlobalWindow(array));
+    CommittedBundle<byte[]> committed = root.commit(Instant.now());
+
+    assertThat(
+        committed.getElements(), containsInAnyOrder(WindowedValue.valueInGlobalWindow(array)));
+  }
+
+  @Test
+  public void noMutationKeyedBundleSucceeds() {
+    CommittedBundle<byte[]> root = factory.createRootBundle(created).commit(Instant.now());
+    UncommittedBundle<byte[]> keyed = factory.createKeyedBundle(root, "mykey", transformed);
+
+    WindowedValue<byte[]> windowedArray =
+        WindowedValue.of(
+            new byte[] {4, 8, 12},
+            new Instant(891L),
+            new IntervalWindow(new Instant(0), new Instant(1000)),
+            PaneInfo.ON_TIME_AND_ONLY_FIRING);
+    keyed.add(windowedArray);
+
+    CommittedBundle<byte[]> committed = keyed.commit(Instant.now());
+    assertThat(committed.getElements(), containsInAnyOrder(windowedArray));
+  }
+
+  @Test
+  public void noMutationCreateBundleSucceeds() {
+    CommittedBundle<byte[]> root = factory.createRootBundle(created).commit(Instant.now());
+    UncommittedBundle<byte[]> intermediate = factory.createBundle(root, transformed);
+
+    WindowedValue<byte[]> windowedArray =
+        WindowedValue.of(
+            new byte[] {4, 8, 12},
+            new Instant(891L),
+            new IntervalWindow(new Instant(0), new Instant(1000)),
+            PaneInfo.ON_TIME_AND_ONLY_FIRING);
+    intermediate.add(windowedArray);
+
+    CommittedBundle<byte[]> committed = intermediate.commit(Instant.now());
+    assertThat(committed.getElements(), containsInAnyOrder(windowedArray));
+  }
+
+  // Mutating an element before it is added to a bundle is allowed.
+
+  @Test
+  public void mutationBeforeAddRootBundleSucceeds() {
+    UncommittedBundle<byte[]> root = factory.createRootBundle(created);
+    byte[] array = new byte[] {0, 1, 2};
+    array[1] = 2;
+    root.add(WindowedValue.valueInGlobalWindow(array));
+    CommittedBundle<byte[]> committed = root.commit(Instant.now());
+
+    assertThat(
+        committed.getElements(), containsInAnyOrder(WindowedValue.valueInGlobalWindow(array)));
+  }
+
+  @Test
+  public void mutationBeforeAddKeyedBundleSucceeds() {
+    CommittedBundle<byte[]> root = factory.createRootBundle(created).commit(Instant.now());
+    UncommittedBundle<byte[]> keyed = factory.createKeyedBundle(root, "mykey", transformed);
+
+    byte[] array = new byte[] {4, 8, 12};
+    array[0] = Byte.MAX_VALUE;
+    WindowedValue<byte[]> windowedArray =
+        WindowedValue.of(
+            array,
+            new Instant(891L),
+            new IntervalWindow(new Instant(0), new Instant(1000)),
+            PaneInfo.ON_TIME_AND_ONLY_FIRING);
+    keyed.add(windowedArray);
+
+    CommittedBundle<byte[]> committed = keyed.commit(Instant.now());
+    assertThat(committed.getElements(), containsInAnyOrder(windowedArray));
+  }
+
+  @Test
+  public void mutationBeforeAddCreateBundleSucceeds() {
+    CommittedBundle<byte[]> root = factory.createRootBundle(created).commit(Instant.now());
+    UncommittedBundle<byte[]> intermediate = factory.createBundle(root, transformed);
+
+    byte[] array = new byte[] {4, 8, 12};
+    WindowedValue<byte[]> windowedArray =
+        WindowedValue.of(
+            array,
+            new Instant(891L),
+            new IntervalWindow(new Instant(0), new Instant(1000)),
+            PaneInfo.ON_TIME_AND_ONLY_FIRING);
+    // Mutated after wrapping but still before add(), so this must not be reported.
+    array[2] = -3;
+    intermediate.add(windowedArray);
+
+    CommittedBundle<byte[]> committed = intermediate.commit(Instant.now());
+    assertThat(committed.getElements(), containsInAnyOrder(windowedArray));
+  }
+
+  // Mutating an element after it was added must make commit() fail.
+
+  @Test
+  public void mutationAfterAddRootBundleThrows() {
+    UncommittedBundle<byte[]> root = factory.createRootBundle(created);
+    byte[] array = new byte[] {0, 1, 2};
+    root.add(WindowedValue.valueInGlobalWindow(array));
+
+    array[1] = 2;
+    thrown.expect(UserCodeException.class);
+    thrown.expectCause(isA(IllegalMutationException.class));
+    thrown.expectMessage("Values must not be mutated in any way after being output");
+    root.commit(Instant.now());
+  }
+
+  @Test
+  public void mutationAfterAddKeyedBundleThrows() {
+    CommittedBundle<byte[]> root = factory.createRootBundle(created).commit(Instant.now());
+    UncommittedBundle<byte[]> keyed = factory.createKeyedBundle(root, "mykey", transformed);
+
+    byte[] array = new byte[] {4, 8, 12};
+    WindowedValue<byte[]> windowedArray =
+        WindowedValue.of(
+            array,
+            new Instant(891L),
+            new IntervalWindow(new Instant(0), new Instant(1000)),
+            PaneInfo.ON_TIME_AND_ONLY_FIRING);
+    keyed.add(windowedArray);
+
+    array[0] = Byte.MAX_VALUE;
+    thrown.expect(UserCodeException.class);
+    thrown.expectCause(isA(IllegalMutationException.class));
+    thrown.expectMessage("Values must not be mutated in any way after being output");
+    keyed.commit(Instant.now());
+  }
+
+  @Test
+  public void mutationAfterAddCreateBundleThrows() {
+    CommittedBundle<byte[]> root = factory.createRootBundle(created).commit(Instant.now());
+    UncommittedBundle<byte[]> intermediate = factory.createBundle(root, transformed);
+
+    byte[] array = new byte[] {4, 8, 12};
+    WindowedValue<byte[]> windowedArray =
+        WindowedValue.of(
+            array,
+            new Instant(891L),
+            new IntervalWindow(new Instant(0), new Instant(1000)),
+            PaneInfo.ON_TIME_AND_ONLY_FIRING);
+    intermediate.add(windowedArray);
+
+    array[2] = -3;
+    thrown.expect(UserCodeException.class);
+    thrown.expectCause(isA(IllegalMutationException.class));
+    thrown.expectMessage("Values must not be mutated in any way after being output");
+    intermediate.commit(Instant.now());
+  }
+
+  /** Outputs every input element unchanged; used only to give {@code created} a consumer. */
+  private static class IdentityDoFn<T> extends DoFn<T, T> {
+    @Override
+    public void processElement(DoFn<T, T>.ProcessContext c) throws Exception {
+      c.output(c.element());
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/e13cacb8/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactoryTest.java
new file mode 100644
index 0000000..6cef60d
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactoryTest.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.direct;
+
+import org.apache.beam.runners.direct.InProcessPipelineRunner.CommittedBundle;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.util.IllegalMutationException;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.Serializable;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+
+/**
+ * Tests for {@link ImmutabilityEnforcementFactory}.
+ *
+ * <p>The class is {@link Serializable} and its fields are transient, presumably because the
+ * anonymous {@link DoFn} built in {@link #setup()} captures the enclosing test instance.
+ */
+@RunWith(JUnit4.class)
+public class ImmutabilityEnforcementFactoryTest implements Serializable {
+  @Rule public transient ExpectedException thrown = ExpectedException.none();
+  private transient ImmutabilityEnforcementFactory factory;
+  private transient BundleFactory bundleFactory;
+  private transient PCollection<byte[]> pcollection;
+  private transient AppliedPTransform<?, ?, ?> consumer;
+
+  @Before
+  public void setup() {
+    factory = new ImmutabilityEnforcementFactory();
+    bundleFactory = InProcessBundleFactory.create();
+    TestPipeline p = TestPipeline.create();
+    pcollection =
+        p.apply(
+                Create.of(
+                    "foo".getBytes(StandardCharsets.UTF_8),
+                    "spamhameggs".getBytes(StandardCharsets.UTF_8)))
+            .apply(
+                ParDo.of(
+                    new DoFn<byte[], byte[]>() {
+                      @Override
+                      public void processElement(DoFn<byte[], byte[]>.ProcessContext c)
+                          throws Exception {
+                        c.element()[0] = 'b';
+                      }
+                    }));
+    // The enforcement under test is created for this consumer of `pcollection`.
+    consumer = pcollection.apply(Count.<byte[]>globally()).getProducingTransformInternal();
+  }
+
+  @Test
+  public void unchangedSucceeds() {
+    WindowedValue<byte[]> element =
+        WindowedValue.valueInGlobalWindow("bar".getBytes(StandardCharsets.UTF_8));
+    CommittedBundle<byte[]> elements =
+        bundleFactory.createRootBundle(pcollection).add(element).commit(Instant.now());
+
+    ModelEnforcement<byte[]> enforcement = factory.forBundle(elements, consumer);
+    enforcement.beforeElement(element);
+    enforcement.afterElement(element);
+    enforcement.afterFinish(
+        elements,
+        StepTransformResult.withoutHold(consumer).build(),
+        Collections.<CommittedBundle<?>>emptyList());
+  }
+
+  @Test
+  public void mutatedDuringProcessElementThrows() {
+    WindowedValue<byte[]> element =
+        WindowedValue.valueInGlobalWindow("bar".getBytes(StandardCharsets.UTF_8));
+    CommittedBundle<byte[]> elements =
+        bundleFactory.createRootBundle(pcollection).add(element).commit(Instant.now());
+
+    ModelEnforcement<byte[]> enforcement = factory.forBundle(elements, consumer);
+    enforcement.beforeElement(element);
+    element.getValue()[0] = 'f';
+    thrown.expect(IllegalMutationException.class);
+    thrown.expectMessage(consumer.getFullName());
+    // NOTE(review): spelling presumably mirrors the production message; fix both together.
+    thrown.expectMessage("illegaly mutated");
+    thrown.expectMessage("Input values must not be mutated");
+    // afterElement is expected to throw, so nothing after it in this test would ever run.
+    enforcement.afterElement(element);
+  }
+
+  @Test
+  public void mutatedAfterProcessElementFails() {
+    WindowedValue<byte[]> element =
+        WindowedValue.valueInGlobalWindow("bar".getBytes(StandardCharsets.UTF_8));
+    CommittedBundle<byte[]> elements =
+        bundleFactory.createRootBundle(pcollection).add(element).commit(Instant.now());
+
+    ModelEnforcement<byte[]> enforcement = factory.forBundle(elements, consumer);
+    enforcement.beforeElement(element);
+    enforcement.afterElement(element);
+
+    // Mutation after afterElement must still be detected when the bundle finishes.
+    element.getValue()[0] = 'f';
+    thrown.expect(IllegalMutationException.class);
+    thrown.expectMessage(consumer.getFullName());
+    thrown.expectMessage("illegaly mutated");
+    thrown.expectMessage("Input values must not be mutated");
+    enforcement.afterFinish(
+        elements,
+        StepTransformResult.withoutHold(consumer).build(),
+        Collections.<CommittedBundle<?>>emptyList());
+  }
+}