Repository: incubator-beam Updated Branches: refs/heads/master 0442a2416 -> b2b5f429f
Implement getAggregatorValues. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/89e2bb52 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/89e2bb52 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/89e2bb52 Branch: refs/heads/master Commit: 89e2bb521d8ed8480a2af102614248f29942cbe2 Parents: 13edbec Author: Tom White <t...@cloudera.com> Authored: Mon Jun 29 22:59:42 2015 +0100 Committer: Tom White <t...@cloudera.com> Committed: Thu Mar 10 11:15:14 2016 +0000 ---------------------------------------------------------------------- .../dataflow/spark/EvaluationContext.java | 3 +-- .../dataflow/spark/SparkRuntimeContext.java | 19 +++++++++++++++++++ .../dataflow/spark/MultiOutputWordCountTest.java | 17 +++++++++++++++-- 3 files changed, 35 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/89e2bb52/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java index c7aa7c6..df3f7f7 100644 --- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java +++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java @@ -168,8 +168,7 @@ public class EvaluationContext implements EvaluationResult { @Override public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator) throws AggregatorRetrievalException { - //TODO: Support this. - throw new UnsupportedOperationException("getAggregatorValues is not yet supported."); + return runtime.getAggregatorValues(aggregator); } @Override http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/89e2bb52/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java index fbc16d6..51db39b 100644 --- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java +++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java @@ -17,6 +17,7 @@ package com.cloudera.dataflow.spark; import java.io.IOException; import java.io.Serializable; +import java.util.Collection; import java.util.HashMap; import java.util.Map; @@ -27,12 +28,14 @@ import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.coders.CoderRegistry; import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.AggregatorValues; import com.google.cloud.dataflow.sdk.transforms.Aggregator; import com.google.cloud.dataflow.sdk.transforms.Combine; import com.google.cloud.dataflow.sdk.transforms.Max; import com.google.cloud.dataflow.sdk.transforms.Min; import com.google.cloud.dataflow.sdk.transforms.Sum; import com.google.cloud.dataflow.sdk.values.TypeDescriptor; +import com.google.common.collect.ImmutableList; import org.apache.spark.Accumulator; import org.apache.spark.api.java.JavaSparkContext; @@ -90,6 +93,22 @@ public class SparkRuntimeContext implements Serializable { return accum.value().getValue(aggregatorName, typeClass); } + public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator) { + final T aggregatorValue = (T) getAggregatorValue(aggregator.getName(), + aggregator.getCombineFn().getOutputType().getRawType()); + return new AggregatorValues<T>() { + @Override + public Collection<T> getValues() { + return ImmutableList.of(aggregatorValue); + } + + @Override + public Map<String, T> getValuesAtSteps() { + throw new UnsupportedOperationException("getValuesAtSteps is not supported."); + } + }; + } + public synchronized PipelineOptions getPipelineOptions() { return deserializePipelineOptions(serializedPipelineOptions); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/89e2bb52/runners/spark/src/test/java/com/cloudera/dataflow/spark/MultiOutputWordCountTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/com/cloudera/dataflow/spark/MultiOutputWordCountTest.java b/runners/spark/src/test/java/com/cloudera/dataflow/spark/MultiOutputWordCountTest.java index b16320d..bf2ecdc 100644 --- a/runners/spark/src/test/java/com/cloudera/dataflow/spark/MultiOutputWordCountTest.java +++ b/runners/spark/src/test/java/com/cloudera/dataflow/spark/MultiOutputWordCountTest.java @@ -18,6 +18,7 @@ package com.cloudera.dataflow.spark; import com.google.cloud.dataflow.sdk.Pipeline; import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.AggregatorValues; import com.google.cloud.dataflow.sdk.transforms.Aggregator; import com.google.cloud.dataflow.sdk.transforms.ApproximateUnique; import com.google.cloud.dataflow.sdk.transforms.Count; @@ -36,6 +37,7 @@ import com.google.cloud.dataflow.sdk.values.PCollectionTuple; import com.google.cloud.dataflow.sdk.values.PCollectionView; import com.google.cloud.dataflow.sdk.values.TupleTag; import com.google.cloud.dataflow.sdk.values.TupleTagList; +import com.google.common.collect.Iterables; import org.junit.Assert; import org.junit.Test; @@ -56,7 +58,8 @@ public class MultiOutputWordCountTest { PCollection<String> union = list.apply(Flatten.<String>pCollections()); PCollectionView<String> regexView = regex.apply(View.<String>asSingleton()); - PCollectionTuple luc = union.apply(new CountWords(regexView)); + CountWords countWords = new CountWords(regexView); + PCollectionTuple luc = union.apply(countWords); PCollection<Long> unique = luc.get(lowerCnts).apply( ApproximateUnique.<KV<String, Long>>globally(16)); @@ -70,6 +73,10 @@ public class MultiOutputWordCountTest { Assert.assertEquals(18, actualTotalWords); int actualMaxWordLength = res.getAggregatorValue("maxWordLength", Integer.class); Assert.assertEquals(6, actualMaxWordLength); + AggregatorValues<Integer> aggregatorValues = res.getAggregatorValues(countWords + .getTotalWordsAggregator()); + Assert.assertEquals(18, Iterables.getOnlyElement(aggregatorValues.getValues()).intValue()); + res.close(); } @@ -108,16 +115,18 @@ public class MultiOutputWordCountTest { public static class CountWords extends PTransform<PCollection<String>, PCollectionTuple> { private final PCollectionView<String> regex; + private final ExtractWordsFn extractWordsFn; public CountWords(PCollectionView<String> regex) { this.regex = regex; + this.extractWordsFn = new ExtractWordsFn(regex); } @Override public PCollectionTuple apply(PCollection<String> lines) { // Convert lines of text into individual words. PCollectionTuple lowerUpper = lines - .apply(ParDo.of(new ExtractWordsFn(regex)) + .apply(ParDo.of(extractWordsFn) .withSideInputs(regex) .withOutputTags(lower, TupleTagList.of(upper))); lowerUpper.get(lower).setCoder(StringUtf8Coder.of()); @@ -130,5 +139,9 @@ public class MultiOutputWordCountTest { .of(lowerCnts, lowerCounts) .and(upperCnts, upperCounts); } + + Aggregator<Integer, Integer> getTotalWordsAggregator() { + return extractWordsFn.totalWords; + } } }