Hi,

I am trying to run the basic SVM example, but I get low accuracy (around 50%)
when I predict on the same data I used for training. I am probably doing the
prediction the wrong way. My code is below; I would appreciate any help.


import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;

import scala.Tuple2;
import edu.illinois.biglbjava.readers.LabeledPointReader;

public class SimpleDistSVM {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("SVM Classifier Example");
    SparkContext sc = new SparkContext(conf);
    String inputPath=args[0];

    // Read training data
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc,
inputPath).toJavaRDD();

    // Run training algorithm to build the model.
    int numIterations = 3;
    final SVMModel model = SVMWithSGD.train(data.rdd(), numIterations);

    // Clear the default threshold.
    model.clearThreshold();


    // Predict points in test set and map to an RDD of 0/1 values where 0
is misclassication and 1 is correct classification
    JavaRDD<Integer> classification = data.map(new Function<LabeledPoint,
Integer>() {
         public Integer call(LabeledPoint p) {
           int label = (int) p.label();
           Double score = model.predict(p.features());
           if((score >=0 && label == 1) || (score <0 && label == 0))
           {
           return 1; //correct classiciation
           }
           else
            return 0;

         }
       }
     );
    // sum up all values in the rdd to get the number of correctly
classified examples
     int sum=classification.reduce(new Function2<Integer, Integer,
Integer>()
    {
    public Integer call(Integer arg0, Integer arg1)
    throws Exception {
    return arg0+arg1;
    }});

     //compute accuracy as the percentage of the correctly classified
examples
     double accuracy=((double)sum)/((double)classification.count());
     System.out.println("Accuracy = " + accuracy);

        }
      }
    );
  }
}

Reply via email to