This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 568f92017fe1 [SPARK-52051][ML][CONNECT] Enable model summary when memory control is enabled
568f92017fe1 is described below

commit 568f92017fe1b96a14b3a86b4f07154e2cffa159
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Fri May 9 22:33:41 2025 +0800

    [SPARK-52051][ML][CONNECT] Enable model summary when memory control is enabled
    
    ### What changes were proposed in this pull request?
    
    Enable model summary in Spark Connect when memory control is enabled.
    
    ### Why are the changes needed?
    
    Motivation:
    Model summary is necessary in many use cases.

    Although it does not support offloading yet, we can still enable it. Users can use the summary object within the offloading timeout.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50843 from WeichenXu123/spark-connect-enable-summary.
    
    Lead-authored-by: Weichen Xu <weichen...@databricks.com>
    Co-authored-by: WeichenXu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 .../src/main/resources/error/error-conditions.json |  6 ++-
 .../spark/ml/classification/FMClassifier.scala     | 21 +++++------
 .../apache/spark/ml/classification/LinearSVC.scala | 21 +++++------
 .../ml/classification/LogisticRegression.scala     | 43 ++++++++++------------
 .../MultilayerPerceptronClassifier.scala           | 19 ++++------
 .../ml/classification/RandomForestClassifier.scala | 35 ++++++++----------
 .../spark/ml/clustering/BisectingKMeans.scala      | 23 +++++-------
 .../spark/ml/clustering/GaussianMixture.scala      | 13 +++----
 .../org/apache/spark/ml/clustering/KMeans.scala    | 22 +++++------
 .../regression/GeneralizedLinearRegression.scala   | 18 +++------
 .../spark/ml/regression/LinearRegression.scala     | 19 ++++------
 .../apache/spark/ml/util/HasTrainingSummary.scala  |  6 ---
 .../pyspark/ml/tests/connect/test_parity_tuning.py | 10 -----
 python/pyspark/testing/connectutils.py             |  3 --
 .../apache/spark/sql/connect/config/Connect.scala  | 10 ++---
 .../org/apache/spark/sql/connect/ml/MLCache.scala  | 31 ++--------------
 .../apache/spark/sql/connect/ml/MLHandler.scala    | 13 +------
 .../org/apache/spark/sql/connect/ml/MLSuite.scala  | 35 +-----------------
 18 files changed, 116 insertions(+), 232 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 2b55f8094472..ad36523c7343 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -826,7 +826,9 @@
       },
       "CACHE_INVALID" : {
         "message" : [
-          "Cannot retrieve <objectName> from the ML cache. It is probably because the entry has been evicted."
+          "Cannot retrieve Summary object <objectName> from the ML cache.",
+          "The Summary object is evicted if it hasn't been used for a specified period of time.",
+          "You can configure the timeout by setting the Spark cluster config 'spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout'."
         ]
       },
       "ML_CACHE_SIZE_OVERFLOW_EXCEPTION" : {
@@ -834,7 +836,7 @@
           "The model cache size in current session is about to exceed",
           "<mlCacheMaxSize> bytes.",
           "Please delete existing cached model by executing 'del model' in python client before fitting new model or loading new model,",
-          "or increase Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxSize'."
+          "or increase Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxStorageSize'."
         ]
       },
       "MODEL_SIZE_OVERFLOW_EXCEPTION" : {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index c815174b8c42..222cfbb80c3d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -224,18 +224,15 @@ class FMClassifier @Since("3.0.0") (
     val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors))
     val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
 
-    if (SummaryUtils.enableTrainingSummary) {
-      val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
-      val summary = new FMClassificationTrainingSummaryImpl(
-        summaryModel.transform(dataset),
-        probabilityColName,
-        predictionColName,
-        $(labelCol),
-        weightColName,
-        objectiveHistory)
-      model.setSummary(Some(summary))
-    }
-    model
+    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
+    val summary = new FMClassificationTrainingSummaryImpl(
+      summaryModel.transform(dataset),
+      probabilityColName,
+      predictionColName,
+      $(labelCol),
+      weightColName,
+      objectiveHistory)
+    model.setSummary(Some(summary))
   }
 
   @Since("3.0.0")
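
The same mechanical change repeats in every estimator and clusterer below: the 'if (SummaryUtils.enableTrainingSummary)' guard is dropped and 'model.setSummary(Some(summary))' becomes the method's tail expression, replacing the explicit trailing 'model'. This compiles because setSummary returns the model itself. A minimal sketch of that fluent-setter shape, simplified from the HasTrainingSummary trait touched later in this diff (not the verbatim trait):

    // Minimal sketch (simplified): because setSummary returns `this`, it can
    // serve as the final expression of train() and stand in for the old
    // trailing `model`.
    private[spark] trait HasTrainingSummary[T] {
      private[ml] var trainingSummary: Option[T] = None

      private[ml] def setSummary(summary: Option[T]): this.type = {
        this.trainingSummary = summary
        this
      }
    }
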
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 6a3658537c3c..c5d1170318f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -277,18 +277,15 @@ class LinearSVC @Since("2.2.0") (
     val model = copyValues(new LinearSVCModel(uid, coefficients, intercept))
     val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
 
-    if (SummaryUtils.enableTrainingSummary) {
-      val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel()
-      val summary = new LinearSVCTrainingSummaryImpl(
-        summaryModel.transform(dataset),
-        rawPredictionColName,
-        predictionColName,
-        $(labelCol),
-        weightColName,
-        objectiveHistory)
-      model.setSummary(Some(summary))
-    }
-    model
+    val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel()
+    val summary = new LinearSVCTrainingSummaryImpl(
+      summaryModel.transform(dataset),
+      rawPredictionColName,
+      predictionColName,
+      $(labelCol),
+      weightColName,
+      objectiveHistory)
+    model.setSummary(Some(summary))
   }
 
   private def trainImpl(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index fd9bc6f0131f..d09cacf3fb5b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -711,30 +711,27 @@ class LogisticRegression @Since("1.2.0") (
       numClasses, checkMultinomial(numClasses)))
     val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
 
-    if (SummaryUtils.enableTrainingSummary) {
-      val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
-      val logRegSummary = if (numClasses <= 2) {
-        new BinaryLogisticRegressionTrainingSummaryImpl(
-          summaryModel.transform(dataset),
-          probabilityColName,
-          predictionColName,
-          $(labelCol),
-          $(featuresCol),
-          weightColName,
-          objectiveHistory)
-      } else {
-        new LogisticRegressionTrainingSummaryImpl(
-          summaryModel.transform(dataset),
-          probabilityColName,
-          predictionColName,
-          $(labelCol),
-          $(featuresCol),
-          weightColName,
-          objectiveHistory)
-      }
-      model.setSummary(Some(logRegSummary))
+    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
+    val logRegSummary = if (numClasses <= 2) {
+      new BinaryLogisticRegressionTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        probabilityColName,
+        predictionColName,
+        $(labelCol),
+        $(featuresCol),
+        weightColName,
+        objectiveHistory)
+    } else {
+      new LogisticRegressionTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        probabilityColName,
+        predictionColName,
+        $(labelCol),
+        $(featuresCol),
+        weightColName,
+        objectiveHistory)
     }
-    model
+    model.setSummary(Some(logRegSummary))
   }
 
   private def createBounds(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 5f163c43b402..2359749f8b48 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -249,17 +249,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
       objectiveHistory: Array[Double]): MultilayerPerceptronClassificationModel = {
     val model = copyValues(new MultilayerPerceptronClassificationModel(uid, weights))
 
-    if (SummaryUtils.enableTrainingSummary) {
-      val (summaryModel, _, predictionColName) = model.findSummaryModel()
-      val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
-        summaryModel.transform(dataset),
-        predictionColName,
-        $(labelCol),
-        "",
-        objectiveHistory)
-      model.setSummary(Some(summary))
-    }
-    model
+    val (summaryModel, _, predictionColName) = model.findSummaryModel()
+    val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
+      summaryModel.transform(dataset),
+      predictionColName,
+      $(labelCol),
+      "",
+      objectiveHistory)
+    model.setSummary(Some(summary))
   }
 }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 391d61a9793f..f64e2a6d4efc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -185,26 +185,23 @@ class RandomForestClassifier @Since("1.4.0") (
     val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
 
     val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
-    if (SummaryUtils.enableTrainingSummary) {
-      val rfSummary = if (numClasses <= 2) {
-        new BinaryRandomForestClassificationTrainingSummaryImpl(
-          summaryModel.transform(dataset),
-          probabilityColName,
-          predictionColName,
-          $(labelCol),
-          weightColName,
-          Array(0.0))
-      } else {
-        new RandomForestClassificationTrainingSummaryImpl(
-          summaryModel.transform(dataset),
-          predictionColName,
-          $(labelCol),
-          weightColName,
-          Array(0.0))
-      }
-      model.setSummary(Some(rfSummary))
+    val rfSummary = if (numClasses <= 2) {
+      new BinaryRandomForestClassificationTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        probabilityColName,
+        predictionColName,
+        $(labelCol),
+        weightColName,
+        Array(0.0))
+    } else {
+      new RandomForestClassificationTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        predictionColName,
+        $(labelCol),
+        weightColName,
+        Array(0.0))
     }
-    model
+    model.setSummary(Some(rfSummary))
   }
 
   @Since("1.4.1")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index d63a0663690e..f1cd126a6406 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -303,19 +303,16 @@ class BisectingKMeans @Since("2.0.0") (
     val parentModel = bkm.runWithWeight(instances, handlePersistence, Some(instr))
     val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
 
-    if (SummaryUtils.enableTrainingSummary) {
-      val summary = new BisectingKMeansSummary(
-        model.transform(dataset),
-        $(predictionCol),
-        $(featuresCol),
-        $(k),
-        $(maxIter),
-        parentModel.trainingCost)
-      instr.logNamedValue("clusterSizes", summary.clusterSizes)
-      instr.logNumFeatures(model.clusterCenters.head.size)
-      model.setSummary(Some(summary))
-    }
-    model
+    val summary = new BisectingKMeansSummary(
+      model.transform(dataset),
+      $(predictionCol),
+      $(featuresCol),
+      $(k),
+      $(maxIter),
+      parentModel.trainingCost)
+    instr.logNamedValue("clusterSizes", summary.clusterSizes)
+    instr.logNumFeatures(model.clusterCenters.head.size)
+    model.setSummary(Some(summary))
   }
 
   @Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index e937313af95d..ee0b19f8129d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -430,14 +430,11 @@ class GaussianMixture @Since("2.0.0") (
 
     val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists))
       .setParent(this)
-    if (SummaryUtils.enableTrainingSummary) {
-      val summary = new GaussianMixtureSummary(model.transform(dataset),
-        $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration)
-      instr.logNamedValue("logLikelihood", logLikelihood)
-      instr.logNamedValue("clusterSizes", summary.clusterSizes)
-      model.setSummary(Some(summary))
-    }
-    model
+    val summary = new GaussianMixtureSummary(model.transform(dataset),
+      $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration)
+    instr.logNamedValue("logLikelihood", logLikelihood)
+    instr.logNamedValue("clusterSizes", summary.clusterSizes)
+    model.setSummary(Some(summary))
   }
 
   private def trainImpl(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 9dcb43a36120..e87dc9eb040b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -391,18 +391,16 @@ class KMeans @Since("1.5.0") (
     }
 
     val model = copyValues(new KMeansModel(uid, oldModel).setParent(this))
-    if (SummaryUtils.enableTrainingSummary) {
-      val summary = new KMeansSummary(
-        model.transform(dataset),
-        $(predictionCol),
-        $(featuresCol),
-        $(k),
-        oldModel.numIter,
-        oldModel.trainingCost)
-
-      model.setSummary(Some(summary))
-      instr.logNamedValue("clusterSizes", summary.clusterSizes)
-    }
+    val summary = new KMeansSummary(
+      model.transform(dataset),
+      $(predictionCol),
+      $(featuresCol),
+      $(k),
+      oldModel.numIter,
+      oldModel.trainingCost)
+
+    model.setSummary(Some(summary))
+    instr.logNamedValue("clusterSizes", summary.clusterSizes)
     model
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index d1def5952396..777b70e7d021 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -418,12 +418,9 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
       val model = copyValues(
         new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept)
           .setParent(this))
-      if (SummaryUtils.enableTrainingSummary) {
-        val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
-          wlsModel.diagInvAtWA.toArray, 1, getSolver)
-        model.setSummary(Some(trainingSummary))
-      }
-      model
+      val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
+        wlsModel.diagInvAtWA.toArray, 1, getSolver)
+      model.setSummary(Some(trainingSummary))
     } else {
       val instances = validated.rdd.map {
         case Row(label: Double, weight: Double, offset: Double, features: Vector) =>
@@ -438,12 +435,9 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
       val model = copyValues(
         new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept)
           .setParent(this))
-      if (SummaryUtils.enableTrainingSummary) {
-        val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
-          irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
-        model.setSummary(Some(trainingSummary))
-      }
-      model
+      val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
+        irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
+      model.setSummary(Some(trainingSummary))
     }
 
     model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index f313786cf598..847115eb02b1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -432,17 +432,14 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
 
     val model = createModel(parameters, yMean, yStd, featuresMean, featuresStd)
 
-    if (SummaryUtils.enableTrainingSummary) {
-      // Handle possible missing or invalid prediction columns
-      val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
-      val trainingSummary = new LinearRegressionTrainingSummary(
-        summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol),
-        summaryModel.get(summaryModel.weightCol).getOrElse(""),
-        summaryModel.numFeatures, summaryModel.getFitIntercept,
-        Array(0.0), objectiveHistory)
-      model.setSummary(Some(trainingSummary))
-    }
-    model
+    // Handle possible missing or invalid prediction columns
+    val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
+    val trainingSummary = new LinearRegressionTrainingSummary(
+      summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol),
+      summaryModel.get(summaryModel.weightCol).getOrElse(""),
+      summaryModel.numFeatures, summaryModel.getFitIntercept,
+      Array(0.0), objectiveHistory)
+    model.setSummary(Some(trainingSummary))
   }
 
   private def trainWithNormal(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala
index a6d09c2e9142..0ba8ce072ab4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala
@@ -50,9 +50,3 @@ private[spark] trait HasTrainingSummary[T] {
     this
   }
 }
-
-private[spark] object SummaryUtils {
-
-  // This flag is only used by Spark Connect
-  private[spark] var enableTrainingSummary: Boolean = true
-}
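
With the SummaryUtils flag gone, training always attaches a summary. A hedged sanity check of the resulting behavior (df is an assumed training DataFrame with label/features columns, not part of this commit):

    // Hypothetical check; `df` is an assumed training DataFrame.
    import org.apache.spark.ml.classification.LogisticRegression

    val model = new LogisticRegression().setMaxIter(5).fit(df)
    assert(model.hasSummary) // previously false when a Connect server had
                             // memory control enabled; now always true
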
diff --git a/python/pyspark/ml/tests/connect/test_parity_tuning.py b/python/pyspark/ml/tests/connect/test_parity_tuning.py
index 7e098b9b909e..2d21644ceed5 100644
--- a/python/pyspark/ml/tests/connect/test_parity_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_parity_tuning.py
@@ -25,16 +25,6 @@ class TuningParityTests(TuningTestsMixin, ReusedConnectTestCase):
     pass
 
 
-class TuningParityWithMLCacheOffloadingEnabledTests(TuningTestsMixin, ReusedConnectTestCase):
-    @classmethod
-    def conf(cls):
-        conf = super().conf()
-        conf.set("spark.connect.session.connectML.mlCache.memoryControl.enabled", "true")
-        conf.set("spark.connect.session.connectML.mlCache.memoryControl.maxInMemorySize", "1024")
-        conf.set("spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout", "1")
-        return conf
-
-
 if __name__ == "__main__":
     from pyspark.ml.tests.connect.test_parity_tuning import *  # noqa: F401
 
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 667f2b9ada06..5e2d6ae8724b 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -158,9 +158,6 @@ class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUti
         # Set a static token for all tests so the parallelism doesn't overwrite each
         # tests' environment variables
         conf.set("spark.connect.authenticate.token", "deadbeef")
-        # Disable ml cache offloading,
-        # offloading hasn't supported APIs like model.summary / model.evaluate
-        conf.set("spark.connect.session.connectML.mlCache.memoryControl.enabled", "false")
         return conf
 
     @classmethod
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 33feca5b0c94..5fe62295d1a5 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -352,8 +352,8 @@ object Connect {
       .version("4.1.0")
       .internal()
       .bytesConf(ByteUnit.BYTE)
-      // By default, 1/3 of total designated memory (the configured -Xmx).
-      .createWithDefault(Runtime.getRuntime.maxMemory() / 3)
+      // By default, 1/4 of total designated memory (the configured -Xmx).
+      .createWithDefault(Runtime.getRuntime.maxMemory() / 4)
 
   val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_OFFLOADING_TIMEOUT =
     buildConf("spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout")
@@ -363,7 +363,7 @@ object Connect {
       .version("4.1.0")
       .internal()
       .timeConf(TimeUnit.MINUTES)
-      .createWithDefault(5)
+      .createWithDefault(15)
 
   val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_MODEL_SIZE =
     buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxModelSize")
@@ -373,8 +373,8 @@ object Connect {
       .bytesConf(ByteUnit.BYTE)
       .createWithDefaultString("1g")
 
-  val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_SIZE =
-    buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxSize")
+  val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_STORAGE_SIZE =
+    buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxStorageSize")
       .doc("Maximum total size (including in-memory and offloaded data) of the ml cache. " +
         "The size is in bytes.")
       .version("4.1.0")
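
For reference, a hedged sketch of setting the renamed storage cap together with the offloading timeout cited in the updated CACHE_INVALID message; the values are arbitrary examples, not recommendations from this commit:

    // Arbitrary example values; both configs are internal Connect server settings.
    spark.conf.set(
      "spark.connect.session.connectML.mlCache.memoryControl.maxStorageSize", "10g")
    spark.conf.set(
      "spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout", "30min")
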
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index d59525a6190c..e2d6d0a9f6dd 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -85,14 +85,10 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     }
   }
 
-  private[ml] val cachedSummary: ConcurrentMap[String, Summary] = {
-    new ConcurrentHashMap[String, Summary]()
-  }
-
   private[ml] val totalMLCacheSizeBytes: AtomicLong = new AtomicLong(0)
   private[spark] def getMLCacheMaxSize: Long = {
     sessionHolder.session.conf.get(
-      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_SIZE)
+      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_STORAGE_SIZE)
   }
   private[spark] def getModelMaxSize: Long = {
     sessionHolder.session.conf.get(
@@ -118,17 +114,6 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     }
   }
 
-  private[spark] def checkSummaryAvail(): Unit = {
-    if (getMemoryControlEnabled) {
-      throw MlUnsupportedException(
-        "SparkML 'model.summary' and 'model.evaluate' APIs are not supported' when " +
-          "Spark Connect session ML cache offloading is enabled. You can use APIs in " +
-          "'pyspark.ml.evaluation' instead, or you can set Spark config " +
-          "'spark.connect.session.connectML.mlCache.memoryControl.enabled' to 'false' to " +
-          "disable Spark Connect session ML cache offloading.")
-    }
-  }
-
   /**
    * Cache an object into a map of MLCache, and return its key
    * @param obj
@@ -140,8 +125,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     val objectId = UUID.randomUUID().toString
 
     if (obj.isInstanceOf[Summary]) {
-      checkSummaryAvail()
-      cachedSummary.put(objectId, obj.asInstanceOf[Summary])
+      cachedModel.put(objectId, CacheItem(obj, 0))
     } else if (obj.isInstanceOf[Model[_]]) {
       val sizeBytes = if (getMemoryControlEnabled) {
         val _sizeBytes = estimateObjectSize(obj)
@@ -175,8 +159,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     if (refId == helperID) {
       helper
     } else {
-      var obj: Object =
-        Option(cachedModel.get(refId)).map(_.obj).getOrElse(cachedSummary.get(refId))
+      var obj: Object = Option(cachedModel.get(refId)).map(_.obj).getOrElse(null)
       if (obj == null && getMemoryControlEnabled) {
         val loadPath = offloadedModelsDir.resolve(refId)
         if (Files.isDirectory(loadPath)) {
@@ -221,12 +204,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
   def remove(refId: String): Boolean = {
     val modelIsRemoved = _removeModel(refId)
 
-    if (modelIsRemoved) {
-      true
-    } else {
-      val removedSummary = cachedSummary.remove(refId)
-      removedSummary != null
-    }
+    modelIsRemoved
   }
 
   /**
@@ -235,7 +213,6 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
   def clear(): Int = {
     val size = cachedModel.size()
     cachedModel.clear()
-    cachedSummary.clear()
     if (getMemoryControlEnabled) {
       FileUtils.cleanDirectory(new File(offloadedModelsDir.toString))
     }
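
Net effect in MLCache: summaries now live in the same cachedModel map as models, registered with a size of 0 so they never count toward the storage cap, while still aging out on the offloading timeout described by the CACHE_INVALID message. A paraphrased sketch of the put-side branching after this patch (not the literal code):

    // Paraphrased sketch of the cache-put branching after this change.
    def register(objectId: String, obj: Object): Unit = obj match {
      case s: Summary =>
        // No size accounting: summaries wrap DataFrames, which resist size
        // estimation and offloading, so they are tracked with size 0.
        cachedModel.put(objectId, CacheItem(s, 0))
      case m: Model[_] =>
        val sizeBytes = if (getMemoryControlEnabled) estimateObjectSize(m) else 0L
        cachedModel.put(objectId, CacheItem(m, sizeBytes))
    }
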
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 3afc15c97a66..204c874060cc 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.Model
 import org.apache.spark.ml.param.{ParamMap, Params}
 import org.apache.spark.ml.tree.TreeConfig
-import org.apache.spark.ml.util.{MLWritable, Summary, SummaryUtils}
+import org.apache.spark.ml.util.{MLWritable, Summary}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
 import org.apache.spark.sql.connect.ml.Serializer.deserializeMethodArguments
@@ -54,9 +54,6 @@ private class AttributeHelper(
   def getAttribute: Any = {
     assert(methods.length >= 1)
     methods.foldLeft(instance()) { (obj, m) =>
-      if (obj.isInstanceOf[Summary]) {
-        sessionHolder.mlCache.checkSummaryAvail()
-      }
       if (m.argValues.isEmpty) {
         MLUtils.invokeMethodAllowed(obj, m.name)
       } else {
@@ -122,11 +119,6 @@ private[connect] object MLHandler extends Logging {
     val mlCache = sessionHolder.mlCache
     val memoryControlEnabled = sessionHolder.mlCache.getMemoryControlEnabled
 
-    // Disable model training summary when memory control is enabled
-    // because training summary can't support
-    // size estimation and offloading.
-    SummaryUtils.enableTrainingSummary = !memoryControlEnabled
-
     if (memoryControlEnabled) {
       val maxModelSize = sessionHolder.mlCache.getModelMaxSize
 
@@ -249,9 +241,6 @@ private[connect] object MLHandler extends Logging {
           case proto.MlCommand.Write.TypeCase.OBJ_REF => // save a model
             val objId = mlCommand.getWrite.getObjRef.getId
             val model = mlCache.get(objId).asInstanceOf[Model[_]]
-            if (model == null) {
-              throw MLCacheInvalidException(s"model $objId")
-            }
             val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
             MLUtils.setInstanceParams(copiedModel, mlCommand.getWrite.getParams)
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index fcb721b51e64..bdc094e7b2b7 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -135,8 +135,6 @@ class MLSuite extends MLHelper {
   // Estimator/Model works
   test("LogisticRegression works") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
-    sessionHolder.session.conf
-      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED.key, "false")
 
     // estimator read/write
     val ret = readWrite(sessionHolder, getLogisticRegression, getMaxIter)
@@ -259,37 +257,6 @@ class MLSuite extends MLHelper {
     }
   }
 
-  test("Exception: cannot retrieve object") {
-    val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
-    sessionHolder.session.conf
-      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED.key, "false")
-    val modelId = trainLogisticRegressionModel(sessionHolder)
-
-    // Fetch summary attribute
-    val accuracyCommand = proto.MlCommand
-      .newBuilder()
-      .setFetch(
-        proto.Fetch
-          .newBuilder()
-          .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
-          .addMethods(proto.Fetch.Method.newBuilder().setMethod("summary"))
-          .addMethods(proto.Fetch.Method.newBuilder().setMethod("accuracy")))
-      .build()
-
-    // Successfully fetch summary.accuracy from the cached model
-    MLHandler.handleMlCommand(sessionHolder, accuracyCommand)
-
-    // Remove the model from cache
-    sessionHolder.mlCache.clear()
-
-    // No longer able to retrieve the model from cache
-    val e = intercept[MLCacheInvalidException] {
-      MLHandler.handleMlCommand(sessionHolder, accuracyCommand)
-    }
-    val msg = e.getMessage
-    assert(msg.contains(s"$modelId from the ML cache"))
-  }
-
   test("access the attribute which is not in allowed list") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
     val modelId = trainLogisticRegressionModel(sessionHolder)
@@ -469,7 +436,7 @@ class MLSuite extends MLHelper {
     sessionHolder.session.conf
       .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_MODEL_SIZE.key, "8000")
     sessionHolder.session.conf
-      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_SIZE.key, "10000")
+      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_STORAGE_SIZE.key, "10000")
     trainLogisticRegressionModel(sessionHolder)
     intercept[MLCacheSizeOverflowException] {
       trainLogisticRegressionModel(sessionHolder)

