Repository: spark
Updated Branches:
  refs/heads/master 2c04c8a1a -> cc12a86fb


[SPARK-7575] [ML] [DOC] Example code for OneVsRest

Java and Scala examples for OneVsRest. Both examples fix the base classifier to
Logistic Regression and accept that classifier's configuration parameters as
command-line options.

Author: Ram Sriharsha <[email protected]>

Closes #6115 from harsha2010/SPARK-7575 and squashes the following commits:

87ad3c7 [Ram Sriharsha] extra line
f5d9891 [Ram Sriharsha] Merge branch 'master' into SPARK-7575
7076084 [Ram Sriharsha] cleanup
dfd660c [Ram Sriharsha] cleanup
8703e4f [Ram Sriharsha] update doc
cb23995 [Ram Sriharsha] fix commandline options for JavaOneVsRestExample
69e91f8 [Ram Sriharsha] cleanup
7f4e127 [Ram Sriharsha] cleanup
d4c40d0 [Ram Sriharsha] Code Review fixes
461eb38 [Ram Sriharsha] cleanup
e0106d9 [Ram Sriharsha] Fix typo
935cf56 [Ram Sriharsha] Try to match Java and Scala Example Commandline options
5323ff9 [Ram Sriharsha] cleanup
196a59a [Ram Sriharsha] cleanup
6adfa0c [Ram Sriharsha] Style Fix
8cfc5d5 [Ram Sriharsha] [SPARK-7575] Example code for OneVsRest


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cc12a86f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cc12a86f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cc12a86f

Branch: refs/heads/master
Commit: cc12a86fb049f2be1f45baf461d202ec356ccf8f
Parents: 2c04c8a
Author: Ram Sriharsha <[email protected]>
Authored: Fri May 15 19:33:20 2015 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Fri May 15 19:33:20 2015 -0700

----------------------------------------------------------------------
 .../spark/examples/ml/JavaOneVsRestExample.java | 236 +++++++++++++++++++
 .../spark/examples/ml/OneVsRestExample.scala    | 185 +++++++++++++++
 2 files changed, 421 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cc12a86f/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
