Repository: spark
Updated Branches:
  refs/heads/branch-2.0 554e0f30a -> 97fd9a09c


[SPARK-15316][PYSPARK][ML] Add linkPredictionCol to GeneralizedLinearRegression

## What changes were proposed in this pull request?

Add the linkPredictionCol param to the PySpark GeneralizedLinearRegression and
fix the PyDoc so that the family/link list renders as a bullet list.

## How was this patch tested?

Ran the doctests and built the docs locally.

Author: Holden Karau <hol...@us.ibm.com>

Closes #13106 from 
holdenk/SPARK-15316-add-linkPredictionCol-toGeneralizedLinearRegression.

(cherry picked from commit e71cd96bf733f0440f818c6efc7a04b68d7cbe45)
Signed-off-by: Nick Pentreath <ni...@za.ibm.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/97fd9a09
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/97fd9a09
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/97fd9a09

Branch: refs/heads/branch-2.0
Commit: 97fd9a09ce1313ad7b9569fc3ca8e944d36d0ce9
Parents: 554e0f3
Author: Holden Karau <hol...@us.ibm.com>
Authored: Thu May 19 20:59:19 2016 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Thu May 19 20:59:42 2016 +0200

----------------------------------------------------------------------
 python/pyspark/ml/regression.py | 46 +++++++++++++++++++++++++++---------
 1 file changed, 35 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/97fd9a09/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index cfcbbfc..25640b1 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1245,10 +1245,14 @@ class GeneralizedLinearRegression(JavaEstimator, 
HasLabelCol, HasFeaturesCol, Ha
     predictor (link function) and a description of the error distribution 
(family). It supports
     "gaussian", "binomial", "poisson" and "gamma" as family. Valid link 
functions for each family
     is listed below. The first link function of each family is the default one.
-    - "gaussian" -> "identity", "log", "inverse"
-    - "binomial" -> "logit", "probit", "cloglog"
-    - "poisson"  -> "log", "identity", "sqrt"
-    - "gamma"    -> "inverse", "identity", "log"
+
+    * "gaussian" -> "identity", "log", "inverse"
+
+    * "binomial" -> "logit", "probit", "cloglog"
+
+    * "poisson"  -> "log", "identity", "sqrt"
+
+    * "gamma"    -> "inverse", "identity", "log"
 
     .. seealso:: `GLM 
<https://en.wikipedia.org/wiki/Generalized_linear_model>`_
 
@@ -1258,9 +1262,12 @@ class GeneralizedLinearRegression(JavaEstimator, 
HasLabelCol, HasFeaturesCol, Ha
     ...     (1.0, Vectors.dense(1.0, 2.0)),
     ...     (2.0, Vectors.dense(0.0, 0.0)),
     ...     (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
-    >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
+    >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", 
linkPredictionCol="p")
     >>> model = glr.fit(df)
-    >>> abs(model.transform(df).head().prediction - 1.5) < 0.001
+    >>> transformed = model.transform(df)
+    >>> abs(transformed.head().prediction - 1.5) < 0.001
+    True
+    >>> abs(transformed.head().p - 1.5) < 0.001
     True
     >>> model.coefficients
     DenseVector([1.5..., -1.0...])
@@ -1290,20 +1297,23 @@ class GeneralizedLinearRegression(JavaEstimator, 
HasLabelCol, HasFeaturesCol, Ha
                  "relationship between the linear predictor and the mean of 
the distribution " +
                  "function. Supported options: identity, log, inverse, logit, 
probit, cloglog " +
                  "and sqrt.", typeConverter=TypeConverters.toString)
+    linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link 
prediction (linear " +
+                              "predictor) column name", 
typeConverter=TypeConverters.toString)
 
     @keyword_only
     def __init__(self, labelCol="label", featuresCol="features", 
predictionCol="prediction",
                  family="gaussian", link=None, fitIntercept=True, maxIter=25, 
tol=1e-6,
-                 regParam=0.0, weightCol=None, solver="irls"):
+                 regParam=0.0, weightCol=None, solver="irls", 
linkPredictionCol=""):
         """
         __init__(self, labelCol="label", featuresCol="features", 
predictionCol="prediction", \
                  family="gaussian", link=None, fitIntercept=True, maxIter=25, 
tol=1e-6, \
-                 regParam=0.0, weightCol=None, solver="irls")
+                 regParam=0.0, weightCol=None, solver="irls", 
linkPredictionCol="")
         """
         super(GeneralizedLinearRegression, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.GeneralizedLinearRegression", 
self.uid)
-        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, 
regParam=0.0, solver="irls")
+        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, 
regParam=0.0, solver="irls",
+                         linkPredictionCol="")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
@@ -1311,11 +1321,11 @@ class GeneralizedLinearRegression(JavaEstimator, 
HasLabelCol, HasFeaturesCol, Ha
     @since("2.0.0")
     def setParams(self, labelCol="label", featuresCol="features", 
predictionCol="prediction",
                   family="gaussian", link=None, fitIntercept=True, maxIter=25, 
tol=1e-6,
-                  regParam=0.0, weightCol=None, solver="irls"):
+                  regParam=0.0, weightCol=None, solver="irls", 
linkPredictionCol=""):
         """
         setParams(self, labelCol="label", featuresCol="features", 
predictionCol="prediction", \
                   family="gaussian", link=None, fitIntercept=True, maxIter=25, 
tol=1e-6, \
-                  regParam=0.0, weightCol=None, solver="irls")
+                  regParam=0.0, weightCol=None, solver="irls", 
linkPredictionCol="")
         Sets params for generalized linear regression.
         """
         kwargs = self.setParams._input_kwargs
@@ -1339,6 +1349,20 @@ class GeneralizedLinearRegression(JavaEstimator, 
HasLabelCol, HasFeaturesCol, Ha
         return self.getOrDefault(self.family)
 
     @since("2.0.0")
+    def setLinkPredictionCol(self, value):
+        """
+        Sets the value of :py:attr:`linkPredictionCol`.
+        """
+        return self._set(linkPredictionCol=value)
+
+    @since("2.0.0")
+    def getLinkPredictionCol(self):
+        """
+        Gets the value of linkPredictionCol or its default value.
+        """
+        return self.getOrDefault(self.linkPredictionCol)
+
+    @since("2.0.0")
     def setLink(self, value):
         """
         Sets the value of :py:attr:`link`.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to