Github user MLnick commented on a diff in the pull request:

    https://github.com/apache/spark/pull/12920#discussion_r62800683
  
    --- Diff: examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java ---
    @@ -17,222 +17,68 @@
     
     package org.apache.spark.examples.ml;
     
    -import org.apache.commons.cli.*;
    -
     // $example on$
     import org.apache.spark.ml.classification.LogisticRegression;
     import org.apache.spark.ml.classification.OneVsRest;
     import org.apache.spark.ml.classification.OneVsRestModel;
    -import org.apache.spark.ml.util.MetadataUtils;
    -import org.apache.spark.mllib.evaluation.MulticlassMetrics;
    -import org.apache.spark.mllib.linalg.Matrix;
    -import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
     import org.apache.spark.sql.Dataset;
     import org.apache.spark.sql.Row;
    -import org.apache.spark.sql.SparkSession;
    -import org.apache.spark.sql.types.StructField;
     // $example off$
    +import org.apache.spark.sql.SparkSession;
    +
     
     /**
    - * An example runner for Multiclass to Binary Reduction with One Vs Rest.
    - * The example uses Logistic Regression as the base classifier. All parameters that
    - * can be specified on the base classifier can be passed in to the runner options.
    + * An example of Multiclass to Binary Reduction with One Vs Rest,
    + * using Logistic Regression as the base classifier.
      * Run with
      * <pre>
    - * bin/run-example ml.JavaOneVsRestExample [options]
    + * bin/run-example ml.JavaOneVsRestExample
      * </pre>
      */
     public class JavaOneVsRestExample {
    -
    -  private static class Params {
    -    String input;
    -    String testInput = null;
    -    Integer maxIter = 100;
    -    double tol = 1E-6;
    -    boolean fitIntercept = true;
    -    Double regParam = null;
    -    Double elasticNetParam = null;
    -    double fracTest = 0.2;
    -  }
    -
       public static void main(String[] args) {
    -    // parse the arguments
    -    Params params = parse(args);
         SparkSession spark = SparkSession
           .builder()
           .appName("JavaOneVsRestExample")
           .getOrCreate();
     
         // $example on$
    -    // configure the base classifier
    -    LogisticRegression classifier = new LogisticRegression()
    -      .setMaxIter(params.maxIter)
    -      .setTol(params.tol)
    -      .setFitIntercept(params.fitIntercept);
    +    // load data file.
    +    Dataset<Row> inputData = spark.read().format("libsvm")
    +      .load("data/mllib/sample_multiclass_classification_data.txt");
     
    -    if (params.regParam != null) {
    -      classifier.setRegParam(params.regParam);
    -    }
    -    if (params.elasticNetParam != null) {
    -      classifier.setElasticNetParam(params.elasticNetParam);
    -    }
    +    // generate the train/test split.
    +    Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2});
    +    Dataset<Row> train = tmp[0];
    +    Dataset<Row> test = tmp[1];
     
    -    // instantiate the One Vs Rest Classifier
    -    OneVsRest ovr = new OneVsRest().setClassifier(classifier);
    -
    -    String input = params.input;
    -    Dataset<Row> inputData = spark.read().format("libsvm").load(input);
    -    Dataset<Row> train;
    -    Dataset<Row> test;
    +    // configure the base classifier.
    +    LogisticRegression classifier = new LogisticRegression()
    +      .setMaxIter(10)
    +      .setTol(1E-6)
    +      .setFitIntercept(true);
     
    -    // compute the train/ test split: if testInput is not provided use part of input
    -    String testInput = params.testInput;
    -    if (testInput != null) {
    -      train = inputData;
    -      // compute the number of features in the training set.
    -      int numFeatures = inputData.first().<Vector>getAs(1).size();
    -      test = spark.read().format("libsvm").option("numFeatures",
    -        String.valueOf(numFeatures)).load(testInput);
    -    } else {
    -      double f = params.fracTest;
    -      Dataset<Row>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 
12345);
    -      train = tmp[0];
    -      test = tmp[1];
    -    }
    +    // instantiate the One Vs Rest Classifier.
    +    OneVsRest ovr = new OneVsRest().setClassifier(classifier);
     
    -    // train the multiclass model
    +    // train the multiclass model.
    --- End diff --
    
    Sorry, one last thing: just below this hunk we call `train.cache()`, but we don't
do that in the other examples. In fact, from a quick look, none of the other
examples cache their training data. So perhaps remove that call and just use
`ovr.fit(train);`.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to