http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachineTest.java ---------------------------------------------------------------------- diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachineTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachineTest.java index 45a5cfb..e4d10a0 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachineTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/AfterWatermarkStateMachineTest.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnMergeContext; +import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; import org.apache.beam.runners.core.triggers.TriggerStateMachineTester.SimpleTriggerStateMachineTester; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; @@ -45,8 +46,8 @@ import org.mockito.MockitoAnnotations; @RunWith(JUnit4.class) public class AfterWatermarkStateMachineTest { - @Mock private TriggerStateMachine mockEarly; - @Mock private TriggerStateMachine mockLate; + @Mock private OnceTriggerStateMachine mockEarly; + @Mock private OnceTriggerStateMachine mockLate; private SimpleTriggerStateMachineTester<IntervalWindow> tester; private static TriggerStateMachine.TriggerContext anyTriggerContext() { @@ -69,7 +70,7 @@ public class AfterWatermarkStateMachineTest { MockitoAnnotations.initMocks(this); } - public void testRunningAsTrigger(TriggerStateMachine mockTrigger, IntervalWindow window) + public void testRunningAsTrigger(OnceTriggerStateMachine mockTrigger, IntervalWindow window) throws Exception { // Don't fire due to mock saying no
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/StubTriggerStateMachine.java ---------------------------------------------------------------------- diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/StubTriggerStateMachine.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/StubTriggerStateMachine.java index 1bc757e..4512848 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/StubTriggerStateMachine.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/triggers/StubTriggerStateMachine.java @@ -18,11 +18,12 @@ package org.apache.beam.runners.core.triggers; import com.google.common.collect.Lists; +import org.apache.beam.runners.core.triggers.TriggerStateMachine.OnceTriggerStateMachine; /** - * No-op {@link TriggerStateMachine} implementation for testing. + * No-op {@link OnceTriggerStateMachine} implementation for testing. */ -abstract class StubTriggerStateMachine extends TriggerStateMachine { +abstract class StubTriggerStateMachine extends OnceTriggerStateMachine { /** * Create a stub {@link TriggerStateMachine} instance which returns the specified name on {@link * #toString()}. @@ -41,7 +42,7 @@ abstract class StubTriggerStateMachine extends TriggerStateMachine { } @Override - public void onFire(TriggerContext context) throws Exception { + protected void onOnlyFiring(TriggerContext context) throws Exception { } @Override http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/pom.xml ---------------------------------------------------------------------- diff --git a/runners/direct-java/pom.xml b/runners/direct-java/pom.xml index e14e813..bec2113 100644 --- a/runners/direct-java/pom.xml +++ b/runners/direct-java/pom.xml @@ -22,7 +22,7 @@ <parent> <groupId>org.apache.beam</groupId> <artifactId>beam-runners-parent</artifactId> - <version>2.2.0-SNAPSHOT</version> + <version>2.1.0-SNAPSHOT</version> <relativePath>../pom.xml</relativePath> </parent> @@ -117,7 +117,7 @@ </relocation> </relocations> <transformers> - <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" /> + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/> </transformers> </configuration> </execution> @@ -155,8 +155,7 @@ <systemPropertyVariables> <beamTestPipelineOptions> [ - "--runner=DirectRunner", - "--runnerDeterminedSharding=false" + "--runner=DirectRunner" ] </beamTestPipelineOptions> </systemPropertyVariables> http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CommittedResult.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CommittedResult.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CommittedResult.java index 70e3ac3..8c45449 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CommittedResult.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CommittedResult.java @@ -19,8 +19,8 @@ package org.apache.beam.runners.direct; import com.google.auto.value.AutoValue; -import com.google.common.base.Optional; import java.util.Set; +import javax.annotation.Nullable; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.View.CreatePCollectionView; @@ -36,10 +36,12 @@ abstract class CommittedResult { /** * Returns the {@link CommittedBundle} that contains the input elements that could not be - * processed by the evaluation. The returned optional is present if there were any unprocessed - * input elements, and absent otherwise. + * processed by the evaluation. + * + * <p>{@code null} if the input bundle was null. */ - public abstract Optional<? extends CommittedBundle<?>> getUnprocessedInputs(); + @Nullable + public abstract CommittedBundle<?> getUnprocessedInputs(); /** * Returns the outputs produced by the transform. @@ -57,7 +59,7 @@ abstract class CommittedResult { public static CommittedResult create( TransformResult<?> original, - Optional<? extends CommittedBundle<?>> unprocessedElements, + CommittedBundle<?> unprocessedElements, Iterable<? extends CommittedBundle<?>> outputs, Set<OutputType> producedOutputs) { return new AutoValue_CommittedResult(original.getTransform(), http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java index ad17b2b..c2c0afa 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java @@ -17,8 +17,6 @@ */ package org.apache.beam.runners.direct; -import static com.google.common.base.Preconditions.checkArgument; - import com.google.common.collect.ListMultimap; import java.util.Collection; import java.util.List; @@ -38,8 +36,7 @@ import org.apache.beam.sdk.values.PValue; class DirectGraph { private final Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers; private final Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters; - private final ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers; - private final ListMultimap<PValue, AppliedPTransform<?, ?, ?>> allConsumers; + private final ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers; private final Set<AppliedPTransform<?, ?, ?>> rootTransforms; private final Map<AppliedPTransform<?, ?, ?>, String> stepNames; @@ -47,36 +44,23 @@ class DirectGraph { public static DirectGraph create( Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers, Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters, - ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers, - ListMultimap<PValue, AppliedPTransform<?, ?, ?>> allConsumers, + ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers, Set<AppliedPTransform<?, ?, ?>> rootTransforms, Map<AppliedPTransform<?, ?, ?>, String> stepNames) { - return new DirectGraph( - producers, viewWriters, perElementConsumers, allConsumers, rootTransforms, stepNames); + return new DirectGraph(producers, viewWriters, primitiveConsumers, rootTransforms, stepNames); } private DirectGraph( Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers, Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters, - ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers, - ListMultimap<PValue, AppliedPTransform<?, ?, ?>> allConsumers, + ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers, Set<AppliedPTransform<?, ?, ?>> rootTransforms, Map<AppliedPTransform<?, ?, ?>, String> stepNames) { this.producers = producers; this.viewWriters = viewWriters; - this.perElementConsumers = perElementConsumers; - this.allConsumers = allConsumers; + this.primitiveConsumers = primitiveConsumers; this.rootTransforms = rootTransforms; this.stepNames = stepNames; - for (AppliedPTransform<?, ?, ?> step : stepNames.keySet()) { - for (PValue input : step.getInputs().values()) { - checkArgument( - allConsumers.get(input).contains(step), - "Step %s lists value %s as input, but it is not in the graph of consumers", - step.getFullName(), - input); - } - } } AppliedPTransform<?, ?, ?> getProducer(PCollection<?> produced) { @@ -87,22 +71,14 @@ class DirectGraph { return viewWriters.get(view); } - List<AppliedPTransform<?, ?, ?>> getPerElementConsumers(PValue consumed) { - return perElementConsumers.get(consumed); - } - - List<AppliedPTransform<?, ?, ?>> getAllConsumers(PValue consumed) { - return allConsumers.get(consumed); + List<AppliedPTransform<?, ?, ?>> getPrimitiveConsumers(PValue consumed) { + return primitiveConsumers.get(consumed); } Set<AppliedPTransform<?, ?, ?>> getRootTransforms() { return rootTransforms; } - Set<PCollection<?>> getPCollections() { - return producers.keySet(); - } - Set<PCollectionView<?>> getViews() { return viewWriters.keySet(); } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java index 675de2c..d54de5d 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java @@ -21,26 +21,19 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ListMultimap; -import com.google.common.collect.Sets; -import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; -import org.apache.beam.runners.core.construction.TransformInputs; -import org.apache.beam.runners.direct.ViewOverrideFactory.WriteView; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.PValue; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the @@ -48,15 +41,11 @@ import org.slf4j.LoggerFactory; * input after the upstream transform has produced and committed output. */ class DirectGraphVisitor extends PipelineVisitor.Defaults { - private static final Logger LOG = LoggerFactory.getLogger(DirectGraphVisitor.class); private Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers = new HashMap<>(); private Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters = new HashMap<>(); - private Set<PCollectionView<?>> consumedViews = new HashSet<>(); - private ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers = - ArrayListMultimap.create(); - private ListMultimap<PValue, AppliedPTransform<?, ?, ?>> allConsumers = + private ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers = ArrayListMultimap.create(); private Set<AppliedPTransform<?, ?, ?>> rootTransforms = new HashSet<>(); @@ -84,13 +73,6 @@ class DirectGraphVisitor extends PipelineVisitor.Defaults { getClass().getSimpleName()); if (node.isRootNode()) { finalized = true; - checkState( - viewWriters.keySet().containsAll(consumedViews), - "All %ss that are consumed must be written by some %s %s: Missing %s", - PCollectionView.class.getSimpleName(), - WriteView.class.getSimpleName(), - PTransform.class.getSimpleName(), - Sets.difference(consumedViews, viewWriters.keySet())); } } @@ -101,30 +83,18 @@ class DirectGraphVisitor extends PipelineVisitor.Defaults { if (node.getInputs().isEmpty()) { rootTransforms.add(appliedTransform); } else { - Collection<PValue> mainInputs = - TransformInputs.nonAdditionalInputs(node.toAppliedPTransform(getPipeline())); - if (!mainInputs.containsAll(node.getInputs().values())) { - LOG.debug( - "Inputs reduced to {} from {} by removing additional inputs", - mainInputs, - node.getInputs().values()); - } - for (PValue value : mainInputs) { - perElementConsumers.put(value, appliedTransform); - } for (PValue value : node.getInputs().values()) { - allConsumers.put(value, appliedTransform); + primitiveConsumers.put(value, appliedTransform); + } + if (node.getTransform() instanceof ViewOverrideFactory.WriteView) { + viewWriters.put( + ((ViewOverrideFactory.WriteView<?, ?>) node.getTransform()).getView(), + node.toAppliedPTransform(getPipeline())); } - } - if (node.getTransform() instanceof ParDo.MultiOutput) { - consumedViews.addAll(((ParDo.MultiOutput<?, ?>) node.getTransform()).getSideInputs()); - } else if (node.getTransform() instanceof ViewOverrideFactory.WriteView) { - viewWriters.put( - ((WriteView) node.getTransform()).getView(), node.toAppliedPTransform(getPipeline())); } } - @Override + @Override public void visitValue(PValue value, TransformHierarchy.Node producer) { AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(producer); if (value instanceof PCollection && !producers.containsKey(value)) { @@ -149,6 +119,6 @@ class DirectGraphVisitor extends PipelineVisitor.Defaults { public DirectGraph getGraph() { checkState(finalized, "Can't get a graph before the Pipeline has been completely traversed"); return DirectGraph.create( - producers, viewWriters, perElementConsumers, allConsumers, rootTransforms, stepNames); + producers, viewWriters, primitiveConsumers, rootTransforms, stepNames); } } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java index 06b8e29..2fc0dd4 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKey.java @@ -36,17 +36,13 @@ import org.apache.beam.sdk.values.WindowingStrategy; class DirectGroupByKey<K, V> extends ForwardingPTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> { - private final PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> original; + private final GroupByKey<K, V> original; static final String DIRECT_GBKO_URN = "urn:beam:directrunner:transforms:gbko:v1"; static final String DIRECT_GABW_URN = "urn:beam:directrunner:transforms:gabw:v1"; - private final WindowingStrategy<?, ?> outputWindowingStrategy; - DirectGroupByKey( - PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> original, - WindowingStrategy<?, ?> outputWindowingStrategy) { - this.original = original; - this.outputWindowingStrategy = outputWindowingStrategy; + DirectGroupByKey(GroupByKey<K, V> from) { + this.original = from; } @Override @@ -61,6 +57,9 @@ class DirectGroupByKey<K, V> // key/value input elements and the window merge operation of the // window function associated with the input PCollection. WindowingStrategy<?, ?> inputWindowingStrategy = input.getWindowingStrategy(); + // Update the windowing strategy as appropriate. + WindowingStrategy<?, ?> outputWindowingStrategy = + original.updateWindowingStrategy(inputWindowingStrategy); // By default, implement GroupByKey via a series of lower-level operations. return input http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java index 9c2de3d..c2eb5e7 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java @@ -17,34 +17,26 @@ */ package org.apache.beam.runners.direct; -import com.google.common.collect.Iterables; import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.transforms.GroupByKey; -import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; /** A {@link PTransformOverrideFactory} for {@link GroupByKey} PTransforms. */ final class DirectGroupByKeyOverrideFactory<K, V> extends SingleInputOutputOverrideFactory< - PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, - PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>> { + PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>> { @Override public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> getReplacementTransform( AppliedPTransform< - PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, - PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>> + PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>> transform) { - - PCollection<KV<K, Iterable<V>>> output = - (PCollection<KV<K, Iterable<V>>>) Iterables.getOnlyElement(transform.getOutputs().values()); - return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(transform), - new DirectGroupByKey<>(transform.getTransform(), output.getWindowingStrategy())); + new DirectGroupByKey<>(transform.getTransform())); } } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java index 53fb2f2..0e6fbab 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRegistrar.java @@ -50,7 +50,7 @@ public class DirectRegistrar { @Override public Iterable<Class<? extends PipelineOptions>> getPipelineOptions() { return ImmutableList.<Class<? extends PipelineOptions>>of( - DirectOptions.class, DirectTestOptions.class); + DirectOptions.class); } } } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 4621224..dbd1ec4 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -30,7 +30,6 @@ import java.util.Map; import java.util.Set; import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems; import org.apache.beam.runners.core.construction.PTransformMatchers; -import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.SplittableParDo; import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult; import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory; @@ -38,11 +37,17 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.PipelineRunner; +import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.metrics.MetricResults; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.PTransformOverride; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.ParDo.MultiOutput; +import org.apache.beam.sdk.transforms.View.CreatePCollectionView; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.PCollection; import org.joda.time.Duration; @@ -69,17 +74,16 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { IMMUTABILITY { @Override public boolean appliesTo(PCollection<?> collection, DirectGraph graph) { - return CONTAINS_UDF.contains( - PTransformTranslation.urnForTransform(graph.getProducer(collection).getTransform())); + return CONTAINS_UDF.contains(graph.getProducer(collection).getTransform().getClass()); } }; /** * The set of {@link PTransform PTransforms} that execute a UDF. Useful for some enforcements. */ - private static final Set<String> CONTAINS_UDF = + private static final Set<Class<? extends PTransform>> CONTAINS_UDF = ImmutableSet.of( - PTransformTranslation.READ_TRANSFORM_URN, PTransformTranslation.PAR_DO_TRANSFORM_URN); + Read.Bounded.class, Read.Unbounded.class, ParDo.SingleOutput.class, MultiOutput.class); public abstract boolean appliesTo(PCollection<?> collection, DirectGraph graph); @@ -108,19 +112,22 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { return bundleFactory; } - private static Map<String, Collection<ModelEnforcementFactory>> + @SuppressWarnings("rawtypes") + private static Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> defaultModelEnforcements(Set<Enforcement> enabledEnforcements) { - ImmutableMap.Builder<String, Collection<ModelEnforcementFactory>> enforcements = - ImmutableMap.builder(); + ImmutableMap.Builder<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> + enforcements = ImmutableMap.builder(); ImmutableList.Builder<ModelEnforcementFactory> enabledParDoEnforcements = ImmutableList.builder(); if (enabledEnforcements.contains(Enforcement.IMMUTABILITY)) { enabledParDoEnforcements.add(ImmutabilityEnforcementFactory.create()); } Collection<ModelEnforcementFactory> parDoEnforcements = enabledParDoEnforcements.build(); - enforcements.put(PTransformTranslation.PAR_DO_TRANSFORM_URN, parDoEnforcements); + enforcements.put(ParDo.SingleOutput.class, parDoEnforcements); + enforcements.put(MultiOutput.class, parDoEnforcements); return enforcements.build(); } + } //////////////////////////////////////////////////////////////////////////////////////////////// @@ -216,45 +223,42 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { @SuppressWarnings("rawtypes") @VisibleForTesting List<PTransformOverride> defaultTransformOverrides() { - DirectTestOptions testOptions = options.as(DirectTestOptions.class); - ImmutableList.Builder<PTransformOverride> builder = ImmutableList.builder(); - if (testOptions.isRunnerDeterminedSharding()) { - builder.add( - PTransformOverride.of( - PTransformMatchers.writeWithRunnerDeterminedSharding(), - new WriteWithShardingFactory())); /* Uses a view internally. */ - } - builder = builder.add( - PTransformOverride.of( - PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN), - new ViewOverrideFactory())) /* Uses pardos and GBKs */ + return ImmutableList.<PTransformOverride>builder() + .add( + PTransformOverride.of( + PTransformMatchers.writeWithRunnerDeterminedSharding(), + new WriteWithShardingFactory())) /* Uses a view internally. */ + .add( + PTransformOverride.of( + PTransformMatchers.classEqualTo(CreatePCollectionView.class), + new ViewOverrideFactory())) /* Uses pardos and GBKs */ .add( PTransformOverride.of( - PTransformMatchers.urnEqualTo(PTransformTranslation.TEST_STREAM_TRANSFORM_URN), + PTransformMatchers.classEqualTo(TestStream.class), new DirectTestStreamFactory(this))) /* primitive */ // SplittableParMultiDo is implemented in terms of nonsplittable simple ParDos and extra // primitives .add( PTransformOverride.of( - PTransformMatchers.splittableParDo(), new ParDoMultiOverrideFactory())) + PTransformMatchers.splittableParDoMulti(), new ParDoMultiOverrideFactory())) // state and timer pardos are implemented in terms of simple ParDos and extra primitives .add( PTransformOverride.of( - PTransformMatchers.stateOrTimerParDo(), new ParDoMultiOverrideFactory())) + PTransformMatchers.stateOrTimerParDoMulti(), new ParDoMultiOverrideFactory())) .add( PTransformOverride.of( - PTransformMatchers.urnEqualTo( - SplittableParDo.SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN), + PTransformMatchers.classEqualTo(SplittableParDo.ProcessKeyedElements.class), new SplittableParDoViaKeyedWorkItems.OverrideFactory())) .add( PTransformOverride.of( - PTransformMatchers.urnEqualTo(SplittableParDo.SPLITTABLE_GBKIKWI_URN), + PTransformMatchers.classEqualTo( + SplittableParDoViaKeyedWorkItems.GBKIntoKeyedWorkItems.class), new DirectGBKIntoKeyedWorkItemsOverrideFactory())) /* Returns a GBKO */ .add( PTransformOverride.of( - PTransformMatchers.urnEqualTo(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN), - new DirectGroupByKeyOverrideFactory())); /* returns two chained primitives. */ - return builder.build(); + PTransformMatchers.classEqualTo(GroupByKey.class), + new DirectGroupByKeyOverrideFactory())) /* returns two chained primitives. */ + .build(); } /** http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTestOptions.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTestOptions.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTestOptions.java deleted file mode 100644 index a426443..0000000 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectTestOptions.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.direct; - -import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.options.ApplicationNameOptions; -import org.apache.beam.sdk.options.Default; -import org.apache.beam.sdk.options.Description; -import org.apache.beam.sdk.options.Hidden; -import org.apache.beam.sdk.options.PipelineOptions; - -/** - * Internal-only options for tweaking the behavior of the {@link DirectRunner} in ways that users - * should never do. - * - * <p>Currently, the only use is to disable user-friendly overrides that prevent fully testing - * certain composite transforms. - */ -@Internal -@Hidden -public interface DirectTestOptions extends PipelineOptions, ApplicationNameOptions { - @Default.Boolean(true) - @Description( - "Indicates whether this is an automatically-run unit test.") - boolean isRunnerDeterminedSharding(); - void setRunnerDeterminedSharding(boolean goAheadAndDetermineSharding); -} http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java index d192785..e215070 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java @@ -20,7 +20,6 @@ package org.apache.beam.runners.direct; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Optional; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.util.concurrent.MoreExecutors; @@ -159,9 +158,12 @@ class EvaluationContext { } else { outputTypes.add(OutputType.BUNDLE); } - CommittedResult committedResult = - CommittedResult.create( - result, getUnprocessedInput(completedBundle, result), committedBundles, outputTypes); + CommittedResult committedResult = CommittedResult.create(result, + completedBundle == null + ? null + : completedBundle.withElements((Iterable) result.getUnprocessedElements()), + committedBundles, + outputTypes); // Update state internals CopyOnAccessInMemoryStateInternals theirState = result.getState(); if (theirState != null) { @@ -185,22 +187,6 @@ class EvaluationContext { return committedResult; } - /** - * Returns an {@link Optional} containing a bundle which contains all of the unprocessed elements - * that were not processed from the {@code completedBundle}. If all of the elements of the {@code - * completedBundle} were processed, or if {@code completedBundle} is null, returns an absent - * {@link Optional}. - */ - private Optional<? extends CommittedBundle<?>> getUnprocessedInput( - @Nullable CommittedBundle<?> completedBundle, TransformResult<?> result) { - if (completedBundle == null || Iterables.isEmpty(result.getUnprocessedElements())) { - return Optional.absent(); - } - CommittedBundle<?> residual = - completedBundle.withElements((Iterable) result.getUnprocessedElements()); - return Optional.of(residual); - } - private Iterable<? extends CommittedBundle<?>> commitBundles( Iterable<? extends UncommittedBundle<?>> bundles) { ImmutableList.Builder<CommittedBundle<?>> completed = ImmutableList.builder(); http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java index 75e2562..71ab4cc 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java @@ -49,11 +49,11 @@ import javax.annotation.Nullable; import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.KeyedWorkItems; import org.apache.beam.runners.core.TimerInternals.TimerData; -import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.direct.WatermarkManager.FiredTimers; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineResult.State; import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; @@ -77,7 +77,9 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { private final DirectGraph graph; private final RootProviderRegistry rootProviderRegistry; private final TransformEvaluatorRegistry registry; - private final Map<String, Collection<ModelEnforcementFactory>> transformEnforcements; + @SuppressWarnings("rawtypes") + private final Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> + transformEnforcements; private final EvaluationContext evaluationContext; @@ -110,7 +112,9 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { DirectGraph graph, RootProviderRegistry rootProviderRegistry, TransformEvaluatorRegistry registry, - Map<String, Collection<ModelEnforcementFactory>> transformEnforcements, + @SuppressWarnings("rawtypes") + Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> + transformEnforcements, EvaluationContext context) { return new ExecutorServiceParallelExecutor( targetParallelism, @@ -126,7 +130,8 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { DirectGraph graph, RootProviderRegistry rootProviderRegistry, TransformEvaluatorRegistry registry, - Map<String, Collection<ModelEnforcementFactory>> transformEnforcements, + @SuppressWarnings("rawtypes") + Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> transformEnforcements, EvaluationContext context) { this.targetParallelism = targetParallelism; // Don't use Daemon threads for workers. The Pipeline should continue to execute even if there @@ -232,8 +237,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { Collection<ModelEnforcementFactory> enforcements = MoreObjects.firstNonNull( - transformEnforcements.get( - PTransformTranslation.urnForTransform(transform.getTransform())), + transformEnforcements.get(transform.getTransform().getClass()), Collections.<ModelEnforcementFactory>emptyList()); TransformExecutor<T> callable = @@ -351,18 +355,17 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { for (CommittedBundle<?> outputBundle : committedResult.getOutputs()) { allUpdates.offer( ExecutorUpdate.fromBundle( - outputBundle, graph.getPerElementConsumers(outputBundle.getPCollection()))); + outputBundle, graph.getPrimitiveConsumers(outputBundle.getPCollection()))); } - Optional<? extends CommittedBundle<?>> unprocessedInputs = - committedResult.getUnprocessedInputs(); - if (unprocessedInputs.isPresent()) { + CommittedBundle<?> unprocessedInputs = committedResult.getUnprocessedInputs(); + if (unprocessedInputs != null && !Iterables.isEmpty(unprocessedInputs.getElements())) { if (inputBundle.getPCollection() == null) { // TODO: Split this logic out of an if statement - pendingRootBundles.get(result.getTransform()).offer(unprocessedInputs.get()); + pendingRootBundles.get(result.getTransform()).offer(unprocessedInputs); } else { allUpdates.offer( ExecutorUpdate.fromBundle( - unprocessedInputs.get(), + unprocessedInputs, Collections.<AppliedPTransform<?, ?, ?>>singleton( committedResult.getTransform()))); } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java index 516f798..8aa75cf 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java @@ -20,6 +20,7 @@ package org.apache.beam.runners.direct; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; +import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -78,7 +79,6 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator (TransformEvaluator<T>) createEvaluator( (AppliedPTransform) application, - (PCollection<InputT>) inputBundle.getPCollection(), inputBundle.getKey(), doFn, transform.getSideInputs(), @@ -102,7 +102,6 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator @SuppressWarnings({"unchecked", "rawtypes"}) DoFnLifecycleManagerRemovingTransformEvaluator<InputT> createEvaluator( AppliedPTransform<PCollection<InputT>, PCollectionTuple, ?> application, - PCollection<InputT> mainInput, StructuralKey<?> inputBundleKey, DoFn<InputT, OutputT> doFn, List<PCollectionView<?>> sideInputs, @@ -121,7 +120,6 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator createParDoEvaluator( application, inputBundleKey, - mainInput, sideInputs, mainOutputTag, additionalOutputTags, @@ -134,7 +132,6 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator ParDoEvaluator<InputT> createParDoEvaluator( AppliedPTransform<PCollection<InputT>, PCollectionTuple, ?> application, StructuralKey<?> key, - PCollection<InputT> mainInput, List<PCollectionView<?>> sideInputs, TupleTag<OutputT> mainOutputTag, List<TupleTag<?>> additionalOutputTags, @@ -147,7 +144,8 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator evaluationContext, stepContext, application, - mainInput.getWindowingStrategy(), + ((PCollection<InputT>) Iterables.getOnlyElement(application.getInputs().values())) + .getWindowingStrategy(), fn, key, sideInputs, @@ -175,4 +173,5 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator } return pcs; } + } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java index 891d102..858ea34 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java @@ -19,17 +19,15 @@ package org.apache.beam.runners.direct; import static com.google.common.base.Preconditions.checkState; -import java.io.IOException; -import java.util.List; import java.util.Map; import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.KeyedWorkItemCoder; import org.apache.beam.runners.core.KeyedWorkItems; import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.PTransformTranslation; -import org.apache.beam.runners.core.construction.ParDoTranslation; import org.apache.beam.runners.core.construction.ReplacementOutputs; import org.apache.beam.runners.core.construction.SplittableParDo; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.runners.AppliedPTransform; @@ -38,6 +36,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.ParDo.MultiOutput; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.AfterPane; @@ -49,8 +48,6 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.PCollectionViews; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -63,48 +60,36 @@ import org.apache.beam.sdk.values.WindowingStrategy; */ class ParDoMultiOverrideFactory<InputT, OutputT> implements PTransformOverrideFactory< - PCollection<? extends InputT>, PCollectionTuple, - PTransform<PCollection<? extends InputT>, PCollectionTuple>> { + PCollection<? extends InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>> { @Override public PTransformReplacement<PCollection<? extends InputT>, PCollectionTuple> getReplacementTransform( AppliedPTransform< - PCollection<? extends InputT>, PCollectionTuple, - PTransform<PCollection<? extends InputT>, PCollectionTuple>> - application) { - - try { - return PTransformReplacement.of( - PTransformReplacements.getSingletonMainInput(application), - getReplacementForApplication(application)); - } catch (IOException exc) { - throw new RuntimeException(exc); - } + PCollection<? extends InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>> + transform) { + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), + getReplacementTransform(transform.getTransform())); } @SuppressWarnings("unchecked") - private PTransform<PCollection<? extends InputT>, PCollectionTuple> getReplacementForApplication( - AppliedPTransform< - PCollection<? extends InputT>, PCollectionTuple, - PTransform<PCollection<? extends InputT>, PCollectionTuple>> - application) - throws IOException { - - DoFn<InputT, OutputT> fn = (DoFn<InputT, OutputT>) ParDoTranslation.getDoFn(application); + private PTransform<PCollection<? extends InputT>, PCollectionTuple> getReplacementTransform( + MultiOutput<InputT, OutputT> transform) { + DoFn<InputT, OutputT> fn = transform.getFn(); DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass()); - if (signature.processElement().isSplittable()) { - return (PTransform) SplittableParDo.forAppliedParDo(application); + return new SplittableParDo(transform); } else if (signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0) { - return new GbkThenStatefulParDo( - fn, - ParDoTranslation.getMainOutputTag(application), - ParDoTranslation.getAdditionalOutputTags(application), - ParDoTranslation.getSideInputs(application)); + // Based on the fact that the signature is stateful, DoFnSignatures ensures + // that it is also keyed + MultiOutput<KV<?, ?>, OutputT> keyedTransform = + (MultiOutput<KV<?, ?>, OutputT>) transform; + + return new GbkThenStatefulParDo(keyedTransform); } else { - return application.getTransform(); + return transform; } } @@ -116,25 +101,10 @@ class ParDoMultiOverrideFactory<InputT, OutputT> static class GbkThenStatefulParDo<K, InputT, OutputT> extends PTransform<PCollection<KV<K, InputT>>, PCollectionTuple> { - private final transient DoFn<KV<K, InputT>, OutputT> doFn; - private final TupleTagList additionalOutputTags; - private final TupleTag<OutputT> mainOutputTag; - private final List<PCollectionView<?>> sideInputs; - - public GbkThenStatefulParDo( - DoFn<KV<K, InputT>, OutputT> doFn, - TupleTag<OutputT> mainOutputTag, - TupleTagList additionalOutputTags, - List<PCollectionView<?>> sideInputs) { - this.doFn = doFn; - this.additionalOutputTags = additionalOutputTags; - this.mainOutputTag = mainOutputTag; - this.sideInputs = sideInputs; - } + private final MultiOutput<KV<K, InputT>, OutputT> underlyingParDo; - @Override - public Map<TupleTag<?>, PValue> getAdditionalInputs() { - return PCollectionViews.toAdditionalInputs(sideInputs); + public GbkThenStatefulParDo(MultiOutput<KV<K, InputT>, OutputT> underlyingParDo) { + this.underlyingParDo = underlyingParDo; } @Override @@ -190,9 +160,7 @@ class ParDoMultiOverrideFactory<InputT, OutputT> adjustedInput // Explode the resulting iterable into elements that are exactly the ones from // the input - .apply( - "Stateful ParDo", - new StatefulParDo<>(doFn, mainOutputTag, additionalOutputTags, sideInputs)); + .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, input)); return outputs; } @@ -204,41 +172,25 @@ class ParDoMultiOverrideFactory<InputT, OutputT> static class StatefulParDo<K, InputT, OutputT> extends PTransformTranslation.RawPTransform< PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>>, PCollectionTuple> { - private final transient DoFn<KV<K, InputT>, OutputT> doFn; - private final TupleTagList additionalOutputTags; - private final TupleTag<OutputT> mainOutputTag; - private final List<PCollectionView<?>> sideInputs; + private final transient MultiOutput<KV<K, InputT>, OutputT> underlyingParDo; + private final transient PCollection<KV<K, InputT>> originalInput; public StatefulParDo( - DoFn<KV<K, InputT>, OutputT> doFn, - TupleTag<OutputT> mainOutputTag, - TupleTagList additionalOutputTags, - List<PCollectionView<?>> sideInputs) { - this.doFn = doFn; - this.mainOutputTag = mainOutputTag; - this.additionalOutputTags = additionalOutputTags; - this.sideInputs = sideInputs; - } - - public DoFn<KV<K, InputT>, OutputT> getDoFn() { - return doFn; - } - - public TupleTag<OutputT> getMainOutputTag() { - return mainOutputTag; - } - - public List<PCollectionView<?>> getSideInputs() { - return sideInputs; + MultiOutput<KV<K, InputT>, OutputT> underlyingParDo, + PCollection<KV<K, InputT>> originalInput) { + this.underlyingParDo = underlyingParDo; + this.originalInput = originalInput; } - public TupleTagList getAdditionalOutputTags() { - return additionalOutputTags; + public MultiOutput<KV<K, InputT>, OutputT> getUnderlyingParDo() { + return underlyingParDo; } @Override - public Map<TupleTag<?>, PValue> getAdditionalInputs() { - return PCollectionViews.toAdditionalInputs(sideInputs); + public <T> Coder<T> getDefaultOutputCoder( + PCollection<? extends KeyedWorkItem<K, KV<K, InputT>>> input, PCollection<T> output) + throws CannotProvideCoderException { + return underlyingParDo.getDefaultOutputCoder(originalInput, output); } @Override @@ -247,7 +199,8 @@ class ParDoMultiOverrideFactory<InputT, OutputT> PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal( input.getPipeline(), - TupleTagList.of(getMainOutputTag()).and(getAdditionalOutputTags().getAll()), + TupleTagList.of(underlyingParDo.getMainOutputTag()) + .and(underlyingParDo.getAdditionalOutputTags().getAll()), input.getWindowingStrategy(), input.isBounded()); http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java index e6b51b7..b85f481c 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java @@ -35,6 +35,7 @@ import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateInternalsFactory; import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.TimerInternalsFactory; +import org.apache.beam.runners.core.construction.ElementAndRestriction; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.DoFn; @@ -42,7 +43,6 @@ import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; @@ -54,7 +54,8 @@ import org.joda.time.Instant; class SplittableProcessElementsEvaluatorFactory< InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT>> implements TransformEvaluatorFactory { - private final ParDoEvaluatorFactory<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> + private final ParDoEvaluatorFactory< + KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>, OutputT> delegateFactory; private final EvaluationContext evaluationContext; @@ -83,13 +84,14 @@ class SplittableProcessElementsEvaluatorFactory< } @SuppressWarnings({"unchecked", "rawtypes"}) - private TransformEvaluator<KeyedWorkItem<String, KV<InputT, RestrictionT>>> createEvaluator( - AppliedPTransform< - PCollection<KeyedWorkItem<String, KV<InputT, RestrictionT>>>, PCollectionTuple, - ProcessElements<InputT, OutputT, RestrictionT, TrackerT>> - application, - CommittedBundle<InputT> inputBundle) - throws Exception { + private TransformEvaluator<KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>> + createEvaluator( + AppliedPTransform< + PCollection<KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>>, + PCollectionTuple, ProcessElements<InputT, OutputT, RestrictionT, TrackerT>> + application, + CommittedBundle<InputT> inputBundle) + throws Exception { final ProcessElements<InputT, OutputT, RestrictionT, TrackerT> transform = application.getTransform(); @@ -99,7 +101,9 @@ class SplittableProcessElementsEvaluatorFactory< DoFnLifecycleManager fnManager = DoFnLifecycleManager.of(processFn); processFn = ((ProcessFn<InputT, OutputT, RestrictionT, TrackerT>) - fnManager.<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT>get()); + fnManager + .<KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>, OutputT> + get()); String stepName = evaluationContext.getStepName(application); final DirectExecutionContext.DirectStepContext stepContext = @@ -107,13 +111,11 @@ class SplittableProcessElementsEvaluatorFactory< .getExecutionContext(application, inputBundle.getKey()) .getStepContext(stepName); - final ParDoEvaluator<KeyedWorkItem<String, KV<InputT, RestrictionT>>> + final ParDoEvaluator<KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>> parDoEvaluator = delegateFactory.createParDoEvaluator( application, inputBundle.getKey(), - (PCollection<KeyedWorkItem<String, KV<InputT, RestrictionT>>>) - inputBundle.getPCollection(), transform.getSideInputs(), transform.getMainOutputTag(), transform.getAdditionalOutputTags().getAll(), @@ -185,16 +187,17 @@ class SplittableProcessElementsEvaluatorFactory< } private static <InputT, OutputT, RestrictionT> - ParDoEvaluator.DoFnRunnerFactory<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> + ParDoEvaluator.DoFnRunnerFactory< + KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>, OutputT> processFnRunnerFactory() { return new ParDoEvaluator.DoFnRunnerFactory< - KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT>() { + KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>, OutputT>() { @Override public PushbackSideInputDoFnRunner< - KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> + KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>, OutputT> createRunner( PipelineOptions options, - DoFn<KeyedWorkItem<String, KV<InputT, RestrictionT>>, OutputT> fn, + DoFn<KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>, OutputT> fn, List<PCollectionView<?>> sideInputs, ReadyCheckingSideInputReader sideInputReader, OutputManager outputManager, http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java index bdec9c8..506c84c 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java @@ -98,7 +98,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo throws Exception { final DoFn<KV<K, InputT>, OutputT> doFn = - application.getTransform().getDoFn(); + application.getTransform().getUnderlyingParDo().getFn(); final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); // If the DoFn is stateful, schedule state clearing. @@ -117,12 +117,11 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo DoFnLifecycleManagerRemovingTransformEvaluator<KV<K, InputT>> delegateEvaluator = delegateFactory.createEvaluator( (AppliedPTransform) application, - (PCollection) inputBundle.getPCollection(), inputBundle.getKey(), doFn, - application.getTransform().getSideInputs(), - application.getTransform().getMainOutputTag(), - application.getTransform().getAdditionalOutputTags().getAll()); + application.getTransform().getUnderlyingParDo().getSideInputs(), + application.getTransform().getUnderlyingParDo().getMainOutputTag(), + application.getTransform().getUnderlyingParDo().getAdditionalOutputTags().getAll()); return new StatefulParDoEvaluator<>(delegateEvaluator); } @@ -152,11 +151,12 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo transformOutputWindow .getTransform() .getTransform() + .getUnderlyingParDo() .getMainOutputTag()); WindowingStrategy<?, ?> windowingStrategy = pc.getWindowingStrategy(); BoundedWindow window = transformOutputWindow.getWindow(); final DoFn<?, ?> doFn = - transformOutputWindow.getTransform().getTransform().getDoFn(); + transformOutputWindow.getTransform().getTransform().getUnderlyingParDo().getFn(); final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); final DirectStepContext stepContext = http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java index 16c8589..2da7a71 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java @@ -22,7 +22,6 @@ import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Supplier; import com.google.common.collect.Iterables; -import java.io.IOException; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -31,7 +30,6 @@ import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.ReplacementOutputs; -import org.apache.beam.runners.core.construction.TestStreamTranslation; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.testing.TestStream; @@ -162,8 +160,7 @@ class TestStreamEvaluatorFactory implements TransformEvaluatorFactory { } static class DirectTestStreamFactory<T> - implements PTransformOverrideFactory< - PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>> { + implements PTransformOverrideFactory<PBegin, PCollection<T>, TestStream<T>> { private final DirectRunner runner; DirectTestStreamFactory(DirectRunner runner) { @@ -172,17 +169,10 @@ class TestStreamEvaluatorFactory implements TransformEvaluatorFactory { @Override public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform( - AppliedPTransform<PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>> transform) { - try { - return PTransformReplacement.of( - transform.getPipeline().begin(), - new DirectTestStream<T>(runner, TestStreamTranslation.getTestStream(transform))); - } catch (IOException exc) { - throw new RuntimeException( - String.format( - "Transform could not be converted to %s", TestStream.class.getSimpleName()), - exc); - } + AppliedPTransform<PBegin, PCollection<T>, TestStream<T>> transform) { + return PTransformReplacement.of( + transform.getPipeline().begin(), + new DirectTestStream<T>(runner, transform.getTransform())); } @Override http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java index 8a281a7..057f4a1 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewEvaluatorFactory.java @@ -28,6 +28,7 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.View.CreatePCollectionView; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; /** * The {@link DirectRunner} {@link TransformEvaluatorFactory} for the {@link CreatePCollectionView} @@ -59,13 +60,12 @@ class ViewEvaluatorFactory implements TransformEvaluatorFactory { public void cleanup() throws Exception {} private <InT, OuT> TransformEvaluator<Iterable<InT>> createEvaluator( - final AppliedPTransform< - PCollection<Iterable<InT>>, PCollection<Iterable<InT>>, WriteView<InT, OuT>> + final AppliedPTransform<PCollection<Iterable<InT>>, PCollectionView<OuT>, WriteView<InT, OuT>> application) { PCollection<Iterable<InT>> input = (PCollection<Iterable<InT>>) Iterables.getOnlyElement(application.getInputs().values()); - final PCollectionViewWriter<InT, OuT> writer = - context.createPCollectionViewWriter(input, application.getTransform().getView()); + final PCollectionViewWriter<InT, OuT> writer = context.createPCollectionViewWriter(input, + (PCollectionView<OuT>) Iterables.getOnlyElement(application.getOutputs().values())); return new TransformEvaluator<Iterable<InT>>() { private final List<WindowedValue<InT>> elements = new ArrayList<>(); http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java index 5dcf016..fdff63d 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java @@ -18,12 +18,11 @@ package org.apache.beam.runners.direct; -import java.io.IOException; +import java.util.Collections; import java.util.Map; -import org.apache.beam.runners.core.construction.CreatePCollectionViewTranslation; +import org.apache.beam.runners.core.construction.ForwardingPTransform; import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.PTransformTranslation.RawPTransform; -import org.apache.beam.runners.core.construction.ReplacementOutputs; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.runners.AppliedPTransform; @@ -44,56 +43,46 @@ import org.apache.beam.sdk.values.TupleTag; */ class ViewOverrideFactory<ElemT, ViewT> implements PTransformOverrideFactory< - PCollection<ElemT>, PCollection<ElemT>, - PTransform<PCollection<ElemT>, PCollection<ElemT>>> { + PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT, ViewT>> { @Override - public PTransformReplacement<PCollection<ElemT>, PCollection<ElemT>> getReplacementTransform( + public PTransformReplacement<PCollection<ElemT>, PCollectionView<ViewT>> getReplacementTransform( AppliedPTransform< - PCollection<ElemT>, PCollection<ElemT>, - PTransform<PCollection<ElemT>, PCollection<ElemT>>> + PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT, ViewT>> transform) { - - PCollectionView<ViewT> view; - try { - view = CreatePCollectionViewTranslation.getView(transform); - } catch (IOException exc) { - throw new RuntimeException( - String.format( - "Could not extract %s from transform %s", - PCollectionView.class.getSimpleName(), transform), - exc); - } - - return PTransformReplacement.of( + return PTransformReplacement.of( PTransformReplacements.getSingletonMainInput(transform), - new GroupAndWriteView<ElemT, ViewT>(view)); + new GroupAndWriteView<>(transform.getTransform())); } @Override public Map<PValue, ReplacementOutput> mapOutputs( - Map<TupleTag<?>, PValue> outputs, PCollection<ElemT> newOutput) { - return ReplacementOutputs.singleton(outputs, newOutput); + Map<TupleTag<?>, PValue> outputs, PCollectionView<ViewT> newOutput) { + return Collections.emptyMap(); } /** The {@link DirectRunner} composite override for {@link CreatePCollectionView}. */ static class GroupAndWriteView<ElemT, ViewT> - extends PTransform<PCollection<ElemT>, PCollection<ElemT>> { - private final PCollectionView<ViewT> view; + extends ForwardingPTransform<PCollection<ElemT>, PCollectionView<ViewT>> { + private final CreatePCollectionView<ElemT, ViewT> og; - private GroupAndWriteView(PCollectionView<ViewT> view) { - this.view = view; + private GroupAndWriteView(CreatePCollectionView<ElemT, ViewT> og) { + this.og = og; } @Override - public PCollection<ElemT> expand(final PCollection<ElemT> input) { - input + public PCollectionView<ViewT> expand(PCollection<ElemT> input) { + return input .apply(WithKeys.<Void, ElemT>of((Void) null)) .setCoder(KvCoder.of(VoidCoder.of(), input.getCoder())) .apply(GroupByKey.<Void, ElemT>create()) .apply(Values.<Iterable<ElemT>>create()) - .apply(new WriteView<ElemT, ViewT>(view)); - return input; + .apply(new WriteView<ElemT, ViewT>(og)); + } + + @Override + protected PTransform<PCollection<ElemT>, PCollectionView<ViewT>> delegate() { + return og; } } @@ -105,24 +94,22 @@ class ViewOverrideFactory<ElemT, ViewT> * to {@link ViewT}. */ static final class WriteView<ElemT, ViewT> - extends RawPTransform<PCollection<Iterable<ElemT>>, PCollection<Iterable<ElemT>>> { - private final PCollectionView<ViewT> view; + extends RawPTransform<PCollection<Iterable<ElemT>>, PCollectionView<ViewT>> { + private final CreatePCollectionView<ElemT, ViewT> og; - WriteView(PCollectionView<ViewT> view) { - this.view = view; + WriteView(CreatePCollectionView<ElemT, ViewT> og) { + this.og = og; } @Override @SuppressWarnings("deprecation") - public PCollection<Iterable<ElemT>> expand(PCollection<Iterable<ElemT>> input) { - return PCollection.<Iterable<ElemT>>createPrimitiveOutputInternal( - input.getPipeline(), input.getWindowingStrategy(), input.isBounded()) - .setCoder(input.getCoder()); + public PCollectionView<ViewT> expand(PCollection<Iterable<ElemT>> input) { + return og.getView(); } @SuppressWarnings("deprecation") public PCollectionView<ViewT> getView() { - return view; + return og.getView(); } @Override http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java index 599b74f..40ce163 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java @@ -54,7 +54,6 @@ import javax.annotation.concurrent.GuardedBy; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.TimerInternals.TimerData; -import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.state.TimeDomain; @@ -63,6 +62,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Instant; /** @@ -831,11 +831,11 @@ class WatermarkManager { private Collection<Watermark> getInputProcessingWatermarks(AppliedPTransform<?, ?, ?> transform) { ImmutableList.Builder<Watermark> inputWmsBuilder = ImmutableList.builder(); - Collection<PValue> inputs = TransformInputs.nonAdditionalInputs(transform); + Map<TupleTag<?>, PValue> inputs = transform.getInputs(); if (inputs.isEmpty()) { inputWmsBuilder.add(THE_END_OF_TIME); } - for (PValue pvalue : inputs) { + for (PValue pvalue : inputs.values()) { Watermark producerOutputWatermark = getValueWatermark(pvalue).synchronizedProcessingOutputWatermark; inputWmsBuilder.add(producerOutputWatermark); @@ -845,11 +845,11 @@ class WatermarkManager { private List<Watermark> getInputWatermarks(AppliedPTransform<?, ?, ?> transform) { ImmutableList.Builder<Watermark> inputWatermarksBuilder = ImmutableList.builder(); - Collection< PValue> inputs = TransformInputs.nonAdditionalInputs(transform); + Map<TupleTag<?>, PValue> inputs = transform.getInputs(); if (inputs.isEmpty()) { inputWatermarksBuilder.add(THE_END_OF_TIME); } - for (PValue pvalue : inputs) { + for (PValue pvalue : inputs.values()) { Watermark producerOutputWatermark = getValueWatermark(pvalue).outputWatermark; inputWatermarksBuilder.add(producerOutputWatermark); } @@ -987,16 +987,16 @@ class WatermarkManager { // refresh. for (CommittedBundle<?> bundle : result.getOutputs()) { for (AppliedPTransform<?, ?, ?> consumer : - graph.getPerElementConsumers(bundle.getPCollection())) { + graph.getPrimitiveConsumers(bundle.getPCollection())) { TransformWatermarks watermarks = transformToWatermarks.get(consumer); watermarks.addPending(bundle); } } TransformWatermarks completedTransform = transformToWatermarks.get(result.getTransform()); - if (result.getUnprocessedInputs().isPresent()) { + if (input != null) { // Add the unprocessed inputs - completedTransform.addPending(result.getUnprocessedInputs().get()); + completedTransform.addPending(result.getUnprocessedInputs()); } completedTransform.updateTimers(timerUpdate); if (input != null) { @@ -1035,7 +1035,7 @@ class WatermarkManager { if (updateResult.isAdvanced()) { Set<AppliedPTransform<?, ?, ?>> additionalRefreshes = new HashSet<>(); for (PValue outputPValue : toRefresh.getOutputs().values()) { - additionalRefreshes.addAll(graph.getPerElementConsumers(outputPValue)); + additionalRefreshes.addAll(graph.getPrimitiveConsumers(outputPValue)); } return additionalRefreshes; } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java index ba796ae..65a5a19 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java @@ -21,13 +21,11 @@ package org.apache.beam.runners.direct; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Supplier; import com.google.common.base.Suppliers; -import java.io.IOException; import java.io.Serializable; import java.util.Collections; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; import org.apache.beam.runners.core.construction.PTransformReplacements; -import org.apache.beam.runners.core.construction.WriteFilesTranslation; import org.apache.beam.sdk.io.WriteFiles; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverrideFactory; @@ -45,35 +43,23 @@ import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; /** - * A {@link PTransformOverrideFactory} that overrides {@link WriteFiles} {@link PTransform - * PTransforms} with an unspecified number of shards with a write with a specified number of shards. - * The number of shards is the log base 10 of the number of input records, with up to 2 additional - * shards. + * A {@link PTransformOverrideFactory} that overrides {@link WriteFiles} + * {@link PTransform PTransforms} with an unspecified number of shards with a write with a + * specified number of shards. The number of shards is the log base 10 of the number of input + * records, with up to 2 additional shards. */ class WriteWithShardingFactory<InputT> - implements PTransformOverrideFactory< - PCollection<InputT>, PDone, PTransform<PCollection<InputT>, PDone>> { + implements PTransformOverrideFactory<PCollection<InputT>, PDone, WriteFiles<InputT>> { static final int MAX_RANDOM_EXTRA_SHARDS = 3; @VisibleForTesting static final int MIN_SHARDS_FOR_LOG = 3; @Override public PTransformReplacement<PCollection<InputT>, PDone> getReplacementTransform( - AppliedPTransform<PCollection<InputT>, PDone, PTransform<PCollection<InputT>, PDone>> - transform) { - try { - WriteFiles<InputT, ?, ?> replacement = - WriteFiles.to( - WriteFilesTranslation.getSink(transform), - WriteFilesTranslation.getFormatFunction(transform)); - if (WriteFilesTranslation.isWindowedWrites(transform)) { - replacement = replacement.withWindowedWrites(); - } - return PTransformReplacement.of( - PTransformReplacements.getSingletonMainInput(transform), - replacement.withSharding(new LogElementShardsWithDrift<InputT>())); - } catch (IOException e) { - throw new RuntimeException(e); - } + AppliedPTransform<PCollection<InputT>, PDone, WriteFiles<InputT>> transform) { + + return PTransformReplacement.of( + PTransformReplacements.getSingletonMainInput(transform), + transform.getTransform().withSharding(new LogElementShardsWithDrift<InputT>())); } @Override http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java index 8b95b34..cf19dc2 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CommittedResultTest.java @@ -18,9 +18,9 @@ package org.apache.beam.runners.direct; +import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertThat; -import com.google.common.base.Optional; import com.google.common.collect.ImmutableList; import java.io.Serializable; import java.util.Collections; @@ -72,7 +72,7 @@ public class CommittedResultTest implements Serializable { CommittedResult result = CommittedResult.create( StepTransformResult.withoutHold(transform).build(), - Optional.<CommittedBundle<?>>absent(), + bundleFactory.createBundle(created).commit(Instant.now()), Collections.<CommittedBundle<?>>emptyList(), EnumSet.noneOf(OutputType.class)); @@ -88,11 +88,11 @@ public class CommittedResultTest implements Serializable { CommittedResult result = CommittedResult.create( StepTransformResult.withoutHold(transform).build(), - Optional.of(bundle), + bundle, Collections.<CommittedBundle<?>>emptyList(), EnumSet.noneOf(OutputType.class)); - assertThat(result.getUnprocessedInputs().get(), + assertThat(result.getUnprocessedInputs(), Matchers.<CommittedBundle<?>>equalTo(bundle)); } @@ -101,14 +101,11 @@ public class CommittedResultTest implements Serializable { CommittedResult result = CommittedResult.create( StepTransformResult.withoutHold(transform).build(), - Optional.<CommittedBundle<?>>absent(), + null, Collections.<CommittedBundle<?>>emptyList(), EnumSet.noneOf(OutputType.class)); - assertThat( - result.getUnprocessedInputs(), - Matchers.<Optional<? extends CommittedBundle<?>>>equalTo( - Optional.<CommittedBundle<?>>absent())); + assertThat(result.getUnprocessedInputs(), nullValue()); } @Test @@ -123,7 +120,7 @@ public class CommittedResultTest implements Serializable { CommittedResult result = CommittedResult.create( StepTransformResult.withoutHold(transform).build(), - Optional.<CommittedBundle<?>>absent(), + bundleFactory.createBundle(created).commit(Instant.now()), outputs, EnumSet.of(OutputType.BUNDLE, OutputType.PCOLLECTION_VIEW)); http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java index bf3e83e..576edf3 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java @@ -151,13 +151,13 @@ public class DirectGraphVisitorTest implements Serializable { graph.getProducer(flattened); assertThat( - graph.getPerElementConsumers(created), + graph.getPrimitiveConsumers(created), Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder( transformedProducer, flattenedProducer)); assertThat( - graph.getPerElementConsumers(transformed), + graph.getPrimitiveConsumers(transformed), Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(flattenedProducer)); - assertThat(graph.getPerElementConsumers(flattened), emptyIterable()); + assertThat(graph.getPrimitiveConsumers(flattened), emptyIterable()); } @Test @@ -173,10 +173,10 @@ public class DirectGraphVisitorTest implements Serializable { AppliedPTransform<?, ?, ?> flattenedProducer = graph.getProducer(flattened); assertThat( - graph.getPerElementConsumers(created), + graph.getPrimitiveConsumers(created), Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(flattenedProducer, flattenedProducer)); - assertThat(graph.getPerElementConsumers(flattened), emptyIterable()); + assertThat(graph.getPrimitiveConsumers(flattened), emptyIterable()); } @Test http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphs.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphs.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphs.java index 7707f7f..43de091 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphs.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphs.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.direct; import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; @@ -26,12 +25,6 @@ import org.apache.beam.sdk.values.PValue; /** Test utilities for the {@link DirectRunner}. */ final class DirectGraphs { - public static void performDirectOverrides(Pipeline p) { - p.replaceAll( - DirectRunner.fromOptions(PipelineOptionsFactory.create().as(DirectOptions.class)) - .defaultTransformOverrides()); - } - public static DirectGraph getGraph(Pipeline p) { DirectGraphVisitor visitor = new DirectGraphVisitor(); p.traverseTopologically(visitor); http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRegistrarTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRegistrarTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRegistrarTest.java index 4b909bc..603e43e 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRegistrarTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRegistrarTest.java @@ -37,7 +37,7 @@ public class DirectRegistrarTest { @Test public void testCorrectOptionsAreReturned() { assertEquals( - ImmutableList.of(DirectOptions.class, DirectTestOptions.class), + ImmutableList.of(DirectOptions.class), new Options().getPipelineOptions()); } http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java index 699a318..c0e43d6 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java @@ -101,13 +101,14 @@ public class EvaluationContextTest { view = created.apply(View.<Integer>asIterable()); unbounded = p.apply(GenerateSequence.from(0)); - p.replaceAll(runner.defaultTransformOverrides()); + p.replaceAll( + DirectRunner.fromOptions(TestPipeline.testingPipelineOptions()) + .defaultTransformOverrides()); KeyedPValueTrackingVisitor keyedPValueTrackingVisitor = KeyedPValueTrackingVisitor.create(); p.traverseTopologically(keyedPValueTrackingVisitor); BundleFactory bundleFactory = ImmutableListBundleFactory.create(); - DirectGraphs.performDirectOverrides(p); graph = DirectGraphs.getGraph(p); context = EvaluationContext.create( @@ -414,7 +415,7 @@ public class EvaluationContextTest { StepTransformResult.withoutHold(unboundedProducer).build()); assertThat(context.isDone(), is(false)); - for (AppliedPTransform<?, ?, ?> consumers : graph.getPerElementConsumers(created)) { + for (AppliedPTransform<?, ?, ?> consumers : graph.getPrimitiveConsumers(created)) { context.handleResult( committedBundle, ImmutableList.<TimerData>of(),
