Repository: mahout Updated Branches: refs/heads/flink-binding 05f1b8ff0 -> d31b07048
MAHOUT-1640:Better collections would significantly improve vector-operation speed, closes apache/mahout#81 Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/d31b0704 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/d31b0704 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/d31b0704 Branch: refs/heads/flink-binding Commit: d31b07048cd66feff95c432d0796b774bfd62b07 Parents: 05f1b8f Author: smarthi <[email protected]> Authored: Tue Mar 8 15:32:57 2016 -0500 Committer: smarthi <[email protected]> Committed: Wed Mar 9 18:09:14 2016 -0500 ---------------------------------------------------------------------- LICENSE.txt | 2 +- math/pom.xml | 6 + .../mahout/math/RandomAccessSparseVector.java | 154 +++++++++---------- .../math/TestRandomAccessSparseVector.java | 2 +- .../java/org/apache/mahout/math/VectorTest.java | 13 +- 5 files changed, 91 insertions(+), 86 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/d31b0704/LICENSE.txt ---------------------------------------------------------------------- diff --git a/LICENSE.txt b/LICENSE.txt index 336bce7..5a59b14 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -2,7 +2,7 @@ The following license applies to software from the Apache Software Foundation. It also applies to software from the Uncommons Watchmaker and Math -projects, Google Guava software, and MongoDB.org driver software +projects, Google Guava software, MongoDB.org driver software and fastutil. -------------------------------------------------------------------------- Apache License http://git-wip-us.apache.org/repos/asf/mahout/blob/d31b0704/math/pom.xml ---------------------------------------------------------------------- diff --git a/math/pom.xml b/math/pom.xml index 4badfb6..af96ea0 100644 --- a/math/pom.xml +++ b/math/pom.xml @@ -146,6 +146,12 @@ </dependency> <dependency> + <groupId>it.unimi.dsi</groupId> + <artifactId>fastutil</artifactId> + <version>7.0.11</version> + </dependency> + + <dependency> <groupId>org.slf4j</groupId> <artifactId>slf4j-api</artifactId> </dependency> http://git-wip-us.apache.org/repos/asf/mahout/blob/d31b0704/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java b/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java index 3efac7e..9316915 100644 --- a/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java +++ b/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java @@ -17,21 +17,23 @@ package org.apache.mahout.math; +import it.unimi.dsi.fastutil.doubles.DoubleIterator; +import it.unimi.dsi.fastutil.ints.Int2DoubleMap; +import it.unimi.dsi.fastutil.ints.Int2DoubleMap.Entry; +import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; + import java.util.Iterator; import java.util.NoSuchElementException; -import org.apache.mahout.math.list.DoubleArrayList; -import org.apache.mahout.math.map.OpenIntDoubleHashMap; -import org.apache.mahout.math.map.OpenIntDoubleHashMap.MapElement; import org.apache.mahout.math.set.AbstractSet; - /** Implements vector that only stores non-zero doubles */ public class RandomAccessSparseVector extends AbstractVector { private static final int INITIAL_CAPACITY = 11; - private OpenIntDoubleHashMap values; + private Int2DoubleOpenHashMap values; /** For serialization purposes only. */ public RandomAccessSparseVector() { @@ -44,7 +46,7 @@ public class RandomAccessSparseVector extends AbstractVector { public RandomAccessSparseVector(int cardinality, int initialCapacity) { super(cardinality); - values = new OpenIntDoubleHashMap(initialCapacity); + values = new Int2DoubleOpenHashMap(initialCapacity, .5f); } public RandomAccessSparseVector(Vector other) { @@ -54,14 +56,14 @@ public class RandomAccessSparseVector extends AbstractVector { } } - private RandomAccessSparseVector(int cardinality, OpenIntDoubleHashMap values) { + private RandomAccessSparseVector(int cardinality, Int2DoubleOpenHashMap values) { super(cardinality); this.values = values; } public RandomAccessSparseVector(RandomAccessSparseVector other, boolean shallowCopy) { super(other.size()); - values = shallowCopy ? other.values : (OpenIntDoubleHashMap)other.values.clone(); + values = shallowCopy ? other.values : other.values.clone(); } @Override @@ -71,7 +73,7 @@ public class RandomAccessSparseVector extends AbstractVector { @Override public RandomAccessSparseVector clone() { - return new RandomAccessSparseVector(size(), (OpenIntDoubleHashMap) values.clone()); + return new RandomAccessSparseVector(size(), values.clone()); } @Override @@ -123,7 +125,7 @@ public class RandomAccessSparseVector extends AbstractVector { public void setQuick(int index, double value) { invalidateCachedLength(); if (value == 0.0) { - values.removeKey(index); + values.remove(index); } else { values.put(index, value); } @@ -132,7 +134,7 @@ public class RandomAccessSparseVector extends AbstractVector { @Override public void incrementQuick(int index, double increment) { invalidateCachedLength(); - values.adjustOrPutValue(index, increment, increment); + values.addTo( index, increment); } @@ -153,14 +155,9 @@ public class RandomAccessSparseVector extends AbstractVector { @Override public int getNumNonZeroElements() { - DoubleArrayList elementValues = values.values(); - int numMappedElements = elementValues.size(); + final DoubleIterator iterator = values.values().iterator(); int numNonZeros = 0; - for (int index = 0; index < numMappedElements; index++) { - if (elementValues.getQuick(index) != 0) { - numNonZeros++; - } - } + for( int i = values.size(); i-- != 0; ) if ( iterator.nextDouble() != 0 ) numNonZeros++; return numNonZeros; } @@ -190,6 +187,49 @@ public class RandomAccessSparseVector extends AbstractVector { } */ + private final class NonZeroIterator implements Iterator<Element> { + final ObjectIterator<Int2DoubleMap.Entry> fastIterator = values.int2DoubleEntrySet().fastIterator(); + final RandomAccessElement element = new RandomAccessElement( fastIterator ); + + @Override + public boolean hasNext() { + return fastIterator.hasNext(); + } + + @Override + public Element next() { + if ( ! hasNext() ) throw new NoSuchElementException(); + element.entry = fastIterator.next(); + return element; + } +} + + final class RandomAccessElement implements Element { + Int2DoubleMap.Entry entry; + final ObjectIterator<Int2DoubleMap.Entry> fastIterator; + + public RandomAccessElement( ObjectIterator<Entry> fastIterator ) { + super(); + this.fastIterator = fastIterator; + } + + @Override + public double get() { + return entry.getDoubleValue(); + } + + @Override + public int index() { + return entry.getIntKey(); + } + + @Override + public void set( double value ) { + invalidateCachedLength(); + if (value == 0.0) fastIterator.remove(); + else entry.setValue( value ); + } + } /** * NOTE: this implementation reuses the Vector.Element instance for each call of next(). If you need to preserve the * instance, you need to make a copy of it @@ -199,7 +239,7 @@ public class RandomAccessSparseVector extends AbstractVector { */ @Override public Iterator<Element> iterateNonZero() { - return new NonDefaultIterator(); + return new NonZeroIterator(); } @Override @@ -207,54 +247,30 @@ public class RandomAccessSparseVector extends AbstractVector { return new AllIterator(); } - private final class NonDefaultIterator implements Iterator<Element> { - private final class NonDefaultElement implements Element { - @Override - public double get() { - return mapElement.get(); - } - - @Override - public int index() { - return mapElement.index(); - } - - @Override - public void set(double value) { - invalidateCachedLength(); - mapElement.set(value); - } - } - - - private MapElement mapElement; - private final NonDefaultElement element = new NonDefaultElement(); - - private final Iterator<MapElement> iterator; - - private NonDefaultIterator() { - this.iterator = values.iterator(); - } + final class GeneralElement implements Element { + int index; + double value; @Override - public boolean hasNext() { - return iterator.hasNext(); + public double get() { + return value; } @Override - public Element next() { - mapElement = iterator.next(); // This will throw an exception at the end of enumeration. - return element; + public int index() { + return index; } @Override - public void remove() { - throw new UnsupportedOperationException(); + public void set( double value ) { + invalidateCachedLength(); + if (value == 0.0) values.remove( index ); + else values.put( index, value ); } - } +} private final class AllIterator implements Iterator<Element> { - private final RandomAccessElement element = new RandomAccessElement(); + private final GeneralElement element = new GeneralElement(); private AllIterator() { element.index = -1; @@ -270,7 +286,7 @@ public class RandomAccessSparseVector extends AbstractVector { if (!hasNext()) { throw new NoSuchElementException(); } - element.index++; + element.value = values.get( ++element.index ); return element; } @@ -279,28 +295,4 @@ public class RandomAccessSparseVector extends AbstractVector { throw new UnsupportedOperationException(); } } - - private final class RandomAccessElement implements Element { - int index; - - @Override - public double get() { - return values.get(index); - } - - @Override - public int index() { - return index; - } - - @Override - public void set(double value) { - invalidateCachedLength(); - if (value == 0.0) { - values.removeKey(index); - } else { - values.put(index, value); - } - } - } } http://git-wip-us.apache.org/repos/asf/mahout/blob/d31b0704/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java ---------------------------------------------------------------------- diff --git a/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java b/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java index 088bba0..ecc005d 100644 --- a/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java +++ b/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java @@ -50,7 +50,7 @@ public final class TestRandomAccessSparseVector extends AbstractVectorTest<Rando w.set(13, 100500.); w.set(19, 3.141592); - for (String token : Splitter.on(',').split(w.toString().substring(1, w.toString().length() - 2))) { + for (String token : Splitter.on(',').split(w.toString().substring(1, w.toString().length() - 1))) { String[] tokens = token.split(":"); assertEquals(Double.parseDouble(tokens[1]), w.get(Integer.parseInt(tokens[0])), 0.0); } http://git-wip-us.apache.org/repos/asf/mahout/blob/d31b0704/math/src/test/java/org/apache/mahout/math/VectorTest.java ---------------------------------------------------------------------- diff --git a/math/src/test/java/org/apache/mahout/math/VectorTest.java b/math/src/test/java/org/apache/mahout/math/VectorTest.java index 67dc1e9..d355499 100644 --- a/math/src/test/java/org/apache/mahout/math/VectorTest.java +++ b/math/src/test/java/org/apache/mahout/math/VectorTest.java @@ -18,6 +18,7 @@ package org.apache.mahout.math; import java.util.Collection; +import java.util.HashSet; import java.util.Iterator; import java.util.NoSuchElementException; import java.util.Set; @@ -919,17 +920,23 @@ public final class VectorTest extends MahoutTestCase { Iterator<Element> it = vector.nonZeroes().iterator(); Element element = null; int i = 0; + HashSet<Integer> indexes = new HashSet<Integer>(); while (it.hasNext()) { // hasNext is called more often than next if (i % 2 == 0) { element = it.next(); + indexes.add(element.index()); } //noinspection ConstantConditions - assertEquals(element.index(), 2* (i/2)); - assertEquals(element.get(), vector.get(2* (i/2)), 0); + assertEquals(element.get(), vector.get(element.index()), 0); ++i; } assertEquals(7, i); // Last element is print only once. - + assertEquals(4, indexes.size()); + assertTrue(indexes.contains(0)); + assertTrue(indexes.contains(2)); + assertTrue(indexes.contains(4)); + assertTrue(indexes.contains(6)); + // Test all iterator. it = vector.all().iterator(); element = null;
