Repository: spark Updated Branches: refs/heads/master a9b93e073 -> 208fff3ac
[SPARK-14164][MLLIB] Improve input layer validation of MultilayerPerceptronClassifier ## What changes were proposed in this pull request? This PR improves validation of the `layers` param of MultilayerPerceptronClassifier (the array must contain more than one layer and every layer size must be greater than 0) and adds related test cases. ```scala - // TODO: how to check ALSO that all elements are greater than 0? - ParamValidators.arrayLengthGt(1) + (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1 ``` ## How was this patch tested? Passes the Jenkins tests, including the new test cases. Author: Dongjoon Hyun <[email protected]> Closes #11964 from dongjoon-hyun/SPARK-14164. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/208fff3a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/208fff3a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/208fff3a Branch: refs/heads/master Commit: 208fff3ac87f200fd4e6f0407d70bf81cf8c556f Parents: a9b93e0 Author: Dongjoon Hyun <[email protected]> Authored: Thu Mar 31 09:39:15 2016 -0700 Committer: Xiangrui Meng <[email protected]> Committed: Thu Mar 31 09:39:15 2016 -0700 ---------------------------------------------------------------------- .../MultilayerPerceptronClassifier.scala | 3 +-- .../MultilayerPerceptronClassifierSuite.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/208fff3a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index f6de5f2..7ce3ec6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -43,8 +43,7 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams "Sizes of layers from input layer to output layer" + " E.g., Array(780, 100, 10) means 780 inputs, " + "one hidden layer with 100 neurons and output layer of 10 neurons.", - // TODO: how to check ALSO that all elements are greater than 0? - ParamValidators.arrayLengthGt(1) + (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1 ) /** @group getParam */ http://git-wip-us.apache.org/repos/asf/spark/blob/208fff3a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 5df8e6a..53c7a55 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -43,6 +43,23 @@ class MultilayerPerceptronClassifierSuite ).toDF("features", "label") } + test("Input Validation") { + val mlpc = new MultilayerPerceptronClassifier() + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int]()) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](1)) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](0, 1)) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](1, 0)) + } + mlpc.setLayers(Array[Int](1, 1)) + } + test("XOR function learning as binary classification problem with two outputs.") { val layers = Array[Int](2, 5, 2) val trainer = new MultilayerPerceptronClassifier() --------------------------------------------------------------------- To unsubscribe, 
e-mail: [email protected] For additional commands, e-mail: [email protected]
