Move aggregator support classes out of runners namespace, make private
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/adec254d Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/adec254d Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/adec254d Branch: refs/heads/gearpump-runner Commit: adec254d5fdb409e786a1fc2bcee38f8a7a04408 Parents: 9da4bbc Author: Kenneth Knowles <[email protected]> Authored: Fri Jul 1 14:56:20 2016 -0700 Committer: Kenneth Knowles <[email protected]> Committed: Wed Aug 10 11:34:03 2016 -0700 ---------------------------------------------------------------------- .../beam/runners/direct/DirectRunner.java | 7 +- .../beam/runners/flink/FlinkRunnerResult.java | 4 +- .../runners/dataflow/DataflowPipelineJob.java | 4 +- .../beam/runners/dataflow/DataflowRunner.java | 4 +- .../dataflow/DataflowPipelineJobTest.java | 4 +- .../spark/translation/EvaluationContext.java | 4 +- .../spark/translation/SparkRuntimeContext.java | 2 +- .../translation/MultiOutputWordCountTest.java | 2 +- .../beam/sdk/AggregatorPipelineExtractor.java | 93 ++++++++ .../beam/sdk/AggregatorRetrievalException.java | 33 +++ .../org/apache/beam/sdk/AggregatorValues.java | 52 +++++ .../main/java/org/apache/beam/sdk/Pipeline.java | 10 + .../org/apache/beam/sdk/PipelineResult.java | 2 - .../runners/AggregatorPipelineExtractor.java | 93 -------- .../runners/AggregatorRetrievalException.java | 33 --- .../beam/sdk/runners/AggregatorValues.java | 52 ----- .../sdk/AggregatorPipelineExtractorTest.java | 229 +++++++++++++++++++ .../AggregatorPipelineExtractorTest.java | 229 ------------------- .../apache/beam/sdk/transforms/DoFnTest.java | 1 + .../apache/beam/sdk/transforms/OldDoFnTest.java | 3 +- 20 files changed, 434 insertions(+), 427 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index a9c8ecb..f2b781e 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -20,15 +20,14 @@ package org.apache.beam.runners.direct; import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly; import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult; import org.apache.beam.runners.direct.ViewEvaluatorFactory.ViewOverrideFactory; +import org.apache.beam.sdk.AggregatorRetrievalException; +import org.apache.beam.sdk.AggregatorValues; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.io.Write; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.runners.AggregatorPipelineExtractor; -import org.apache.beam.sdk.runners.AggregatorRetrievalException; -import org.apache.beam.sdk.runners.AggregatorValues; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.AppliedPTransform; @@ -244,7 +243,7 @@ public class DirectRunner executor.start(consumerTrackingVisitor.getRootTransforms()); Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = - new AggregatorPipelineExtractor(pipeline).getAggregatorSteps(); + pipeline.getAggregatorSteps(); DirectPipelineResult result = new DirectPipelineResult(executor, context, aggregatorSteps); if (options.isBlockOnRun()) { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java index cae0b2a..923d54c 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkRunnerResult.java @@ -18,8 +18,8 @@ package org.apache.beam.runners.flink; import org.apache.beam.sdk.PipelineResult; -import org.apache.beam.sdk.runners.AggregatorRetrievalException; -import org.apache.beam.sdk.runners.AggregatorValues; +import org.apache.beam.sdk.AggregatorRetrievalException; +import org.apache.beam.sdk.AggregatorValues; import org.apache.beam.sdk.transforms.Aggregator; import org.joda.time.Duration; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java index a6baa4f..e043e23 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineJob.java @@ -23,9 +23,9 @@ import org.apache.beam.runners.dataflow.internal.DataflowAggregatorTransforms; import org.apache.beam.runners.dataflow.internal.DataflowMetricUpdateExtractor; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; import org.apache.beam.runners.dataflow.util.MonitoringUtil; +import org.apache.beam.sdk.AggregatorRetrievalException; +import org.apache.beam.sdk.AggregatorValues; import org.apache.beam.sdk.PipelineResult; -import org.apache.beam.sdk.runners.AggregatorRetrievalException; -import org.apache.beam.sdk.runners.AggregatorValues; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.util.AttemptAndTimeBoundedExponentialBackOff; import org.apache.beam.sdk.util.AttemptBoundedExponentialBackOff; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index fadd9c7..3b68e92 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -71,7 +71,6 @@ import org.apache.beam.sdk.io.Write; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsValidator; import org.apache.beam.sdk.options.StreamingOptions; -import org.apache.beam.sdk.runners.AggregatorPipelineExtractor; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.runners.TransformTreeNode; import org.apache.beam.sdk.transforms.Aggregator; @@ -596,9 +595,8 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { // Obtain all of the extractors from the PTransforms used in the pipeline so the // DataflowPipelineJob has access to them. - AggregatorPipelineExtractor aggregatorExtractor = new AggregatorPipelineExtractor(pipeline); Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = - aggregatorExtractor.getAggregatorSteps(); + pipeline.getAggregatorSteps(); DataflowAggregatorTransforms aggregatorTransforms = new DataflowAggregatorTransforms(aggregatorSteps, jobSpecification.getStepNames()); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java index 343d538..e6277d9 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java @@ -35,10 +35,10 @@ import static org.mockito.Mockito.when; import org.apache.beam.runners.dataflow.internal.DataflowAggregatorTransforms; import org.apache.beam.runners.dataflow.testing.TestDataflowPipelineOptions; import org.apache.beam.runners.dataflow.util.MonitoringUtil; +import org.apache.beam.sdk.AggregatorRetrievalException; +import org.apache.beam.sdk.AggregatorValues; import org.apache.beam.sdk.PipelineResult.State; import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.runners.AggregatorRetrievalException; -import org.apache.beam.sdk.runners.AggregatorValues; import org.apache.beam.sdk.testing.FastNanoClockAndSleeper; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.AppliedPTransform; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index 169c2af..4ccac0e 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -22,10 +22,10 @@ import static com.google.common.base.Preconditions.checkArgument; import org.apache.beam.runners.spark.EvaluationResult; import org.apache.beam.runners.spark.coders.CoderHelpers; +import org.apache.beam.sdk.AggregatorRetrievalException; +import org.apache.beam.sdk.AggregatorValues; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.runners.AggregatorRetrievalException; -import org.apache.beam.sdk.runners.AggregatorValues; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java index 46f5b33..c2edd02 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkRuntimeContext.java @@ -20,12 +20,12 @@ package org.apache.beam.runners.spark.translation; import org.apache.beam.runners.spark.aggregators.AggAccumParam; import org.apache.beam.runners.spark.aggregators.NamedAggregators; +import org.apache.beam.sdk.AggregatorValues; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.runners.AggregatorValues; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Max; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java index 291f7b2..0d0c0b4 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java @@ -20,11 +20,11 @@ package org.apache.beam.runners.spark.translation; import org.apache.beam.runners.spark.EvaluationResult; import org.apache.beam.runners.spark.SparkRunner; +import org.apache.beam.sdk.AggregatorValues; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.runners.AggregatorValues; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.ApproximateUnique; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorPipelineExtractor.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorPipelineExtractor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorPipelineExtractor.java new file mode 100644 index 0000000..ac215c9 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorPipelineExtractor.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk; + +import org.apache.beam.sdk.Pipeline.PipelineVisitor; +import org.apache.beam.sdk.runners.TransformTreeNode; +import org.apache.beam.sdk.transforms.Aggregator; +import org.apache.beam.sdk.transforms.AggregatorRetriever; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PValue; + +import com.google.common.collect.HashMultimap; +import com.google.common.collect.SetMultimap; + +import java.util.Collection; +import java.util.Collections; +import java.util.Map; + +/** + * Retrieves {@link Aggregator Aggregators} at each {@link ParDo} and returns a {@link Map} of + * {@link Aggregator} to the {@link PTransform PTransforms} in which it is present. + */ +class AggregatorPipelineExtractor { + private final Pipeline pipeline; + + /** + * Creates an {@code AggregatorPipelineExtractor} for the given {@link Pipeline}. + */ + public AggregatorPipelineExtractor(Pipeline pipeline) { + this.pipeline = pipeline; + } + + /** + * Returns a {@link Map} between each {@link Aggregator} in the {@link Pipeline} to the {@link + * PTransform PTransforms} in which it is used. + */ + public Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> getAggregatorSteps() { + HashMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps = HashMultimap.create(); + pipeline.traverseTopologically(new AggregatorVisitor(aggregatorSteps)); + return aggregatorSteps.asMap(); + } + + private static class AggregatorVisitor extends PipelineVisitor.Defaults { + private final SetMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps; + + public AggregatorVisitor(SetMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps) { + this.aggregatorSteps = aggregatorSteps; + } + + @Override + public void visitPrimitiveTransform(TransformTreeNode node) { + PTransform<?, ?> transform = node.getTransform(); + addStepToAggregators(transform, getAggregators(transform)); + } + + private Collection<Aggregator<?, ?>> getAggregators(PTransform<?, ?> transform) { + if (transform != null) { + if (transform instanceof ParDo.Bound) { + return AggregatorRetriever.getAggregators(((ParDo.Bound<?, ?>) transform).getFn()); + } else if (transform instanceof ParDo.BoundMulti) { + return AggregatorRetriever.getAggregators(((ParDo.BoundMulti<?, ?>) transform).getFn()); + } + } + return Collections.emptyList(); + } + + private void addStepToAggregators( + PTransform<?, ?> transform, Collection<Aggregator<?, ?>> aggregators) { + for (Aggregator<?, ?> aggregator : aggregators) { + aggregatorSteps.put(aggregator, transform); + } + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) {} + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorRetrievalException.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorRetrievalException.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorRetrievalException.java new file mode 100644 index 0000000..3040815 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorRetrievalException.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk; + +import org.apache.beam.sdk.transforms.Aggregator; + +/** + * Signals that an exception has occurred while retrieving {@link Aggregator}s. + */ +public class AggregatorRetrievalException extends Exception { + /** + * Constructs a new {@code AggregatorRetrievalException} with the specified detail message and + * cause. + */ + public AggregatorRetrievalException(String message, Throwable cause) { + super(message, cause); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorValues.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorValues.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorValues.java new file mode 100644 index 0000000..efaad85 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/AggregatorValues.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk; + +import org.apache.beam.sdk.transforms.Aggregator; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.OldDoFn; + +import java.util.Collection; +import java.util.Map; + +/** + * A collection of values associated with an {@link Aggregator}. Aggregators declared in a + * {@link OldDoFn} are emitted on a per-{@code OldDoFn}-application basis. + * + * @param <T> the output type of the aggregator + */ +public abstract class AggregatorValues<T> { + /** + * Get the values of the {@link Aggregator} at all steps it was used. + */ + public Collection<T> getValues() { + return getValuesAtSteps().values(); + } + + /** + * Get the values of the {@link Aggregator} by the user name at each step it was used. + */ + public abstract Map<String, T> getValuesAtSteps(); + + /** + * Get the total value of this {@link Aggregator} by applying the specified {@link CombineFn}. + */ + public T getTotalValue(CombineFn<T, ?, T> combineFn) { + return combineFn.apply(getValues()); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java index e4f3e4a..1bbc56f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.runners.TransformTreeNode; +import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.PTransform; @@ -47,6 +48,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -518,6 +520,14 @@ public class Pipeline { } /** + * Returns a {@link Map} from each {@link Aggregator} in the {@link Pipeline} to the {@link + * PTransform PTransforms} in which it is used. + */ + public Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> getAggregatorSteps() { + return new AggregatorPipelineExtractor(this).getAggregatorSteps(); + } + + /** * Builds a name from a "/"-delimited prefix and a name. */ private String buildName(String namePrefix, String name) { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java index 993962c..edfc924 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/PipelineResult.java @@ -17,8 +17,6 @@ */ package org.apache.beam.sdk; -import org.apache.beam.sdk.runners.AggregatorRetrievalException; -import org.apache.beam.sdk.runners.AggregatorValues; import org.apache.beam.sdk.transforms.Aggregator; import org.joda.time.Duration; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java deleted file mode 100644 index 146ddfa..0000000 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractor.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.runners; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.Pipeline.PipelineVisitor; -import org.apache.beam.sdk.transforms.Aggregator; -import org.apache.beam.sdk.transforms.AggregatorRetriever; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PValue; - -import com.google.common.collect.HashMultimap; -import com.google.common.collect.SetMultimap; - -import java.util.Collection; -import java.util.Collections; -import java.util.Map; - -/** - * Retrieves {@link Aggregator Aggregators} at each {@link ParDo} and returns a {@link Map} of - * {@link Aggregator} to the {@link PTransform PTransforms} in which it is present. - */ -public class AggregatorPipelineExtractor { - private final Pipeline pipeline; - - /** - * Creates an {@code AggregatorPipelineExtractor} for the given {@link Pipeline}. - */ - public AggregatorPipelineExtractor(Pipeline pipeline) { - this.pipeline = pipeline; - } - - /** - * Returns a {@link Map} between each {@link Aggregator} in the {@link Pipeline} to the {@link - * PTransform PTransforms} in which it is used. - */ - public Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> getAggregatorSteps() { - HashMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps = HashMultimap.create(); - pipeline.traverseTopologically(new AggregatorVisitor(aggregatorSteps)); - return aggregatorSteps.asMap(); - } - - private static class AggregatorVisitor extends PipelineVisitor.Defaults { - private final SetMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps; - - public AggregatorVisitor(SetMultimap<Aggregator<?, ?>, PTransform<?, ?>> aggregatorSteps) { - this.aggregatorSteps = aggregatorSteps; - } - - @Override - public void visitPrimitiveTransform(TransformTreeNode node) { - PTransform<?, ?> transform = node.getTransform(); - addStepToAggregators(transform, getAggregators(transform)); - } - - private Collection<Aggregator<?, ?>> getAggregators(PTransform<?, ?> transform) { - if (transform != null) { - if (transform instanceof ParDo.Bound) { - return AggregatorRetriever.getAggregators(((ParDo.Bound<?, ?>) transform).getFn()); - } else if (transform instanceof ParDo.BoundMulti) { - return AggregatorRetriever.getAggregators(((ParDo.BoundMulti<?, ?>) transform).getFn()); - } - } - return Collections.emptyList(); - } - - private void addStepToAggregators( - PTransform<?, ?> transform, Collection<Aggregator<?, ?>> aggregators) { - for (Aggregator<?, ?> aggregator : aggregators) { - aggregatorSteps.put(aggregator, transform); - } - } - - @Override - public void visitValue(PValue value, TransformTreeNode producer) {} - } -} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorRetrievalException.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorRetrievalException.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorRetrievalException.java deleted file mode 100644 index a0973c3..0000000 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorRetrievalException.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.runners; - -import org.apache.beam.sdk.transforms.Aggregator; - -/** - * Signals that an exception has occurred while retrieving {@link Aggregator}s. - */ -public class AggregatorRetrievalException extends Exception { - /** - * Constructs a new {@code AggregatorRetrievalException} with the specified detail message and - * cause. - */ - public AggregatorRetrievalException(String message, Throwable cause) { - super(message, cause); - } -} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorValues.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorValues.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorValues.java deleted file mode 100644 index 6f6836e..0000000 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/AggregatorValues.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.runners; - -import org.apache.beam.sdk.transforms.Aggregator; -import org.apache.beam.sdk.transforms.Combine.CombineFn; -import org.apache.beam.sdk.transforms.OldDoFn; - -import java.util.Collection; -import java.util.Map; - -/** - * A collection of values associated with an {@link Aggregator}. Aggregators declared in a - * {@link OldDoFn} are emitted on a per-{@code OldDoFn}-application basis. - * - * @param <T> the output type of the aggregator - */ -public abstract class AggregatorValues<T> { - /** - * Get the values of the {@link Aggregator} at all steps it was used. - */ - public Collection<T> getValues() { - return getValuesAtSteps().values(); - } - - /** - * Get the values of the {@link Aggregator} by the user name at each step it was used. - */ - public abstract Map<String, T> getValuesAtSteps(); - - /** - * Get the total value of this {@link Aggregator} by applying the specified {@link CombineFn}. - */ - public T getTotalValue(CombineFn<T, ?, T> combineFn) { - return combineFn.apply(getValues()); - } -} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/test/java/org/apache/beam/sdk/AggregatorPipelineExtractorTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/AggregatorPipelineExtractorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/AggregatorPipelineExtractorTest.java new file mode 100644 index 0000000..930fbe7 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/AggregatorPipelineExtractorTest.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.apache.beam.sdk.Pipeline.PipelineVisitor; +import org.apache.beam.sdk.runners.TransformTreeNode; +import org.apache.beam.sdk.transforms.Aggregator; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.Max; +import org.apache.beam.sdk.transforms.Min; +import org.apache.beam.sdk.transforms.OldDoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Sum; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** + * Tests for {@link AggregatorPipelineExtractor}. + */ +@RunWith(JUnit4.class) +public class AggregatorPipelineExtractorTest { + @Mock + private Pipeline p; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetAggregatorStepsWithParDoBoundExtractsSteps() { + @SuppressWarnings("rawtypes") + ParDo.Bound bound = mock(ParDo.Bound.class, "Bound"); + AggregatorProvidingDoFn<ThreadGroup, StrictMath> fn = new AggregatorProvidingDoFn<>(); + when(bound.getFn()).thenReturn(fn); + + Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn()); + Aggregator<Integer, Integer> aggregatorTwo = fn.addAggregator(new Min.MinIntegerFn()); + + TransformTreeNode transformNode = mock(TransformTreeNode.class); + when(transformNode.getTransform()).thenReturn(bound); + + doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode))) + .when(p) + .traverseTopologically(Mockito.any(PipelineVisitor.class)); + + AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); + + Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = + extractor.getAggregatorSteps(); + + assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne)); + assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorTwo)); + assertEquals(aggregatorSteps.size(), 2); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetAggregatorStepsWithParDoBoundMultiExtractsSteps() { + @SuppressWarnings("rawtypes") + ParDo.BoundMulti bound = mock(ParDo.BoundMulti.class, "BoundMulti"); + AggregatorProvidingDoFn<Object, Void> fn = new AggregatorProvidingDoFn<>(); + when(bound.getFn()).thenReturn(fn); + + Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Max.MaxLongFn()); + Aggregator<Double, Double> aggregatorTwo = fn.addAggregator(new Min.MinDoubleFn()); + + TransformTreeNode transformNode = mock(TransformTreeNode.class); + when(transformNode.getTransform()).thenReturn(bound); + + doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode))) + .when(p) + .traverseTopologically(Mockito.any(PipelineVisitor.class)); + + AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); + + Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = + extractor.getAggregatorSteps(); + + assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne)); + assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorTwo)); + assertEquals(2, aggregatorSteps.size()); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetAggregatorStepsWithOneAggregatorInMultipleStepsAddsSteps() { + @SuppressWarnings("rawtypes") + ParDo.Bound bound = mock(ParDo.Bound.class, "Bound"); + @SuppressWarnings("rawtypes") + ParDo.BoundMulti otherBound = mock(ParDo.BoundMulti.class, "otherBound"); + AggregatorProvidingDoFn<String, Math> fn = new AggregatorProvidingDoFn<>(); + when(bound.getFn()).thenReturn(fn); + when(otherBound.getFn()).thenReturn(fn); + + Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn()); + Aggregator<Double, Double> aggregatorTwo = fn.addAggregator(new Min.MinDoubleFn()); + + TransformTreeNode transformNode = mock(TransformTreeNode.class); + when(transformNode.getTransform()).thenReturn(bound); + TransformTreeNode otherTransformNode = mock(TransformTreeNode.class); + when(otherTransformNode.getTransform()).thenReturn(otherBound); + + doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode, otherTransformNode))) + .when(p) + .traverseTopologically(Mockito.any(PipelineVisitor.class)); + + AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); + + Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = + extractor.getAggregatorSteps(); + + assertEquals( + ImmutableSet.<PTransform<?, ?>>of(bound, otherBound), aggregatorSteps.get(aggregatorOne)); + assertEquals( + ImmutableSet.<PTransform<?, ?>>of(bound, otherBound), aggregatorSteps.get(aggregatorTwo)); + assertEquals(2, aggregatorSteps.size()); + } + + @SuppressWarnings("unchecked") + @Test + public void testGetAggregatorStepsWithDifferentStepsAddsSteps() { + @SuppressWarnings("rawtypes") + ParDo.Bound bound = mock(ParDo.Bound.class, "Bound"); + + AggregatorProvidingDoFn<ThreadGroup, Void> fn = new AggregatorProvidingDoFn<>(); + Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn()); + + when(bound.getFn()).thenReturn(fn); + + @SuppressWarnings("rawtypes") + ParDo.BoundMulti otherBound = mock(ParDo.BoundMulti.class, "otherBound"); + + AggregatorProvidingDoFn<Long, Long> otherFn = new AggregatorProvidingDoFn<>(); + Aggregator<Double, Double> aggregatorTwo = otherFn.addAggregator(new Sum.SumDoubleFn()); + + when(otherBound.getFn()).thenReturn(otherFn); + + TransformTreeNode transformNode = mock(TransformTreeNode.class); + when(transformNode.getTransform()).thenReturn(bound); + TransformTreeNode otherTransformNode = mock(TransformTreeNode.class); + when(otherTransformNode.getTransform()).thenReturn(otherBound); + + doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode, otherTransformNode))) + .when(p) + .traverseTopologically(Mockito.any(PipelineVisitor.class)); + + AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); + + Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = + extractor.getAggregatorSteps(); + + assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne)); + assertEquals(ImmutableSet.<PTransform<?, ?>>of(otherBound), aggregatorSteps.get(aggregatorTwo)); + assertEquals(2, aggregatorSteps.size()); + } + + private static class VisitNodesAnswer implements Answer<Object> { + private final List<TransformTreeNode> nodes; + + public VisitNodesAnswer(List<TransformTreeNode> nodes) { + this.nodes = nodes; + } + + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + PipelineVisitor visitor = (PipelineVisitor) invocation.getArguments()[0]; + for (TransformTreeNode node : nodes) { + visitor.visitPrimitiveTransform(node); + } + return null; + } + } + + private static class AggregatorProvidingDoFn<InT, OuT> extends OldDoFn<InT, OuT> { + public <InputT, OutT> Aggregator<InputT, OutT> addAggregator( + CombineFn<InputT, ?, OutT> combiner) { + return createAggregator(randomName(), combiner); + } + + private String randomName() { + return UUID.randomUUID().toString(); + } + + @Override + public void processElement(OldDoFn<InT, OuT>.ProcessContext c) throws Exception { + fail(); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java deleted file mode 100644 index 13476e2..0000000 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/AggregatorPipelineExtractorTest.java +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.runners; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.Pipeline.PipelineVisitor; -import org.apache.beam.sdk.transforms.Aggregator; -import org.apache.beam.sdk.transforms.Combine.CombineFn; -import org.apache.beam.sdk.transforms.Max; -import org.apache.beam.sdk.transforms.Min; -import org.apache.beam.sdk.transforms.OldDoFn; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.Sum; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; - -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.UUID; - -/** - * Tests for {@link AggregatorPipelineExtractor}. - */ -@RunWith(JUnit4.class) -public class AggregatorPipelineExtractorTest { - @Mock - private Pipeline p; - - @Before - public void setup() { - MockitoAnnotations.initMocks(this); - } - - @SuppressWarnings("unchecked") - @Test - public void testGetAggregatorStepsWithParDoBoundExtractsSteps() { - @SuppressWarnings("rawtypes") - ParDo.Bound bound = mock(ParDo.Bound.class, "Bound"); - AggregatorProvidingDoFn<ThreadGroup, StrictMath> fn = new AggregatorProvidingDoFn<>(); - when(bound.getFn()).thenReturn(fn); - - Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn()); - Aggregator<Integer, Integer> aggregatorTwo = fn.addAggregator(new Min.MinIntegerFn()); - - TransformTreeNode transformNode = mock(TransformTreeNode.class); - when(transformNode.getTransform()).thenReturn(bound); - - doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode))) - .when(p) - .traverseTopologically(Mockito.any(PipelineVisitor.class)); - - AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); - - Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = - extractor.getAggregatorSteps(); - - assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne)); - assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorTwo)); - assertEquals(aggregatorSteps.size(), 2); - } - - @SuppressWarnings("unchecked") - @Test - public void testGetAggregatorStepsWithParDoBoundMultiExtractsSteps() { - @SuppressWarnings("rawtypes") - ParDo.BoundMulti bound = mock(ParDo.BoundMulti.class, "BoundMulti"); - AggregatorProvidingDoFn<Object, Void> fn = new AggregatorProvidingDoFn<>(); - when(bound.getFn()).thenReturn(fn); - - Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Max.MaxLongFn()); - Aggregator<Double, Double> aggregatorTwo = fn.addAggregator(new Min.MinDoubleFn()); - - TransformTreeNode transformNode = mock(TransformTreeNode.class); - when(transformNode.getTransform()).thenReturn(bound); - - doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode))) - .when(p) - .traverseTopologically(Mockito.any(PipelineVisitor.class)); - - AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); - - Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = - extractor.getAggregatorSteps(); - - assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne)); - assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorTwo)); - assertEquals(2, aggregatorSteps.size()); - } - - @SuppressWarnings("unchecked") - @Test - public void testGetAggregatorStepsWithOneAggregatorInMultipleStepsAddsSteps() { - @SuppressWarnings("rawtypes") - ParDo.Bound bound = mock(ParDo.Bound.class, "Bound"); - @SuppressWarnings("rawtypes") - ParDo.BoundMulti otherBound = mock(ParDo.BoundMulti.class, "otherBound"); - AggregatorProvidingDoFn<String, Math> fn = new AggregatorProvidingDoFn<>(); - when(bound.getFn()).thenReturn(fn); - when(otherBound.getFn()).thenReturn(fn); - - Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn()); - Aggregator<Double, Double> aggregatorTwo = fn.addAggregator(new Min.MinDoubleFn()); - - TransformTreeNode transformNode = mock(TransformTreeNode.class); - when(transformNode.getTransform()).thenReturn(bound); - TransformTreeNode otherTransformNode = mock(TransformTreeNode.class); - when(otherTransformNode.getTransform()).thenReturn(otherBound); - - doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode, otherTransformNode))) - .when(p) - .traverseTopologically(Mockito.any(PipelineVisitor.class)); - - AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); - - Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = - extractor.getAggregatorSteps(); - - assertEquals( - ImmutableSet.<PTransform<?, ?>>of(bound, otherBound), aggregatorSteps.get(aggregatorOne)); - assertEquals( - ImmutableSet.<PTransform<?, ?>>of(bound, otherBound), aggregatorSteps.get(aggregatorTwo)); - assertEquals(2, aggregatorSteps.size()); - } - - @SuppressWarnings("unchecked") - @Test - public void testGetAggregatorStepsWithDifferentStepsAddsSteps() { - @SuppressWarnings("rawtypes") - ParDo.Bound bound = mock(ParDo.Bound.class, "Bound"); - - AggregatorProvidingDoFn<ThreadGroup, Void> fn = new AggregatorProvidingDoFn<>(); - Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn()); - - when(bound.getFn()).thenReturn(fn); - - @SuppressWarnings("rawtypes") - ParDo.BoundMulti otherBound = mock(ParDo.BoundMulti.class, "otherBound"); - - AggregatorProvidingDoFn<Long, Long> otherFn = new AggregatorProvidingDoFn<>(); - Aggregator<Double, Double> aggregatorTwo = otherFn.addAggregator(new Sum.SumDoubleFn()); - - when(otherBound.getFn()).thenReturn(otherFn); - - TransformTreeNode transformNode = mock(TransformTreeNode.class); - when(transformNode.getTransform()).thenReturn(bound); - TransformTreeNode otherTransformNode = mock(TransformTreeNode.class); - when(otherTransformNode.getTransform()).thenReturn(otherBound); - - doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode, otherTransformNode))) - .when(p) - .traverseTopologically(Mockito.any(PipelineVisitor.class)); - - AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p); - - Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps = - extractor.getAggregatorSteps(); - - assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne)); - assertEquals(ImmutableSet.<PTransform<?, ?>>of(otherBound), aggregatorSteps.get(aggregatorTwo)); - assertEquals(2, aggregatorSteps.size()); - } - - private static class VisitNodesAnswer implements Answer<Object> { - private final List<TransformTreeNode> nodes; - - public VisitNodesAnswer(List<TransformTreeNode> nodes) { - this.nodes = nodes; - } - - @Override - public Object answer(InvocationOnMock invocation) throws Throwable { - PipelineVisitor visitor = (PipelineVisitor) invocation.getArguments()[0]; - for (TransformTreeNode node : nodes) { - visitor.visitPrimitiveTransform(node); - } - return null; - } - } - - private static class AggregatorProvidingDoFn<InT, OuT> extends OldDoFn<InT, OuT> { - public <InputT, OutT> Aggregator<InputT, OutT> addAggregator( - CombineFn<InputT, ?, OutT> combiner) { - return createAggregator(randomName(), combiner); - } - - private String randomName() { - return UUID.randomUUID().toString(); - } - - @Override - public void processElement(OldDoFn<InT, OuT>.ProcessContext c) throws Exception { - fail(); - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTest.java index 710e4ce..3fb3193 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnTest.java @@ -128,6 +128,7 @@ public class DoFnTest implements Serializable { DoFn<Void, Void> doFn = new NoOpDoFn(); Aggregator<Double, Double> aggregatorOne = + doFn.createAggregator(nameOne, combiner); Aggregator<Double, Double> aggregatorTwo = doFn.createAggregator(nameTwo, combiner); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/adec254d/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/OldDoFnTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/OldDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/OldDoFnTest.java index 9d144b3..5946d9a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/OldDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/OldDoFnTest.java @@ -24,10 +24,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertThat; +import org.apache.beam.sdk.AggregatorValues; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.PipelineResult; -import org.apache.beam.sdk.runners.AggregatorValues; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Combine.CombineFn; @@ -36,6 +36,7 @@ import org.apache.beam.sdk.transforms.Sum.SumIntegerFn; import org.apache.beam.sdk.transforms.display.DisplayData; import com.google.common.collect.ImmutableMap; + import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category;
