This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 017b7d3  [SPARK-36074][SQL] Add error class for StructType.findNestedField
017b7d3 is described below

commit 017b7d3f0b05da5513014e6b2b76acc16bcd006b
Author: Wenchen Fan <[email protected]>
AuthorDate: Tue Jul 13 21:13:58 2021 +0800

    [SPARK-36074][SQL] Add error class for StructType.findNestedField
    
    ### What changes were proposed in this pull request?
    
    This PR adds the INVALID_FIELD_NAME and AMBIGUOUS_FIELD_NAME error classes
    for the errors thrown by `StructType.findNestedField`. It also cleans up the
    code there and adds unit tests for this method.
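    
    As an illustration only, not part of this patch: code in the
    `org.apache.spark.sql` package (the method is `private[sql]`) would now see
    the new error class when a field path tries to look through a non-struct
    type. The schema and field names below are made up for the example:
    
        import org.apache.spark.sql.AnalysisException
        import org.apache.spark.sql.types._
    
        val schema = new StructType().add("m", MapType(IntegerType, IntegerType))
        try {
          // "m.key" is an int, so the trailing ".x" cannot be resolved
          schema.findNestedField(Seq("m", "key", "x"), includeCollections = true)
        } catch {
          case e: AnalysisException =>
            // expected per the diff below: error class INVALID_FIELD_NAME,
            // message "Field name m.key.x is invalid: m.key is not a struct."
            assert(e.getErrorClass == "INVALID_FIELD_NAME")
        }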
    
    ### Why are the changes needed?
    
    To make these errors follow the new error message framework.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests, plus new unit tests for `StructType.findNestedField` in `StructTypeSuite`.
    
    Closes #33282 from cloud-fan/error.
    
    Authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 core/src/main/resources/error/error-classes.json   |   8 ++
 .../org/apache/spark/sql/AnalysisException.scala   |   8 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   4 +-
 .../spark/sql/catalyst/analysis/package.scala      |   3 +-
 .../spark/sql/errors/QueryCompilationErrors.scala  |  15 ++-
 .../org/apache/spark/sql/types/StructType.scala    |  92 ++++++++---------
 .../apache/spark/sql/types/StructTypeSuite.scala   | 109 +++++++++++++++++++++
 7 files changed, 183 insertions(+), 56 deletions(-)
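
The central signature change is an extra `Origin` parameter on
`StructType.findNestedField`, so the errors it throws can carry the query
position. A minimal sketch of the new call shape, illustrative only: the schema
is made up, and the call must live in the `org.apache.spark.sql` package since
the method is `private[sql]`:

    import org.apache.spark.sql.catalyst.trees.Origin
    import org.apache.spark.sql.types._

    val schema = new StructType()
      .add("s1", new StructType().add("s12", "int"))

    // Origin carries the line/position of the SQL text that referenced the
    // field, so AMBIGUOUS_FIELD_NAME / INVALID_FIELD_NAME can point back at it.
    val found = schema.findNestedField(
      Seq("s1", "s12"),
      includeCollections = true,
      context = Origin(line = Some(1), startPosition = Some(10)))
    // found == Some(Seq("s1") -> StructField("s12", IntegerType))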

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 02feb9d..6ab113b 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1,4 +1,8 @@
 {
+  "AMBIGUOUS_FIELD_NAME" : {
+    "message" : [ "Field name %s is ambiguous and has %s matching fields in 
the struct." ],
+    "sqlState" : "42000"
+  },
   "DIVIDE_BY_ZERO" : {
     "message" : [ "divide by zero" ],
     "sqlState" : "22012"
@@ -7,6 +11,10 @@
     "message" : [ "Found duplicate keys '%s'" ],
     "sqlState" : "23000"
   },
+  "INVALID_FIELD_NAME" : {
+    "message" : [ "Field name %s is invalid: %s is not a struct." ],
+    "sqlState" : "42000"
+  },
   "MISSING_COLUMN" : {
     "message" : [ "cannot resolve '%s' given input columns: [%s]" ],
     "sqlState" : "42000"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 6299431..d0a3a71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import org.apache.spark.{SparkThrowable, SparkThrowableHelper}
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
 
 /**
 * Thrown when a query fails to analyze, usually because the query itself is invalid.
@@ -48,12 +49,11 @@ class AnalysisException protected[sql] (
   def this(
       errorClass: String,
       messageParameters: Array[String],
-      line: Option[Int],
-      startPosition: Option[Int]) =
+      origin: Origin) =
     this(
       SparkThrowableHelper.getMessage(errorClass, messageParameters),
-      line = line,
-      startPosition = startPosition,
+      line = origin.line,
+      startPosition = origin.startPosition,
       errorClass = Some(errorClass),
       messageParameters = messageParameters)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2d747f7..64f6b79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3613,7 +3613,9 @@ class Analyzer(override val catalogManager: CatalogManager)
         table: ResolvedTable,
         fieldName: Seq[String],
         context: Expression): ResolvedFieldName = {
-      table.schema.findNestedField(fieldName, includeCollections = true, conf.resolver).map {
+      table.schema.findNestedField(
+        fieldName, includeCollections = true, conf.resolver, context.origin
+      ).map {
         case (path, field) => ResolvedFieldName(path, field)
       }.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context))
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index 8ad8706..81683ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -51,8 +51,7 @@ package object analysis {
       throw new AnalysisException(
         errorClass = errorClass,
         messageParameters = messageParameters,
-        line = t.origin.line,
-        startPosition = t.origin.startPosition)
+        origin = t.origin)
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index d1dcbbc..6322676 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition}
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, LogicalPlan, SerdeInfo, Window}
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, FailFastMode, ParseMode, PermissiveMode}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -1352,9 +1352,12 @@ private[spark] object QueryCompilationErrors {
         s"${evalTypes.mkString(",")}")
   }
 
-  def ambiguousFieldNameError(fieldName: String, names: String): Throwable = {
+  def ambiguousFieldNameError(
+      fieldName: Seq[String], numMatches: Int, context: Origin): Throwable = {
     new AnalysisException(
-      s"Ambiguous field name: $fieldName. Found multiple columns that can 
match: $names")
+      errorClass = "AMBIGUOUS_FIELD_NAME",
+      messageParameters = Array(fieldName.quoted, numMatches.toString),
+      origin = context)
   }
 
   def cannotUseIntervalTypeInTableSchemaError(): Throwable = {
@@ -2359,8 +2362,10 @@ private[spark] object QueryCompilationErrors {
       context.origin.startPosition)
   }
 
-  def invalidFieldName(fieldName: Seq[String], path: Seq[String]): Throwable = {
+  def invalidFieldName(fieldName: Seq[String], path: Seq[String], context: Origin): Throwable = {
     new AnalysisException(
-      s"Field name ${fieldName.quoted} is invalid, ${path.quoted} is not a struct.")
+      errorClass = "INVALID_FIELD_NAME",
+      messageParameters = Array(fieldName.quoted, path.quoted),
+      origin = context)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index a9ba2b9..87ff4eb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -27,6 +27,7 @@ import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser}
+import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.catalyst.util.{truncatedString, StringUtils}
 import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -317,66 +318,69 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   private[sql] def findNestedField(
       fieldNames: Seq[String],
       includeCollections: Boolean = false,
-      resolver: Resolver = _ == _): Option[(Seq[String], StructField)] = {
-    def prettyFieldName(nameParts: Seq[String]): String = {
-      import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-      nameParts.quoted
-    }
+      resolver: Resolver = _ == _,
+      context: Origin = Origin()): Option[(Seq[String], StructField)] = {
 
     def findField(
         struct: StructType,
         searchPath: Seq[String],
         normalizedPath: Seq[String]): Option[(Seq[String], StructField)] = {
-      searchPath.headOption.flatMap { searchName =>
-        val found = struct.fields.filter(f => resolver(searchName, f.name))
-        if (found.length > 1) {
-          val names = found.map(f => prettyFieldName(normalizedPath :+ f.name))
-            .mkString("[", ", ", " ]")
-          throw QueryCompilationErrors.ambiguousFieldNameError(
-            prettyFieldName(normalizedPath :+ searchName), names)
-        } else if (found.isEmpty) {
-          None
+      assert(searchPath.nonEmpty)
+      val searchName = searchPath.head
+      val found = struct.fields.filter(f => resolver(searchName, f.name))
+      if (found.length > 1) {
+        throw QueryCompilationErrors.ambiguousFieldNameError(fieldNames, found.length, context)
+      } else if (found.isEmpty) {
+        None
+      } else {
+        val field = found.head
+        val currentPath = normalizedPath :+ field.name
+        val newSearchPath = searchPath.tail
+        if (newSearchPath.isEmpty) {
+          Some(normalizedPath -> field)
         } else {
-          val field = found.head
-          (searchPath.tail, field.dataType, includeCollections) match {
-            case (Seq(), _, _) =>
-              Some(normalizedPath -> field)
-
-            case (names, struct: StructType, _) =>
-              findField(struct, names, normalizedPath :+ field.name)
-
-            case (_, _, false) =>
-              None // types nested in maps and arrays are not used
+          (newSearchPath, field.dataType) match {
+            case (_, s: StructType) =>
+              findField(s, newSearchPath, currentPath)
 
-            case (Seq("key"), MapType(keyType, _, _), true) =>
-              // return the key type as a struct field to include nullability
-              Some((normalizedPath :+ field.name) -> StructField("key", keyType, nullable = false))
+            case _ if !includeCollections =>
+              throw QueryCompilationErrors.invalidFieldName(fieldNames, currentPath, context)
 
-            case (Seq("key", names @ _*), MapType(struct: StructType, _, _), 
true) =>
-              findField(struct, names, normalizedPath ++ Seq(field.name, 
"key"))
+            case (Seq("key", rest @ _*), MapType(keyType, _, _)) =>
+              findFieldInCollection(keyType, false, rest, currentPath, "key")
 
-            case (Seq("value"), MapType(_, valueType, isNullable), true) =>
-              // return the value type as a struct field to include nullability
-              Some((normalizedPath :+ field.name) ->
-                StructField("value", valueType, nullable = isNullable))
+            case (Seq("value", rest @ _*), MapType(_, valueType, isNullable)) 
=>
+              findFieldInCollection(valueType, isNullable, rest, currentPath, 
"value")
 
-            case (Seq("value", names @ _*), MapType(_, struct: StructType, _), 
true) =>
-              findField(struct, names, normalizedPath ++ Seq(field.name, 
"value"))
-
-            case (Seq("element"), ArrayType(elementType, isNullable), true) =>
-              // return the element type as a struct field to include 
nullability
-              Some((normalizedPath :+ field.name) ->
-                StructField("element", elementType, nullable = isNullable))
-
-            case (Seq("element", names @ _*), ArrayType(struct: StructType, 
_), true) =>
-              findField(struct, names, normalizedPath ++ Seq(field.name, 
"element"))
+            case (Seq("element", rest @ _*), ArrayType(elementType, 
isNullable)) =>
+              findFieldInCollection(elementType, isNullable, rest, 
currentPath, "element")
 
             case _ =>
-              throw QueryCompilationErrors.invalidFieldName(fieldNames, 
normalizedPath)
+              throw QueryCompilationErrors.invalidFieldName(fieldNames, 
currentPath, context)
           }
         }
       }
     }
+
+    def findFieldInCollection(
+        dt: DataType,
+        nullable: Boolean,
+        searchPath: Seq[String],
+        normalizedPath: Seq[String],
+        collectionFieldName: String): Option[(Seq[String], StructField)] = {
+      if (searchPath.isEmpty) {
+        Some(normalizedPath -> StructField(collectionFieldName, dt, nullable))
+      } else {
+        val newPath = normalizedPath :+ collectionFieldName
+        dt match {
+          case s: StructType =>
+            findField(s, searchPath, newPath)
+          case _ =>
+            throw QueryCompilationErrors.invalidFieldName(fieldNames, newPath, context)
+        }
+      }
+    }
+
     findField(this, fieldNames, Nil)
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
index 18821b8..8db3831 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.types
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.internal.SQLConf
@@ -273,4 +275,111 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
       checkIntervalDDL(start, end, DT.fieldToString)
     }
   }
+
+  test("findNestedField") {
+    val innerStruct = new StructType()
+      .add("s11", "int")
+      .add("s12", "int")
+    val input = new StructType()
+      .add("s1", innerStruct)
+      .add("s2", new StructType().add("x", "int").add("X", "int"))
+      .add("m1", MapType(IntegerType, IntegerType))
+      .add("m2", MapType(
+        new StructType().add("a", "int"),
+        new StructType().add("b", "int")
+      ))
+      .add("a1", ArrayType(IntegerType))
+      .add("a2", ArrayType(new StructType().add("c", "int")))
+
+    def check(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field, resolver = caseInsensitiveResolution)
+      assert(res == expect)
+    }
+
+    def caseSensitiveCheck(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field, resolver = caseSensitiveResolution)
+      assert(res == expect)
+    }
+
+    def checkCollection(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field,
+        includeCollections = true, resolver = caseInsensitiveResolution)
+      assert(res == expect)
+    }
+
+    // struct type
+    check(Seq("non_exist"), None)
+    check(Seq("S1"), Some(Nil -> StructField("s1", innerStruct)))
+    caseSensitiveCheck(Seq("S1"), None)
+    check(Seq("s1", "S12"), Some(Seq("s1") -> StructField("s12", IntegerType)))
+    caseSensitiveCheck(Seq("s1", "S12"), None)
+    check(Seq("S1.non_exist"), None)
+    var e = intercept[AnalysisException] {
+      check(Seq("S1", "S12", "S123"), None)
+    }
+    assert(e.getMessage.contains("Field name S1.S12.S123 is invalid: s1.s12 is 
not a struct"))
+
+    // ambiguous name
+    e = intercept[AnalysisException] {
+      check(Seq("S2", "x"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name S2.x is ambiguous and has 2 matching fields in the struct"))
+    caseSensitiveCheck(Seq("s2", "x"), Some(Seq("s2") -> StructField("x", 
IntegerType)))
+
+    // simple map type
+    e = intercept[AnalysisException] {
+      check(Seq("m1", "key"), None)
+    }
+    assert(e.getMessage.contains("Field name m1.key is invalid: m1 is not a 
struct"))
+    checkCollection(Seq("m1", "key"), Some(Seq("m1") -> StructField("key", 
IntegerType, false)))
+    checkCollection(Seq("M1", "value"), Some(Seq("m1") -> StructField("value", 
IntegerType)))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M1", "key", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name M1.key.name is invalid: m1.key is 
not a struct"))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M1", "value", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name M1.value.name is invalid: 
m1.value is not a struct"))
+
+    // map of struct
+    checkCollection(Seq("M2", "key", "A"),
+      Some(Seq("m2", "key") -> StructField("a", IntegerType)))
+    checkCollection(Seq("M2", "key", "non_exist"), None)
+    checkCollection(Seq("M2", "value", "b"),
+      Some(Seq("m2", "value") -> StructField("b", IntegerType)))
+    checkCollection(Seq("M2", "value", "non_exist"), None)
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("m2", "key", "A", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name m2.key.A.name is invalid: 
m2.key.a is not a struct"))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M2", "value", "b", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name M2.value.b.name is invalid: m2.value.b is not a struct"))
+
+    // simple array type
+    e = intercept[AnalysisException] {
+      check(Seq("A1", "element"), None)
+    }
+    assert(e.getMessage.contains("Field name A1.element is invalid: a1 is not 
a struct"))
+    checkCollection(Seq("A1", "element"), Some(Seq("a1") -> 
StructField("element", IntegerType)))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("A1", "element", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name A1.element.name is invalid: a1.element is not a struct"))
+
+    // array of struct
+    checkCollection(Seq("A2", "element", "C"),
+      Some(Seq("a2", "element") -> StructField("c", IntegerType)))
+    checkCollection(Seq("A2", "element", "non_exist"), None)
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("a2", "element", "C", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name a2.element.C.name is invalid: a2.element.c is not a struct"))
+  }
 }
