Repository: spark Updated Branches: refs/heads/master 7c36ee46d -> 6ad8d4c37
[SPARK-25289][ML] Avoid exception in ChiSqSelector with FDR when no feature is selected ## What changes were proposed in this pull request? Currently, when FDR is used for `ChiSqSelector` and no feature is selected, an exception is thrown because the `max` operation fails on an empty collection. This PR fixes the problem by handling this case and returning an empty array, as sklearn (which was the reference for the initial implementation of FDR) does. ## How was this patch tested? Added a unit test. Closes #22303 from mgaido91/SPARK-25289. Authored-by: Marco Gaido <[email protected]> Signed-off-by: Sean Owen <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6ad8d4c3 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6ad8d4c3 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6ad8d4c3 Branch: refs/heads/master Commit: 6ad8d4c375772c0c907c25837de762b5b9266a8e Parents: 7c36ee4 Author: Marco Gaido <[email protected]> Authored: Sat Sep 1 08:41:07 2018 -0500 Committer: Sean Owen <[email protected]> Committed: Sat Sep 1 08:41:07 2018 -0500 ---------------------------------------------------------------------- .../org/apache/spark/mllib/feature/ChiSqSelector.scala | 12 ++++++++---- .../apache/spark/ml/feature/ChiSqSelectorSuite.scala | 11 +++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/6ad8d4c3/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index f923be8..aa78e91 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.test.ChiSqTestResult import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} @@ -272,13 +273,16 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { // https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure val tempRes = chiSqTestResult .sortBy { case (res, _) => res.pValue } - val maxIndex = tempRes + val selected = tempRes .zipWithIndex .filter { case ((res, _), index) => res.pValue <= fdr * (index + 1) / chiSqTestResult.length } - .map { case (_, index) => index } - .max - tempRes.take(maxIndex + 1) + if (selected.isEmpty) { + Array.empty[(ChiSqTestResult, Int)] + } else { + val maxIndex = selected.map(_._2).max + tempRes.take(maxIndex + 1) + } case ChiSqSelector.FWE => chiSqTestResult .filter { case (res, _) => res.pValue < fwe / chiSqTestResult.length } http://git-wip-us.apache.org/repos/asf/spark/blob/6ad8d4c3/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index c843df9..80499e7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -163,6 +163,17 @@ class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest { } } + test("SPARK-25289: ChiSqSelector should not fail when selecting no features with FDR") { + val 
labeledPoints = (0 to 1).map { n => + val v = Vectors.dense((1 to 3).map(_ => n * 1.0).toArray) + (n.toDouble, v) + } + val inputDF = spark.createDataFrame(labeledPoints).toDF("label", "features") + val selector = new ChiSqSelector().setSelectorType("fdr").setFdr(0.05) + val model = selector.fit(inputDF) + assert(model.selectedFeatures.isEmpty) + } + private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = { val selectorModel = selector.fit(data) testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel, --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
