Author: robinanil Date: Mon Feb 22 14:38:48 2010 New Revision: 912585 URL: http://svn.apache.org/viewvc?rev=912585&view=rev Log: Distance Measure improvements
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java?rev=912585&r1=912584&r2=912585&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java Mon Feb 22 14:38:48 2010 @@ -19,7 +19,6 @@ import java.util.Collection; import java.util.Collections; -import java.util.Iterator; import org.apache.hadoop.mapred.JobConf; import org.apache.mahout.common.parameters.Parameter; @@ -71,18 +70,8 @@ if (v1.size() != v2.size()) { throw new CardinalityException(); } - double lengthSquaredv1 = 0.0; - Iterator<Vector.Element> iter = v1.iterateNonZero(); - while (iter.hasNext()) { - Vector.Element elt = iter.next(); - lengthSquaredv1 += elt.get() * elt.get(); - } - iter = v2.iterateNonZero(); - double lengthSquaredv2 = 0.0; - while (iter.hasNext()) { - Vector.Element elt = iter.next(); - lengthSquaredv2 += elt.get() * elt.get(); - } + double lengthSquaredv1 = v1.getLengthSquared(); + double lengthSquaredv2 = v2.getLengthSquared(); double dotProduct = v1.dot(v2); double denominator = Math.sqrt(lengthSquaredv1) * Math.sqrt(lengthSquaredv2); @@ -97,12 +86,8 @@ @Override public double distance(double centroidLengthSquare, Vector centroid, Vector v) { - Iterator<Vector.Element> iter = v.iterateNonZero(); - double lengthSquaredv = 0.0; - while (iter.hasNext()) { - Vector.Element elt = iter.next(); - lengthSquaredv += elt.get() * elt.get(); - } + + double lengthSquaredv = v.getLengthSquared(); double dotProduct = centroid.dot(v); double denominator = Math.sqrt(centroidLengthSquare) * Math.sqrt(lengthSquaredv); Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java?rev=912585&r1=912584&r2=912585&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java Mon Feb 22 14:38:48 2010 @@ -62,9 +62,8 @@ } double result = 0; Vector vector = v1.minus(v2); - Iterator<Vector.Element> iter = vector.iterateNonZero(); // this contains all non zero elements between - // the - // two + Iterator<Vector.Element> iter = vector.iterateNonZero(); + // this contains all non zero elements between the two while (iter.hasNext()) { Vector.Element e = iter.next(); result += Math.abs(e.get()); Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java?rev=912585&r1=912584&r2=912585&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java Mon Feb 22 14:38:48 2010 @@ -49,7 +49,7 @@ @Override public double distance(Vector v1, Vector v2) { - return v1.getDistanceSquared(v2); + return v2.getDistanceSquared(v1); } @Override Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java?rev=912585&r1=912584&r2=912585&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java Mon Feb 22 14:38:48 2010 @@ -39,8 +39,16 @@ */ @Override public double distance(Vector a, Vector b) { - double ab = dot(a, b); - double denominator = dot(a, a) + dot(b, b) - ab; + double ab = 0; + double denominator = 0; + if (getWeights() != null) { + ab = dot(b, a); // b is SequentialAccess + denominator = dot(a, a) + dot(b, b) - ab; + } else { + ab = b.dot(a); // b is SequentialAccess + denominator = a.getLengthSquared() + b.getLengthSquared() - ab; + } + if (denominator < ab) { // correct for fp round-off: distance >= 0 denominator = ab; } @@ -61,9 +69,7 @@ while (it.hasNext() && (el = it.next()) != null) { double elementValue = el.get(); double value = elementValue * (sameVector ? elementValue : b.getQuick(el.index())); - if (weights != null) { - value *= weights.getQuick(el.index()); - } + value *= weights.getQuick(el.index()); dot += value; } return dot;