Repository: spark Updated Branches: refs/heads/branch-1.2 127c19b44 -> 16da988c5
[SPARK-4369] [MLLib] fix TreeModel.predict() with RDD Fix TreeModel.predict() with RDD, added tests for it. (Also checked that other models don't have this issue) Author: Davies Liu <dav...@databricks.com> Closes #3230 from davies/predict and squashes the following commits: 81172aa [Davies Liu] fix predict (cherry picked from commit bd86118c4e980f94916f892c76fb808fd4c8bd85) Signed-off-by: Xiangrui Meng <m...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/16da988c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/16da988c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/16da988c Branch: refs/heads/branch-1.2 Commit: 16da988c5cdae935151e307a66a5385bac5167c3 Parents: 127c19b Author: Davies Liu <dav...@databricks.com> Authored: Wed Nov 12 13:56:41 2014 -0800 Committer: Xiangrui Meng <m...@databricks.com> Committed: Wed Nov 12 13:56:50 2014 -0800 ---------------------------------------------------------------------- .../mllib/tree/model/DecisionTreeModel.scala | 12 +++++++++ python/pyspark/mllib/tree.py | 26 +++++++++++--------- 2 files changed, 26 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/16da988c/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index ec1d99a..ac4d02e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.tree.model +import org.apache.spark.api.java.JavaRDD import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD @@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable features.map(x => predict(x)) } + + /** + * Predict values for the given data set using the model trained. + * + * @param features JavaRDD representing data points to be predicted + * @return JavaRDD of predictions for each of the given data points + */ + def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { + predict(features.rdd) + } + /** * Get number of nodes in tree, including leaf nodes. */ http://git-wip-us.apache.org/repos/asf/spark/blob/16da988c/python/pyspark/mllib/tree.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 5d1a3c0..ef0d556 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -124,10 +124,13 @@ class DecisionTree(object): Predict: 0.0 Else (feature 0 > 0.0) Predict: 1.0 - >>> model.predict(array([1.0])) > 0 - True - >>> model.predict(array([0.0])) == 0 - True + >>> model.predict(array([1.0])) + 1.0 + >>> model.predict(array([0.0])) + 0.0 + >>> rdd = sc.parallelize([[1.0], [0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] """ return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) @@ -170,14 +173,13 @@ class DecisionTree(object): ... ] >>> >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {}) - >>> model.predict(array([0.0, 1.0])) == 1 - True - >>> model.predict(array([0.0, 0.0])) == 0 - True - >>> model.predict(SparseVector(2, {1: 1.0})) == 1 - True - >>> model.predict(SparseVector(2, {1: 0.0})) == 0 - True + >>> model.predict(SparseVector(2, {1: 1.0})) + 1.0 + >>> model.predict(SparseVector(2, {1: 0.0})) + 0.0 + >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] """ return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org