This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 98dd2b024003 [SPARK-49615][ML] Make all ML feature transformers dataset schema validation conforming "spark.sql.caseSensitive" config
98dd2b024003 is described below

commit 98dd2b0240033b32f200016da0da4e2ec6af4386
Author: Weichen Xu <[email protected]>
AuthorDate: Mon Nov 4 13:41:05 2024 -0800

    [SPARK-49615][ML] Make all ML feature transformers dataset schema validation conforming "spark.sql.caseSensitive" config
    
    ### What changes were proposed in this pull request?
    
    Make the dataset schema validation in all ML feature transformers conform
    to the "spark.sql.caseSensitive" config.
    
    ### Why are the changes needed?
    
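    Bug fix. Several transformers looked up columns with `schema(colName)`
    (`StructType.apply`), which matches field names exactly, so their schema
    validation ignored "spark.sql.caseSensitive". These lookups now go through
    `SchemaUtils` helpers (`getSchemaField`, `getSchemaFieldType`,
    `checkSchemaFieldExist`) so that column-name resolution follows that config.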
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48747 from WeichenXu123/SPARK-49615-2.
    
    Authored-by: Weichen Xu <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../scala/org/apache/spark/ml/Transformer.scala    |  3 +-
 .../org/apache/spark/ml/feature/Binarizer.scala    |  4 ++-
 .../scala/org/apache/spark/ml/feature/DCT.scala    |  4 ++-
 .../org/apache/spark/ml/feature/HashingTF.scala    |  2 +-
 .../org/apache/spark/ml/feature/Normalizer.scala   |  4 ++-
 .../org/apache/spark/ml/feature/RFormula.scala     | 33 +++++++++++++---------
 .../org/apache/spark/ml/feature/Selector.scala     |  4 ++-
 .../apache/spark/ml/feature/StringIndexer.scala    |  4 +--
 .../ml/feature/UnivariateFeatureSelector.scala     |  4 ++-
 .../apache/spark/ml/feature/VectorIndexer.scala    |  8 ++++--
 .../apache/spark/ml/feature/VectorSizeHint.scala   |  8 ++++--
 .../org/apache/spark/ml/feature/VectorSlicer.scala |  8 ++++--
 .../org/apache/spark/ml/util/SchemaUtils.scala     | 13 ++++++++-
 13 files changed, 69 insertions(+), 30 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 3b60b5ae294e..1d74f3c8a969 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -103,7 +104,7 @@ abstract class UnaryTransformer[IN: TypeTag, OUT: TypeTag, 
T <: UnaryTransformer
   protected def validateInputType(inputType: DataType): Unit = {}
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
+    val inputType = SchemaUtils.getSchemaFieldType(schema, $(inputCol))
     validateInputType(inputType)
     if (schema.fieldNames.contains($(outputCol))) {
       throw new IllegalArgumentException(s"Output column ${$(outputCol)} 
already exists.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 5486c39034fd..8123438fd887 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -214,7 +214,9 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") 
override val uid: String)
         case DoubleType =>
           BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
         case _: VectorUDT =>
-          val size = AttributeGroup.fromStructField(schema(inputColName)).size
+          val size = AttributeGroup.fromStructField(
+            SchemaUtils.getSchemaField(schema, inputColName)
+          ).size
           if (size < 0) {
             StructField(outputColName, new VectorUDT)
           } else {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
index d057e5a62e50..9a8bfb195666 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
@@ -79,7 +79,9 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: 
String)
   override def transformSchema(schema: StructType): StructType = {
     var outputSchema = super.transformSchema(schema)
     if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) {
-      val size = AttributeGroup.fromStructField(schema($(inputCol))).size
+      val size = AttributeGroup.fromStructField(
+        SchemaUtils.getSchemaField(schema, $(inputCol))
+      ).size
       if (size >= 0) {
         outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
           $(outputCol), size)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 3b42105958c7..dab0a6494fdb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -105,7 +105,7 @@ class HashingTF @Since("3.0.0") private[ml] (
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
+    val inputType = SchemaUtils.getSchemaFieldType(schema, $(inputCol))
     require(inputType.isInstanceOf[ArrayType],
       s"The input column must be ${ArrayType.simpleString}, but got 
${inputType.catalogString}.")
     val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index 4c7583b8381d..c7b7164e42f3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -71,7 +71,9 @@ class Normalizer @Since("1.4.0") (@Since("1.4.0") override 
val uid: String)
   override def transformSchema(schema: StructType): StructType = {
     var outputSchema = super.transformSchema(schema)
     if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) {
-      val size = AttributeGroup.fromStructField(schema($(inputCol))).size
+      val size = AttributeGroup.fromStructField(
+        SchemaUtils.getSchemaField(schema, $(inputCol))
+      ).size
       if (size >= 0) {
         outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
           $(outputCol), size)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 77bd18423ef1..221d70c18d5a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -220,7 +220,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
 
     // First we index each string column referenced by the input terms.
     val indexed = terms.zipWithIndex.map { case (term, i) =>
-      dataset.schema(term).dataType match {
+      val termField = SchemaUtils.getSchemaField(dataset.schema, term)
+      termField.dataType match {
         case _: StringType =>
           val indexCol = tmpColumn("stridx")
           encoderStages += new StringIndexer()
@@ -231,7 +232,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
           prefixesToRewrite(indexCol + "_") = term + "_"
           (term, indexCol)
         case _: VectorUDT =>
-          val group = AttributeGroup.fromStructField(dataset.schema(term))
+          val group = AttributeGroup.fromStructField(termField)
           val size = if (group.size < 0) {
             firstRow.getAs[Vector](i).size
           } else {
@@ -250,7 +251,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
     // Then we handle one-hot encoding and interactions between terms.
     var keepReferenceCategory = false
     val encodedTerms = resolvedFormula.terms.map {
-      case Seq(term) if dataset.schema(term).dataType == StringType =>
+      case Seq(term) if SchemaUtils.getSchemaFieldType(dataset.schema, term) 
== StringType =>
         val encodedCol = tmpColumn("onehot")
         // Formula w/o intercept, one of the categories in the first category 
feature is
         // being used as reference category, we will not drop any category for 
that feature.
@@ -292,7 +293,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
     encoderStages += new ColumnPruner(tempColumns.toSet)
 
     if ((dataset.schema.fieldNames.contains(resolvedFormula.label) &&
-      dataset.schema(resolvedFormula.label).dataType == StringType) || 
$(forceIndexLabel)) {
+      SchemaUtils.getSchemaFieldType(
+        dataset.schema, resolvedFormula.label) == StringType) || 
$(forceIndexLabel)) {
       encoderStages += new StringIndexer()
         .setInputCol(resolvedFormula.label)
         .setOutputCol($(labelCol))
@@ -359,8 +361,8 @@ class RFormulaModel private[feature](
     val withFeatures = pipelineModel.transformSchema(schema)
     if (resolvedFormula.label.isEmpty || hasLabelCol(withFeatures)) {
       withFeatures
-    } else if (schema.exists(_.name == resolvedFormula.label)) {
-      val nullable = schema(resolvedFormula.label).dataType match {
+    } else if (SchemaUtils.checkSchemaFieldExist(schema, 
resolvedFormula.label)) {
+      val nullable = SchemaUtils.getSchemaFieldType(schema, 
resolvedFormula.label) match {
         case _: NumericType | BooleanType => false
         case _ => true
       }
@@ -387,8 +389,8 @@ class RFormulaModel private[feature](
     val labelName = resolvedFormula.label
     if (labelName.isEmpty || hasLabelCol(dataset.schema)) {
       dataset.toDF()
-    } else if (dataset.schema.exists(_.name == labelName)) {
-      dataset.schema(labelName).dataType match {
+    } else if (SchemaUtils.checkSchemaFieldExist(dataset.schema, labelName)) {
+      SchemaUtils.getSchemaFieldType(dataset.schema, labelName) match {
         case _: NumericType | BooleanType =>
           dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
         case other =>
@@ -402,10 +404,12 @@ class RFormulaModel private[feature](
   }
 
   private def checkCanTransform(schema: StructType): Unit = {
-    val columnNames = schema.map(_.name)
-    require(!columnNames.contains($(featuresCol)), "Features column already 
exists.")
     require(
-      !columnNames.contains($(labelCol)) || 
schema($(labelCol)).dataType.isInstanceOf[NumericType],
+      !SchemaUtils.checkSchemaFieldExist(schema, $(featuresCol)), "Features 
column already exists."
+    )
+    require(
+      !SchemaUtils.checkSchemaFieldExist(schema, $(labelCol))
+      || SchemaUtils.getSchemaFieldType(schema, 
$(labelCol)).isInstanceOf[NumericType],
       s"Label column already exists and is not of type 
${NumericType.simpleString}.")
   }
 
@@ -550,7 +554,9 @@ private class VectorAttributeRewriter(
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     val metadata = {
-      val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
+      val group = AttributeGroup.fromStructField(
+        SchemaUtils.getSchemaField(dataset.schema, vectorCol)
+      )
       val attrs = group.attributes.get.map { attr =>
         if (attr.name.isDefined) {
           val name = prefixesToRewrite.foldLeft(attr.name.get) { case 
(curName, (from, to)) =>
@@ -563,7 +569,8 @@ private class VectorAttributeRewriter(
       }
       new AttributeGroup(vectorCol, attrs).toMetadata()
     }
-    val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
+    val vectorColFieldName = SchemaUtils.getSchemaField(dataset.schema, 
vectorCol).name
+    val otherCols = dataset.columns.filter(_ != 
vectorColFieldName).map(dataset.col)
     val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
     import org.apache.spark.util.ArrayImplicits._
     dataset.select((otherCols :+ rewrittenCol).toImmutableArraySeq : _*)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
index 8ff880b7b8aa..dde1068c5b92 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Selector.scala
@@ -341,7 +341,9 @@ private[feature] object SelectorModel {
       featuresCol: String,
       isNumericAttribute: Boolean): StructField = {
     val selector = selectedFeatures.toSet
-    val origAttrGroup = AttributeGroup.fromStructField(schema(featuresCol))
+    val origAttrGroup = AttributeGroup.fromStructField(
+      SchemaUtils.getSchemaField(schema, featuresCol)
+    )
     val featureAttributes: Array[Attribute] = if 
(origAttrGroup.attributes.nonEmpty) {
       origAttrGroup.attributes.get.zipWithIndex.filter(x => 
selector.contains(x._2)).map(_._1)
     } else {
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 20b03edf23c4..1acffa471e9a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -564,7 +564,7 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0") 
override val uid: String)
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
     val inputColName = $(inputCol)
-    val inputDataType = schema(inputColName).dataType
+    val inputDataType = SchemaUtils.getSchemaFieldType(schema, inputColName)
     require(inputDataType.isInstanceOf[NumericType],
       s"The input column $inputColName must be a numeric type, " +
         s"but got $inputDataType.")
@@ -579,7 +579,7 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0") 
override val uid: String)
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val inputColSchema = dataset.schema($(inputCol))
+    val inputColSchema = SchemaUtils.getSchemaField(dataset.schema, 
$(inputCol))
     // If the labels array is empty use column metadata
     val values = if (!isDefined(labels) || $(labels).isEmpty) {
       Attribute.fromStructField(inputColSchema)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
index 9c2033c28430..ea1a8c6438c8 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
@@ -410,7 +410,9 @@ object UnivariateFeatureSelectorModel extends 
MLReadable[UnivariateFeatureSelect
       featuresCol: String,
       isNumericAttribute: Boolean): StructField = {
     val selector = selectedFeatures.toSet
-    val origAttrGroup = AttributeGroup.fromStructField(schema(featuresCol))
+    val origAttrGroup = AttributeGroup.fromStructField(
+      SchemaUtils.getSchemaField(schema, featuresCol)
+    )
     val featureAttributes: Array[Attribute] = if 
(origAttrGroup.attributes.nonEmpty) {
       origAttrGroup.attributes.get.zipWithIndex.filter(x => 
selector.contains(x._2)).map(_._1)
     } else {
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index ff89dee68ea3..b2323d2b706f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -445,7 +445,9 @@ class VectorIndexerModel private[ml] (
     SchemaUtils.checkColumnType(schema, $(inputCol), dataType)
 
     // If the input metadata specifies numFeatures, compare with expected 
numFeatures.
-    val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol)))
+    val origAttrGroup = AttributeGroup.fromStructField(
+      SchemaUtils.getSchemaField(schema, $(inputCol))
+    )
     val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) {
       Some(origAttrGroup.attributes.get.length)
     } else {
@@ -466,7 +468,9 @@ class VectorIndexerModel private[ml] (
    * @return  Output column field.  This field does not contain non-ML 
metadata.
    */
   private def prepOutputField(schema: StructType): StructField = {
-    val origAttrGroup = AttributeGroup.fromStructField(schema($(inputCol)))
+    val origAttrGroup = AttributeGroup.fromStructField(
+      SchemaUtils.getSchemaField(schema, $(inputCol))
+    )
     val featureAttributes: Array[Attribute] = if 
(origAttrGroup.attributes.nonEmpty) {
       // Convert original attributes to modified attributes
       val origAttrs: Array[Attribute] = origAttrGroup.attributes.get
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
index 5c96d07e0ca9..4abb607733e3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.VectorUDT
 import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol}
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, 
Identifiable}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, 
Identifiable, SchemaUtils}
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StringType, StructType}
@@ -98,7 +98,9 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") 
override val uid: String)
     val localSize = getSize
     val localHandleInvalid = getHandleInvalid
 
-    val group = AttributeGroup.fromStructField(dataset.schema(localInputCol))
+    val group = AttributeGroup.fromStructField(
+      SchemaUtils.getSchemaField(dataset.schema, localInputCol)
+    )
     val newGroup = validateSchemaAndSize(dataset.schema, group)
     if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size 
== localSize) {
       dataset.toDF()
@@ -139,7 +141,7 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") 
override val uid: String)
     val localSize = getSize
     val localInputCol = getInputCol
 
-    val inputColType = schema(getInputCol).dataType
+    val inputColType = SchemaUtils.getSchemaFieldType(schema, getInputCol)
     require(
       inputColType.isInstanceOf[VectorUDT],
       s"Input column, $getInputCol must be of Vector type, got $inputColType"
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 5687ba878634..58a44a41f0e8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -99,7 +99,9 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") 
override val uid: Stri
   override def transform(dataset: Dataset[_]): DataFrame = {
     // Validity checks
     transformSchema(dataset.schema)
-    val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
+    val inputAttr = AttributeGroup.fromStructField(
+      SchemaUtils.getSchemaField(dataset.schema, $(inputCol))
+    )
     if ($(indices).nonEmpty) {
       val size = inputAttr.size
       if (size >= 0) {
@@ -130,7 +132,9 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") 
override val uid: Stri
 
   /** Get the feature indices in order: indices, names */
   private def getSelectedFeatureIndices(schema: StructType): Array[Int] = {
-    val nameFeatures = 
MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names))
+    val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(
+      SchemaUtils.getSchemaField(schema, $(inputCol)), $(names)
+    )
     val indFeatures = $(indices)
     val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length
     lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" 
+
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala 
b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index ff132e2a29a8..538664183872 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -41,7 +41,7 @@ private[spark] object SchemaUtils {
       colName: String,
       dataType: DataType,
       msg: String = ""): Unit = {
-    val actualDataType = schema(colName).dataType
+    val actualDataType = SchemaUtils.getSchemaField(schema, colName).dataType
     val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
     require(actualDataType.equals(dataType),
       s"Column $colName must be of type 
${dataType.getClass}:${dataType.catalogString} " +
@@ -238,4 +238,15 @@ private[spark] object SchemaUtils {
   def getSchemaFieldType(schema: StructType, colName: String): DataType = {
     getSchemaField(schema, colName).dataType
   }
+
+  /**
+   * Check whether a certain column name exists in the schema.
+   * @param schema input schema
+   * @param colName column name, nested column name is supported.
+   */
+  def checkSchemaFieldExist(schema: StructType, colName: String): Boolean = {
+    val colSplits = AttributeNameParser.parseAttributeName(colName)
+    val fieldOpt = schema.findNestedField(colSplits, resolver = 
SQLConf.get.resolver)
+    fieldOpt.isDefined
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to