Repository: crunch
Updated Branches:
  refs/heads/master e1bfa47e8 -> ebe306126
CRUNCH-437: Fix Crunch-on-Spark duplicate value aggregation


Project: http://git-wip-us.apache.org/repos/asf/crunch/repo
Commit: http://git-wip-us.apache.org/repos/asf/crunch/commit/ebe30612
Tree: http://git-wip-us.apache.org/repos/asf/crunch/tree/ebe30612
Diff: http://git-wip-us.apache.org/repos/asf/crunch/diff/ebe30612

Branch: refs/heads/master
Commit: ebe3061260ee02752536384ff1f7d9cf72d09831
Parents: e1bfa47
Author: Josh Wills <[email protected]>
Authored: Sat Jul 5 19:10:41 2014 -0700
Committer: Josh Wills <[email protected]>
Committed: Sat Jul 5 19:16:51 2014 -0700

----------------------------------------------------------------------
 .../org/apache/crunch/SparkAggregatorIT.java    | 53 +++++++++++++++
 .../impl/spark/fn/CombineMapsideFunction.java   | 70 +++++++++++++++-----
 2 files changed, 106 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/crunch/blob/ebe30612/crunch-spark/src/it/java/org/apache/crunch/SparkAggregatorIT.java
----------------------------------------------------------------------
diff --git a/crunch-spark/src/it/java/org/apache/crunch/SparkAggregatorIT.java b/crunch-spark/src/it/java/org/apache/crunch/SparkAggregatorIT.java
new file mode 100644
index 0000000..bc6ebea
--- /dev/null
+++ b/crunch-spark/src/it/java/org/apache/crunch/SparkAggregatorIT.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.crunch;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.apache.crunch.impl.spark.SparkPipeline;
+import org.apache.crunch.io.From;
+import org.apache.crunch.test.TemporaryPath;
+import org.apache.crunch.types.avro.Avros;
+import org.junit.Rule;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+public class SparkAggregatorIT {
+  @Rule
+  public TemporaryPath tempDir = new TemporaryPath();
+
+  @Test
+  public void testCount() throws Exception {
+    SparkPipeline pipeline = new SparkPipeline("local", "aggregator");
+    PCollection<String> set1 = pipeline.read(From.textFile(tempDir.copyResourceFileName("set1.txt")));
+    PCollection<String> set2 = pipeline.read(From.textFile(tempDir.copyResourceFileName("set2.txt")));
+    Iterable<Pair<Integer, Long>> cnts = set1.union(set2)
+        .parallelDo(new CntFn(), Avros.ints())
+        .count().materialize();
+    assertEquals(ImmutableList.of(Pair.of(1, 7L)), Lists.newArrayList(cnts));
+    pipeline.done();
+  }
+
+  private static class CntFn extends MapFn<String, Integer> {
+    @Override
+    public Integer map(String input) {
+      return 1;
+    }
+  }
+}


http://git-wip-us.apache.org/repos/asf/crunch/blob/ebe30612/crunch-spark/src/main/java/org/apache/crunch/impl/spark/fn/CombineMapsideFunction.java
----------------------------------------------------------------------
diff --git a/crunch-spark/src/main/java/org/apache/crunch/impl/spark/fn/CombineMapsideFunction.java b/crunch-spark/src/main/java/org/apache/crunch/impl/spark/fn/CombineMapsideFunction.java
index 3600f16..3cc8e05 100644
--- a/crunch-spark/src/main/java/org/apache/crunch/impl/spark/fn/CombineMapsideFunction.java
+++ b/crunch-spark/src/main/java/org/apache/crunch/impl/spark/fn/CombineMapsideFunction.java
@@ -17,20 +17,16 @@
  */
 package org.apache.crunch.impl.spark.fn;
 
-import com.google.common.base.Function;
-import com.google.common.collect.HashMultimap;
-import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
-import com.google.common.collect.Multimap;
+import com.google.common.collect.UnmodifiableIterator;
 import org.apache.crunch.CombineFn;
 import org.apache.crunch.Pair;
 import org.apache.crunch.impl.mem.emit.InMemoryEmitter;
 import org.apache.crunch.impl.spark.SparkRuntimeContext;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import scala.Tuple2;
 
-import javax.annotation.Nullable;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -51,31 +47,36 @@ public class CombineMapsideFunction<K, V> extends PairFlatMapFunction<Iterator<T
   @Override
   public Iterable<Tuple2<K, V>> call(Iterator<Tuple2<K, V>> iter) throws Exception {
     ctxt.initialize(combineFn);
-    Multimap<K, V> cache = HashMultimap.create();
+    Map<K, List<V>> cache = Maps.newHashMap();
     int cnt = 0;
     while (iter.hasNext()) {
       Tuple2<K, V> t = iter.next();
-      cache.put(t._1, t._2);
+      List<V> values = cache.get(t._1());
+      if (values == null) {
+        values = Lists.newArrayList();
+        cache.put(t._1(), values);
+      }
+      values.add(t._2());
       cnt++;
       if (cnt % REDUCE_EVERY_N == 0) {
         cache = reduce(cache);
       }
     }
-    return Iterables.transform(reduce(cache).entries(), new Function<Map.Entry<K, V>, Tuple2<K, V>>() {
-      @Override
-      public Tuple2<K, V> apply(Map.Entry<K, V> input) {
-        return new Tuple2<K, V>(input.getKey(), input.getValue());
-      }
-    });
+    return new Flattener<K, V>(cache);
   }
 
-  private Multimap<K, V> reduce(Multimap<K, V> cache) {
+  private Map<K, List<V>> reduce(Map<K, List<V>> cache) {
     Set<K> keys = cache.keySet();
-    Multimap<K, V> res = HashMultimap.create(keys.size(), keys.size());
+    Map<K, List<V>> res = Maps.newHashMap();
     for (K key : keys) {
       for (Pair<K, V> p : reduce(key, cache.get(key))) {
-        res.put(p.first(), p.second());
+        List<V> values = res.get(p.first());
+        if (values == null) {
+          values = Lists.newArrayList();
+          res.put(p.first(), values);
+        }
+        values.add(p.second());
       }
     }
     return res;
@@ -87,4 +88,39 @@ public class CombineMapsideFunction<K, V> extends PairFlatMapFunction<Iterator<T
     combineFn.cleanup(emitter);
     return emitter.getOutput();
   }
+
+  private static class Flattener<K, V> implements Iterable<Tuple2<K, V>> {
+    private final Map<K, List<V>> entries;
+
+    public Flattener(Map<K, List<V>> entries) {
+      this.entries = entries;
+    }
+
+    @Override
+    public Iterator<Tuple2<K, V>> iterator() {
+      return new UnmodifiableIterator<Tuple2<K, V>>() {
+        private Iterator<K> keyIter = entries.keySet().iterator();
+        private K currentKey;
+        private Iterator<V> valueIter = null;
+
+        @Override
+        public boolean hasNext() {
+          while (valueIter == null || !valueIter.hasNext()) {
+            if (keyIter.hasNext()) {
+              currentKey = keyIter.next();
+              valueIter = entries.get(currentKey).iterator();
+            } else {
+              return false;
+            }
+          }
+          return true;
+        }
+
+        @Override
+        public Tuple2<K, V> next() {
+          return new Tuple2<K, V>(currentKey, valueIter.next());
+        }
+      };
+    }
+  }
 }
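
Note on the change above: the map-side combine cache previously used a Guava
HashMultimap, which is backed by hash sets and therefore stores each distinct
(key, value) pair only once. For an aggregation such as count(), where many
identical pairs like (1, 1L) are produced, that deduplication silently drops
values and under-counts; switching the cache to a Map<K, List<V>> keeps every
occurrence. The standalone sketch below (illustrative only, not part of this
commit; the class name DuplicateValueDemo is made up) shows the difference in
behavior:

  import com.google.common.collect.HashMultimap;
  import com.google.common.collect.Lists;
  import com.google.common.collect.Multimap;

  import java.util.HashMap;
  import java.util.List;
  import java.util.Map;

  public class DuplicateValueDemo {
    public static void main(String[] args) {
      // Set-backed multimap: the second put of the same (key, value) pair is a no-op.
      Multimap<Integer, Long> multimap = HashMultimap.create();
      multimap.put(1, 1L);
      multimap.put(1, 1L);
      System.out.println(multimap.get(1).size()); // prints 1 -- one occurrence was lost

      // List-valued map: both occurrences are retained, as in the patched cache.
      Map<Integer, List<Long>> map = new HashMap<Integer, List<Long>>();
      for (int i = 0; i < 2; i++) {
        List<Long> values = map.get(1);
        if (values == null) {
          values = Lists.newArrayList();
          map.put(1, values);
        }
        values.add(1L);
      }
      System.out.println(map.get(1).size()); // prints 2
    }
  }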
