GitHub user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20633#discussion_r175965272
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
    @@ -351,27 +359,88 @@ private[ml] object DefaultParamsReader {
           timestamp: Long,
           sparkVersion: String,
           params: JValue,
    +      defaultParams: JValue,
           metadata: JValue,
           metadataJson: String) {
     
    +
    +    private def getValueFromParams(params: JValue): Seq[(String, JValue)] 
= {
    +      params match {
    +        case JObject(pairs) => pairs
    +        case _ =>
    +          throw new IllegalArgumentException(
    +            s"Cannot recognize JSON metadata: $metadataJson.")
    +      }
    +    }
    +
         /**
          * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of 
the given name.
          * This can be useful for getting a Param value before an instance of 
`Params`
    -     * is available.
    +     * is available. This will look up `params` first, if not existing 
then looking up
    +     * `defaultParams`.
          */
         def getParamValue(paramName: String): JValue = {
           implicit val format = DefaultFormats
    -      params match {
    +
    +      // Looking up for `params` first.
    +      var pairs = getValueFromParams(params)
    +      var foundPairs = pairs.filter { case (pName, jsonValue) =>
    +        pName == paramName
    +      }
    +      if (foundPairs.length == 0) {
    +        // Looking up for `defaultParams` then.
    +        pairs = getValueFromParams(defaultParams)
    +        foundPairs = pairs.filter { case (pName, jsonValue) =>
    +          pName == paramName
    +        }
    +      }
    +      assert(foundPairs.length == 1, s"Expected one instance of Param 
'$paramName' but found" +
    +        s" ${foundPairs.length} in JSON Params: " + 
pairs.map(_.toString).mkString(", "))
    +
    +      foundPairs.map(_._2).head
    +    }
    +
    +    /**
    +     * Extract Params from metadata, and set them in the instance.
    +     * This works if all Params (except params included by `skipParams` 
list) implement
    +     * [[org.apache.spark.ml.param.Param.jsonDecode()]].
    +     *
    +     * @param skipParams The params included in `skipParams` won't be set. 
This is useful if some
    +     *                   params don't implement 
[[org.apache.spark.ml.param.Param.jsonDecode()]]
    +     *                   and need special handling.
    +     */
    +    def getAndSetParams(
    +        instance: Params,
    +        skipParams: Option[List[String]] = None): Unit = {
    +      setParams(instance, false, skipParams)
    +      setParams(instance, true, skipParams)
    +    }
    +
    +    private def setParams(
    +        instance: Params,
    +        isDefault: Boolean,
    +        skipParams: Option[List[String]]): Unit = {
    +      implicit val format = DefaultFormats
    +      val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion)
    +      val paramsToSet = if (isDefault) defaultParams else params
    +      paramsToSet match {
             case JObject(pairs) =>
    -          val values = pairs.filter { case (pName, jsonValue) =>
    -            pName == paramName
    -          }.map(_._2)
    -          assert(values.length == 1, s"Expected one instance of Param 
'$paramName' but found" +
    -            s" ${values.length} in JSON Params: " + 
pairs.map(_.toString).mkString(", "))
    -          values.head
    +          pairs.foreach { case (paramName, jsonValue) =>
    +            if (skipParams == None || !skipParams.get.contains(paramName)) 
{
    +              val param = instance.getParam(paramName)
    +              val value = param.jsonDecode(compact(render(jsonValue)))
    +              if (isDefault) {
    +                instance.setDefault(param, value)
    +              } else {
    +                instance.set(param, value)
    +              }
    +            }
    +          }
    +        // For metadata file prior to Spark 2.4, there is no default 
section.
    +        case JNothing if isDefault && (major == 2 && minor < 4 || major < 
2) =>
    --- End diff --
    
    This logic would be simpler if this check were moved into the getAndSetParams
method, which could then skip calling setParams(instance, true, skipParams) entirely
for metadata written by Spark 2.3 and earlier.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to