Wayne Zhang created SPARK-18701:
-----------------------------------

             Summary: Poisson GLM fails due to wrong initialization
                 Key: SPARK-18701
                 URL: https://issues.apache.org/jira/browse/SPARK-18701
             Project: Spark
          Issue Type: New Feature
          Components: ML
    Affects Versions: 2.0.2
            Reporter: Wayne Zhang
            Priority: Critical
             Fix For: 2.2.0


Poisson GLM fails on many standard data sets, in particular those whose 
response contains zeros. The cause is an incorrect initialization that leads 
to near-zero probabilities and weights. The following simple example 
reproduces the error. 

{code:borderStyle=solid}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression
// toDF() needs spark.implicits._, which spark-shell imports automatically.

// Count response with many zeros; the label is the first argument.
val datasetPoissonLogWithZero = Seq(
  LabeledPoint(0.0, Vectors.dense(18, 1.0)),
  LabeledPoint(1.0, Vectors.dense(12, 0.0)),
  LabeledPoint(0.0, Vectors.dense(15, 0.0)),
  LabeledPoint(0.0, Vectors.dense(13, 2.0)),
  LabeledPoint(0.0, Vectors.dense(15, 1.0)),
  LabeledPoint(1.0, Vectors.dense(16, 1.0)),
  LabeledPoint(0.0, Vectors.dense(10, 0.0)),
  LabeledPoint(0.0, Vectors.dense(15, 0.0)),
  LabeledPoint(0.0, Vectors.dense(12, 2.0)),
  LabeledPoint(0.0, Vectors.dense(13, 0.0)),
  LabeledPoint(1.0, Vectors.dense(15, 0.0)),
  LabeledPoint(1.0, Vectors.dense(15, 0.0)),
  LabeledPoint(0.0, Vectors.dense(15, 0.0)),
  LabeledPoint(0.0, Vectors.dense(12, 2.0)),
  LabeledPoint(1.0, Vectors.dense(12, 2.0))
).toDF()

val glr = new GeneralizedLinearRegression()
  .setFamily("poisson")
  .setLink("log")
  .setMaxIter(20)
  .setRegParam(0)

val model = glr.fit(datasetPoissonLogWithZero)
{code}

The problem is in the initialization: the mean is initialized to the response, 
which can be zero. Applying the log link then yields very negative numbers 
(the value is protected against -Inf), which in turn leads to near-zero 
probabilities and weights in the weighted least squares. The fix is easy: 
just add a small constant, highlighted in red below. 
 

    override def initialize(y: Double, weight: Double): Double = {
      require(y >= 0.0, "The response variable of Poisson family " +
        s"should be non-negative, but got $y")
      y {color:red}+ 0.1 {color}
    }
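
To make the mechanism above concrete, here is a minimal sketch in plain Scala (no Spark required). It is an illustration only, not Spark's actual code: the object name, the helper names, and the clamp value 1e-16 are assumptions. It relies on the standard result that for the Poisson family with log link the IRLS working weight is 1 / (V(mu) * g'(mu)^2) = mu, so a mean initialized at (nearly) zero also gives a (nearly) zero weight.

```scala
object PoissonInitSketch {
  // Assumed lower bound that protects log(0); the constant Spark uses may differ.
  val Epsilon = 1e-16

  // Log link with clamping so that mu = 0 does not produce -Inf.
  def logLink(mu: Double): Double = math.log(math.max(mu, Epsilon))

  // IRLS working weight for Poisson with log link:
  // 1 / (V(mu) * g'(mu)^2) = 1 / (mu * (1 / mu)^2) = mu.
  def workingWeight(mu: Double): Double = math.max(mu, Epsilon)

  // Initialization as described in the report: the mean starts at the response.
  def initOriginal(y: Double): Double = y

  // Proposed initialization: shift the response by a small constant.
  def initFixed(y: Double): Double = y + 0.1

  def main(args: Array[String]): Unit = {
    for (y <- Seq(0.0, 1.0)) {
      val m0 = initOriginal(y)
      val m1 = initFixed(y)
      println(f"y=$y%.1f original: eta=${logLink(m0)}%.2f weight=${workingWeight(m0)}%.1e")
      println(f"y=$y%.1f fixed:    eta=${logLink(m1)}%.2f weight=${workingWeight(m1)}%.1e")
    }
  }
}
```

For y = 0, the original initialization gives a linear predictor of about -36.8 and a working weight at the clamp value, whereas the shifted initialization starts from log(0.1), about -2.3, with a weight of 0.1.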

I already have a fix and test code. Will create a PR. 



