zhengruifeng commented on a change in pull request #27461: [SPARK-30736][ML] One-Pass ChiSquareTest
URL: https://github.com/apache/spark/pull/27461#discussion_r376918926
 
 

 ##########
 File path: mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
 ##########
 @@ -83,66 +81,161 @@ private[spark] object ChiSqTest extends Logging {
    */
   def chiSquaredFeatures(data: RDD[LabeledPoint],
       methodName: String = PEARSON.name): Array[ChiSqTestResult] = {
-    val numCols = data.first().features.size
-    val results = new Array[ChiSqTestResult](numCols)
-    var labels: Map[Double, Int] = null
-    // at most 1000 columns at a time
-    val batchSize = 1000
-    var batch = 0
-    while (batch * batchSize < numCols) {
-      // The following block of code can be cleaned up and made public as
-      // chiSquared(data: RDD[(V1, V2)])
-      val startCol = batch * batchSize
-      val endCol = startCol + math.min(batchSize, numCols - startCol)
-      val pairCounts = data.mapPartitions { iter =>
-        val distinctLabels = mutable.HashSet.empty[Double]
-        val allDistinctFeatures: Map[Int, mutable.HashSet[Double]] =
-          Map((startCol until endCol).map(col => (col, 
mutable.HashSet.empty[Double])): _*)
-        var i = 1
-        iter.flatMap { case LabeledPoint(label, features) =>
-          if (i % 1000 == 0) {
-            if (distinctLabels.size > maxCategories) {
-              throw new SparkException(s"Chi-square test expect factors 
(categorical values) but "
-                + s"found more than $maxCategories distinct label values.")
-            }
-            allDistinctFeatures.foreach { case (col, distinctFeatures) =>
-              if (distinctFeatures.size > maxCategories) {
-                throw new SparkException(s"Chi-square test expect factors 
(categorical values) but "
-                  + s"found more than $maxCategories distinct values in column 
$col.")
-              }
-            }
-          }
-          i += 1
-          distinctLabels += label
-          val brzFeatures = features.asBreeze
-          (startCol until endCol).map { col =>
-            val feature = brzFeatures(col)
-            allDistinctFeatures(col) += feature
-            (col, feature, label)
-          }
+    data.first().features match {
+      case dv: DenseVector =>
+        chiSquaredDenseFeatures(data, dv.size, methodName)
+      case sv: SparseVector =>
+        chiSquaredSparseFeatures(data, sv.size, methodName)
+    }
+  }
+
+  private def chiSquaredDenseFeatures(data: RDD[LabeledPoint],
+      numFeatures: Int,
+      methodName: String = PEARSON.name): Array[ChiSqTestResult] = {
+    data.flatMap { case LabeledPoint(label, features) =>
+      require(features.size == numFeatures)
+      features.iterator.map { case (col, value) =>
+        (col, (value, label))
+      }
+    }.aggregateByKey(new OpenHashMap[(Double, Double), Long])(
+      seqOp = { case (count, t) =>
+        count.changeValue(t, 1L, _ + 1L)
+        count
+      },
+      combOp = { case (count1, count2) =>
+        count2.iterator.foreach { case (t, c) =>
+          count1.changeValue(t, c, _ + c)
         }
-      }.countByValue()
+        count1
+      }
+    ).map { case (col, count) =>
+      val label2Index = 
count.iterator.map(_._1._2).toArray.distinct.sorted.zipWithIndex.toMap
+      val numLabels = label2Index.size
+      if (numLabels > maxCategories) {
+        throw new SparkException(s"Chi-square test expect factors (categorical 
values) but "
+          + s"found more than $maxCategories distinct label values.")
+      }
+
+      val value2Index = 
count.iterator.map(_._1._1).toArray.distinct.sorted.zipWithIndex.toMap
+      val numValues = value2Index.size
+      if (numValues > maxCategories) {
+        throw new SparkException(s"Chi-square test expect factors (categorical 
values) but "
+          + s"found more than $maxCategories distinct values in column $col.")
+      }
 
-      if (labels == null) {
-        // Do this only once for the first column since labels are invariant 
across features.
-        labels =
-          pairCounts.keys.filter(_._1 == 
startCol).map(_._3).toArray.distinct.zipWithIndex.toMap
+      val contingency = new DenseMatrix(numValues, numLabels,
+        Array.ofDim[Double](numValues * numLabels))
+      count.foreach { case ((value, label), c) =>
+        val i = value2Index(value)
+        val j = label2Index(label)
+        contingency.update(i, j, c)
       }
-      val numLabels = labels.size
-      pairCounts.keys.groupBy(_._1).foreach { case (col, keys) =>
-        val features = keys.map(_._2).toArray.distinct.zipWithIndex.toMap
-        val numRows = features.size
-        val contingency = new BDM(numRows, numLabels, new 
Array[Double](numRows * numLabels))
-        keys.foreach { case (_, feature, label) =>
-          val i = features(feature)
-          val j = labels(label)
-          contingency(i, j) += pairCounts((col, feature, label))
+
+      val result = ChiSqTest.chiSquaredMatrix(contingency, methodName)
+      (col, result.pValue, result.degreesOfFreedom, result.statistic, 
result.nullHypothesis)
+    }.collect().sortBy(_._1).map {
+      case (_, pValue, degreesOfFreedom, statistic, nullHypothesis) =>
+        new ChiSqTestResult(pValue, degreesOfFreedom, statistic, methodName, 
nullHypothesis)
+    }
+  }
+
+  private def chiSquaredSparseFeatures(data: RDD[LabeledPoint],
+      numFeatures: Int,
+      methodName: String = PEARSON.name): Array[ChiSqTestResult] = {
+    val labelCounts = data.map(_.label).countByValue()
+    val numLabels = labelCounts.size
+    if (numLabels > maxCategories) {
+      throw new SparkException(s"Chi-square test expect factors (categorical 
values) but "
+        + s"found more than $maxCategories distinct label values.")
+    }
+
+    val numInstances = labelCounts.valuesIterator.sum
+    val label2Index = labelCounts.keys.toArray.sorted.zipWithIndex.toMap
+
+    val sc = data.sparkContext
+    val bcLabels = sc.broadcast((labelCounts, label2Index))
+
+    val results = data.flatMap { case LabeledPoint(label, features) =>
+      require(features.size == numFeatures)
+      features.nonZeroIterator.map { case (col, value) =>
+        (col, (value, label))
+      }
+    }.aggregateByKey(new OpenHashMap[(Double, Double), Long])(
+      seqOp = { case (count, t) =>
+        count.changeValue(t, 1L, _ + 1L)
+        count
+      },
+      combOp = { case (count1, count2) =>
+        count2.iterator.foreach { case (t, c) =>
+          count1.changeValue(t, c, _ + c)
+        }
+        count1
+      }
+    ).map { case (col, count) =>
+      val (labelCounts, label2Index) = bcLabels.value
+      val nnz = count.iterator.map(_._2).sum
+      require(numInstances >= nnz)
+
+      val value2Index = if (numInstances == nnz) {
+        count.iterator.map(_._1._1).toArray.distinct.sorted.zipWithIndex.toMap
+      } else {
+        (count.iterator.map(_._1._1).toArray :+ 
0.0).distinct.sorted.zipWithIndex.toMap
+      }
+      val numValues = value2Index.size
+      if (numValues > maxCategories) {
+        throw new SparkException(s"Chi-square test expect factors (categorical 
values) but "
+          + s"found more than $maxCategories distinct values in column $col.")
+      }
+
+      val contingency = new DenseMatrix(numValues, numLabels,
+        Array.ofDim[Double](numValues * numLabels))
+      count.foreach { case ((value, label), c) =>
 
 Review comment:
   Yes

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to