Repository: spark
Updated Branches:
  refs/heads/master 3fabbc576 -> 6b94420f6

[SPARK-24231][PYSPARK][ML] Provide Python API for evaluateEachIteration for spark.ml GBTs

## What changes were proposed in this pull request?

Add evaluateEachIteration for GBTClassification and GBTRegressionModel

## How was this patch tested?

doctest

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Lu WANG <[email protected]>

Closes #21335 from ludatabricks/SPARK-14682.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6b94420f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6b94420f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6b94420f

Branch: refs/heads/master
Commit: 6b94420f6c672683678a54404e6341a0b9ab3c24
Parents: 3fabbc5
Author: Lu WANG <[email protected]>
Authored: Tue May 15 14:16:31 2018 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Tue May 15 14:16:31 2018 -0700

----------------------------------------------------------------------
 python/pyspark/ml/classification.py | 15 +++++++++++++++
 python/pyspark/ml/regression.py | 18 ++++++++++++++++++
 2 files changed, 33 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6b94420f/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index ec17653..424ecfd 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1222,6 +1222,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
     True
     >>> model.trees
     [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+    >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
+    ...              ["indexed", "features"])
+    >>> model.evaluateEachIteration(validation)
+    [0.25..., 0.23..., 0.21..., 0.19..., 0.18...]

     .. versionadded:: 1.4.0
     """
@@ -1319,6 +1323,17 @@ class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWrita
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

+    @since("2.4.0")
+    def evaluateEachIteration(self, dataset):
+        """
+        Method to compute error or loss for every iteration of gradient boosting.
+
+        :param dataset:
+            Test dataset to evaluate model on, where dataset is an
+            instance of :py:class:`pyspark.sql.DataFrame`
+        """
+        return self._call_java("evaluateEachIteration", dataset)
+

 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,

http://git-wip-us.apache.org/repos/asf/spark/blob/6b94420f/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 9a66d87..dd0b62f 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1056,6 +1056,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
     True
     >>> model.trees
     [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+    >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
+    ...              ["label", "features"])
+    >>> model.evaluateEachIteration(validation, "squared")
+    [0.0, 0.0, 0.0, 0.0, 0.0]

     .. versionadded:: 1.4.0
     """
@@ -1156,6 +1160,20 @@ class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
         """Trees in this ensemble. Warning: These have null parent Estimators."""
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]

+    @since("2.4.0")
+    def evaluateEachIteration(self, dataset, loss):
+        """
+        Method to compute error or loss for every iteration of gradient boosting.
+
+        :param dataset:
+            Test dataset to evaluate model on, where dataset is an
+            instance of :py:class:`pyspark.sql.DataFrame`
+        :param loss:
+            The loss function used to compute error.
+            Supported options: squared, absolute
+        """
+        return self._call_java("evaluateEachIteration", dataset, loss)
+

 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
