This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 1a010a1a1984 [SPARK-52035][ML] Decouple LinearRegressionTrainingSummary and LinearRegressionModel 1a010a1a1984 is described below commit 1a010a1a1984cf0427022bb38ff8d9e28460f974 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu May 8 18:22:19 2025 +0800 [SPARK-52035][ML] Decouple LinearRegressionTrainingSummary and LinearRegressionModel ### What changes were proposed in this pull request? Decouple LinearRegressionTrainingSummary and LinearRegressionModel ### Why are the changes needed? LinearRegressionTrainingSummary holds a reference to the model, making it hard to handle in connect server ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? No Closes #50825 from zhengruifeng/ml_refactor_lir_summary. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../spark/ml/regression/LinearRegression.scala | 105 +++++++++++---------- 1 file changed, 57 insertions(+), 48 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 29fa3d5e123b..f313786cf598 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -435,10 +435,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String if (SummaryUtils.enableTrainingSummary) { // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - val trainingSummary = new LinearRegressionTrainingSummary( summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), - model, Array(0.0), objectiveHistory) + summaryModel.get(summaryModel.weightCol).getOrElse(""), + summaryModel.numFeatures, summaryModel.getFitIntercept, + Array(0.0), objectiveHistory) model.setSummary(Some(trainingSummary)) } model @@ -460,9 +461,17 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val lrModel = copyValues(new LinearRegressionModel( uid, model.coefficients.compressed, model.intercept)) val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol() + + val coefficientArray = if (summaryModel.getFitIntercept) { + summaryModel.coefficients.toArray ++ Array(summaryModel.intercept) + } else { + summaryModel.coefficients.toArray + } val trainingSummary = new LinearRegressionTrainingSummary( summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), - summaryModel, model.diagInvAtWA.toArray, model.objectiveHistory) + summaryModel.get(summaryModel.weightCol).getOrElse(""), + summaryModel.numFeatures, summaryModel.getFitIntercept, + model.diagInvAtWA.toArray, model.objectiveHistory, coefficientArray) lrModel.setSummary(Some(trainingSummary)) } @@ -494,7 +503,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val trainingSummary = new LinearRegressionTrainingSummary( summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), - model, Array(0.0), Array(0.0)) + summaryModel.get(summaryModel.weightCol).getOrElse(""), + summaryModel.numFeatures, summaryModel.getFitIntercept, + Array(0.0), Array(0.0)) model.setSummary(Some(trainingSummary)) } @@ -737,7 +748,9 @@ class LinearRegressionModel private[ml] ( // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, - $(labelCol), $(featuresCol), summaryModel, Array(0.0)) + $(labelCol), $(featuresCol), + summaryModel.get(summaryModel.weightCol).getOrElse(""), + summaryModel.numFeatures, summaryModel.getFitIntercept, Array(0.0)) } /** @@ -879,6 +892,8 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { * * @param predictions predictions output by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + * @param coefficientArray Coefficients of the linear regression model, only necessary when + * diagInvAtWA is not Array(0). */ @Since("1.5.0") class LinearRegressionTrainingSummary private[regression] ( @@ -886,16 +901,22 @@ class LinearRegressionTrainingSummary private[regression] ( predictionCol: String, labelCol: String, featuresCol: String, - model: LinearRegressionModel, + private val weightCol: String, + private val numFeatures: Int, + private val fitIntercept: Boolean, diagInvAtWA: Array[Double], - val objectiveHistory: Array[Double]) + val objectiveHistory: Array[Double], + private val coefficientArray: Array[Double] = Array.emptyDoubleArray) extends LinearRegressionSummary( predictions, predictionCol, labelCol, featuresCol, - model, - diagInvAtWA) { + weightCol, + numFeatures, + fitIntercept, + diagInvAtWA, + coefficientArray) { /** * Number of training iterations until termination @@ -919,6 +940,8 @@ class LinearRegressionTrainingSummary private[regression] ( * each instance. * @param labelCol Field in "predictions" which gives the true label of each instance. * @param featuresCol Field in "predictions" which gives the features of each instance as a vector. + * @param coefficientArray Coefficients of the linear regression model, only necessary when + * diagInvAtWA is not Array(0). */ @Since("1.5.0") class LinearRegressionSummary private[regression] ( @@ -926,23 +949,21 @@ class LinearRegressionSummary private[regression] ( val predictionCol: String, val labelCol: String, val featuresCol: String, - private val privateModel: LinearRegressionModel, - private val diagInvAtWA: Array[Double]) extends Summary with Serializable { + private val weightCol: String, + private val numFeatures: Int, + private val fitIntercept: Boolean, + private val diagInvAtWA: Array[Double], + private val coefficientArray: Array[Double] = Array.emptyDoubleArray) + extends Summary with Serializable { @transient private val metrics = { - val weightCol = - if (!privateModel.isDefined(privateModel.weightCol) || privateModel.getWeightCol.isEmpty) { - lit(1.0) - } else { - col(privateModel.getWeightCol).cast(DoubleType) - } - + val w = if (weightCol.isEmpty) lit(1.0) else col(weightCol).cast(DoubleType) new RegressionMetrics( predictions - .select(col(predictionCol), col(labelCol).cast(DoubleType), weightCol) + .select(col(predictionCol), col(labelCol).cast(DoubleType), w) .rdd .map { case Row(pred: Double, label: Double, weight: Double) => (pred, label, weight) }, - !privateModel.getFitIntercept) + !fitIntercept) } /** @@ -990,9 +1011,9 @@ class LinearRegressionSummary private[regression] ( */ @Since("2.3.0") val r2adj: Double = { - val interceptDOF = if (privateModel.getFitIntercept) 1 else 0 + val interceptDOF = if (fitIntercept) 1 else 0 1 - (1 - r2) * (numInstances - interceptDOF) / - (numInstances - privateModel.coefficients.size - interceptDOF) + (numInstances - numFeatures - interceptDOF) } /** Residuals (label - predicted value) */ @@ -1007,10 +1028,10 @@ class LinearRegressionSummary private[regression] ( /** Degrees of freedom */ @Since("2.2.0") - val degreesOfFreedom: Long = if (privateModel.getFitIntercept) { - numInstances - privateModel.coefficients.size - 1 + val degreesOfFreedom: Long = if (fitIntercept) { + numInstances - numFeatures - 1 } else { - numInstances - privateModel.coefficients.size + numInstances - numFeatures } /** @@ -1018,15 +1039,10 @@ class LinearRegressionSummary private[regression] ( * the square root of the instance weights. */ lazy val devianceResiduals: Array[Double] = { - val weighted = - if (!privateModel.isDefined(privateModel.weightCol) || privateModel.getWeightCol.isEmpty) { - lit(1.0) - } else { - sqrt(col(privateModel.getWeightCol)) - } + val w = if (weightCol.isEmpty) lit(1.0) else sqrt(col(weightCol)) val dr = predictions - .select(col(privateModel.getLabelCol).minus(col(privateModel.getPredictionCol)) - .multiply(weighted).as("weightedResiduals")) + .select(col(labelCol).minus(col(predictionCol)) + .multiply(w).as("weightedResiduals")) .select(min(col("weightedResiduals")).as("min"), max(col("weightedResiduals")).as("max")) .first() Array(dr.getDouble(0), dr.getDouble(1)) @@ -1046,15 +1062,14 @@ class LinearRegressionSummary private[regression] ( throw new UnsupportedOperationException( "No Std. Error of coefficients available for this LinearRegressionModel") } else { - val rss = - if (!privateModel.isDefined(privateModel.weightCol) || privateModel.getWeightCol.isEmpty) { + val rss = if (weightCol.isEmpty) { meanSquaredError * numInstances - } else { - val t = udf { (pred: Double, label: Double, weight: Double) => - math.pow(label - pred, 2.0) * weight } - predictions.select(t(col(privateModel.getPredictionCol), col(privateModel.getLabelCol), - col(privateModel.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0) - } + } else { + val t = udf { (pred: Double, label: Double, weight: Double) => + math.pow(label - pred, 2.0) * weight } + predictions.select(t(col(predictionCol), col(labelCol), + col(weightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0) + } val sigma2 = rss / degreesOfFreedom diagInvAtWA.map(_ * sigma2).map(math.sqrt) } @@ -1074,12 +1089,7 @@ class LinearRegressionSummary private[regression] ( throw new UnsupportedOperationException( "No t-statistic available for this LinearRegressionModel") } else { - val estimate = if (privateModel.getFitIntercept) { - Array.concat(privateModel.coefficients.toArray, Array(privateModel.intercept)) - } else { - privateModel.coefficients.toArray - } - estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 } + coefficientArray.zip(coefficientStandardErrors).map { x => x._1 / x._2 } } } @@ -1100,6 +1110,5 @@ class LinearRegressionSummary private[regression] ( tValues.map { x => 2.0 * (1.0 - StudentsT(degreesOfFreedom.toDouble).cdf(math.abs(x))) } } } - } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org