zhengruifeng commented on a change in pull request #26803: [SPARK-30178][ML]
RobustScaler support large numFeatures
URL: https://github.com/apache/spark/pull/26803#discussion_r355827074
##########
File path: mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
##########
@@ -147,49 +146,43 @@ class RobustScaler (override val uid: String)
override def fit(dataset: Dataset[_]): RobustScalerModel = {
transformSchema(dataset.schema, logging = true)
- val localRelativeError = $(relativeError)
- val summaries = dataset.select($(inputCol)).rdd.map {
- case Row(vec: Vector) => vec
- }.mapPartitions { iter =>
- var agg: Array[QuantileSummaries] = null
- while (iter.hasNext) {
- val vec = iter.next()
- if (agg == null) {
- agg = Array.fill(vec.size)(
- new QuantileSummaries(QuantileSummaries.defaultCompressThreshold,
localRelativeError))
- }
- require(vec.size == agg.length,
- s"Number of dimensions must be ${agg.length} but got ${vec.size}")
- var i = 0
- while (i < vec.size) {
- agg(i) = agg(i).insert(vec(i))
- i += 1
- }
- }
-
- if (agg == null) {
- Iterator.empty
- } else {
- Iterator.single(agg.map(_.compress))
- }
- }.treeReduce { (agg1, agg2) =>
- require(agg1.length == agg2.length)
- var i = 0
- while (i < agg1.length) {
- agg1(i) = agg1(i).merge(agg2(i))
- i += 1
- }
- agg1
- }
+ val vectors = dataset.select($(inputCol)).rdd.map { case Row(vec: Vector)
=> vec }
+ val numFeatures = MetadataUtils.getNumFeatures(dataset.schema($(inputCol)))
+ .getOrElse(vectors.first().size)
- val (range, median) = summaries.map { s =>
- (s.query($(upper)).get - s.query($(lower)).get,
- s.query(0.5).get)
- }.unzip
+ val localRelativeError = $(relativeError)
+ val localUpper = $(upper)
+ val localLower = $(lower)
+
+ val collected = vectors.flatMap { vec =>
+ require(vec.size == numFeatures,
+ s"Number of dimensions must be $numFeatures but got ${vec.size}")
+ Iterator.range(0, numFeatures).map { i => (i, vec(i)) }
+ }.aggregateByKey(
+ new QuantileSummaries(QuantileSummaries.defaultCompressThreshold,
localRelativeError))(
+ seqOp = (s, v) => s.insert(v),
+ combOp = (s1, s2) => s1.compress.merge(s2.compress)
Review comment:
The `merge` op requires that both `QuantileSummaries` instances are already compressed:
```scala
scala> var s1 = new
QuantileSummaries(QuantileSummaries.defaultCompressThreshold, 0.0001)
s1: org.apache.spark.sql.catalyst.util.QuantileSummaries =
org.apache.spark.sql.catalyst.util.QuantileSummaries@6bafe3b0
scala> var s2 = new
QuantileSummaries(QuantileSummaries.defaultCompressThreshold, 0.0001)
s2: org.apache.spark.sql.catalyst.util.QuantileSummaries =
org.apache.spark.sql.catalyst.util.QuantileSummaries@17c7cc93
scala> s1 = s1.insert(1.0)
s1: org.apache.spark.sql.catalyst.util.QuantileSummaries =
org.apache.spark.sql.catalyst.util.QuantileSummaries@6bafe3b0
scala> s1 = s1.insert(2.0)
s1: org.apache.spark.sql.catalyst.util.QuantileSummaries =
org.apache.spark.sql.catalyst.util.QuantileSummaries@6bafe3b0
scala> s2 = s2.insert(3.0)
s2: org.apache.spark.sql.catalyst.util.QuantileSummaries =
org.apache.spark.sql.catalyst.util.QuantileSummaries@17c7cc93
scala> s2 = s2.insert(4.0)
s2: org.apache.spark.sql.catalyst.util.QuantileSummaries =
org.apache.spark.sql.catalyst.util.QuantileSummaries@17c7cc93
scala> s1.merge(s2)
java.lang.IllegalArgumentException: requirement failed: Current buffer needs
to be compressed before merge
at scala.Predef$.require(Predef.scala:224)
at
org.apache.spark.sql.catalyst.util.QuantileSummaries.merge(QuantileSummaries.scala:154)
... 49 elided
scala> s1.merge(s2.compress)
compress compressThreshold compressed
scala> s1.merge(s2.compressed)
<console>:29: error: type mismatch;
found : Boolean
required: org.apache.spark.sql.catalyst.util.QuantileSummaries
s1.merge(s2.compressed)
^
scala> s1.merge(s2.compress)
java.lang.IllegalArgumentException: requirement failed: Current buffer needs
to be compressed before merge
at scala.Predef$.require(Predef.scala:224)
at
org.apache.spark.sql.catalyst.util.QuantileSummaries.merge(QuantileSummaries.scala:154)
... 49 elided
scala> s1.compress.merge(s2)
compress compressThreshold compressed
scala> s1.compress.merge(s2)
java.lang.IllegalArgumentException: requirement failed: Other buffer needs
to be compressed before merge
at scala.Predef$.require(Predef.scala:224)
at
org.apache.spark.sql.catalyst.util.QuantileSummaries.merge(QuantileSummaries.scala:155)
... 49 elided
scala> s1.compress.merge(s2.compress)
compress compressThreshold compressed
scala> s1.compress.merge(s2.compress)
res7: org.apache.spark.sql.catalyst.util.QuantileSummaries =
org.apache.spark.sql.catalyst.util.QuantileSummaries@722b4f64
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]