Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20633#discussion_r175965272
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -351,27 +359,88 @@ private[ml] object DefaultParamsReader {
timestamp: Long,
sparkVersion: String,
params: JValue,
+ defaultParams: JValue,
metadata: JValue,
metadataJson: String) {
+
+ private def getValueFromParams(params: JValue): Seq[(String, JValue)]
= {
+ params match {
+ case JObject(pairs) => pairs
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Cannot recognize JSON metadata: $metadataJson.")
+ }
+ }
+
/**
* Get the JSON value of the [[org.apache.spark.ml.param.Param]] of
the given name.
* This can be useful for getting a Param value before an instance of
`Params`
- * is available.
+ * is available. This will look up `params` first, if not existing
then looking up
+ * `defaultParams`.
*/
def getParamValue(paramName: String): JValue = {
implicit val format = DefaultFormats
- params match {
+
+ // Looking up for `params` first.
+ var pairs = getValueFromParams(params)
+ var foundPairs = pairs.filter { case (pName, jsonValue) =>
+ pName == paramName
+ }
+ if (foundPairs.length == 0) {
+ // Looking up for `defaultParams` then.
+ pairs = getValueFromParams(defaultParams)
+ foundPairs = pairs.filter { case (pName, jsonValue) =>
+ pName == paramName
+ }
+ }
+ assert(foundPairs.length == 1, s"Expected one instance of Param
'$paramName' but found" +
+ s" ${foundPairs.length} in JSON Params: " +
pairs.map(_.toString).mkString(", "))
+
+ foundPairs.map(_._2).head
+ }
+
+ /**
+ * Extract Params from metadata, and set them in the instance.
+ * This works if all Params (except params included by `skipParams`
list) implement
+ * [[org.apache.spark.ml.param.Param.jsonDecode()]].
+ *
+ * @param skipParams The params included in `skipParams` won't be set.
This is useful if some
+ * params don't implement
[[org.apache.spark.ml.param.Param.jsonDecode()]]
+ * and need special handling.
+ */
+ def getAndSetParams(
+ instance: Params,
+ skipParams: Option[List[String]] = None): Unit = {
+ setParams(instance, false, skipParams)
+ setParams(instance, true, skipParams)
+ }
+
+ private def setParams(
+ instance: Params,
+ isDefault: Boolean,
+ skipParams: Option[List[String]]): Unit = {
+ implicit val format = DefaultFormats
+ val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion)
+ val paramsToSet = if (isDefault) defaultParams else params
+ paramsToSet match {
case JObject(pairs) =>
- val values = pairs.filter { case (pName, jsonValue) =>
- pName == paramName
- }.map(_._2)
- assert(values.length == 1, s"Expected one instance of Param
'$paramName' but found" +
- s" ${values.length} in JSON Params: " +
pairs.map(_.toString).mkString(", "))
- values.head
+ pairs.foreach { case (paramName, jsonValue) =>
+ if (skipParams == None || !skipParams.get.contains(paramName))
{
+ val param = instance.getParam(paramName)
+ val value = param.jsonDecode(compact(render(jsonValue)))
+ if (isDefault) {
+ instance.setDefault(param, value)
+ } else {
+ instance.set(param, value)
+ }
+ }
+ }
+ // For metadata file prior to Spark 2.4, there is no default
section.
+ case JNothing if isDefault && (major == 2 && minor < 4 || major <
2) =>
--- End diff --
This logic would be simpler if this version check were moved into the
getAndSetParams method, which could simply skip calling
setParams(instance, true, skipParams) for metadata written by Spark 2.3 and earlier.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]