Repository: incubator-beam
Updated Branches:
  refs/heads/master 0442a2416 -> b2b5f429f


Implement getAggregatorValues.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/89e2bb52
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/89e2bb52
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/89e2bb52

Branch: refs/heads/master
Commit: 89e2bb521d8ed8480a2af102614248f29942cbe2
Parents: 13edbec
Author: Tom White <t...@cloudera.com>
Authored: Mon Jun 29 22:59:42 2015 +0100
Committer: Tom White <t...@cloudera.com>
Committed: Thu Mar 10 11:15:14 2016 +0000

----------------------------------------------------------------------
 .../dataflow/spark/EvaluationContext.java        |  3 +--
 .../dataflow/spark/SparkRuntimeContext.java      | 19 +++++++++++++++++++
 .../dataflow/spark/MultiOutputWordCountTest.java | 17 +++++++++++++++--
 3 files changed, 35 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/89e2bb52/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
index c7aa7c6..df3f7f7 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/EvaluationContext.java
@@ -168,8 +168,7 @@ public class EvaluationContext implements EvaluationResult {
   @Override
   public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator)
       throws AggregatorRetrievalException {
-    //TODO: Support this.
-    throw new UnsupportedOperationException("getAggregatorValues is not yet supported.");
+    return runtime.getAggregatorValues(aggregator);
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/89e2bb52/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java
index fbc16d6..51db39b 100644
--- a/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java
+++ b/runners/spark/src/main/java/com/cloudera/dataflow/spark/SparkRuntimeContext.java
@@ -17,6 +17,7 @@ package com.cloudera.dataflow.spark;
 
 import java.io.IOException;
 import java.io.Serializable;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -27,12 +28,14 @@ import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
 import com.google.cloud.dataflow.sdk.coders.Coder;
 import com.google.cloud.dataflow.sdk.coders.CoderRegistry;
 import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.runners.AggregatorValues;
 import com.google.cloud.dataflow.sdk.transforms.Aggregator;
 import com.google.cloud.dataflow.sdk.transforms.Combine;
 import com.google.cloud.dataflow.sdk.transforms.Max;
 import com.google.cloud.dataflow.sdk.transforms.Min;
 import com.google.cloud.dataflow.sdk.transforms.Sum;
 import com.google.cloud.dataflow.sdk.values.TypeDescriptor;
+import com.google.common.collect.ImmutableList;
 import org.apache.spark.Accumulator;
 import org.apache.spark.api.java.JavaSparkContext;
 
@@ -90,6 +93,22 @@ public class SparkRuntimeContext implements Serializable {
     return accum.value().getValue(aggregatorName, typeClass);
   }
 
+  public <T> AggregatorValues<T> getAggregatorValues(Aggregator<?, T> aggregator) {
+    final T aggregatorValue = (T) getAggregatorValue(aggregator.getName(),
+        aggregator.getCombineFn().getOutputType().getRawType());
+    return new AggregatorValues<T>() {
+      @Override
+      public Collection<T> getValues() {
+        return ImmutableList.of(aggregatorValue);
+      }
+
+      @Override
+      public Map<String, T> getValuesAtSteps() {
+        throw new UnsupportedOperationException("getValuesAtSteps is not supported.");
+      }
+    };
+  }
+
   public synchronized PipelineOptions getPipelineOptions() {
     return deserializePipelineOptions(serializedPipelineOptions);
   }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/89e2bb52/runners/spark/src/test/java/com/cloudera/dataflow/spark/MultiOutputWordCountTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/com/cloudera/dataflow/spark/MultiOutputWordCountTest.java b/runners/spark/src/test/java/com/cloudera/dataflow/spark/MultiOutputWordCountTest.java
index b16320d..bf2ecdc 100644
--- a/runners/spark/src/test/java/com/cloudera/dataflow/spark/MultiOutputWordCountTest.java
+++ b/runners/spark/src/test/java/com/cloudera/dataflow/spark/MultiOutputWordCountTest.java
@@ -18,6 +18,7 @@ package com.cloudera.dataflow.spark;
 import com.google.cloud.dataflow.sdk.Pipeline;
 import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
 import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
+import com.google.cloud.dataflow.sdk.runners.AggregatorValues;
 import com.google.cloud.dataflow.sdk.transforms.Aggregator;
 import com.google.cloud.dataflow.sdk.transforms.ApproximateUnique;
 import com.google.cloud.dataflow.sdk.transforms.Count;
@@ -36,6 +37,7 @@ import com.google.cloud.dataflow.sdk.values.PCollectionTuple;
 import com.google.cloud.dataflow.sdk.values.PCollectionView;
 import com.google.cloud.dataflow.sdk.values.TupleTag;
 import com.google.cloud.dataflow.sdk.values.TupleTagList;
+import com.google.common.collect.Iterables;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -56,7 +58,8 @@ public class MultiOutputWordCountTest {
 
     PCollection<String> union = list.apply(Flatten.<String>pCollections());
     PCollectionView<String> regexView = regex.apply(View.<String>asSingleton());
-    PCollectionTuple luc = union.apply(new CountWords(regexView));
+    CountWords countWords = new CountWords(regexView);
+    PCollectionTuple luc = union.apply(countWords);
     PCollection<Long> unique = luc.get(lowerCnts).apply(
         ApproximateUnique.<KV<String, Long>>globally(16));
 
@@ -70,6 +73,10 @@ public class MultiOutputWordCountTest {
     Assert.assertEquals(18, actualTotalWords);
     int actualMaxWordLength = res.getAggregatorValue("maxWordLength", Integer.class);
     Assert.assertEquals(6, actualMaxWordLength);
+    AggregatorValues<Integer> aggregatorValues = res.getAggregatorValues(countWords
+        .getTotalWordsAggregator());
+    Assert.assertEquals(18, Iterables.getOnlyElement(aggregatorValues.getValues()).intValue());
+
     res.close();
   }
 
@@ -108,16 +115,18 @@ public class MultiOutputWordCountTest {
   public static class CountWords extends PTransform<PCollection<String>, PCollectionTuple> {
 
     private final PCollectionView<String> regex;
+    private final ExtractWordsFn extractWordsFn;
 
     public CountWords(PCollectionView<String> regex) {
       this.regex = regex;
+      this.extractWordsFn = new ExtractWordsFn(regex);
     }
 
     @Override
     public PCollectionTuple apply(PCollection<String> lines) {
       // Convert lines of text into individual words.
       PCollectionTuple lowerUpper = lines
-          .apply(ParDo.of(new ExtractWordsFn(regex))
+          .apply(ParDo.of(extractWordsFn)
               .withSideInputs(regex)
               .withOutputTags(lower, TupleTagList.of(upper)));
       lowerUpper.get(lower).setCoder(StringUtf8Coder.of());
@@ -130,5 +139,9 @@ public class MultiOutputWordCountTest {
           .of(lowerCnts, lowerCounts)
           .and(upperCnts, upperCounts);
     }
+
+    Aggregator<Integer, Integer> getTotalWordsAggregator() {
+      return extractWordsFn.totalWords;
+    }
   }
 }

Reply via email to