Author: jmannix
Date: Thu Aug  9 18:27:49 2012
New Revision: 1371361

URL: http://svn.apache.org/viewvc?rev=1371361&view=rev
Log:
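Fixes MAHOUT-1051: when loading the input vectors, place each row at the index given by its IntWritable key (instead of the order in which records are read), unwrap NamedVector rows to their delegate, and pass !sequentialAccess as the final SparseRowMatrix constructor argument.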

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java?rev=1371361&r1=1371360&r2=1371361&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java Thu Aug  9 18:27:49 2012
@@ -45,8 +45,11 @@ import org.apache.mahout.math.DenseVecto
 import org.apache.mahout.math.DistributedRowMatrixWriter;
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.SparseMatrix;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.NamedVector;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -486,18 +489,33 @@ public class InMemoryCollapsedVariationa
         subPaths.add(fileStatus.getPath());
       }
     }
-    List<Vector> vectorList = Lists.newArrayList();
+    List<Pair<Integer, Vector>> rowList = Lists.newArrayList();
+    int numRows = Integer.MIN_VALUE;
+    int numCols = -1;
+    boolean sequentialAccess = false;
     for (Path subPath : subPaths) {
       for (Pair<IntWritable, VectorWritable> record
           : new SequenceFileIterable<IntWritable, VectorWritable>(subPath, true, conf)) {
-        vectorList.add(record.getSecond().get());
+        int id = record.getFirst().get();
+        Vector vector = record.getSecond().get();
+        if (vector instanceof NamedVector) {
+          vector = ((NamedVector)vector).getDelegate();
+        }
+        if (numCols < 0) {
+          numCols = vector.size();
+          sequentialAccess = vector.isSequentialAccess();
+        }
+        rowList.add(Pair.of(id, vector));
+        numRows = Math.max(numRows, id);
       }
     }
-    int numRows = vectorList.size();
-    int numCols = vectorList.get(0).size();
-    return new SparseRowMatrix(numRows, numCols,
-        vectorList.toArray(new Vector[vectorList.size()]), true,
-        vectorList.get(0).isSequentialAccess());
+    numRows++;
+    Vector[] rowVectors = new Vector[numRows];
+    for (Pair<Integer, Vector> pair : rowList) {
+      rowVectors[pair.getFirst()] = pair.getSecond();
+    }
+    return new SparseRowMatrix(numRows, numCols, rowVectors, true, !sequentialAccess);
+
   }
 
   @Override


Reply via email to