Hi,

I just tried running the SGD example with the following command line (adapted 
from the corresponding JIRA issue):

./bin/mahout org.apache.mahout.classifier.sgd.TrainLogistic --passes 100 --rate 50 --lambda 0.001 --input examples/src/main/resources/donut.csv --features 21 --output donut.model --target color --categories 2 --predictors x y xx xy yy a b c --types n n

When running the command above, I ran into a few NullPointerExceptions. I was 
able to fix them with a few tiny changes. If the attachment has not been 
stripped, those changes should be attached to this mail as a diff to highlight 
the lines of code that caused the trouble. However, I was wondering whether I 
simply used the wrong command line.

Isabel
diff --git a/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java b/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
index 5cbdef2..bde3021 100644
--- a/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
+++ b/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
@@ -243,8 +243,9 @@ public class CsvRecordFactory implements RecordFactory {
       if (predictor >= 0) {
         value = values.get(predictor);
       } else {
-        value = null;
+        value = "null";
       }
+System.out.println(value);
       predictorEncoders.get(predictor).addToVector(value, featureVector);
     }
     return targetValue;
diff --git a/core/src/main/java/org/apache/mahout/vectors/ConstantValueEncoder.java b/core/src/main/java/org/apache/mahout/vectors/ConstantValueEncoder.java
index d76fd81..3112681 100644
--- a/core/src/main/java/org/apache/mahout/vectors/ConstantValueEncoder.java
+++ b/core/src/main/java/org/apache/mahout/vectors/ConstantValueEncoder.java
@@ -34,7 +34,7 @@ public class ConstantValueEncoder extends FeatureVectorEncoder {
     for (int i = 0; i < probes; i++) {
       int n = hashForProbe(originalForm, data.size(), name, i);
       if(isTraceEnabled()){
-        trace((byte[]) null, n);                
+        trace(new byte[]{}, n);
       }
       data.set(n, data.get(n) + getWeight(originalForm,weight));
     }
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
index 30cd353..3f7d1d5 100644
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
+++ b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
@@ -132,6 +132,8 @@ public final class TrainLogistic {
 
   private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
     double weight = 0;
+    if (csv.getTraceDictionary().get(predictor) == null)
+      return 0;
     for (Integer column : csv.getTraceDictionary().get(predictor)) {
       weight += lr.getBeta().get(row, column);
     }

Attachment: signature.asc
Description: This is a digitally signed message part.

Reply via email to