Github user yanboliang commented on a diff in the pull request:
https://github.com/apache/spark/pull/16630#discussion_r128200260
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
---
@@ -1458,4 +1475,167 @@ class GeneralizedLinearRegressionTrainingSummary
private[regression] (
"No p-value available for this GeneralizedLinearRegressionModel")
}
}
+
+ /**
+ * Coefficient matrix with feature name, coefficient, standard error,
+ * tValue and pValue.
+ */
+ @Since("2.3.0")
+ lazy val coefficientMatrix: Array[(String, Double, Double, Double,
Double)] = {
+ if (isNormalSolver) {
+ var featureNamesLocal = featureNames
+ var coefficients = model.coefficients.toArray
+ var idx = Array.range(0, coefficients.length)
+ if (model.getFitIntercept) {
+ featureNamesLocal = featureNamesLocal :+ "(Intercept)"
+ coefficients = coefficients :+ model.intercept
+ // Reorder so that intercept comes first
+ idx = (coefficients.length - 1) +: idx
+ }
+ val result = for (i <- idx) yield
+ (featureNamesLocal(i), coefficients(i),
coefficientStandardErrors(i),
+ tValues(i), pValues(i))
+ result
+ } else {
+ throw new UnsupportedOperationException(
+ "No summary table available for this
GeneralizedLinearRegressionModel")
+ }
+ }
+
+ private def round(x: Double, digit: Int): String = {
+ BigDecimal(x).setScale(digit,
BigDecimal.RoundingMode.HALF_UP).toString()
+ }
+
+ private[regression] def showString(_numRows: Int, truncate: Int = 20,
+ numDigits: Int = 3): String = {
+ val numRows = _numRows.max(1)
+ val data = coefficientMatrix.take(numRows)
+ val hasMoreData = coefficientMatrix.size > numRows
+
+ val colNames = Array("Feature", "Estimate", "StdError", "TValue",
"PValue")
+ val numCols = colNames.size
+
+ val rows = colNames +: data.map( row => {
+ val mrow = for (cell <- row.productIterator) yield {
+ val str = cell match {
+ case s: String => s
+ case n: Double => round(n, numDigits).toString
+ }
+ if (truncate > 0 && str.length > truncate) {
+ // do not show ellipses for strings shorter than 4 characters.
+ if (truncate < 4) str.substring(0, truncate)
+ else str.substring(0, truncate - 3) + "..."
+ } else {
+ str
+ }
+ }
+ mrow.toArray
+ })
+
+ val sb = new StringBuilder
+ val colWidths = Array.fill(numCols)(3)
+
+ // Compute the width of each column
+ for (row <- rows) {
+ for ((cell, i) <- row.zipWithIndex) {
+ colWidths(i) = math.max(colWidths(i), cell.length)
+ }
+ }
+
+ // Create SeparateLine
+ val sep: String = colWidths.map("-" * _).addString(sb, "+", "+",
"+\n").toString()
+
+ // column names
+ rows.head.zipWithIndex.map { case (cell, i) =>
+ if (truncate > 0) {
+ StringUtils.leftPad(cell, colWidths(i))
+ } else {
+ StringUtils.rightPad(cell, colWidths(i))
+ }
+ }.addString(sb, "|", "|", "|\n")
+ sb.append(sep)
+
+ // data
+ rows.tail.map {
+ _.zipWithIndex.map { case (cell, i) =>
+ if (truncate > 0) {
+ StringUtils.leftPad(cell.toString, colWidths(i))
+ } else {
+ StringUtils.rightPad(cell.toString, colWidths(i))
+ }
+ }.addString(sb, "|", "|", "|\n")
+ }
+
+ // For Data that has more than "numRows" records
+ if (hasMoreData) {
+ sb.append("...\n")
+ sb.append(sep)
+ val rowsString = if (numRows == 1) "row" else "rows"
+ sb.append(s"only showing top $numRows $rowsString\n")
+ } else {
+ sb.append(sep)
+ }
+
+ sb.append("\n")
+ sb.append(s"(Dispersion parameter for ${family.name} family taken to
be " +
+ round(dispersion, numDigits) + ")")
+
+ sb.append("\n")
+ val nd = "Null deviance: " + round(nullDeviance, numDigits) +
+ s" on $degreesOfFreedom degrees of freedom"
+ val rd = "Residual deviance: " + round(deviance, numDigits) +
+ s" on $residualDegreeOfFreedom degrees of freedom"
+ val l = math.max(nd.length, rd.length)
+ sb.append(StringUtils.leftPad(nd, l))
+ sb.append("\n")
+ sb.append(StringUtils.leftPad(rd, l))
+
+ if (family.name != "tweedie") {
+ sb.append("\n")
+ sb.append(s"AIC: " + round(aic, numDigits))
+ }
+
+ sb.toString()
+ }
+
+ /**
+ * Displays the summary of a GeneralizedLinearModel fit.
+ *
+ * @since 2.3.0
+ */
+ def show(): Unit = {
+ val numRows = coefficientMatrix.size
+ show(numRows, true, 3)
+ }
+
+ /**
+ * Displays the top numRows rows of the summary of a
GeneralizedLinearModel fit.
+ *
+ * @param numRows Number of rows to show
+ *
+ * @since 2.3.0
+ */
+ @Since("2.3.0")
+ def show(numRows: Int): Unit = {
+ show(numRows, true, 3)
+ }
+
+ /**
+ * Displays the summary of a GeneralizedLinearModel fit. Strings more
than 20 characters
+ * will be truncated, and all cells will be aligned right.
+ *
+ * @param numRows Number of rows to show
+ * @param truncate Whether truncate long strings. If true, strings more
than 20 characters will
+ * be truncated and all cells will be aligned right
+ * @param numDigits Number of decimal places used to round numerical
values.
+ *
+ * @since 2.3.0
+ */
+ // scalastyle:off println
+ def show(numRows: Int, truncate: Boolean, numDigits: Int): Unit = if
(truncate) {
--- End diff --
I think not all of these functions are useful for the GLM summary. I'd recommend
keeping only one ```show``` function with default settings, such as ```numRows =
coefficientMatrix.size```, ```truncate = 20``` and ```numDigits = 3```. This is a
little different from ```Dataset.show```: it's not necessary to provide lots of
options for users to set, since users just want to see output like R's. The code
will then be cleaner.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]