Github user smurching commented on a diff in the pull request:
https://github.com/apache/spark/pull/19208#discussion_r139578700
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala ---
@@ -276,12 +315,32 @@ object TrainValidationSplitModel extends
MLReadable[TrainValidationSplitModel] {
ValidatorParams.validateParams(instance)
+ protected var shouldPersistSubModels: Boolean = false
+
+ /**
+ * Set option for persist sub models.
+ */
+ @Since("2.3.0")
+ def persistSubModels(persist: Boolean): this.type = {
+ shouldPersistSubModels = persist
+ this
+ }
+
override protected def saveImpl(path: String): Unit = {
import org.json4s.JsonDSL._
- val extraMetadata = "validationMetrics" ->
instance.validationMetrics.toSeq
+ val extraMetadata = ("validationMetrics" ->
instance.validationMetrics.toSeq) ~
+ ("shouldPersistSubModels" -> shouldPersistSubModels)
ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+ if (shouldPersistSubModels) {
+ require(instance.subModels != null, "Cannot get sub models to
persist.")
+ val subModelsPath = new Path(path, "subModels")
+ for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
+ val modelPath = new Path(subModelsPath,
paramIndex.toString).toString
+
instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath)
--- End diff --
Should we clean up/remove the partially-persisted `subModels` if any of
these `save()` calls fail? For example, if we have four subModels and the
first three `save()` calls succeed but the fourth fails, should we delete
the folders for the first three subModels?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]