Repository: spark
Updated Branches:
  refs/heads/branch-1.0 f9eeddccb -> bc9a96e2e


[SPARK-1741][MLLIB] add predict(JavaRDD) to RegressionModel, 
ClassificationModel, and KMeans

`model.predict` returns an RDD of a Scala primitive type (Int/Double), whose 
element type is seen as Object from Java. Adding predict(JavaRDD) makes life 
easier for Java users.

Added tests for KMeans, LinearRegression, and NaiveBayes.

Will update examples after https://github.com/apache/spark/pull/653 gets merged.

cc: @srowen

Author: Xiangrui Meng <[email protected]>

Closes #670 from mengxr/predict-javardd and squashes the following commits:

b77ccd8 [Xiangrui Meng] Merge branch 'master' into predict-javardd
43caac9 [Xiangrui Meng] add predict(JavaRDD) to RegressionModel, 
ClassificationModel, and KMeans
(cherry picked from commit d52761d67f42ad4d2ff02d96f0675fb3ab709f38)

Signed-off-by: Patrick Wendell <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc9a96e2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc9a96e2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc9a96e2

Branch: refs/heads/branch-1.0
Commit: bc9a96e2e97d4a9b4a2075fb026be320b96bd08b
Parents: f9eeddc
Author: Xiangrui Meng <[email protected]>
Authored: Thu May 15 11:59:59 2014 -0700
Committer: Patrick Wendell <[email protected]>
Committed: Thu May 15 12:00:26 2014 -0700

----------------------------------------------------------------------
 .../classification/ClassificationModel.scala    | 11 +++++++++-
 .../spark/mllib/clustering/KMeansModel.scala    |  5 +++++
 .../mllib/regression/RegressionModel.scala      | 11 +++++++++-
 .../classification/JavaNaiveBayesSuite.java     | 16 +++++++++++++++
 .../spark/mllib/clustering/JavaKMeansSuite.java | 14 +++++++++++++
 .../regression/JavaLinearRegressionSuite.java   | 21 ++++++++++++++++++++
 6 files changed, 76 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bc9a96e2/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
index 6332301..b7a1d90 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.mllib.classification
 
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
-import org.apache.spark.annotation.Experimental
 
 /**
  * :: Experimental ::
@@ -43,4 +44,12 @@ trait ClassificationModel extends Serializable {
    * @return predicted category from the trained model
    */
   def predict(testData: Vector): Double
+
+  /**
+   * Predict values for examples stored in a JavaRDD.
+   * @param testData JavaRDD representing data points to be predicted
+   * @return a JavaRDD[java.lang.Double] where each entry contains the 
corresponding prediction
+   */
+  def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
+    predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bc9a96e2/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index ce14b06..fba21ae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.clustering
 
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
@@ -40,6 +41,10 @@ class KMeansModel private[mllib] (val clusterCenters: 
Array[Vector]) extends Ser
     points.map(p => KMeans.findClosest(centersWithNorm, new 
BreezeVectorWithNorm(p))._1)
   }
 
+  /** Maps given points to their cluster indices. */
+  def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
+    predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
+
   /**
    * Return the K-means cost (sum of squared distances of points to their 
nearest center) for this
    * model on the given data.

http://git-wip-us.apache.org/repos/asf/spark/blob/bc9a96e2/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
index b27e158..64b02f7 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.mllib.regression
 
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.annotation.Experimental
 
 @Experimental
 trait RegressionModel extends Serializable {
@@ -38,4 +39,12 @@ trait RegressionModel extends Serializable {
    * @return Double prediction from the trained model
    */
   def predict(testData: Vector): Double
+
+  /**
+   * Predict values for examples stored in a JavaRDD.
+   * @param testData JavaRDD representing data points to be predicted
+   * @return a JavaRDD[java.lang.Double] where each entry contains the 
corresponding prediction
+   */
+  def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
+    predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bc9a96e2/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
 
b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index c80b113..743a43a 100644
--- 
a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.classification;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.junit.After;
@@ -87,4 +89,18 @@ public class JavaNaiveBayesSuite implements Serializable {
     int numAccurate2 = validatePrediction(POINTS, model2);
     Assert.assertEquals(POINTS.size(), numAccurate2);
   }
+
+  @Test
+  public void testPredictJavaRDD() {
+    JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache();
+    NaiveBayesModel model = NaiveBayes.train(examples.rdd());
+    JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, 
Vector>() {
+      @Override
+      public Vector call(LabeledPoint v) throws Exception {
+        return v.features();
+      }});
+    JavaRDD<Double> predictions = model.predict(vectors);
+    // Should be able to get the first prediction.
+    predictions.first();
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bc9a96e2/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
index 49a614b..0c916ca 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
@@ -88,4 +88,18 @@ public class JavaKMeansSuite implements Serializable {
       .run(data.rdd());
     assertEquals(expectedCenter, model.clusterCenters()[0]);
   }
+
+  @Test
+  public void testPredictJavaRDD() {
+    List<Vector> points = Lists.newArrayList(
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0)
+    );
+    JavaRDD<Vector> data = sc.parallelize(points, 2);
+    KMeansModel model = new 
KMeans().setK(1).setMaxIterations(5).run(data.rdd());
+    JavaRDD<Integer> predictions = model.predict(data);
+    // Should be able to get the first prediction.
+    predictions.first();
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bc9a96e2/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
 
b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
index 7151e55..6dc6877 100644
--- 
a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -25,8 +25,10 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.util.LinearDataGenerator;
 
 public class JavaLinearRegressionSuite implements Serializable {
@@ -92,4 +94,23 @@ public class JavaLinearRegressionSuite implements 
Serializable {
     Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
   }
 
+  @Test
+  public void testPredictJavaRDD() {
+    int nPoints = 100;
+    double A = 0.0;
+    double[] weights = {10, 10};
+    JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+      LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 
0.1), 2).cache();
+    LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
+    LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
+    JavaRDD<Vector> vectors = testRDD.map(new Function<LabeledPoint, Vector>() 
{
+      @Override
+      public Vector call(LabeledPoint v) throws Exception {
+        return v.features();
+      }
+    });
+    JavaRDD<Double> predictions = model.predict(vectors);
+    // Should be able to get the first prediction.
+    predictions.first();
+  }
 }

Reply via email to