Author: dlyubimov
Date: Tue Mar 29 04:36:33 2011
New Revision: 1086473

URL: http://svn.apache.org/viewvc?rev=1086473&view=rev
Log:
MAHOUT-638 first installment: the fix. I will add tests on various types of 
vectors a bit later.

Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java?rev=1086473&r1=1086472&r2=1086473&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java Tue Mar 29 04:36:33 2011
@@ -42,6 +42,8 @@ import org.apache.hadoop.mapreduce.lib.i
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
 import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
 import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.hadoop.stochasticsvd.QJob.QJobKeyWritable;
@@ -148,14 +150,25 @@ public class BtJob {
                                                                         // A row
                                                                         // labels.
 
-      int n = aRow.size();
       Vector btRow = btValue.get();
-      for (int i = 0; i < n; i++) {
-        double mul = aRow.getQuick(i);
-        for (int j = 0; j < kp; j++)
-          btRow.setQuick(j, mul * qRow.getQuick(j));
-        btKey.set(i);
-        context.write(btKey, btValue);
+      if ( (aRow instanceof SequentialAccessSparseVector) ||
+          (aRow instanceof RandomAccessSparseVector )) {
+        for ( Vector.Element el:aRow ) { 
+          double mul=el.get();
+          for ( int j =0; j < kp; j++ ) 
+            btRow.setQuick(j, mul * qRow.getQuick(j));
+          btKey.set(el.index());
+          context.write(btKey, btValue);
+        }
+      } else { 
+        int n = aRow.size();
+        for (int i = 0; i < n; i++) {
+          double mul = aRow.getQuick(i);
+          for (int j = 0; j < kp; j++)
+            btRow.setQuick(j, mul * qRow.getQuick(j));
+          btKey.set(i);
+          context.write(btKey, btValue);
+        }
       }
 
     }

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java?rev=1086473&r1=1086472&r2=1086473&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java Tue Mar 29 04:36:33 2011
@@ -20,6 +20,7 @@ package org.apache.mahout.math.hadoop.st
 import java.util.Arrays;
 import java.util.Random;
 
+import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.SequentialAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.Vector.Element;
@@ -59,7 +60,8 @@ public class Omega {
     assert yRow.length == kp;
 
     Arrays.fill(yRow, 0);
-    if (aRow instanceof SequentialAccessSparseVector) {
+    if ((aRow instanceof SequentialAccessSparseVector)||
+        (aRow instanceof RandomAccessSparseVector)){
       int j = 0;
       for (Element el : aRow) {
         accumDots(j, el.get(), yRow);


Reply via email to