Repository: spark Updated Branches: refs/heads/branch-2.0 238b7b416 -> eb0db9090
[SPARK-14814][MLLIB] API: Java compatibility, docs ## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-14814 Fix the Java-compatibility `predict` function in MLlib's DecisionTreeModel. As discussed in the JIRA, the other compatibility issues don't need fixes. ## How was this patch tested? Existing unit tests. Author: Yuhao Yang <[email protected]> Closes #12971 from hhbyyh/javacompatibility. (cherry picked from commit 68abc1b4e9afbb6c2a87689221a46b835dded102) Signed-off-by: Sean Owen <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eb0db909 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eb0db909 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eb0db909 Branch: refs/heads/branch-2.0 Commit: eb0db909009afd9289d24fd5a59eb060b8aafc5f Parents: 238b7b4 Author: Yuhao Yang <[email protected]> Authored: Mon May 9 09:08:54 2016 +0100 Committer: Sean Owen <[email protected]> Committed: Mon May 9 09:09:07 2016 +0100 ---------------------------------------------------------------------- .../apache/spark/mllib/tree/model/DecisionTreeModel.scala | 4 ++-- .../apache/spark/mllib/tree/JavaDecisionTreeSuite.java | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/eb0db909/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a87f8a6..c13b9a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -75,8 +75,8 @@ class DecisionTreeModel 
@Since("1.0.0") ( * @return JavaRDD of predictions for each of the given data points */ @Since("1.2.0") - def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { - predict(features.rdd) + def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = { + predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } /** http://git-wip-us.apache.org/repos/asf/spark/blob/eb0db909/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 8dd2906..60585d2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -28,6 +28,8 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.configuration.Algo; import org.apache.spark.mllib.tree.configuration.Strategy; @@ -95,6 +97,14 @@ public class JavaDecisionTreeSuite implements Serializable { DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy); + // java compatibility test + JavaRDD<Double> predictions = model.predict(rdd.map(new Function<LabeledPoint, Vector>() { + @Override + public Vector call(LabeledPoint v1) { + return v1.features(); + } + })); + int numCorrect = validatePrediction(arr, model); Assert.assertTrue(numCorrect == rdd.count()); } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
