zhengruifeng commented on a change in pull request #27979:
[SPARK-31138][ML][FOLLOWUP] ANOVA optimization
URL: https://github.com/apache/spark/pull/27979#discussion_r396079565
##########
File path: mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala
##########
@@ -75,81 +74,68 @@ object ANOVATest {
dataset: Dataset[_],
featuresCol: String,
labelCol: String): Array[SelectionTestResult] = {
-
val spark = dataset.sparkSession
import spark.implicits._
SchemaUtils.checkColumnType(dataset.schema, featuresCol, new VectorUDT)
SchemaUtils.checkNumericType(dataset.schema, labelCol)
- val numFeatures = MetadataUtils.getNumFeatures(dataset, featuresCol)
- val Row(numSamples: Long, numClasses: Long) =
- dataset.select(count(labelCol), countDistinct(labelCol)).head
-
dataset.select(col(labelCol).cast("double"), col(featuresCol))
.as[(Double, Vector)]
.rdd
.flatMap { case (label, features) =>
- features.iterator.map { case (col, value) => (col, (label, value,
value * value)) }
+ features.iterator.map { case (col, value) => (col, (label, value)) }
}.aggregateByKey[(Double, Double, OpenHashMap[Double, Double],
OpenHashMap[Double, Long])](
(0.0, 0.0, new OpenHashMap[Double, Double], new OpenHashMap[Double,
Long]))(
seqOp = {
- case (
+ case ((sum, sumOfSq, sums, counts), (label, value)) =>
// sums: mapOfSumPerClass (key: label, value: sum of features for
each label)
// counts: mapOfCountPerClass key: label, value: count of features
for each label
- (sum: Double, sumOfSq: Double, sums, counts),
- (label, feature, featureSq)
- ) =>
- sums.changeValue(label, feature, _ + feature)
+ sums.changeValue(label, value, _ + value)
counts.changeValue(label, 1L, _ + 1L)
- (sum + feature, sumOfSq + featureSq, sums, counts)
+ (sum + value, sumOfSq + value * value, sums, counts)
},
combOp = {
- case (
- (sum1, sumOfSq1, sums1, counts1),
- (sum2, sumOfSq2, sums2, counts2)
- ) =>
- sums2.foreach { case (v, w) =>
- sums1.changeValue(v, w, _ + w)
- }
- counts2.foreach { case (v, w) =>
- counts1.changeValue(v, w, _ + w)
- }
+ case ((sum1, sumOfSq1, sums1, counts1), (sum2, sumOfSq2, sums2,
counts2)) =>
+ sums2.foreach { case (v, w) => sums1.changeValue(v, w, _ + w) }
+ counts2.foreach { case (v, w) => counts1.changeValue(v, w, _ + w) }
(sum1 + sum2, sumOfSq1 + sumOfSq2, sums1, counts1)
}
- ).map {
- case (col, (sum, sumOfSq, sums, counts)) =>
- // e.g. features are [3.3, 2.5, 1.0, 3.0, 2.0] and labels are [1,
2, 1, 3, 3]
- // sum: sum of all the features (3.3+2.5+1.0+3.0+2.0)
- // sumOfSq: sum of squares of all the features
(3.3^2+2.5^2+1.0^2+3.0^2+2.0^2)
- // sums: mapOfSumPerClass (key: label, value: sum of features for
each label)
- // ( 1 -> 3.3 + 1.0, 2 ->
2.5, 3 -> 3.0 + 2.0 )
- // counts: mapOfCountPerClass (key: label, value: count of
features for each label)
- // ( 1 -> 2, 2 -> 2, 3 ->
2 )
- // sqSum: square of sum of all data ((3.3+2.5+1.0+3.0+2.0)^2)
- val sqSum = sum * sum
- val ssTot = sumOfSq - sqSum / numSamples
+ ).map { case (col, (sum, sumOfSq, sums, counts)) =>
+ val numSamples = counts.iterator.map(_._2).sum
+ val numClasses = counts.size
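Review comment:
Directly compute `numSamples` and `numClasses` here from the per-class `counts` map, instead of running a separate `count`/`countDistinct` query over the dataset beforehand.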
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]