This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a010a1a1984 [SPARK-52035][ML] Decouple LinearRegressionTrainingSummary and LinearRegressionModel
1a010a1a1984 is described below

commit 1a010a1a1984cf0427022bb38ff8d9e28460f974
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu May 8 18:22:19 2025 +0800

    [SPARK-52035][ML] Decouple LinearRegressionTrainingSummary and LinearRegressionModel
    
    ### What changes were proposed in this pull request?
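    Decouple LinearRegressionTrainingSummary (and LinearRegressionSummary) from LinearRegressionModel: instead of holding a reference to the model, the summaries now receive the weight column, the number of features, the fitIntercept flag and, when needed, the coefficient array directly.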
    
    ### Why are the changes needed?
    LinearRegressionTrainingSummary holds a reference to the model, which makes it hard to handle in the Spark Connect server.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #50825 from zhengruifeng/ml_refactor_lir_summary.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../spark/ml/regression/LinearRegression.scala     | 105 +++++++++++----------
 1 file changed, 57 insertions(+), 48 deletions(-)

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 29fa3d5e123b..f313786cf598 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -435,10 +435,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") 
override val uid: String
     if (SummaryUtils.enableTrainingSummary) {
       // Handle possible missing or invalid prediction columns
       val (summaryModel, predictionColName) = 
model.findSummaryModelAndPredictionCol()
-
       val trainingSummary = new LinearRegressionTrainingSummary(
         summaryModel.transform(dataset), predictionColName, $(labelCol), 
$(featuresCol),
-        model, Array(0.0), objectiveHistory)
+        summaryModel.get(summaryModel.weightCol).getOrElse(""),
+        summaryModel.numFeatures, summaryModel.getFitIntercept,
+        Array(0.0), objectiveHistory)
       model.setSummary(Some(trainingSummary))
     }
     model
@@ -460,9 +461,17 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") 
override val uid: String
     val lrModel = copyValues(new LinearRegressionModel(
       uid, model.coefficients.compressed, model.intercept))
     val (summaryModel, predictionColName) = 
lrModel.findSummaryModelAndPredictionCol()
+
+    val coefficientArray = if (summaryModel.getFitIntercept) {
+      summaryModel.coefficients.toArray ++ Array(summaryModel.intercept)
+    } else {
+      summaryModel.coefficients.toArray
+    }
     val trainingSummary = new LinearRegressionTrainingSummary(
       summaryModel.transform(dataset), predictionColName, $(labelCol), 
$(featuresCol),
-      summaryModel, model.diagInvAtWA.toArray, model.objectiveHistory)
+      summaryModel.get(summaryModel.weightCol).getOrElse(""),
+      summaryModel.numFeatures, summaryModel.getFitIntercept,
+      model.diagInvAtWA.toArray, model.objectiveHistory, coefficientArray)
 
     lrModel.setSummary(Some(trainingSummary))
   }
@@ -494,7 +503,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") 
override val uid: String
 
     val trainingSummary = new LinearRegressionTrainingSummary(
       summaryModel.transform(dataset), predictionColName, $(labelCol), 
$(featuresCol),
-      model, Array(0.0), Array(0.0))
+      summaryModel.get(summaryModel.weightCol).getOrElse(""),
+      summaryModel.numFeatures, summaryModel.getFitIntercept,
+      Array(0.0), Array(0.0))
 
     model.setSummary(Some(trainingSummary))
   }
@@ -737,7 +748,9 @@ class LinearRegressionModel private[ml] (
     // Handle possible missing or invalid prediction columns
     val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
     new LinearRegressionSummary(summaryModel.transform(dataset), 
predictionColName,
-      $(labelCol), $(featuresCol), summaryModel, Array(0.0))
+      $(labelCol), $(featuresCol),
+      summaryModel.get(summaryModel.weightCol).getOrElse(""),
+      summaryModel.numFeatures, summaryModel.getFitIntercept, Array(0.0))
   }
 
   /**
@@ -879,6 +892,8 @@ object LinearRegressionModel extends 
MLReadable[LinearRegressionModel] {
  *
  * @param predictions predictions output by the model's `transform` method.
  * @param objectiveHistory objective function (scaled loss + regularization) 
at each iteration.
+ * @param coefficientArray Coefficients of the linear regression model, only 
necessary when
+ *                         diagInvAtWA is not Array(0).
  */
 @Since("1.5.0")
 class LinearRegressionTrainingSummary private[regression] (
@@ -886,16 +901,22 @@ class LinearRegressionTrainingSummary private[regression] 
(
     predictionCol: String,
     labelCol: String,
     featuresCol: String,
-    model: LinearRegressionModel,
+    private val weightCol: String,
+    private val numFeatures: Int,
+    private val fitIntercept: Boolean,
     diagInvAtWA: Array[Double],
-    val objectiveHistory: Array[Double])
+    val objectiveHistory: Array[Double],
+    private val coefficientArray: Array[Double] = Array.emptyDoubleArray)
   extends LinearRegressionSummary(
     predictions,
     predictionCol,
     labelCol,
     featuresCol,
-    model,
-    diagInvAtWA) {
+    weightCol,
+    numFeatures,
+    fitIntercept,
+    diagInvAtWA,
+    coefficientArray) {
 
   /**
    * Number of training iterations until termination
@@ -919,6 +940,8 @@ class LinearRegressionTrainingSummary private[regression] (
  *                      each instance.
  * @param labelCol Field in "predictions" which gives the true label of each 
instance.
  * @param featuresCol Field in "predictions" which gives the features of each 
instance as a vector.
+ * @param coefficientArray Coefficients of the linear regression model, only 
necessary when
+ *                         diagInvAtWA is not Array(0).
  */
 @Since("1.5.0")
 class LinearRegressionSummary private[regression] (
@@ -926,23 +949,21 @@ class LinearRegressionSummary private[regression] (
     val predictionCol: String,
     val labelCol: String,
     val featuresCol: String,
-    private val privateModel: LinearRegressionModel,
-    private val diagInvAtWA: Array[Double]) extends Summary with Serializable {
+    private val weightCol: String,
+    private val numFeatures: Int,
+    private val fitIntercept: Boolean,
+    private val diagInvAtWA: Array[Double],
+    private val coefficientArray: Array[Double] = Array.emptyDoubleArray)
+  extends Summary with Serializable {
 
   @transient private val metrics = {
-    val weightCol =
-      if (!privateModel.isDefined(privateModel.weightCol) || 
privateModel.getWeightCol.isEmpty) {
-        lit(1.0)
-      } else {
-        col(privateModel.getWeightCol).cast(DoubleType)
-      }
-
+    val w = if (weightCol.isEmpty) lit(1.0) else 
col(weightCol).cast(DoubleType)
     new RegressionMetrics(
       predictions
-        .select(col(predictionCol), col(labelCol).cast(DoubleType), weightCol)
+        .select(col(predictionCol), col(labelCol).cast(DoubleType), w)
         .rdd
         .map { case Row(pred: Double, label: Double, weight: Double) => (pred, 
label, weight) },
-      !privateModel.getFitIntercept)
+      !fitIntercept)
   }
 
   /**
@@ -990,9 +1011,9 @@ class LinearRegressionSummary private[regression] (
    */
   @Since("2.3.0")
   val r2adj: Double = {
-    val interceptDOF = if (privateModel.getFitIntercept) 1 else 0
+    val interceptDOF = if (fitIntercept) 1 else 0
     1 - (1 - r2) * (numInstances - interceptDOF) /
-      (numInstances - privateModel.coefficients.size - interceptDOF)
+      (numInstances - numFeatures - interceptDOF)
   }
 
   /** Residuals (label - predicted value) */
@@ -1007,10 +1028,10 @@ class LinearRegressionSummary private[regression] (
 
   /** Degrees of freedom */
   @Since("2.2.0")
-  val degreesOfFreedom: Long = if (privateModel.getFitIntercept) {
-    numInstances - privateModel.coefficients.size - 1
+  val degreesOfFreedom: Long = if (fitIntercept) {
+    numInstances - numFeatures - 1
   } else {
-    numInstances - privateModel.coefficients.size
+    numInstances - numFeatures
   }
 
   /**
@@ -1018,15 +1039,10 @@ class LinearRegressionSummary private[regression] (
    * the square root of the instance weights.
    */
   lazy val devianceResiduals: Array[Double] = {
-    val weighted =
-      if (!privateModel.isDefined(privateModel.weightCol) || 
privateModel.getWeightCol.isEmpty) {
-        lit(1.0)
-      } else {
-        sqrt(col(privateModel.getWeightCol))
-      }
+    val w = if (weightCol.isEmpty) lit(1.0) else sqrt(col(weightCol))
     val dr = predictions
-      
.select(col(privateModel.getLabelCol).minus(col(privateModel.getPredictionCol))
-        .multiply(weighted).as("weightedResiduals"))
+      .select(col(labelCol).minus(col(predictionCol))
+        .multiply(w).as("weightedResiduals"))
       .select(min(col("weightedResiduals")).as("min"), 
max(col("weightedResiduals")).as("max"))
       .first()
     Array(dr.getDouble(0), dr.getDouble(1))
@@ -1046,15 +1062,14 @@ class LinearRegressionSummary private[regression] (
       throw new UnsupportedOperationException(
         "No Std. Error of coefficients available for this 
LinearRegressionModel")
     } else {
-      val rss =
-        if (!privateModel.isDefined(privateModel.weightCol) || 
privateModel.getWeightCol.isEmpty) {
+      val rss = if (weightCol.isEmpty) {
           meanSquaredError * numInstances
-        } else {
-          val t = udf { (pred: Double, label: Double, weight: Double) =>
-            math.pow(label - pred, 2.0) * weight }
-          predictions.select(t(col(privateModel.getPredictionCol), 
col(privateModel.getLabelCol),
-            
col(privateModel.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0)
-        }
+      } else {
+        val t = udf { (pred: Double, label: Double, weight: Double) =>
+          math.pow(label - pred, 2.0) * weight }
+        predictions.select(t(col(predictionCol), col(labelCol),
+          col(weightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0)
+      }
       val sigma2 = rss / degreesOfFreedom
       diagInvAtWA.map(_ * sigma2).map(math.sqrt)
     }
@@ -1074,12 +1089,7 @@ class LinearRegressionSummary private[regression] (
       throw new UnsupportedOperationException(
         "No t-statistic available for this LinearRegressionModel")
     } else {
-      val estimate = if (privateModel.getFitIntercept) {
-        Array.concat(privateModel.coefficients.toArray, 
Array(privateModel.intercept))
-      } else {
-        privateModel.coefficients.toArray
-      }
-      estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 }
+      coefficientArray.zip(coefficientStandardErrors).map { x => x._1 / x._2 }
     }
   }
 
@@ -1100,6 +1110,5 @@ class LinearRegressionSummary private[regression] (
       tValues.map { x => 2.0 * (1.0 - 
StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) }
     }
   }
-
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to