Repository: spark Updated Branches: refs/heads/master 3e6a714c9 -> f180b6534
[SPARK-22060][ML] Fix CrossValidator/TrainValidationSplit param persist/load bug ## What changes were proposed in this pull request? Currently the param persist/loading of CrossValidator/TrainValidationSplit is hardcoded, which is different from other ML estimators. This causes a persistence bug for the newly added `parallelism` param. I refactored the related code to avoid hardcoding param persist/load, and at the same time this solves the `parallelism` persisting bug. This refactoring is very useful because we will add more new params in #19208 ; hardcoded param persisting/loading makes adding new params very troublesome. ## How was this patch tested? Test added. Author: WeichenXu <weichen...@databricks.com> Closes #19278 from WeichenXu123/fix-tuning-param-bug. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f180b653 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f180b653 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f180b653 Branch: refs/heads/master Commit: f180b65343e706c60b995a3d46d0391612bda966 Parents: 3e6a714 Author: WeichenXu <weichen...@databricks.com> Authored: Fri Sep 22 18:15:01 2017 -0700 Committer: Joseph K. 
Bradley <jos...@databricks.com> Committed: Fri Sep 22 18:15:01 2017 -0700 ---------------------------------------------------------------------- .../apache/spark/ml/tuning/CrossValidator.scala | 17 +++++++-------- .../spark/ml/tuning/TrainValidationSplit.scala | 18 ++++++++-------- .../spark/ml/tuning/ValidatorParams.scala | 22 +++++++------------- .../org/apache/spark/ml/util/ReadWrite.scala | 20 +++++++++++++----- .../spark/ml/tuning/CrossValidatorSuite.scala | 3 +++ .../ml/tuning/TrainValidationSplitSuite.scala | 4 +++- 6 files changed, 46 insertions(+), 38 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index ce2a3a2..7c81cb9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -212,14 +212,13 @@ object CrossValidator extends MLReadable[CrossValidator] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) - val numFolds = (metadata.params \ "numFolds").extract[Int] - val seed = (metadata.params \ "seed").extract[Long] - new CrossValidator(metadata.uid) + val cv = new CrossValidator(metadata.uid) .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - .setNumFolds(numFolds) - .setSeed(seed) + DefaultParamsReader.getAndSetParams(cv, metadata, + skipParams = Option(List("estimatorParamMaps"))) + cv } } } @@ -302,17 +301,17 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, 
className) - val numFolds = (metadata.params \ "numFolds").extract[Int] - val seed = (metadata.params \ "seed").extract[Long] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - .set(model.numFolds, numFolds) - .set(model.seed, seed) + DefaultParamsReader.getAndSetParams(model, metadata, + skipParams = Option(List("estimatorParamMaps"))) + model } } } http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 16db0f5..6e3ad40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tuning +import java.io.IOException import java.util.{List => JList} import scala.collection.JavaConverters._ @@ -207,14 +208,13 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) - val trainRatio = (metadata.params \ "trainRatio").extract[Double] - val seed = (metadata.params \ "seed").extract[Long] - new TrainValidationSplit(metadata.uid) + val tvs = new TrainValidationSplit(metadata.uid) .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - .setTrainRatio(trainRatio) - .setSeed(seed) + 
DefaultParamsReader.getAndSetParams(tvs, metadata, + skipParams = Option(List("estimatorParamMaps"))) + tvs } } } @@ -295,17 +295,17 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) - val trainRatio = (metadata.params \ "trainRatio").extract[Double] - val seed = (metadata.params \ "seed").extract[Long] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray + val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - .set(model.trainRatio, trainRatio) - .set(model.seed, seed) + DefaultParamsReader.getAndSetParams(model, metadata, + skipParams = Option(List("estimatorParamMaps"))) + model } } } http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 0ab6eed..363304e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -150,20 +150,14 @@ private[ml] object ValidatorParams { }.toSeq )) - val validatorSpecificParams = instance match { - case cv: CrossValidatorParams => - List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds))) - case tvs: TrainValidationSplitParams => - List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio))) - case _ => - // This should not happen. 
- throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " + - instance.getClass.getCanonicalName) - } - - val jsonParams = validatorSpecificParams ++ List( - "estimatorParamMaps" -> parse(estimatorParamMapsJson), - "seed" -> parse(instance.seed.jsonEncode(instance.getSeed))) + val params = instance.extractParamMap().toSeq + val skipParams = List("estimator", "evaluator", "estimatorParamMaps") + val jsonParams = render(params + .filter { case ParamPair(p, v) => !skipParams.contains(p.name)} + .map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson)) + ) DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 65f142c..7188da3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -396,17 +396,27 @@ private[ml] object DefaultParamsReader { /** * Extract Params from metadata, and set them in the instance. - * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * This works if all Params (except params included by `skipParams` list) implement + * [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * + * @param skipParams The params included in `skipParams` won't be set. This is useful if some + * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] + * and need special handling. 
* TODO: Move to [[Metadata]] method */ - def getAndSetParams(instance: Params, metadata: Metadata): Unit = { + def getAndSetParams( + instance: Params, + metadata: Metadata, + skipParams: Option[List[String]] = None): Unit = { implicit val format = DefaultFormats metadata.params match { case JObject(pairs) => pairs.foreach { case (paramName, jsonValue) => - val param = instance.getParam(paramName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) + if (skipParams == None || !skipParams.get.contains(paramName)) { + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + instance.set(param, value) + } } case _ => throw new IllegalArgumentException( http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index a8d4377..a01744f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -159,12 +159,15 @@ class CrossValidatorSuite .setEvaluator(evaluator) .setNumFolds(20) .setEstimatorParamMaps(paramMaps) + .setSeed(42L) + .setParallelism(2) val cv2 = testDefaultReadWrite(cv, testParams = false) assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) assert(cv.getSeed === cv2.getSeed) + assert(cv.getParallelism === cv2.getParallelism) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] http://git-wip-us.apache.org/repos/asf/spark/blob/f180b653/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala 
---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 7480173..2ed4fbb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.param.{ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -160,11 +160,13 @@ class TrainValidationSplitSuite .setTrainRatio(0.5) .setEstimatorParamMaps(paramMaps) .setSeed(42L) + .setParallelism(2) val tvs2 = testDefaultReadWrite(tvs, testParams = false) assert(tvs.getTrainRatio === tvs2.getTrainRatio) assert(tvs.getSeed === tvs2.getSeed) + assert(tvs.getParallelism === tvs2.getParallelism) ValidatorParamsSuiteHelpers .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org