Repository: spark
Updated Branches:
  refs/heads/master 3fabbc576 -> 6b94420f6


[SPARK-24231][PYSPARK][ML] Provide Python API for evaluateEachIteration for spark.ml GBTs

## What changes were proposed in this pull request?

Add `evaluateEachIteration` to GBTClassificationModel and GBTRegressionModel.
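For illustration only (not part of the patch), here is a minimal sketch of how the classifier-side method is meant to be called from PySpark. The toy DataFrames, column names, parameter values, and app name are assumptions made for this example; the method name and signature follow the diff below.

```python
# Hedged sketch: fit a tiny GBT classifier, then get the per-iteration
# log loss on a held-out DataFrame via the new evaluateEachIteration method.
# The training data, column names, and parameters here are illustrative only.
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import GBTClassifier

spark = SparkSession.builder.master("local[2]").appName("gbt-eval-sketch").getOrCreate()

train = spark.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], [])),
], ["indexed", "features"])

model = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42).fit(train)

validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
                                   ["indexed", "features"])

# Returns a Python list with one loss value (log loss) per boosting iteration.
print(model.evaluateEachIteration(validation))

spark.stop()
```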

## How was this patch tested?

Doctests added for both new methods.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Lu WANG <[email protected]>

Closes #21335 from ludatabricks/SPARK-14682.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6b94420f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6b94420f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6b94420f

Branch: refs/heads/master
Commit: 6b94420f6c672683678a54404e6341a0b9ab3c24
Parents: 3fabbc5
Author: Lu WANG <[email protected]>
Authored: Tue May 15 14:16:31 2018 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Tue May 15 14:16:31 2018 -0700

----------------------------------------------------------------------
 python/pyspark/ml/classification.py | 15 +++++++++++++++
 python/pyspark/ml/regression.py     | 18 ++++++++++++++++++
 2 files changed, 33 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6b94420f/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index ec17653..424ecfd 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1222,6 +1222,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
     True
     >>> model.trees
     [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+    >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
+    ...              ["indexed", "features"])
+    >>> model.evaluateEachIteration(validation)
+    [0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
 
     .. versionadded:: 1.4.0
     """
@@ -1319,6 +1323,17 @@ class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWrita
         """Trees in this ensemble. Warning: These have null parent 
Estimators."""
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
 
+    @since("2.4.0")
+    def evaluateEachIteration(self, dataset):
+        """
+        Method to compute error or loss for every iteration of gradient boosting.
+
+        :param dataset:
+            Test dataset to evaluate model on, where dataset is an
+            instance of :py:class:`pyspark.sql.DataFrame`
+        """
+        return self._call_java("evaluateEachIteration", dataset)
+
 
 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,

http://git-wip-us.apache.org/repos/asf/spark/blob/6b94420f/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 9a66d87..dd0b62f 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1056,6 +1056,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
     True
     >>> model.trees
     [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+    >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
+    ...              ["label", "features"])
+    >>> model.evaluateEachIteration(validation, "squared")
+    [0.0, 0.0, 0.0, 0.0, 0.0]
 
     .. versionadded:: 1.4.0
     """
@@ -1156,6 +1160,20 @@ class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
         """Trees in this ensemble. Warning: These have null parent 
Estimators."""
         return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
 
+    @since("2.4.0")
+    def evaluateEachIteration(self, dataset, loss):
+        """
+        Method to compute error or loss for every iteration of gradient boosting.
+
+        :param dataset:
+            Test dataset to evaluate model on, where dataset is an
+            instance of :py:class:`pyspark.sql.DataFrame`
+        :param loss:
+            The loss function used to compute error.
+            Supported options: squared, absolute
+        """
+        return self._call_java("evaluateEachIteration", dataset, loss)
+
 
 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
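
For context, and not part of the commit, here is a hedged usage sketch of the regression-side method, which additionally takes the loss to compute ("squared" or "absolute") as documented in the docstring above. The toy data and parameter values are illustrative assumptions.

```python
# Hedged sketch: the regression variant takes an explicit loss name.
# Training data and parameters below are illustrative, not from the patch.
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import GBTRegressor

spark = SparkSession.builder.master("local[2]").appName("gbt-reg-eval-sketch").getOrCreate()

train = spark.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], [])),
], ["label", "features"])

model = GBTRegressor(maxIter=5, maxDepth=2, seed=42).fit(train)

validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
                                   ["label", "features"])

# One error value per iteration; watching where the curve flattens is a
# simple way to sanity-check the chosen number of trees.
print(model.evaluateEachIteration(validation, "squared"))
print(model.evaluateEachIteration(validation, "absolute"))

spark.stop()
```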

