viirya commented on a change in pull request #20146: [SPARK-11215][ML] Add multiple columns support to StringIndexer
URL: https://github.com/apache/spark/pull/20146#discussion_r244529348
##########
File path: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
##########
@@ -421,3 +603,49 @@ object IndexToString extends DefaultParamsReadable[IndexToString] {
   @Since("1.6.0")
   override def load(path: String): IndexToString = super.load(path)
 }
+
+/**
+ * A SQL `Aggregator` used by `StringIndexer` to count labels in string columns during fitting.
+ */
+private class StringIndexerAggregator(numColumns: Int, inputColTypes: Seq[DataType])
+  extends Aggregator[Row, Array[OpenHashMap[String, Long]], Array[OpenHashMap[String, Long]]] {
+
+  override def zero: Array[OpenHashMap[String, Long]] =
+    Array.fill(numColumns)(new OpenHashMap[String, Long]())
+
+  def reduce(
+      array: Array[OpenHashMap[String, Long]],
+      row: Row): Array[OpenHashMap[String, Long]] = {
+    for (i <- 0 until numColumns) {
+      val stringValue = row.getString(i)
+      // We don't count for null and NaN values.
Review comment:
You're right. On reflection, filtering after the string conversion is unreliable.
I've now changed it so that NaNs are replaced with null before aggregation. Since
the aggregator already skips nulls, I think this approach is better.
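A minimal, Spark-free sketch of the approach described in the comment above, for illustration only (it is not the PR's code). It assumes rows are modeled as `Seq[Any]` instead of Spark `Row`s, and uses a plain `scala.collection.mutable.Map` in place of Spark's internal `OpenHashMap`. The object and method names (`NaNToNullCountingSketch`, `nanToNull`, `countLabels`) are hypothetical. `zero`, `reduce`, and `merge` only mirror the shape of the `Aggregator` methods. The key point is ordering: NaN is mapped to null *before* any string conversion, so the counting step only has to skip nulls.

```scala
import scala.collection.mutable

// Hypothetical, Spark-free sketch; not the implementation from the PR.
// mutable.Map stands in for Spark's internal OpenHashMap, and a row is a Seq[Any].
object NaNToNullCountingSketch {
  type Counts = Array[mutable.Map[String, Long]]

  // Step 1, before aggregation: replace Double/Float NaN with null.
  // Everything else, including the string "NaN", passes through unchanged.
  def nanToNull(value: Any): Any = value match {
    case d: Double if d.isNaN => null
    case f: Float if f.isNaN  => null
    case other                => other
  }

  // Analogue of Aggregator.zero: one empty count map per input column.
  def zero(numColumns: Int): Counts =
    Array.fill(numColumns)(mutable.Map.empty[String, Long])

  // Analogue of Aggregator.reduce: skip null cells (which now include former NaNs),
  // convert the rest to strings, and increment that label's count.
  def reduce(counts: Counts, row: Seq[Any]): Counts = {
    for (i <- counts.indices if row(i) != null) {
      val label = row(i).toString
      counts(i)(label) = counts(i).getOrElse(label, 0L) + 1L
    }
    counts
  }

  // Analogue of Aggregator.merge: add partial counts column by column.
  def merge(left: Counts, right: Counts): Counts = {
    for (i <- left.indices; (label, n) <- right(i)) {
      left(i)(label) = left(i).getOrElse(label, 0L) + n
    }
    left
  }

  // Replace NaNs with null first, then fold the rows into per-column counts.
  def countLabels(rows: Seq[Seq[Any]], numColumns: Int): Counts =
    rows.map(_.map(nanToNull)).foldLeft(zero(numColumns))(reduce)
}
```

With this ordering, a genuine string label `"NaN"` in a string column is still counted, while a `Double.NaN` in a numeric column is dropped. A filter applied after `toString` could not tell those two cases apart.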