Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/19208#discussion_r138391134
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala ---
@@ -150,20 +150,14 @@ private[ml] object ValidatorParams {
}.toSeq
))
- val validatorSpecificParams = instance match {
- case cv: CrossValidatorParams =>
- List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
- case tvs: TrainValidationSplitParams =>
- List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
- case _ =>
- // This should not happen.
- throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
- instance.getClass.getCanonicalName)
- }
-
- val jsonParams = validatorSpecificParams ++ List(
- "estimatorParamMaps" -> parse(estimatorParamMapsJson),
- "seed" -> parse(instance.seed.jsonEncode(instance.getSeed)))
+ val params = instance.extractParamMap().toSeq
+ val skipParams = List("estimator", "evaluator", "estimatorParamMaps")
+ val jsonParams = render(params
+ .filter { case ParamPair(p, v) => !skipParams.contains(p.name)}
+ .map { case ParamPair(p, v) =>
+ p.name -> parse(p.jsonEncode(v))
+ }.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson))
+ )
--- End diff --
This improves the code here so that we don't need to add code for each
parameter. We now have three newly added parameters (parallelism,
collectSubModels, persistSubModelPath), all of which exist only in the CV/TVS
estimators. The old code can easily cause bugs if we forget to update it when
we add new params.
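
To illustrate the idea outside of Spark, here is a minimal, dependency-free
sketch of the generic approach. Param and ParamPair below are simplified,
hypothetical stand-ins (not the real org.apache.spark.ml.param classes), and
values are encoded to plain strings rather than json4s values. The point is
only that every param is serialized through one filter/map pipeline, so a
newly added param needs no extra code.

```scala
// Hypothetical stand-ins for Spark's Param / ParamPair; NOT the real Spark classes.
final case class Param[T](name: String, encode: T => String) {
  def jsonEncode(value: T): String = encode(value)
}

final case class ParamPair[T](param: Param[T], value: T) {
  // Encode inside the pair so callers never need to know T.
  def encoded: (String, String) = param.name -> param.jsonEncode(value)
}

object GenericParamSave {
  // Params that need custom handling and are therefore skipped by the generic path.
  val skipParams: Set[String] = Set("estimator", "evaluator", "estimatorParamMaps")

  // Serialize every param that is not in the skip list, without naming any of them.
  def toJsonFields(params: Seq[ParamPair[_]]): List[(String, String)] =
    params
      .filter(pair => !skipParams.contains(pair.param.name))
      .map(_.encoded)
      .toList

  def main(args: Array[String]): Unit = {
    val numFolds = Param[Int]("numFolds", _.toString)
    val seed = Param[Long]("seed", _.toString)
    val parallelism = Param[Int]("parallelism", _.toString)
    val estimator = Param[String]("estimator", s => "\"" + s + "\"")

    val params: Seq[ParamPair[_]] = Seq(
      ParamPair(numFolds, 3),
      ParamPair(seed, 42L),
      ParamPair(parallelism, 4),
      ParamPair(estimator, "lr")
    )

    // "estimator" is filtered out; "parallelism" is picked up with no extra code.
    toJsonFields(params).foreach { case (k, v) => println(s"$k -> $v") }
  }
}
```

In this sketch, adding another param only means adding it to the input
sequence. The skip list is the single place that has to change when a param
needs special treatment.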
---