new file mode 100644
index 0000000..75063db
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.commons.cli.*;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.classification.OneVsRest;
+import org.apache.spark.ml.classification.OneVsRestModel;
+import org.apache.spark.ml.util.MetadataUtils;
+import org.apache.spark.mllib.evaluation.MulticlassMetrics;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.StructField;
+
+/**
+ * An example runner for Multiclass to Binary Reduction with One Vs Rest.
+ * The example uses Logistic Regression as the base classifier. All parameters that
+ * can be specified on the base classifier can be passed in to the runner options.
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaOneVsRestExample [options]
+ * </pre>
+ * Besides the base-classifier settings, the runner accepts a required {@code -input}
+ * path, an optional {@code -testInput} path and a {@code -fracTest} hold-out fraction.
+ */
+public class JavaOneVsRestExample {
+
+  /**
+   * Settings parsed from the command line. Optional settings without a default
+   * (testInput, regParam, elasticNetParam) stay null unless the option is given.
+   */
+  private static class Params {
+    String input;
+    String testInput = null;
+    int maxIter = 100;
+    double tol = 1E-6;
+    boolean fitIntercept = true;
+    Double regParam = null;
+    Double elasticNetParam = null;
+    double fracTest = 0.2;
+  }
+
+  public static void main(String[] args) {
+    // parse the arguments
+    Params params = parse(args);
+    SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample");
+    JavaSparkContext jsc = new JavaSparkContext(conf);
+    SQLContext jsql = new SQLContext(jsc);
+
+    // configure the base classifier
+    LogisticRegression classifier = new LogisticRegression()
+      .setMaxIter(params.maxIter)
+      .setTol(params.tol)
+      .setFitIntercept(params.fitIntercept);
+
+    // only override the classifier's own settings when a value was supplied
+    if (params.regParam != null) {
+      classifier.setRegParam(params.regParam);
+    }
+    if (params.elasticNetParam != null) {
+      classifier.setElasticNetParam(params.elasticNetParam);
+    }
+
+    // instantiate the One Vs Rest Classifier
+    OneVsRest ovr = new OneVsRest().setClassifier(classifier);
+
+    String input = params.input;
+    RDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), input);
+    RDD<LabeledPoint> train;
+    RDD<LabeledPoint> test;
+
+    // compute the train/test split: if testInput is not provided use part of input
+    String testInput = params.testInput;
+    if (testInput != null) {
+      train = inputData;
+      // load the test set with the same number of features as the training set
+      int numFeatures = inputData.first().features().size();
+      test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures);
+    } else {
+      double f = params.fracTest;
+      RDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
+      train = tmp[0];
+      test = tmp[1];
+    }
+
+    // train the multiclass model
+    DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class);
+    OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache());
+
+    // score the model on test data
+    DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class);
+    DataFrame predictions = ovrModel.transform(testDataFrame.cache())
+      .select("prediction", "label");
+
+    // obtain metrics
+    MulticlassMetrics metrics = new MulticlassMetrics(predictions);
+    StructField predictionColSchema = predictions.schema().apply("prediction");
+    Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get();
+
+    // compute the false positive rate per label
+    StringBuilder results = new StringBuilder();
+    results.append("label\tfpr\n");
+    for (int label = 0; label < numClasses; label++) {
+      results.append(label);
+      results.append("\t");
+      results.append(metrics.falsePositiveRate((double) label));
+      results.append("\n");
+    }
+
+    Matrix confusionMatrix = metrics.confusionMatrix();
+    // output the Confusion Matrix
+    System.out.println("Confusion Matrix");
+    System.out.println(confusionMatrix);
+    System.out.println();
+    System.out.println(results);
+
+    jsc.stop();
+  }
+
+  /**
+   * Parses the command line into a {@link Params}. On a parse error, a malformed
+   * numeric value, or a fracTest outside [0, 1), prints the usage text and exits.
+   */
+  private static Params parse(String[] args) {
+    Options options = generateCommandlineOptions();
+    CommandLineParser parser = new PosixParser();
+    Params params = new Params();
+
+    try {
+      CommandLine cmd = parser.parse(options, args);
+      String value;
+      if (cmd.hasOption("input")) {
+        params.input = cmd.getOptionValue("input");
+      }
+      if (cmd.hasOption("maxIter")) {
+        value = cmd.getOptionValue("maxIter");
+        params.maxIter = Integer.parseInt(value);
+      }
+      if (cmd.hasOption("tol")) {
+        value = cmd.getOptionValue("tol");
+        params.tol = Double.parseDouble(value);
+      }
+      if (cmd.hasOption("fitIntercept")) {
+        value = cmd.getOptionValue("fitIntercept");
+        params.fitIntercept = Boolean.parseBoolean(value);
+      }
+      if (cmd.hasOption("regParam")) {
+        value = cmd.getOptionValue("regParam");
+        params.regParam = Double.parseDouble(value);
+      }
+      if (cmd.hasOption("elasticNetParam")) {
+        value = cmd.getOptionValue("elasticNetParam");
+        params.elasticNetParam = Double.parseDouble(value);
+      }
+      if (cmd.hasOption("testInput")) {
+        value = cmd.getOptionValue("testInput");
+        params.testInput = value;
+      }
+      if (cmd.hasOption("fracTest")) {
+        value = cmd.getOptionValue("fracTest");
+        params.fracTest = Double.parseDouble(value);
+      }
+
+    } catch (ParseException e) {
+      printHelpAndQuit(options);
+    } catch (NumberFormatException e) {
+      // a numeric option received a value that is not a number: show usage
+      // instead of failing with a raw stack trace.
+      printHelpAndQuit(options);
+    }
+
+    // the hold-out fraction is used as a randomSplit weight and must lie in [0, 1)
+    if (params.fracTest < 0 || params.fracTest >= 1) {
+      System.err.println("fracTest " + params.fracTest +
+        " value incorrect; should be in [0,1).");
+      printHelpAndQuit(options);
+    }
+    return params;
+  }
+
+  /** Builds the Commons CLI option definitions accepted by this example. */
+  private static Options generateCommandlineOptions() {
+    Option input = OptionBuilder.withArgName("input")
+      .hasArg()
+      .isRequired()
+      .withDescription("input path to labeled examples. This path must be specified")
+      .create("input");
+    Option testInput = OptionBuilder.withArgName("testInput")
+      .hasArg()
+      .withDescription("input path to test examples")
+      .create("testInput");
+    Option fracTest = OptionBuilder.withArgName("fracTest")
+      .hasArg()
+      .withDescription("fraction of data to hold out for testing." +
+        " If given option testInput, this option is ignored. default: 0.2")
+      .create("fracTest");
+    Option maxIter = OptionBuilder.withArgName("maxIter")
+      .hasArg()
+      .withDescription("maximum number of iterations for Logistic Regression. default:100")
+      .create("maxIter");
+    Option tol = OptionBuilder.withArgName("tol")
+      .hasArg()
+      .withDescription("the convergence tolerance of iterations " +
+        "for Logistic Regression. default: 1E-6")
+      .create("tol");
+    Option fitIntercept = OptionBuilder.withArgName("fitIntercept")
+      .hasArg()
+      .withDescription("fit intercept for logistic regression. default true")
+      .create("fitIntercept");
+    Option regParam = OptionBuilder.withArgName("regParam")
+      .hasArg()
+      .withDescription("the regularization parameter for Logistic Regression.")
+      .create("regParam");
+    Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam")
+      .hasArg()
+      .withDescription("the ElasticNet mixing parameter for Logistic Regression.")
+      .create("elasticNetParam");
+
+    Options options = new Options()
+      .addOption(input)
+      .addOption(testInput)
+      .addOption(fracTest)
+      .addOption(maxIter)
+      .addOption(tol)
+      .addOption(fitIntercept)
+      .addOption(regParam)
+      .addOption(elasticNetParam);
+
+    return options;
+  }
+
+  /** Prints the usage text for {@code options} and terminates the JVM with status -1. */
+  private static void printHelpAndQuit(Options options) {
+    HelpFormatter formatter = new HelpFormatter();
+    formatter.printHelp("JavaOneVsRestExample", options);
+    System.exit(-1);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/cc12a86f/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
new file mode 100644
index 0000000..b99d0a1
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO}
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkContext, SparkConf}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+
+/**
+ * An example runner for Multiclass to Binary Reduction with One Vs Rest.
+ * The example uses Logistic Regression as the base classifier. All parameters that
+ * can be specified on the base classifier can be passed in to the runner options.
+ * Run with
+ * {{{
+ * ./bin/run-example ml.OneVsRestExample [options]
+ * }}}
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.OneVsRestExample --driver-memory 1g
+ *   [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit`
+ * to submit your app.
+ */
+object OneVsRestExample {
+
+  /**
+   * Command-line settings. `input` is marked required by the option parser; the
+   * `Option` fields are only applied to the classifier when given.
+   */
+  case class Params private[ml] (
+      input: String = null,
+      testInput: Option[String] = None,
+      maxIter: Int = 100,
+      tol: Double = 1E-6,
+      fitIntercept: Boolean = true,
+      regParam: Option[Double] = None,
+      elasticNetParam: Option[Double] = None,
+      fracTest: Double = 0.2) extends AbstractParams[Params]
+
+  def main(args: Array[String]): Unit = {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("OneVsRest Example") {
+      head("OneVsRest Example: multiclass to binary reduction using OneVsRest")
+      opt[String]("input")
+        .text("input path to labeled examples. This path must be specified")
+        .required()
+        .action((x, c) => c.copy(input = x))
+      opt[Double]("fracTest")
+        .text(s"fraction of data to hold out for testing.  If given option testInput, " +
+          s"this option is ignored. default: ${defaultParams.fracTest}")
+        .action((x, c) => c.copy(fracTest = x))
+      opt[String]("testInput")
+        .text("input path to test dataset.  If given, option fracTest is ignored")
+        .action((x, c) => c.copy(testInput = Some(x)))
+      opt[Int]("maxIter")
+        .text(s"maximum number of iterations for Logistic Regression." +
+          s" default: ${defaultParams.maxIter}")
+        .action((x, c) => c.copy(maxIter = x))
+      opt[Double]("tol")
+        .text(s"the convergence tolerance of iterations for Logistic Regression." +
+          s" default: ${defaultParams.tol}")
+        .action((x, c) => c.copy(tol = x))
+      opt[Boolean]("fitIntercept")
+        .text(s"fit intercept for Logistic Regression." +
+          s" default: ${defaultParams.fitIntercept}")
+        .action((x, c) => c.copy(fitIntercept = x))
+      opt[Double]("regParam")
+        .text(s"the regularization parameter for Logistic Regression.")
+        .action((x, c) => c.copy(regParam = Some(x)))
+      opt[Double]("elasticNetParam")
+        .text(s"the ElasticNet mixing parameter for Logistic Regression.")
+        .action((x, c) => c.copy(elasticNetParam = Some(x)))
+      checkConfig { params =>
+        if (params.fracTest < 0 || params.fracTest >= 1) {
+          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+        } else {
+          success
+        }
+      }
+    }
+
+    parser.parse(args, defaultParams) match {
+      case Some(params) => run(params)
+      case None => sys.exit(1)
+    }
+  }
+
+  /** Trains a One-vs-Rest model with the given settings and prints timings and metrics. */
+  private def run(params: Params): Unit = {
+    val conf = new SparkConf().setAppName(s"OneVsRestExample with $params")
+    val sc = new SparkContext(conf)
+    val inputData = MLUtils.loadLibSVMFile(sc, params.input)
+    val sqlContext = new SQLContext(sc)
+    import sqlContext.implicits._
+
+    // compute the train/test split: if testInput is not provided use part of input.
+    val data = params.testInput match {
+      case Some(t) =>
+        // load the test set with the same number of features as the training set.
+        val numFeatures = inputData.first().features.size
+        val testData = MLUtils.loadLibSVMFile(sc, t, numFeatures)
+        Array[RDD[LabeledPoint]](inputData, testData)
+      case None =>
+        val f = params.fracTest
+        inputData.randomSplit(Array(1 - f, f), seed = 12345)
+    }
+    val Array(train, test) = data.map(_.toDF().cache())
+
+    // instantiate the base classifier
+    val classifier = new LogisticRegression()
+      .setMaxIter(params.maxIter)
+      .setTol(params.tol)
+      .setFitIntercept(params.fitIntercept)
+
+    // Set regParam, elasticNetParam if specified in params
+    params.regParam.foreach(classifier.setRegParam)
+    params.elasticNetParam.foreach(classifier.setElasticNetParam)
+
+    // instantiate the One Vs Rest Classifier.
+    val ovr = new OneVsRest().setClassifier(classifier)
+
+    // train the multiclass model.
+    val (trainingDuration, ovrModel) = time(ovr.fit(train))
+
+    // score the model on test data. transform() only builds a lazy DataFrame, so the
+    // result is cached and materialized inside the timed block; otherwise the reported
+    // time would cover plan construction only. The cache also lets the metric jobs
+    // below reuse the scores instead of recomputing them.
+    val (predictionDuration, predictions) = time {
+      val scored = ovrModel.transform(test).cache()
+      scored.count()
+      scored
+    }
+
+    // evaluate the model
+    val predictionsAndLabels = predictions.select("prediction", "label")
+      .map(row => (row.getDouble(0), row.getDouble(1)))
+
+    val metrics = new MulticlassMetrics(predictionsAndLabels)
+
+    val confusionMatrix = metrics.confusionMatrix
+
+    // compute the false positive rate per label
+    val predictionColSchema = predictions.schema("prediction")
+    val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get
+    val fprs = (0 until numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble)))
+
+    println(s" Training Time $trainingDuration sec\n")
+
+    println(s" Prediction Time $predictionDuration sec\n")
+
+    println(s" Confusion Matrix\n ${confusionMatrix.toString}\n")
+
+    println("label\tfpr")
+
+    println(fprs.map { case (label, fpr) => s"$label\t$fpr" }.mkString("\n"))
+
+    sc.stop()
+  }
+
+  /**
+   * Evaluates `block` once and returns the elapsed wall-clock time in seconds
+   * (millisecond resolution, so sub-second runs are not reported as 0) together
+   * with the block's result.
+   */
+  private def time[R](block: => R): (Double, R) = {
+    val t0 = System.nanoTime()
+    val result = block    // call-by-name
+    val t1 = System.nanoTime()
+    (NANO.toMillis(t1 - t0) / 1000.0, result)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to