GitHub user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/19208#discussion_r138393318
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
@@ -212,14 +238,12 @@ object CrossValidator extends MLReadable[CrossValidator] {
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val numFolds = (metadata.params \ "numFolds").extract[Int]
-      val seed = (metadata.params \ "seed").extract[Long]
-      new CrossValidator(metadata.uid)
+      val cv = new CrossValidator(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(estimatorParamMaps)
-        .setNumFolds(numFolds)
-        .setSeed(seed)
+      DefaultParamsReader.getAndSetParams(cv, metadata, skipParams = List("estimatorParamMaps"))
--- End diff --
Use `getAndSetParams` instead of setting every param manually. This simplifies the code and helps keep read/write compatibility, since only the params actually present in the saved metadata are set.
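A minimal sketch of the pattern, assuming a toy string-valued params holder rather than Spark's real `Params`/`Param` machinery. `ToyParams`, `ToyParamsReader`, `ToyParamsDemo`, and `newParam` are hypothetical names for illustration only. Spark's `DefaultParamsReader.getAndSetParams` works on the JSON params stored in the saved metadata; here a plain `Map[String, String]` stands in for it. The sketch shows both points from the comment: the loader needs no per-param code, and a param absent from older metadata simply keeps its default.

```scala
import scala.collection.mutable

// Hypothetical toy stand-in for a Params holder (NOT Spark's API).
// Values are plain strings; the illustrative defaults below apply to any
// param that is missing from the saved metadata.
final class ToyParams {
  private val values = mutable.Map[String, String](
    "numFolds" -> "3",
    "seed" -> "0",
    "newParam" -> "default" // hypothetical param added in a later version
  )

  // Fails for names this instance does not know about.
  def set(name: String, value: String): Unit = {
    require(values.contains(name), s"unknown param: $name")
    values(name) = value
  }

  def get(name: String): String = values(name)
}

object ToyParamsReader {
  // Restore every param present in the saved metadata, except the ones in
  // skipParams, which the caller is expected to load some other way.
  def getAndSetParams(
      instance: ToyParams,
      metadata: Map[String, String],
      skipParams: List[String] = Nil): Unit =
    metadata.foreach { case (name, value) =>
      if (!skipParams.contains(name)) instance.set(name, value)
    }
}

object ToyParamsDemo {
  def main(args: Array[String]): Unit = {
    // Metadata written by an "older" version: it has no entry for newParam.
    val saved = Map(
      "numFolds" -> "5",
      "seed" -> "42",
      "estimatorParamMaps" -> "[]" // loaded separately, so skipped below
    )
    val cv = new ToyParams
    ToyParamsReader.getAndSetParams(cv, saved, skipParams = List("estimatorParamMaps"))
    // prints: numFolds=5, seed=42, newParam=default
    println(Seq("numFolds", "seed", "newParam").map(n => s"$n=${cv.get(n)}").mkString(", "))
  }
}
```

Skipping `estimatorParamMaps` mirrors the diff, where that param is returned by `ValidatorParams.loadImpl` and set explicitly with `setEstimatorParamMaps`; without the skip, the toy `set` would reject it as unknown.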