Hi, I just tried running the SGD example with the following command line (adapted from the corresponding JIRA issue):
./bin/mahout org.apache.mahout.classifier.sgd.TrainLogistic --passes 100 --rate 50 --lambda 0.001 --input examples/src/main/resources/donut.csv --features 21 --output donut.model --target color --categories 2 --predictors x y xx xy yy a b c --types n n

When running the command above I ran into a few NullPointerExceptions, which I was able to fix with a few tiny changes. Unless the attachment gets stripped, the changes should be attached to this mail to highlight the lines of code that caused the trouble. However, I was wondering whether I simply used the wrong command line.

Isabel
diff --git a/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java b/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
index 5cbdef2..bde3021 100644
--- a/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
+++ b/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
@@ -243,8 +243,9 @@ public class CsvRecordFactory implements RecordFactory {
if (predictor >= 0) {
value = values.get(predictor);
} else {
- value = null;
+ value = "null";
}
+System.out.println(value);
predictorEncoders.get(predictor).addToVector(value, featureVector);
}
return targetValue;
diff --git a/core/src/main/java/org/apache/mahout/vectors/ConstantValueEncoder.java b/core/src/main/java/org/apache/mahout/vectors/ConstantValueEncoder.java
index d76fd81..3112681 100644
--- a/core/src/main/java/org/apache/mahout/vectors/ConstantValueEncoder.java
+++ b/core/src/main/java/org/apache/mahout/vectors/ConstantValueEncoder.java
@@ -34,7 +34,7 @@ public class ConstantValueEncoder extends FeatureVectorEncoder {
for (int i = 0; i < probes; i++) {
int n = hashForProbe(originalForm, data.size(), name, i);
if(isTraceEnabled()){
- trace((byte[]) null, n);
+ trace(new byte[]{}, n);
}
data.set(n, data.get(n) + getWeight(originalForm,weight));
}
diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
index 30cd353..3f7d1d5 100644
--- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
+++ b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
@@ -132,6 +132,8 @@ public final class TrainLogistic {
private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
double weight = 0;
+ if (csv.getTraceDictionary().get(predictor) == null)
+ return 0;
for (Integer column : csv.getTraceDictionary().get(predictor)) {
weight += lr.getBeta().get(row, column);
}
signature.asc
Description: This is a digitally signed message part.
