Author: jeastman
Date: Fri Sep  3 15:36:46 2010
New Revision: 992332

URL: http://svn.apache.org/viewvc?rev=992332&view=rev
Log:
MAHOUT-479: Added unit tests that use DistributedLanczosSolver to project the test data
onto an SVD basis. Added a comment to DistributedRowMatrix.times() noting that it
actually computes transposeTimes().

Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
    
mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java?rev=992332&r1=992331&r2=992332&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
 Fri Sep  3 15:36:46 2010
@@ -135,6 +135,12 @@ public class DistributedRowMatrix implem
     return numCols;
   }
 
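+  /**
+   * Computes {@code this.transpose().times(other)}; despite its name this is a transposeTimes().
+   * @param other   a DistributedRowMatrix; must have the same number of rows as this matrix
+   * @return    a DistributedRowMatrix containing the product this<sup>T</sup> * other
+   * @throws IOException if the distributed computation fails
+   */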
   public DistributedRowMatrix times(DistributedRowMatrix other) throws 
IOException {
     if (numRows != other.numRows()) {
       throw new CardinalityException(numRows, other.numRows());

Modified: 
mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java?rev=992332&r1=992331&r2=992332&view=diff
==============================================================================
--- 
mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java
 (original)
+++ 
mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java
 Fri Sep  3 15:36:46 2010
@@ -25,6 +25,10 @@ import java.util.List;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapred.JobConf;
 import org.apache.lucene.analysis.standard.StandardAnalyzer;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
@@ -45,9 +49,13 @@ import org.apache.mahout.clustering.mean
 import org.apache.mahout.common.distance.CosineDistanceMeasure;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.NamedVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.apache.mahout.math.hadoop.decomposer.DistributedLanczosSolver;
 import org.apache.mahout.utils.MahoutTestCase;
 import org.apache.mahout.utils.clustering.ClusterDumper;
 import org.apache.mahout.utils.vectors.TFIDF;
@@ -83,7 +91,7 @@ public final class TestClusterDumper ext
     FileSystem fs = FileSystem.get(conf);
     // Create test data
     getSampleData(DOCS);
-    ClusteringTestUtils.writePointsToFile(sampleData, 
getTestTempFilePath("testdata/file1"), fs, conf);
+    ClusteringTestUtils.writePointsToFile(sampleData, true, 
getTestTempFilePath("testdata/file1"), fs, conf);
   }
 
   private void getSampleData(String[] docs2) throws IOException {
@@ -230,4 +238,100 @@ public final class TestClusterDumper ext
     ClusterDumper clusterDumper = new ClusterDumper(new Path(output, 
"clusters-10"), new Path(output, "clusteredPoints"));
     clusterDumper.printClusters(termDictionary);
   }
+
+  /**
+   * Computes a basis of desiredRank - 1 eigenvectors for the sample data with
+   * DistributedLanczosSolver, projects the data onto it in memory (A * P), writes the projected
+   * rows back to the file system, then clusters them with Canopy + KMeans and dumps the clusters.
+   */
+  @Test
+  public void testKmeansSVD() throws Exception {
+    DistanceMeasure measure = new EuclideanDistanceMeasure();
+    Path output = getTestTempDirPath("output");
+    Path tmp = getTestTempDirPath("tmp");
+    Path eigenvectors = new Path(output, "eigenvectors");
+    int desiredRank = 15;
+    DistributedLanczosSolver solver = new DistributedLanczosSolver();
+    Configuration conf = new Configuration();
+    solver.setConf(conf);
+    Path testData = getTestTempDirPath("testdata");
+    int sampleDimension = sampleData.get(0).get().size();
+    solver.run(testData, tmp, eigenvectors, sampleData.size(), sampleDimension, false,
+        desiredRank);
+    // build in-memory data matrix A: one row per sample, sampleDimension columns
+    Matrix a = new DenseMatrix(sampleData.size(), sampleDimension);
+    int i = 0;
+    for (VectorWritable vw : sampleData) {
+      a.assignRow(i++, vw.get());
+    }
+    // extract the eigenvectors into the columns of P. P must have one row per column of A
+    // for A * P to be defined, so size it from sampleDimension instead of a fixed constant;
+    // this keeps the test valid for any dimensionality of the sample vectors.
+    Matrix p = new DenseMatrix(sampleDimension, desiredRank - 1);
+    FileSystem fs = FileSystem.get(eigenvectors.toUri(), conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, eigenvectors, conf);
+    try {
+      Writable key = (Writable) reader.getKeyClass().newInstance();
+      Writable value = (Writable) reader.getValueClass().newInstance();
+      i = 0;
+      while (reader.next(key, value)) {
+        Vector v = ((VectorWritable) value).get();
+        p.assignColumn(i, v);
+        System.out.println("k=" + key.toString() + " V="
+            + AbstractCluster.formatVector(v, termDictionary));
+        // presumably a fresh value instance per record so the next read cannot reuse v
+        value = (Writable) reader.getValueClass().newInstance();
+        i++;
+      }
+    } finally {
+      reader.close();
+    }
+    // sData = A P: each sample projected onto the eigenvector basis
+    Matrix sData = a.times(p);
+
+    // now write sData back to file system so clustering can run against it
+    Path svdData = new Path(output, "svddata");
+    SequenceFile.Writer writer =
+        new SequenceFile.Writer(fs, conf, svdData, IntWritable.class, VectorWritable.class);
+    try {
+      IntWritable key = new IntWritable();
+      VectorWritable value = new VectorWritable();
+      for (int row = 0; row < sData.numRows(); row++) {
+        key.set(row);
+        value.set(sData.getRow(row));
+        writer.append(key, value);
+      }
+    } finally {
+      writer.close();
+    }
+    // now run the Canopy job to prime kMeans canopies
+    CanopyDriver.runJob(svdData, output, measure, 8, 4, false, false);
+    // now run the KMeans job
+    KMeansDriver.runJob(svdData, new Path(output, "clusters-0"), output, measure, 0.001, 10, 1,
+        true, false);
+    // run ClusterDumper
+    // NOTE(review): assumes KMeans stops after two iterations so "clusters-2" exists; confirm
+    ClusterDumper clusterDumper =
+        new ClusterDumper(new Path(output, "clusters-2"), new Path(output, "clusteredPoints"));
+    clusterDumper.printClusters(termDictionary);
+  }
+
+  /**
+   * Computes a basis of (rank - 1) eigenvectors with DistributedLanczosSolver, projects the
+   * sample data onto it using DistributedRowMatrix operations only, then clusters the projected
+   * rows with Canopy + KMeans and dumps the resulting clusters.
+   */
+  @Test
+  public void testKmeansDSVD() throws Exception {
+    Path outputDir = getTestTempDirPath("output");
+    Path tmpDir = getTestTempDirPath("tmp");
+    Path eigenPath = new Path(outputDir, "eigenvectors");
+    final int rank = 13;
+    DistributedLanczosSolver lanczos = new DistributedLanczosSolver();
+    Configuration baseConf = new Configuration();
+    lanczos.setConf(baseConf);
+    Path testDataDir = getTestTempDirPath("testdata");
+    int numSamples = sampleData.size();
+    int numFeatures = sampleData.get(0).get().size();
+    lanczos.run(testDataDir, tmpDir, eigenPath, numSamples, numFeatures, false, rank);
+
+    // P^T: the eigenvector file viewed as a (rank - 1) x numFeatures row matrix
+    DistributedRowMatrix basisT =
+        new DistributedRowMatrix(eigenPath, tmpDir, rank - 1, numFeatures);
+    JobConf jobConf = new JobConf(baseConf);
+    basisT.configure(jobConf);
+    // A: the sample data, one row per sample
+    DistributedRowMatrix data =
+        new DistributedRowMatrix(testDataDir, tmpDir, numSamples, numFeatures);
+    data.configure(jobConf);
+    // times() computes this^T * other, so (A^T).times(P) yields A * P
+    DistributedRowMatrix dataT = data.transpose();
+    DistributedRowMatrix basis = basisT.transpose();
+    DistributedRowMatrix projected = dataT.times(basis);
+    projected.configure(jobConf);
+
+    Path projectedRows = projected.getRowPath();
+    DistanceMeasure measure = new EuclideanDistanceMeasure();
+    // prime KMeans with Canopy centroids, then start KMeans from clusters-0
+    CanopyDriver.runJob(projectedRows, outputDir, measure, 8, 4, false, false);
+    KMeansDriver.runJob(projectedRows, new Path(outputDir, "clusters-0"), outputDir, measure,
+        0.001, 10, 1, true, false);
+    // NOTE(review): assumes KMeans stops after two iterations so "clusters-2" exists; confirm
+    new ClusterDumper(new Path(outputDir, "clusters-2"), new Path(outputDir, "clusteredPoints"))
+        .printClusters(termDictionary);
+  }
 }


Reply via email to