This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 568f92017fe1 [SPARK-52051][ML][CONNECT] Enable model summary when memory control is enabled 568f92017fe1 is described below commit 568f92017fe1b96a14b3a86b4f07154e2cffa159 Author: Weichen Xu <weichen...@databricks.com> AuthorDate: Fri May 9 22:33:41 2025 +0800 [SPARK-52051][ML][CONNECT] Enable model summary when memory control is enabled ### What changes were proposed in this pull request? Enable model summary in SparkConnect when memory control is enabled. ### Why are the changes needed? Motivation: model summary is necessary in many use-cases. although it hasn't support offloading, we can still enable it. User can use the summary object within the offloading timeout. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #50843 from WeichenXu123/spark-connect-enable-summary. Lead-authored-by: Weichen Xu <weichen...@databricks.com> Co-authored-by: WeichenXu <weichen...@databricks.com> Signed-off-by: Weichen Xu <weichen...@databricks.com> --- .../src/main/resources/error/error-conditions.json | 6 ++- .../spark/ml/classification/FMClassifier.scala | 21 +++++------ .../apache/spark/ml/classification/LinearSVC.scala | 21 +++++------ .../ml/classification/LogisticRegression.scala | 43 ++++++++++------------ .../MultilayerPerceptronClassifier.scala | 19 ++++------ .../ml/classification/RandomForestClassifier.scala | 35 ++++++++---------- .../spark/ml/clustering/BisectingKMeans.scala | 23 +++++------- .../spark/ml/clustering/GaussianMixture.scala | 13 +++---- .../org/apache/spark/ml/clustering/KMeans.scala | 22 +++++------ .../regression/GeneralizedLinearRegression.scala | 18 +++------ .../spark/ml/regression/LinearRegression.scala | 19 ++++------ .../apache/spark/ml/util/HasTrainingSummary.scala | 6 --- .../pyspark/ml/tests/connect/test_parity_tuning.py | 10 ----- python/pyspark/testing/connectutils.py | 3 -- .../apache/spark/sql/connect/config/Connect.scala | 10 ++--- .../org/apache/spark/sql/connect/ml/MLCache.scala | 31 ++-------------- .../apache/spark/sql/connect/ml/MLHandler.scala | 13 +------ .../org/apache/spark/sql/connect/ml/MLSuite.scala | 35 +----------------- 18 files changed, 116 insertions(+), 232 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 2b55f8094472..ad36523c7343 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -826,7 +826,9 @@ }, "CACHE_INVALID" : { "message" : [ - "Cannot retrieve <objectName> from the ML cache. It is probably because the entry has been evicted." + "Cannot retrieve Summary object <objectName> from the ML cache.", + "The Summary object is evicted if it hasn't been used a specified period of time.", + "You can configure the timeout by setting Spark cluster configure 'spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout'." ] }, "ML_CACHE_SIZE_OVERFLOW_EXCEPTION" : { @@ -834,7 +836,7 @@ "The model cache size in current session is about to exceed", "<mlCacheMaxSize> bytes.", "Please delete existing cached model by executing 'del model' in python client before fitting new model or loading new model,", - "or increase Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxSize'." 
+ "or increase Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxStorageSize'." ] }, "MODEL_SIZE_OVERFLOW_EXCEPTION" : { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala index c815174b8c42..222cfbb80c3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala @@ -224,18 +224,15 @@ class FMClassifier @Since("3.0.0") ( val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors)) val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) - if (SummaryUtils.enableTrainingSummary) { - val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() - val summary = new FMClassificationTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - $(labelCol), - weightColName, - objectiveHistory) - model.setSummary(Some(summary)) - } - model + val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() + val summary = new FMClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + weightColName, + objectiveHistory) + model.setSummary(Some(summary)) } @Since("3.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 6a3658537c3c..c5d1170318f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -277,18 +277,15 @@ class LinearSVC @Since("2.2.0") ( val model = copyValues(new LinearSVCModel(uid, coefficients, intercept)) val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) - if (SummaryUtils.enableTrainingSummary) { - val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel() - val summary = new LinearSVCTrainingSummaryImpl( - summaryModel.transform(dataset), - rawPredictionColName, - predictionColName, - $(labelCol), - weightColName, - objectiveHistory) - model.setSummary(Some(summary)) - } - model + val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel() + val summary = new LinearSVCTrainingSummaryImpl( + summaryModel.transform(dataset), + rawPredictionColName, + predictionColName, + $(labelCol), + weightColName, + objectiveHistory) + model.setSummary(Some(summary)) } private def trainImpl( diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index fd9bc6f0131f..d09cacf3fb5b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -711,30 +711,27 @@ class LogisticRegression @Since("1.2.0") ( numClasses, checkMultinomial(numClasses))) val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) - if (SummaryUtils.enableTrainingSummary) { - val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() - val logRegSummary = if (numClasses <= 2) { - new BinaryLogisticRegressionTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - 
$(labelCol), - $(featuresCol), - weightColName, - objectiveHistory) - } else { - new LogisticRegressionTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - $(labelCol), - $(featuresCol), - weightColName, - objectiveHistory) - } - model.setSummary(Some(logRegSummary)) + val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() + val logRegSummary = if (numClasses <= 2) { + new BinaryLogisticRegressionTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + $(featuresCol), + weightColName, + objectiveHistory) + } else { + new LogisticRegressionTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + $(featuresCol), + weightColName, + objectiveHistory) } - model + model.setSummary(Some(logRegSummary)) } private def createBounds( diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 5f163c43b402..2359749f8b48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -249,17 +249,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( objectiveHistory: Array[Double]): MultilayerPerceptronClassificationModel = { val model = copyValues(new MultilayerPerceptronClassificationModel(uid, weights)) - if (SummaryUtils.enableTrainingSummary) { - val (summaryModel, _, predictionColName) = model.findSummaryModel() - val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - "", - objectiveHistory) - model.setSummary(Some(summary)) - } - model + val (summaryModel, _, predictionColName) = model.findSummaryModel() + val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + "", + objectiveHistory) + model.setSummary(Some(summary)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 391d61a9793f..f64e2a6d4efc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -185,26 +185,23 @@ class RandomForestClassifier @Since("1.4.0") ( val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() - if (SummaryUtils.enableTrainingSummary) { - val rfSummary = if (numClasses <= 2) { - new BinaryRandomForestClassificationTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - $(labelCol), - weightColName, - Array(0.0)) - } else { - new RandomForestClassificationTrainingSummaryImpl( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - weightColName, - Array(0.0)) - } - model.setSummary(Some(rfSummary)) + val rfSummary = if (numClasses <= 2) { + new BinaryRandomForestClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + weightColName, + Array(0.0)) + 
} else { + new RandomForestClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + weightColName, + Array(0.0)) } - model + model.setSummary(Some(rfSummary)) } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index d63a0663690e..f1cd126a6406 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -303,19 +303,16 @@ class BisectingKMeans @Since("2.0.0") ( val parentModel = bkm.runWithWeight(instances, handlePersistence, Some(instr)) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) - if (SummaryUtils.enableTrainingSummary) { - val summary = new BisectingKMeansSummary( - model.transform(dataset), - $(predictionCol), - $(featuresCol), - $(k), - $(maxIter), - parentModel.trainingCost) - instr.logNamedValue("clusterSizes", summary.clusterSizes) - instr.logNumFeatures(model.clusterCenters.head.size) - model.setSummary(Some(summary)) - } - model + val summary = new BisectingKMeansSummary( + model.transform(dataset), + $(predictionCol), + $(featuresCol), + $(k), + $(maxIter), + parentModel.trainingCost) + instr.logNamedValue("clusterSizes", summary.clusterSizes) + instr.logNumFeatures(model.clusterCenters.head.size) + model.setSummary(Some(summary)) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index e937313af95d..ee0b19f8129d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -430,14 +430,11 @@ class GaussianMixture @Since("2.0.0") ( val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)) .setParent(this) - if (SummaryUtils.enableTrainingSummary) { - val summary = new GaussianMixtureSummary(model.transform(dataset), - $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration) - instr.logNamedValue("logLikelihood", logLikelihood) - instr.logNamedValue("clusterSizes", summary.clusterSizes) - model.setSummary(Some(summary)) - } - model + val summary = new GaussianMixtureSummary(model.transform(dataset), + $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration) + instr.logNamedValue("logLikelihood", logLikelihood) + instr.logNamedValue("clusterSizes", summary.clusterSizes) + model.setSummary(Some(summary)) } private def trainImpl( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 9dcb43a36120..e87dc9eb040b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -391,18 +391,16 @@ class KMeans @Since("1.5.0") ( } val model = copyValues(new KMeansModel(uid, oldModel).setParent(this)) - if (SummaryUtils.enableTrainingSummary) { - val summary = new KMeansSummary( - model.transform(dataset), - $(predictionCol), - $(featuresCol), - $(k), - oldModel.numIter, - oldModel.trainingCost) - - model.setSummary(Some(summary)) - instr.logNamedValue("clusterSizes", summary.clusterSizes) - } + val summary = new KMeansSummary( + model.transform(dataset), + $(predictionCol), + 
$(featuresCol), + $(k), + oldModel.numIter, + oldModel.trainingCost) + + model.setSummary(Some(summary)) + instr.logNamedValue("clusterSizes", summary.clusterSizes) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index d1def5952396..777b70e7d021 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -418,12 +418,9 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) - if (SummaryUtils.enableTrainingSummary) { - val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, - wlsModel.diagInvAtWA.toArray, 1, getSolver) - model.setSummary(Some(trainingSummary)) - } - model + val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, + wlsModel.diagInvAtWA.toArray, 1, getSolver) + model.setSummary(Some(trainingSummary)) } else { val instances = validated.rdd.map { case Row(label: Double, weight: Double, offset: Double, features: Vector) => @@ -438,12 +435,9 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) - if (SummaryUtils.enableTrainingSummary) { - val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, - irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) - model.setSummary(Some(trainingSummary)) - } - model + val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, + irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) + model.setSummary(Some(trainingSummary)) } model diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f313786cf598..847115eb02b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -432,17 +432,14 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val model = createModel(parameters, yMean, yStd, featuresMean, featuresStd) - if (SummaryUtils.enableTrainingSummary) { - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), - summaryModel.get(summaryModel.weightCol).getOrElse(""), - summaryModel.numFeatures, summaryModel.getFitIntercept, - Array(0.0), objectiveHistory) - model.setSummary(Some(trainingSummary)) - } - model + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), + summaryModel.get(summaryModel.weightCol).getOrElse(""), + summaryModel.numFeatures, summaryModel.getFitIntercept, + Array(0.0), objectiveHistory) + 
model.setSummary(Some(trainingSummary)) } private def trainWithNormal( diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala index a6d09c2e9142..0ba8ce072ab4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala @@ -50,9 +50,3 @@ private[spark] trait HasTrainingSummary[T] { this } } - -private[spark] object SummaryUtils { - - // This flag is only used by Spark Connect - private[spark] var enableTrainingSummary: Boolean = true -} diff --git a/python/pyspark/ml/tests/connect/test_parity_tuning.py b/python/pyspark/ml/tests/connect/test_parity_tuning.py index 7e098b9b909e..2d21644ceed5 100644 --- a/python/pyspark/ml/tests/connect/test_parity_tuning.py +++ b/python/pyspark/ml/tests/connect/test_parity_tuning.py @@ -25,16 +25,6 @@ class TuningParityTests(TuningTestsMixin, ReusedConnectTestCase): pass -class TuningParityWithMLCacheOffloadingEnabledTests(TuningTestsMixin, ReusedConnectTestCase): - @classmethod - def conf(cls): - conf = super().conf() - conf.set("spark.connect.session.connectML.mlCache.memoryControl.enabled", "true") - conf.set("spark.connect.session.connectML.mlCache.memoryControl.maxInMemorySize", "1024") - conf.set("spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout", "1") - return conf - - if __name__ == "__main__": from pyspark.ml.tests.connect.test_parity_tuning import * # noqa: F401 diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 667f2b9ada06..5e2d6ae8724b 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -158,9 +158,6 @@ class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUti # Set a static token for all tests so the parallelism doesn't overwrite each # tests' environment variables conf.set("spark.connect.authenticate.token", "deadbeef") - # Disable ml cache offloading, - # offloading hasn't supported APIs like model.summary / model.evaluate - conf.set("spark.connect.session.connectML.mlCache.memoryControl.enabled", "false") return conf @classmethod diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 33feca5b0c94..5fe62295d1a5 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -352,8 +352,8 @@ object Connect { .version("4.1.0") .internal() .bytesConf(ByteUnit.BYTE) - // By default, 1/3 of total designated memory (the configured -Xmx). - .createWithDefault(Runtime.getRuntime.maxMemory() / 3) + // By default, 1/4 of total designated memory (the configured -Xmx). 
+ .createWithDefault(Runtime.getRuntime.maxMemory() / 4) val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_OFFLOADING_TIMEOUT = buildConf("spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout") @@ -363,7 +363,7 @@ object Connect { .version("4.1.0") .internal() .timeConf(TimeUnit.MINUTES) - .createWithDefault(5) + .createWithDefault(15) val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_MODEL_SIZE = buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxModelSize") @@ -373,8 +373,8 @@ object Connect { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("1g") - val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_SIZE = - buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxSize") + val CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_STORAGE_SIZE = + buildConf("spark.connect.session.connectML.mlCache.memoryControl.maxStorageSize") .doc("Maximum total size (including in-memory and offloaded data) of the ml cache. " + "The size is in bytes.") .version("4.1.0") diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala index d59525a6190c..e2d6d0a9f6dd 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala @@ -85,14 +85,10 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { } } - private[ml] val cachedSummary: ConcurrentMap[String, Summary] = { - new ConcurrentHashMap[String, Summary]() - } - private[ml] val totalMLCacheSizeBytes: AtomicLong = new AtomicLong(0) private[spark] def getMLCacheMaxSize: Long = { sessionHolder.session.conf.get( - Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_SIZE) + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_STORAGE_SIZE) } private[spark] def getModelMaxSize: Long = { sessionHolder.session.conf.get( @@ -118,17 +114,6 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { } } - private[spark] def checkSummaryAvail(): Unit = { - if (getMemoryControlEnabled) { - throw MlUnsupportedException( - "SparkML 'model.summary' and 'model.evaluate' APIs are not supported' when " + - "Spark Connect session ML cache offloading is enabled. 
You can use APIs in " + - "'pyspark.ml.evaluation' instead, or you can set Spark config " + - "'spark.connect.session.connectML.mlCache.memoryControl.enabled' to 'false' to " + - "disable Spark Connect session ML cache offloading.") - } - } - /** * Cache an object into a map of MLCache, and return its key * @param obj @@ -140,8 +125,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { val objectId = UUID.randomUUID().toString if (obj.isInstanceOf[Summary]) { - checkSummaryAvail() - cachedSummary.put(objectId, obj.asInstanceOf[Summary]) + cachedModel.put(objectId, CacheItem(obj, 0)) } else if (obj.isInstanceOf[Model[_]]) { val sizeBytes = if (getMemoryControlEnabled) { val _sizeBytes = estimateObjectSize(obj) @@ -175,8 +159,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { if (refId == helperID) { helper } else { - var obj: Object = - Option(cachedModel.get(refId)).map(_.obj).getOrElse(cachedSummary.get(refId)) + var obj: Object = Option(cachedModel.get(refId)).map(_.obj).getOrElse(null) if (obj == null && getMemoryControlEnabled) { val loadPath = offloadedModelsDir.resolve(refId) if (Files.isDirectory(loadPath)) { @@ -221,12 +204,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { def remove(refId: String): Boolean = { val modelIsRemoved = _removeModel(refId) - if (modelIsRemoved) { - true - } else { - val removedSummary = cachedSummary.remove(refId) - removedSummary != null - } + modelIsRemoved } /** @@ -235,7 +213,6 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { def clear(): Int = { val size = cachedModel.size() cachedModel.clear() - cachedSummary.clear() if (getMemoryControlEnabled) { FileUtils.cleanDirectory(new File(offloadedModelsDir.toString)) } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala index 3afc15c97a66..204c874060cc 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala @@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.Model import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.ml.tree.TreeConfig -import org.apache.spark.ml.util.{MLWritable, Summary, SummaryUtils} +import org.apache.spark.ml.util.{MLWritable, Summary} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.connect.common.LiteralValueProtoConverter import org.apache.spark.sql.connect.ml.Serializer.deserializeMethodArguments @@ -54,9 +54,6 @@ private class AttributeHelper( def getAttribute: Any = { assert(methods.length >= 1) methods.foldLeft(instance()) { (obj, m) => - if (obj.isInstanceOf[Summary]) { - sessionHolder.mlCache.checkSummaryAvail() - } if (m.argValues.isEmpty) { MLUtils.invokeMethodAllowed(obj, m.name) } else { @@ -122,11 +119,6 @@ private[connect] object MLHandler extends Logging { val mlCache = sessionHolder.mlCache val memoryControlEnabled = sessionHolder.mlCache.getMemoryControlEnabled - // Disable model training summary when memory control is enabled - // because training summary can't support - // size estimation and offloading. 
- SummaryUtils.enableTrainingSummary = !memoryControlEnabled - if (memoryControlEnabled) { val maxModelSize = sessionHolder.mlCache.getModelMaxSize @@ -249,9 +241,6 @@ private[connect] object MLHandler extends Logging { case proto.MlCommand.Write.TypeCase.OBJ_REF => // save a model val objId = mlCommand.getWrite.getObjRef.getId val model = mlCache.get(objId).asInstanceOf[Model[_]] - if (model == null) { - throw MLCacheInvalidException(s"model $objId") - } val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]] MLUtils.setInstanceParams(copiedModel, mlCommand.getWrite.getParams) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala index fcb721b51e64..bdc094e7b2b7 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala @@ -135,8 +135,6 @@ class MLSuite extends MLHelper { // Estimator/Model works test("LogisticRegression works") { val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) - sessionHolder.session.conf - .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED.key, "false") // estimator read/write val ret = readWrite(sessionHolder, getLogisticRegression, getMaxIter) @@ -259,37 +257,6 @@ class MLSuite extends MLHelper { } } - test("Exception: cannot retrieve object") { - val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) - sessionHolder.session.conf - .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_ENABLED.key, "false") - val modelId = trainLogisticRegressionModel(sessionHolder) - - // Fetch summary attribute - val accuracyCommand = proto.MlCommand - .newBuilder() - .setFetch( - proto.Fetch - .newBuilder() - .setObjRef(proto.ObjectRef.newBuilder().setId(modelId)) - .addMethods(proto.Fetch.Method.newBuilder().setMethod("summary")) - .addMethods(proto.Fetch.Method.newBuilder().setMethod("accuracy"))) - .build() - - // Successfully fetch summary.accuracy from the cached model - MLHandler.handleMlCommand(sessionHolder, accuracyCommand) - - // Remove the model from cache - sessionHolder.mlCache.clear() - - // No longer able to retrieve the model from cache - val e = intercept[MLCacheInvalidException] { - MLHandler.handleMlCommand(sessionHolder, accuracyCommand) - } - val msg = e.getMessage - assert(msg.contains(s"$modelId from the ML cache")) - } - test("access the attribute which is not in allowed list") { val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) val modelId = trainLogisticRegressionModel(sessionHolder) @@ -469,7 +436,7 @@ class MLSuite extends MLHelper { sessionHolder.session.conf .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_MODEL_SIZE.key, "8000") sessionHolder.session.conf - .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_SIZE.key, "10000") + .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MEMORY_CONTROL_MAX_STORAGE_SIZE.key, "10000") trainLogisticRegressionModel(sessionHolder) intercept[MLCacheSizeOverflowException] { trainLogisticRegressionModel(sessionHolder) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
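---

Editor's note: for context, below is a minimal sketch of what this change permits from a PySpark Connect client. It is illustrative only: the server address, data, and config values are assumptions, while the config names ('spark.connect.session.connectML.mlCache.memoryControl.enabled', '...maxStorageSize', '...offloadingTimeout') and the summary behavior are taken from the diff above.

    # Server-side (cluster) configs, e.g. passed to the Spark Connect server via --conf.
    # The names come from this commit; the values here are illustrative assumptions.
    #   spark.connect.session.connectML.mlCache.memoryControl.enabled=true
    #   spark.connect.session.connectML.mlCache.memoryControl.maxStorageSize=1g
    #   spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout=15   (minutes)
    from pyspark.sql import SparkSession
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.linalg import Vectors

    # Assumes a Spark Connect server is already running at this (hypothetical) address.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    train = spark.createDataFrame(
        [(1.0, Vectors.dense(1.0, 0.5)), (0.0, Vectors.dense(-1.0, -0.5))],
        ["label", "features"],
    )

    model = LogisticRegression(maxIter=5).fit(train)

    # Before this commit, enabling memory control disabled training summaries entirely.
    # Now the summary is cached on the server and stays usable until the offloading
    # timeout evicts it; after eviction, fetching it raises the CACHE_INVALID error
    # defined in error-conditions.json.
    print(model.summary.accuracy)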