Repository: spark Updated Branches: refs/heads/master 6466d6c8a -> f5ebb18c4
[SPARK-14671][ML] Pipeline setStages should handle subclasses of PipelineStage ## What changes were proposed in this pull request? Pipeline.setStages fails for some code examples which worked in 1.5 but fail in 1.6. This tends to occur when using a mix of transformers from ml.feature. It is because Arrays are non-covariant in Scala (Array[T] is invariant, even though the underlying Java arrays are covariant) and the addition of MLWritable to some transformers means the stages0/1 arrays in such examples are not of type Array[PipelineStage]. This PR modifies the following to accept subclasses of PipelineStage: * Pipeline.setStages() * Params.w() ## How was this patch tested? Unit test which fails to compile before this fix. Author: Joseph K. Bradley <jos...@databricks.com> Closes #12430 from jkbradley/pipeline-setstages. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f5ebb18c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f5ebb18c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f5ebb18c Branch: refs/heads/master Commit: f5ebb18c45ffdee2756a80f64239cb9158df1a11 Parents: 6466d6c Author: Joseph K. Bradley <jos...@databricks.com> Authored: Wed Apr 27 16:11:12 2016 -0700 Committer: Joseph K. 
Bradley <jos...@databricks.com> Committed: Wed Apr 27 16:11:12 2016 -0700 ---------------------------------------------------------------------- mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 5 ++++- .../src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 9 ++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f5ebb18c/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 8206672..b02aea9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -103,7 +103,10 @@ class Pipeline @Since("1.4.0") ( /** @group setParam */ @Since("1.2.0") - def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } + def setStages(value: Array[_ <: PipelineStage]): this.type = { + set(stages, value.asInstanceOf[Array[PipelineStage]]) + this + } // Below, we clone stages so that modifications to the list of stages will not change // the Param value in the Pipeline. 
http://git-wip-us.apache.org/repos/asf/spark/blob/f5ebb18c/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index a8c4ac6..1de638f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.Pipeline.SharedReadWrite import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler} -import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -201,6 +201,13 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul pipeline.fit(df) } } + + test("Pipeline.setStages should handle Java Arrays being non-covariant") { + val stages0 = Array(new UnWritableStage("b")) + val stages1 = Array(new WritableStage("a")) + val steps = stages0 ++ stages1 + val p = new Pipeline().setStages(steps) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org