Hi, I am trying to run the straightforward SVM example, but I am getting low accuracy (around 50%) when I predict on the same data that I used for training. I am probably doing the prediction incorrectly. My code is below. I would appreciate any help.
import java.util.List; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.mllib.classification.SVMModel; import org.apache.spark.mllib.classification.SVMWithSGD; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import scala.Tuple2; import edu.illinois.biglbjava.readers.LabeledPointReader; public class SimpleDistSVM { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("SVM Classifier Example"); SparkContext sc = new SparkContext(conf); String inputPath=args[0]; // Read training data JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, inputPath).toJavaRDD(); // Run training algorithm to build the model. int numIterations = 3; final SVMModel model = SVMWithSGD.train(data.rdd(), numIterations); // Clear the default threshold. model.clearThreshold(); // Predict points in test set and map to an RDD of 0/1 values where 0 is misclassication and 1 is correct classification JavaRDD<Integer> classification = data.map(new Function<LabeledPoint, Integer>() { public Integer call(LabeledPoint p) { int label = (int) p.label(); Double score = model.predict(p.features()); if((score >=0 && label == 1) || (score <0 && label == 0)) { return 1; //correct classiciation } else return 0; } } ); // sum up all values in the rdd to get the number of correctly classified examples int sum=classification.reduce(new Function2<Integer, Integer, Integer>() { public Integer call(Integer arg0, Integer arg1) throws Exception { return arg0+arg1; }}); //compute accuracy as the percentage of the correctly classified examples double accuracy=((double)sum)/((double)classification.count()); System.out.println("Accuracy = " + accuracy); } } ); } }