This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9bff2c8bc505 [SPARK-48463][ML] Make StringIndexer supporting nested
input columns
9bff2c8bc505 is described below
commit 9bff2c8bc5059f5be0dc6e8105c11403942a0b9f
Author: Weichen Xu <[email protected]>
AuthorDate: Mon Jul 15 15:19:59 2024 +0800
[SPARK-48463][ML] Make StringIndexer supporting nested input columns
### What changes were proposed in this pull request?
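Make StringIndexer support nested input columns, i.e. input column names that are dotted paths to fields inside struct columns (for example `c1.label`).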
### Why are the changes needed?
Requested by users: StringIndexer previously accepted only top-level column names as input columns.
### Does this PR introduce _any_ user-facing change?
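Yes. StringIndexer and StringIndexerModel now accept dotted paths to nested struct fields as input columns. Schema validation now resolves input columns through the active SparkSession, so it fails if no session is active.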
### How was this patch tested?
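Unit tests (new cases in StringIndexerSuite covering transformSchema and fit/transform with nested input columns).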
### Was this patch authored or co-authored using generative AI tooling?
Closes #47283 from WeichenXu123/SPARK-48463.
Lead-authored-by: Weichen Xu <[email protected]>
Co-authored-by: WeichenXu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
.../apache/spark/ml/feature/StringIndexer.scala | 37 +++++++++++------
.../spark/ml/feature/StringIndexerSuite.scala | 47 +++++++++++++++++++++-
2 files changed, 71 insertions(+), 13 deletions(-)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 60dc4d024071..34f77f029395 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import java.util.ArrayList
+
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
@@ -27,7 +29,7 @@ import org.apache.spark.ml.attribute.{Attribute,
NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoder, Encoders,
Row}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset,
Encoder, Encoders, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions._
@@ -103,8 +105,8 @@ private[feature] trait StringIndexerBase extends Params
with HasHandleInvalid wi
private def validateAndTransformField(
schema: StructType,
inputColName: String,
+ inputDataType: DataType,
outputColName: String): StructField = {
- val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType ||
inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be either string type or numeric
type, " +
s"but got $inputDataType.")
@@ -122,12 +124,22 @@ private[feature] trait StringIndexerBase extends Params
with HasHandleInvalid wi
require(outputColNames.distinct.length == outputColNames.length,
s"Output columns should not be duplicate.")
+ val sparkSession = SparkSession.getActiveSession.get
+ val transformDataset = sparkSession.createDataFrame(new ArrayList[Row](),
schema = schema)
val outputFields = inputColNames.zip(outputColNames).flatMap {
case (inputColName, outputColName) =>
- schema.fieldNames.contains(inputColName) match {
- case true => Some(validateAndTransformField(schema, inputColName,
outputColName))
- case false if skipNonExistsCol => None
- case _ => throw new SparkException(s"Input column $inputColName does
not exist.")
+ try {
+ val dtype = transformDataset.col(inputColName).expr.dataType
+ Some(
+ validateAndTransformField(schema, inputColName, dtype,
outputColName)
+ )
+ } catch {
+ case _: AnalysisException =>
+ if (skipNonExistsCol) {
+ None
+ } else {
+ throw new SparkException(s"Input column $inputColName does not
exist.")
+ }
}
}
StructType(schema.fields ++ outputFields)
@@ -431,11 +443,8 @@ class StringIndexerModel (
val labelToIndex = labelsToIndexArray(i)
val labels = labelsArray(i)
- if (!dataset.schema.fieldNames.contains(inputColName)) {
- logWarning(log"Input column ${MDC(LogKeys.COLUMN_NAME, inputColName)}
does not exist " +
- log"during transformation. Skip StringIndexerModel for this column.")
- outputColNames(i) = null
- } else {
+ try {
+ dataset.col(inputColName)
val filteredLabels = getHandleInvalid match {
case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
case _ => labels
@@ -449,9 +458,13 @@ class StringIndexerModel (
outputColumns(i) = indexer(dataset(inputColName).cast(StringType))
.as(outputColName, metadata)
+ } catch {
+ case _: AnalysisException =>
+ logWarning(log"Input column ${MDC(LogKeys.COLUMN_NAME,
inputColName)} does not exist " +
+ log"during transformation. Skip StringIndexerModel for this
column.")
+ outputColNames(i) = null
}
}
-
val filteredOutputColNames = outputColNames.filter(_ != null)
val filteredOutputColumns = outputColumns.filter(_ != null)
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 99f12eab7d69..8f3750959d2b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -21,7 +21,8 @@ import org.apache.spark.ml.attribute.{Attribute,
NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.sql.Row
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.catalyst.parser.DataTypeParser
+import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.types.{DoubleType, StringType, StructField,
StructType}
class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
@@ -113,6 +114,50 @@ class StringIndexerSuite extends MLTest with
DefaultReadWriteTest {
assert(outSchema("output2").dataType === DoubleType)
}
+ // Checks that transformSchema resolves dotted input column names to nested
+ // struct fields (input1.a.f1/f2 are two levels deep, input2.b1/b2 are one
+ // level deep) as well as a plain top-level column (input3), and appends a
+ // DoubleType output column for each input.
+ test("StringIndexer.transformSchema nested col") {
+ val outputCols = Array("output", "output2", "output3", "output4",
"output5")
+ // NOTE(review): despite its name, idxToStr is a StringIndexer, not an
+ // IndexToString.
+ val idxToStr = new StringIndexer().setInputCols(
+ Array("input1.a.f1", "input1.a.f2", "input2.b1", "input2.b2", "input3")
+ ).setOutputCols(outputCols)
+
+ // Schema-only check: a StructType is passed in, no data rows are involved.
+ val inSchema = DataTypeParser.parseTableSchema(
+ "input1 struct<a struct<f1 string, f2 string>>, " +
+ "input2 struct<b1 string, b2 string>, input3 string"
+ )
+ val outSchema = idxToStr.transformSchema(inSchema)
+
+ // Every requested output column must be present and typed as DoubleType.
+ for (outputCol <- outputCols) {
+ assert(outSchema(outputCol).dataType === DoubleType)
+ }
+ }
+
+ // End-to-end fit/transform where the input is a string field nested inside a
+ // struct: the label column is wrapped into struct column `c1` and referenced
+ // by the dotted path `c1.label`.
+ test("StringIndexer nested input cols") {
+ // Label counts: a x3, c x2, b x1.
+ val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
+ val df = data.toDF("id", "label")
+ .select(col("id"), struct(col("label")).alias("c1"))
+ val indexer = new StringIndexer()
+ .setInputCol("c1.label")
+ .setOutputCol("labelIndex")
+ val indexerModel = indexer.fit(df)
+ MLTestingUtils.checkCopyAndUids(indexer, indexerModel)
+ // Expected indices: a -> 0, b -> 2, c -> 1 (consistent with ordering labels
+ // by descending frequency; matches the label order asserted below).
+ val expected = Seq(
+ (0, 0.0),
+ (1, 2.0),
+ (2, 1.0),
+ (3, 0.0),
+ (4, 0.0),
+ (5, 1.0)
+ ).toDF("id", "labelIndex")
+
+ val dfOutput = indexerModel.transform(df)
+ val outputs = dfOutput.select("id", "labelIndex").collect().toSeq
+ // The output column's nominal-attribute metadata lists the fitted labels,
+ // position i holding the label mapped to index i.
+ val attr = Attribute.fromStructField(outputs.head.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attr.values.get === Array("a", "c", "b"))
+ assert(outputs === expected.collect().toSeq)
+ }
+
test("StringIndexerUnseen") {
val data = Seq((0, "a"), (1, "b"), (4, "b"))
val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]