Author: jmannix
Date: Thu Aug 9 18:27:49 2012
New Revision: 1371361
URL: http://svn.apache.org/viewvc?rev=1371361&view=rev
Log:
Fixes MAHOUT-1051
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java?rev=1371361&r1=1371360&r2=1371361&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java Thu Aug 9 18:27:49 2012
@@ -45,8 +45,11 @@ import org.apache.mahout.math.DenseVecto
import org.apache.mahout.math.DistributedRowMatrixWriter;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.NamedVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -486,18 +489,33 @@ public class InMemoryCollapsedVariationa
subPaths.add(fileStatus.getPath());
}
}
- List<Vector> vectorList = Lists.newArrayList();
+ List<Pair<Integer, Vector>> rowList = Lists.newArrayList();
+ int numRows = Integer.MIN_VALUE;
+ int numCols = -1;
+ boolean sequentialAccess = false;
for (Path subPath : subPaths) {
for (Pair<IntWritable, VectorWritable> record
: new SequenceFileIterable<IntWritable, VectorWritable>(subPath,
true, conf)) {
- vectorList.add(record.getSecond().get());
+ int id = record.getFirst().get();
+ Vector vector = record.getSecond().get();
+ if (vector instanceof NamedVector) {
+ vector = ((NamedVector)vector).getDelegate();
+ }
+ if (numCols < 0) {
+ numCols = vector.size();
+ sequentialAccess = vector.isSequentialAccess();
+ }
+ rowList.add(Pair.of(id, vector));
+ numRows = Math.max(numRows, id);
}
}
- int numRows = vectorList.size();
- int numCols = vectorList.get(0).size();
- return new SparseRowMatrix(numRows, numCols,
- vectorList.toArray(new Vector[vectorList.size()]), true,
- vectorList.get(0).isSequentialAccess());
+ numRows++;
+ Vector[] rowVectors = new Vector[numRows];
+ for (Pair<Integer, Vector> pair : rowList) {
+ rowVectors[pair.getFirst()] = pair.getSecond();
+ }
+ return new SparseRowMatrix(numRows, numCols, rowVectors, true, !sequentialAccess);
+
}
@Override