Author: jeastman
Date: Wed May 13 20:42:22 2009
New Revision: 774521
URL: http://svn.apache.org/viewvc?rev=774521&view=rev
Log:
- committing MAHOUT-109, CosineDistanceMeasure with one change:
- removed abstract from test class definition
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/CosineDistanceMeasure.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/utils/CosineDistanceMeasureTest.java
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/CosineDistanceMeasure.java
URL:
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/CosineDistanceMeasure.java?rev=774521&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/CosineDistanceMeasure.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/CosineDistanceMeasure.java Wed May 13 20:42:22 2009
@@ -0,0 +1,87 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils;
+
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.matrix.CardinalityException;
+import org.apache.mahout.matrix.Vector;
+import org.apache.mahout.utils.parameters.Parameter;
+
+import java.util.Collection;
+import java.util.Collections;
+
+/**
+ * This class implements a cosine distance metric by dividing the dot product
+ * of two vectors by the product of their lengths
+ */
+public class CosineDistanceMeasure implements DistanceMeasure {
+
+ @Override
+ public void configure(JobConf job) {
+ // nothing to do
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void createParameters(String prefix, JobConf jobConf) {
+ // nothing to do
+ }
+
+ public static double distance(double[] p1, double[] p2) {
+ double dotProduct = 0.0;
+ double lengthSquaredp1 = 0.0;
+ double lengthSquaredp2 = 0.0;
+ for (int i = 0; i < p1.length; i++) {
+ lengthSquaredp1 += p1[i] * p1[i];
+ lengthSquaredp2 += p2[i] * p2[i];
+ dotProduct += p1[i] * p2[i];
+ }
+ double denominator = Math.sqrt(lengthSquaredp1) *
Math.sqrt(lengthSquaredp2);
+
+ // correct for floating-point rounding errors
+ if(denominator < dotProduct)
+ denominator = dotProduct;
+
+ return 1.0 - (dotProduct / denominator);
+ }
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ if (v1.cardinality() != v2.cardinality())
+ throw new CardinalityException();
+ double lengthSquaredv1 = 0.0;
+ double lengthSquaredv2 = 0.0;
+ for (int i = 0; i < v1.cardinality(); i++) {
+ lengthSquaredv1 += v1.getQuick(i) * v1.getQuick(i);
+ lengthSquaredv2 += v2.getQuick(i) * v2.getQuick(i);
+ }
+ double dotProduct = v1.dot(v2);
+ double denominator = Math.sqrt(lengthSquaredv1) *
Math.sqrt(lengthSquaredv2);
+
+ // correct for floating-point rounding errors
+ if(denominator < dotProduct)
+ denominator = dotProduct;
+
+ return 1.0 - (dotProduct / denominator);
+ }
+
+}
Added:
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/utils/CosineDistanceMeasureTest.java
URL:
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/utils/CosineDistanceMeasureTest.java?rev=774521&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/utils/CosineDistanceMeasureTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/utils/CosineDistanceMeasureTest.java Wed May 13 20:42:22 2009
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils;
+
+import junit.framework.TestCase;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+
+
+/**
+ * Checks the identity, symmetry and ordering of the distances that
+ * CosineDistanceMeasure reports for three vectors that share a growing
+ * number of non-zero components.
+ */
+public class CosineDistanceMeasureTest extends TestCase {
+
+  /** Tolerance used when comparing computed distances. */
+  private static final double EPSILON = 1.0e-9;
+
+  public void testMeasure() {
+
+    DistanceMeasure distanceMeasure = new CosineDistanceMeasure();
+
+    Vector[] vectors = {
+        new DenseVector(new double[]{1, 0, 0, 0, 0, 0}),
+        new DenseVector(new double[]{1, 1, 1, 0, 0, 0}),
+        new DenseVector(new double[]{1, 1, 1, 1, 1, 1})
+    };
+
+    int size = vectors.length;
+    double[][] distanceMatrix = new double[size][size];
+
+    for (int a = 0; a < size; a++) {
+      for (int b = 0; b < size; b++) {
+        distanceMatrix[a][b] =
+            distanceMeasure.distance(vectors[a], vectors[b]);
+      }
+    }
+
+    for (int a = 0; a < size; a++) {
+      // Compare with a tolerance: the two-argument assertEquals would box
+      // the doubles and demand bit-exact equality of a rounded result.
+      assertEquals(0.0, distanceMatrix[a][a], EPSILON);
+      for (int b = 0; b < size; b++) {
+        // the measure must not depend on argument order
+        assertEquals(distanceMatrix[a][b], distanceMatrix[b][a], EPSILON);
+      }
+    }
+
+    assertTrue(distanceMatrix[0][0] < distanceMatrix[0][1]);
+    assertTrue(distanceMatrix[0][1] < distanceMatrix[0][2]);
+
+    assertTrue(distanceMatrix[1][0] > distanceMatrix[1][1]);
+    assertTrue(distanceMatrix[1][2] < distanceMatrix[1][0]);
+
+    assertTrue(distanceMatrix[2][0] > distanceMatrix[2][1]);
+    assertTrue(distanceMatrix[2][1] > distanceMatrix[2][2]);
+  }
+
+}