Repository: spark Updated Branches: refs/heads/branch-1.4 a038c5174 -> 91ce13109
[SPARK-7429] [ML] Params cleanups. Params.setDefault taking a set of ParamPairs should be annotated with @varargs; I thought this would not work before, but it apparently does. CrossValidator.transform should call transformSchema, since the underlying Model might be a PipelineModel. CC: mengxr Author: Joseph K. Bradley <[email protected]> Closes #5960 from jkbradley/params-cleanups and squashes the following commits: 118b158 [Joseph K. Bradley] Params.setDefault taking a set of ParamPairs should be annotated with @varargs; I thought this would not work before, but it apparently does. CrossValidator.transform should call transformSchema, since the underlying Model might be a PipelineModel. (cherry picked from commit 4f87e9562aa0dfe5467d7fbaba9278213106377c) Signed-off-by: Xiangrui Meng <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/91ce1310 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/91ce1310 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/91ce1310 Branch: refs/heads/branch-1.4 Commit: 91ce13109b197dbfd396a8b38d58c332f8c1047d Parents: a038c51 Author: Joseph K. 
Bradley <[email protected]> Authored: Thu May 7 01:28:44 2015 -0700 Committer: Xiangrui Meng <[email protected]> Committed: Thu May 7 01:28:59 2015 -0700 ---------------------------------------------------------------------- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 4 +--- .../main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 3 ++- .../src/test/java/org/apache/spark/ml/param/JavaTestParams.java | 1 + 3 files changed, 4 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/91ce1310/mllib/src/main/scala/org/apache/spark/ml/param/params.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 51ce19d..6d09962 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -366,13 +366,11 @@ trait Params extends Identifiable with Serializable { /** * Sets default values for a list of params. * - * Note: Java developers should use the single-parameter [[setDefault()]]. - * Annotating this with varargs causes compilation failures. - * * @param paramPairs a list of param pairs that specify params and their default values to set * respectively. Make sure that the params are initialized before this method * gets called. 
*/ + @varargs protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => setDefault(p.param.asInstanceOf[Param[Any]], p.value) http://git-wip-us.apache.org/repos/asf/spark/blob/91ce1310/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 9208127..ac0d1fe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -105,7 +105,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema - transformSchema(dataset.schema, logging = true) + transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext val est = $(estimator) val eval = $(evaluator) @@ -159,6 +159,7 @@ class CrossValidatorModel private[ml] ( } override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } http://git-wip-us.apache.org/repos/asf/spark/blob/91ce1310/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 8abe575..532eca4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -59,5 +59,6 @@ public class JavaTestParams extends JavaParams { ParamValidators.inArray(validStrings)); setDefault(myIntParam, 1); setDefault(myDoubleParam, 0.5); + setDefault(myIntParam.w(1), myDoubleParam.w(0.5)); } 
} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
