GitHub user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19208#discussion_r138391134
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala ---
    @@ -150,20 +150,14 @@ private[ml] object ValidatorParams {
           }.toSeq
         ))
     
    -    val validatorSpecificParams = instance match {
    -      case cv: CrossValidatorParams =>
    -        List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
    -      case tvs: TrainValidationSplitParams =>
    -        List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
    -      case _ =>
    -        // This should not happen.
    -        throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
    -          instance.getClass.getCanonicalName)
    -    }
    -
    -    val jsonParams = validatorSpecificParams ++ List(
    -      "estimatorParamMaps" -> parse(estimatorParamMapsJson),
    -      "seed" -> parse(instance.seed.jsonEncode(instance.getSeed)))
    +    val params = instance.extractParamMap().toSeq
    +    val skipParams = List("estimator", "evaluator", "estimatorParamMaps")
    +    val jsonParams = render(params
    +      .filter { case ParamPair(p, v) => !skipParams.contains(p.name)}
    +      .map { case ParamPair(p, v) =>
    +        p.name -> parse(p.jsonEncode(v))
    +      }.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson))
    +    )
    --- End diff --
    
    I improved the code here so that we don't need to add code for each parameter.
We now have 3 newly added parameters (parallelism, collectSubModels,
persistSubModelPath), all added only to the CV/TVS estimators. The old code here
was error-prone: if we forgot to update it when adding a new param, that param
was simply not saved.

