Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/19208#discussion_r148886008
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
@@ -282,12 +328,40 @@ object CrossValidatorModel extends
MLReadable[CrossValidatorModel] {
ValidatorParams.validateParams(instance)
+ protected var shouldPersistSubModels: Boolean = if
(instance.hasSubModels) true else false
+
+ /**
+ * Extra options for CrossValidatorModelWriter; currently supports
"persistSubModels".
+ * If sub models exist, the default value for option
"persistSubModels" is "true".
+ */
+ @Since("2.3.0")
+ override def option(key: String, value: String): this.type = {
+ key.toLowerCase(Locale.ROOT) match {
+ case "persistsubmodels" => shouldPersistSubModels = value.toBoolean
+ case _ => throw new IllegalArgumentException(
+ s"Illegal option ${key} for CrossValidatorModelWriter")
+ }
+ this
+ }
+
override protected def saveImpl(path: String): Unit = {
import org.json4s.JsonDSL._
- val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
+ val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~
+ ("shouldPersistSubModels" -> shouldPersistSubModels)
ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+ if (shouldPersistSubModels) {
+ require(instance.hasSubModels, "Cannot get sub models to persist.")
--- End diff --
This error message may be unclear. How about adding: "When persisting
tuning models, you can only set persistSubModels to true if the tuning was done
with collectSubModels set to true. To save the sub-models, try rerunning
fitting with collectSubModels set to true."
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]