Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/19208#discussion_r138389265
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
@@ -261,17 +290,40 @@ class CrossValidatorModel private[ml] (
val copied = new CrossValidatorModel(
uid,
bestModel.copy(extra).asInstanceOf[Model[_]],
- avgMetrics.clone())
+ avgMetrics.clone(),
+ CrossValidatorModel.copySubModels(subModels))
copyValues(copied, extra).setParent(parent)
}
@Since("1.6.0")
override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this)
+
+ @Since("2.3.0")
+ @throws[IOException]("If the input path already exists but overwrite is not enabled.")
+ def save(path: String, persistSubModels: Boolean): Unit = {
+ write.asInstanceOf[CrossValidatorModel.CrossValidatorModelWriter]
+ .persistSubModels(persistSubModels).save(path)
+ }
--- End diff --
I added this method because `CrossValidatorModelWriter` is private, so users cannot call it directly. But I'm not sure whether there is a better solution.
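A minimal, self-contained Scala sketch of the visibility pattern the diff relies on. `BaseWriter`, `DemoModel`, `DemoModelWriter` and `savedLog` are hypothetical stand-ins for `MLWriter`, `CrossValidatorModel` and `CrossValidatorModelWriter`, and saves are recorded in memory instead of written to disk. Because the public `write` method only exposes the base writer type, the extra `persistSubModels` setter on the private writer is reachable only through a public overload that downcasts, which is what `save(path, persistSubModels)` does.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for Spark's public MLWriter base type.
abstract class BaseWriter {
  def save(path: String): Unit
}

// Hypothetical stand-in for CrossValidatorModel; subModels is optional.
class DemoModel(val subModels: Option[Seq[String]]) {

  // Like `write: MLWriter`, the public signature exposes only the base type.
  def write: BaseWriter = new DemoModel.DemoModelWriter(this)

  // Public overload: downcast to the private writer to reach its extra setter.
  def save(path: String, persistSubModels: Boolean): Unit =
    write.asInstanceOf[DemoModel.DemoModelWriter]
      .persistSubModels(persistSubModels)
      .save(path)
}

object DemoModel {
  // In-memory log of what each save would write (replaces real file I/O).
  val savedLog: ArrayBuffer[String] = ArrayBuffer.empty[String]

  // Private to the companion pair: user code cannot name this type,
  // so it cannot call persistSubModels directly.
  private class DemoModelWriter(instance: DemoModel) extends BaseWriter {
    private var persist = false

    def persistSubModels(value: Boolean): this.type = {
      persist = value
      this
    }

    override def save(path: String): Unit = {
      // Sub-models are included only when requested and actually present.
      val subModelPart = instance.subModels match {
        case Some(models) if persist => s" +${models.size} subModels"
        case _ => ""
      }
      savedLog += s"${path}: bestModel${subModelPart}"
    }
  }
}
```

In user code, `model.write.persistSubModels(true)` does not compile: `write` returns `BaseWriter`, which has no such method, and the cast target `DemoModelWriter` is not visible outside the `DemoModel` companion pair. That restriction is what motivates the extra overload.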
---