This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 98dd2b024003 [SPARK-49615][ML] Make all ML feature transformers
dataset schema validation conforming "spark.sql.caseSensitive" config
98dd2b024003 is described below
commit 98dd2b0240033b32f200016da0da4e2ec6af4386
Author: Weichen Xu <[email protected]>
AuthorDate: Mon Nov 4 13:41:05 2024 -0800
[SPARK-49615][ML] Make all ML feature transformers dataset schema
validation conforming "spark.sql.caseSensitive" config
### What changes were proposed in this pull request?
Make all ML feature transformers' dataset schema validation conform to the
"spark.sql.caseSensitive" config.
### Why are the changes needed?
Bug fix.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48747 from WeichenXu123/SPARK-49615-2.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../scala/org/apache/spark/ml/Transformer.scala | 3 +-
.../org/apache/spark/ml/feature/Binarizer.scala | 4 ++-
.../scala/org/apache/spark/ml/feature/DCT.scala | 4 ++-
.../org/apache/spark/ml/feature/HashingTF.scala | 2 +-
.../org/apache/spark/ml/feature/Normalizer.scala | 4 ++-
.../org/apache/spark/ml/feature/RFormula.scala | 33 +++++++++++++---------
.../org/apache/spark/ml/feature/Selector.scala | 4 ++-
.../apache/spark/ml/feature/StringIndexer.scala | 4 +--
.../ml/feature/UnivariateFeatureSelector.scala | 4 ++-
.../apache/spark/ml/feature/VectorIndexer.scala | 8 ++++--
.../apache/spark/ml/feature/VectorSizeHint.scala | 8 ++++--
.../org/apache/spark/ml/feature/VectorSlicer.scala | 8 ++++--
.../org/apache/spark/ml/util/SchemaUtils.scala | 13 ++++++++-
13 files changed, 69 insertions(+), 30 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 3b60b5ae294e..1d74f3c8a969 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -103,7 +104,7 @@ abstract class UnaryTransformer[IN: TypeTag, OUT: TypeTag,
T <: UnaryTransformer
protected def validateInputType(inputType: DataType): Unit = {}
override def transformSchema(schema: StructType): StructType = {
- val inputType = schema($(inputCol)).dataType
+ val inputType = SchemaUtils.getSchemaFieldType(schema, $(inputCol))
validateInputType(inputType)
if (schema.fieldNames.contains($(outputCol))) {
throw new IllegalArgumentException(s"Output column ${$(outputCol)}
already exists.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 5486c39034fd..8123438fd887 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -214,7 +214,9 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0")
override val uid: String)
case DoubleType =>
BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
case _: VectorUDT =>
- val size = AttributeGroup.fromStructField(schema(inputColName)).size
+ val size = AttributeGroup.fromStructField(
+ SchemaUtils.getSchemaField(schema, inputColName)
+ ).size
if (size < 0) {
StructField(outputColName, new VectorUDT)
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
index d057e5a62e50..9a8bfb195666 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
@@ -79,7 +79,9 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid:
String)
override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) {
- val size = AttributeGroup.fromStructField(schema($(inputCol))).size
+ val size = AttributeGroup.fromStructField(
+ SchemaUtils.getSchemaField(schema, $(inputCol))
+ ).size
if (size >= 0) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
$(outputCol), size)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 3b42105958c7..dab0a6494fdb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -105,7 +105,7 @@ class HashingTF @Since("3.0.0") private[ml] (
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
- val inputType = schema($(inputCol)).dataType
+ val inputType = SchemaUtils.getSchemaFieldType(schema, $(inputCol))
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ${ArrayType.simpleString}, but got
${inputType.catalogString}.")
val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index 4c7583b8381d..c7b7164e42f3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -71,7 +71,9 @@ class Normalizer @Since("1.4.0") (@Since("1.4.0") override
val uid: String)
override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) {
- val size = AttributeGroup.fromStructField(schema($(inputCol))).size
+ val size = AttributeGroup.fromStructField(
+ SchemaUtils.getSchemaField(schema, $(inputCol))
+ ).size
if (size >= 0) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
$(outputCol), size)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 77bd18423ef1..221d70c18d5a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -220,7 +220,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override
val uid: String)
// First we index each string column referenced by the input terms.
val indexed = terms.zipWithIndex.map { case (term, i) =>
- dataset.schema(term).dataType match {
+ val termField = SchemaUtils.getSchemaField(dataset.schema, term)
+ termField.dataType match {
case _: StringType =>
val indexCol = tmpColumn("stridx")
encoderStages += new StringIndexer()
@@ -231,7 +232,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override
val uid: String)
prefixesToRewrite(indexCol + "_") = term + "_"
(term, indexCol)
case _: VectorUDT =>
- val group = AttributeGroup.fromStructField(dataset.schema(term))
+ val group = AttributeGroup.fromStructField(termField)
val size = if (group.size < 0) {
firstRow.getAs[Vector](i).size
} else {
@@ -250,7 +251,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override
val uid: String)
// Then we handle one-hot encoding and interactions between terms.
var keepReferenceCategory = false
val encodedTerms = resolvedFormula.terms.map {
- case Seq(term) if dataset.schema(term).dataType == StringType =>
+ case Seq(term) if SchemaUtils.getSchemaFieldType(dataset.schema, term)
== StringType =>
val encodedCol = tmpColumn("onehot")
// Formula w/o intercept, one of the categories in the first category
feature is
// being used as reference category, we will not drop any category for
that feature.
@@ -292,7 +293,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override
val uid: String)
encoderStages += new ColumnPruner(tempColumns.toSet)
if ((dataset.schema.fieldNames.contains(resolvedFormula.label) &&
- dataset.schema(resolvedFormula.label).dataType == StringType) ||
$(forceIndexLabel)) {
+ SchemaUtils.getSchemaFieldType(
+ dataset.schema, resolvedFormula.label) == StringType) ||
$(forceIndexLabel)) {
encoderStages += new StringIndexer()
.setInputCol(resolvedFormula.label)
.setOutputCol($(labelCol))
@@ -359,8 +361,8 @@ class RFormulaModel private[feature](
val withFeatures = pipelineModel.transformSchema(schema)
if (resolvedFormula.label.isEmpty || hasLabelCol(withFeatures)) {
withFeatures
- } else if (schema.exists(_.name == resolvedFormula.label)) {
- val nullable = schema(resolvedFormula.label).dataType match {
+ } else if (SchemaUtils.checkSchemaFieldExist(schema,
resolvedFormula.label)) {
+ val nullable = SchemaUtils.getSchemaFieldType(schema,
resolvedFormula.label) match {
case _: NumericType | BooleanType => false
case _ => true
}
@@ -387,8 +389,8 @@ class RFormulaModel private[feature](
val labelName = resolvedFormula.label
if (labelName.isEmpty || hasLabelCol(dataset.schema)) {
dataset.toDF()
- } else if (dataset.schema.exists(_.name == labelName)) {
- dataset.schema(labelName).dataType match {
+ } else if (SchemaUtils.checkSchemaFieldExist(dataset.schema, labelName)) {
+ SchemaUtils.getSchemaFieldType(dataset.schema, labelName) match {
case _: NumericType | BooleanType =>
dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
case other =>
@@ -402,10 +404,12 @@ class RFormulaModel private[feature](
}
private def checkCanTransform(schema: StructType): Unit = {
- val columnNames = schema.map(_.name)
- require(!columnNames.contains($(featuresCol)), "Features column already
exists.")
require(
- !columnNames.contains($(labelCol)) ||
schema($(labelCol)).dataType.isInstanceOf[NumericType],
+ !SchemaUtils.checkSchemaFieldExist(schema, $(featuresCol)), "Features
column already exists."
+ )
+ require(
+ !SchemaUtils.checkSchemaFieldExist(schema, $(labelCol))
+ || SchemaUtils.getSchemaFieldType(schema,
$(labelCol)).isInstanceOf[NumericType],
s"Label column already exists and is not of type
${NumericType.simpleString}.")
}
@@ -550,7 +554,9 @@ private class VectorAttributeRewriter(
override def transform(dataset: Dataset[_]): DataFrame = {
val metadata = {
- val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
+ val group = AttributeGroup.fromStructField(
+ SchemaUtils.getSchemaField(dataset.schema, vectorCol)
+ )
val attrs = group.attributes.get.map { attr =>
if (attr.name.isDefined) {
val name = prefixesToRewrite.foldLeft(attr.name.get) { case
(curName, (from, to)) =>
@@ -563,7 +569,8 @@ private class VectorAttributeRewriter(
}
new AttributeGroup(vectorCol, attrs).toMetadata()
}
- val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
+ val vectorColFieldName = SchemaUtils.getSchemaField(dataset.schema,
vectorCol).name
+ val otherCols = dataset.columns.filter(_ !=
vectorColFieldName).map(dataset.col)
val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
import org.apache.spark.util.ArrayImplicits._
dataset.select((otherCols :+ rewrittenCol).toImmutableArraySeq : _*)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
index 8ff880b7b8aa..dde1068c5b92 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
@@ -341,7 +341,9 @@ private[feature] object SelectorModel {
featuresCol: String,
isNumericAttribute: Boolean): StructField = {
val selector = selectedFeatures.toSet
- val origAttrGroup = AttributeGroup.fromStructField(schema(featuresCol))
+ val origAttrGroup = AttributeGroup.fromStructField(
+ SchemaUtils.getSchemaField(schema, featuresCol)
+ )
val featureAttributes: Array[Attribute] = if
(origAttrGroup.attributes.nonEmpty) {
origAttrGroup.attributes.get.zipWithIndex.filter(x =>
selector.contains(x._2)).map(_._1)
} else {
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 20b03edf23c4..1acffa471e9a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -564,7 +564,7 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0")
override val uid: String)
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
val inputColName = $(inputCol)
- val inputDataType = schema(inputColName).dataType
+ val inputDataType = SchemaUtils.getSchemaFieldType(schema, inputColName)
require(inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be a numeric type, " +
s"but got $inputDataType.")
@@ -579,7 +579,7 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0")
override val uid: String)
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
- val inputColSchema = dataset.schema($(inputCol))
+ val inputColSchema = SchemaUtils.getSchemaField(dataset.schema,
$(inputCol))
// If the labels array is empty use column metadata
val values = if (!isDefined(labels) || $(labels).isEmpty) {
Attribute.fromStructField(inputColSchema)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
index 9c2033c28430..ea1a8c6438c8 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
@@ -410,7 +410,9 @@ object UnivariateFeatureSelectorModel extends
MLReadable[UnivariateFeatureSelect
featuresCol: String,
isNumericAttribute: Boolean): StructField = {
val selector = selectedFeatures.toSet
- val origAttrGroup = AttributeGroup.fromStructField(schema(featuresCol))
+ val origAttrGroup = AttributeGroup.fromStructField(
+ SchemaUtils.getSchemaField(schema, featuresCol)
+ )
val featureAttributes: Array[Attribute] = if
(origAttrGroup.attributes.nonEmpty) {
origAttrGroup.attributes.get.zipWithIndex.filter(x =>
selector.contains(x._2)).map(_._1)
} else {
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index ff89dee68ea3..b2323d2b706f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -445,7 +445,9 @@ class VectorIndexerModel private[ml] (
SchemaUtils.checkColumnType(schema, $(inputCol), dataType)
// If the input metadata specifies numFeatures, compare with expected
numFeatures.
- val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol)))
+ val origAttrGroup = AttributeGroup.fromStructField(
+ SchemaUtils.getSchemaField(schema, $(inputCol))
+ )
val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
Some(origAttrGroup.attributes.get.length)
} else {
@@ -466,7 +468,9 @@ class VectorIndexerModel private[ml] (
* @return Output column field. This field does not contain non-ML
metadata.
*/
private def prepOutputField(schema: StructType): StructField = {
- val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol)))
+ val origAttrGroup = AttributeGroup.fromStructField(
+ SchemaUtils.getSchemaField(schema, $(inputCol))
+ )
val featureAttributes: Array[Attribute] = if
(origAttrGroup.attributes.nonEmpty) {
// Convert original attributes to modified attributes
val origAttrs: Array[Attribute] = origAttrGroup.attributes.get
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
index 5c96d07e0ca9..4abb607733e3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol}
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable,
Identifiable}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable,
Identifiable, SchemaUtils}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructType}
@@ -98,7 +98,9 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0")
override val uid: String)
val localSize = getSize
val localHandleInvalid = getHandleInvalid
- val group = AttributeGroup.fromStructField(dataset.schema(localInputCol))
+ val group = AttributeGroup.fromStructField(
+ SchemaUtils.getSchemaField(dataset.schema, localInputCol)
+ )
val newGroup = validateSchemaAndSize(dataset.schema, group)
if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size
== localSize) {
dataset.toDF()
@@ -139,7 +141,7 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0")
override val uid: String)
val localSize = getSize
val localInputCol = getInputCol
- val inputColType = schema(getInputCol).dataType
+ val inputColType = SchemaUtils.getSchemaFieldType(schema, getInputCol)
require(
inputColType.isInstanceOf[VectorUDT],
s"Input column, $getInputCol must be of Vector type, got $inputColType"
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 5687ba878634..58a44a41f0e8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -99,7 +99,9 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0")
override val uid: Stri
override def transform(dataset: Dataset[_]): DataFrame = {
// Validity checks
transformSchema(dataset.schema)
- val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
+ val inputAttr = AttributeGroup.fromStructField(
+ SchemaUtils.getSchemaField(dataset.schema, $(inputCol))
+ )
if ($(indices).nonEmpty) {
val size = inputAttr.size
if (size >= 0) {
@@ -130,7 +132,9 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0")
override val uid: Stri
/** Get the feature indices in order: indices, names */
private def getSelectedFeatureIndices(schema: StructType): Array[Int] = {
- val nameFeatures =
MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names))
+ val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(
+ SchemaUtils.getSchemaField(schema, $(inputCol)), $(names)
+ )
val indFeatures = $(indices)
val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length
lazy val errMsg = "VectorSlicer requires indices and names to be disjoint"
+
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index ff132e2a29a8..538664183872 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -41,7 +41,7 @@ private[spark] object SchemaUtils {
colName: String,
dataType: DataType,
msg: String = ""): Unit = {
- val actualDataType = schema(colName).dataType
+ val actualDataType = SchemaUtils.getSchemaField(schema, colName).dataType
val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
require(actualDataType.equals(dataType),
s"Column $colName must be of type
${dataType.getClass}:${dataType.catalogString} " +
@@ -238,4 +238,15 @@ private[spark] object SchemaUtils {
def getSchemaFieldType(schema: StructType, colName: String): DataType = {
getSchemaField(schema, colName).dataType
}
+
+ /**
+ * Check whether a certain column name exists in the schema.
+ * @param schema input schema
+ * @param colName column name, nested column name is supported.
+ */
+ def checkSchemaFieldExist(schema: StructType, colName: String): Boolean = {
+ val colSplits = AttributeNameParser.parseAttributeName(colName)
+ val fieldOpt = schema.findNestedField(colSplits, resolver =
SQLConf.get.resolver)
+ fieldOpt.isDefined
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]