This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new 017b7d3  [SPARK-36074][SQL] Add error class for StructType.findNestedField
017b7d3 is described below
commit 017b7d3f0b05da5513014e6b2b76acc16bcd006b
Author: Wenchen Fan <[email protected]>
AuthorDate: Tue Jul 13 21:13:58 2021 +0800
[SPARK-36074][SQL] Add error class for StructType.findNestedField
### What changes were proposed in this pull request?
This PR adds AMBIGUOUS_FIELD_NAME and INVALID_FIELD_NAME error classes for the
errors thrown by `StructType.findNestedField`. It also cleans up the code there
and adds unit tests for this method.
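
For illustration, a minimal sketch (not part of this patch) of how the new error
classes surface to callers. `findNestedField` is `private[sql]`, so the snippet
assumes Spark-internal access, e.g. from a suite like `StructTypeSuite`:

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.types._

val schema = new StructType()
  .add("s", new StructType().add("a", IntegerType))

// A successful lookup returns the normalized parent path and the field.
assert(schema.findNestedField(Seq("s", "a")) ==
  Some(Seq("s") -> StructField("a", IntegerType)))

// Descending into a non-struct leaf now raises an AnalysisException that
// carries the INVALID_FIELD_NAME error class rather than a free-form string.
try {
  schema.findNestedField(Seq("s", "a", "b"))
} catch {
  case e: AnalysisException =>
    assert(e.getErrorClass == "INVALID_FIELD_NAME")
}
```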
### Why are the changes needed?
follow the new error message framework
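
Concretely, the framework keys each exception to an entry in
`error-classes.json` and fills the `%s` placeholders from the parameter array.
A rough sketch (assuming Spark-internal access, since `SparkThrowableHelper` is
`private[spark]`):

```scala
import org.apache.spark.SparkThrowableHelper

// Template: "Field name %s is invalid: %s is not a struct."
val msg = SparkThrowableHelper.getMessage("INVALID_FIELD_NAME", Array("s.a.b", "s.a"))
assert(msg == "Field name s.a.b is invalid: s.a is not a struct.")
```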
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
existing tests
Closes #33282 from cloud-fan/error.
Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
 core/src/main/resources/error/error-classes.json  |   8 ++
 .../org/apache/spark/sql/AnalysisException.scala  |   8 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala    |   4 +-
 .../spark/sql/catalyst/analysis/package.scala     |   3 +-
 .../spark/sql/errors/QueryCompilationErrors.scala |  15 ++-
 .../org/apache/spark/sql/types/StructType.scala   |  92 ++++++++---------
 .../apache/spark/sql/types/StructTypeSuite.scala  | 109 +++++++++++++++++++++
 7 files changed, 183 insertions(+), 56 deletions(-)
diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 02feb9d..6ab113b 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1,4 +1,8 @@
 {
+  "AMBIGUOUS_FIELD_NAME" : {
+    "message" : [ "Field name %s is ambiguous and has %s matching fields in the struct." ],
+    "sqlState" : "42000"
+  },
   "DIVIDE_BY_ZERO" : {
     "message" : [ "divide by zero" ],
     "sqlState" : "22012"
@@ -7,6 +11,10 @@
     "message" : [ "Found duplicate keys '%s'" ],
     "sqlState" : "23000"
   },
+  "INVALID_FIELD_NAME" : {
+    "message" : [ "Field name %s is invalid: %s is not a struct." ],
+    "sqlState" : "42000"
+  },
   "MISSING_COLUMN" : {
     "message" : [ "cannot resolve '%s' given input columns: [%s]" ],
     "sqlState" : "42000"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 6299431..d0a3a71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import org.apache.spark.{SparkThrowable, SparkThrowableHelper}
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
 
 /**
  * Thrown when a query fails to analyze, usually because the query itself is invalid.
@@ -48,12 +49,11 @@ class AnalysisException protected[sql] (
   def this(
       errorClass: String,
       messageParameters: Array[String],
-      line: Option[Int],
-      startPosition: Option[Int]) =
+      origin: Origin) =
     this(
       SparkThrowableHelper.getMessage(errorClass, messageParameters),
-      line = line,
-      startPosition = startPosition,
+      line = origin.line,
+      startPosition = origin.startPosition,
       errorClass = Some(errorClass),
       messageParameters = messageParameters)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2d747f7..64f6b79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3613,7 +3613,9 @@ class Analyzer(override val catalogManager: CatalogManager)
       table: ResolvedTable,
       fieldName: Seq[String],
       context: Expression): ResolvedFieldName = {
-    table.schema.findNestedField(fieldName, includeCollections = true, conf.resolver).map {
+    table.schema.findNestedField(
+      fieldName, includeCollections = true, conf.resolver, context.origin
+    ).map {
       case (path, field) => ResolvedFieldName(path, field)
     }.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context))
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index 8ad8706..81683ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -51,8 +51,7 @@ package object analysis {
       throw new AnalysisException(
         errorClass = errorClass,
         messageParameters = messageParameters,
-        line = t.origin.line,
-        startPosition = t.origin.startPosition)
+        origin = t.origin)
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index d1dcbbc..6322676 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition}
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, LogicalPlan, SerdeInfo, Window}
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, FailFastMode, ParseMode, PermissiveMode}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -1352,9 +1352,12 @@ private[spark] object QueryCompilationErrors {
       s"${evalTypes.mkString(",")}")
   }
 
-  def ambiguousFieldNameError(fieldName: String, names: String): Throwable = {
+  def ambiguousFieldNameError(
+      fieldName: Seq[String], numMatches: Int, context: Origin): Throwable = {
     new AnalysisException(
-      s"Ambiguous field name: $fieldName. Found multiple columns that can match: $names")
+      errorClass = "AMBIGUOUS_FIELD_NAME",
+      messageParameters = Array(fieldName.quoted, numMatches.toString),
+      origin = context)
   }
 
   def cannotUseIntervalTypeInTableSchemaError(): Throwable = {
@@ -2359,8 +2362,10 @@ private[spark] object QueryCompilationErrors {
       context.origin.startPosition)
   }
 
-  def invalidFieldName(fieldName: Seq[String], path: Seq[String]): Throwable = {
+  def invalidFieldName(fieldName: Seq[String], path: Seq[String], context: Origin): Throwable = {
     new AnalysisException(
-      s"Field name ${fieldName.quoted} is invalid, ${path.quoted} is not a struct.")
+      errorClass = "INVALID_FIELD_NAME",
+      messageParameters = Array(fieldName.quoted, path.quoted),
+      origin = context)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index a9ba2b9..87ff4eb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -27,6 +27,7 @@ import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser}
+import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.catalyst.util.{truncatedString, StringUtils}
 import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -317,66 +318,69 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   private[sql] def findNestedField(
       fieldNames: Seq[String],
       includeCollections: Boolean = false,
-      resolver: Resolver = _ == _): Option[(Seq[String], StructField)] = {
-    def prettyFieldName(nameParts: Seq[String]): String = {
-      import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-      nameParts.quoted
-    }
+      resolver: Resolver = _ == _,
+      context: Origin = Origin()): Option[(Seq[String], StructField)] = {
     def findField(
         struct: StructType,
         searchPath: Seq[String],
         normalizedPath: Seq[String]): Option[(Seq[String], StructField)] = {
-      searchPath.headOption.flatMap { searchName =>
-        val found = struct.fields.filter(f => resolver(searchName, f.name))
-        if (found.length > 1) {
-          val names = found.map(f => prettyFieldName(normalizedPath :+ f.name))
-            .mkString("[", ", ", " ]")
-          throw QueryCompilationErrors.ambiguousFieldNameError(
-            prettyFieldName(normalizedPath :+ searchName), names)
-        } else if (found.isEmpty) {
-          None
+      assert(searchPath.nonEmpty)
+      val searchName = searchPath.head
+      val found = struct.fields.filter(f => resolver(searchName, f.name))
+      if (found.length > 1) {
+        throw QueryCompilationErrors.ambiguousFieldNameError(fieldNames, found.length, context)
+      } else if (found.isEmpty) {
+        None
+      } else {
+        val field = found.head
+        val currentPath = normalizedPath :+ field.name
+        val newSearchPath = searchPath.tail
+        if (newSearchPath.isEmpty) {
+          Some(normalizedPath -> field)
        } else {
-          val field = found.head
-          (searchPath.tail, field.dataType, includeCollections) match {
-            case (Seq(), _, _) =>
-              Some(normalizedPath -> field)
-
-            case (names, struct: StructType, _) =>
-              findField(struct, names, normalizedPath :+ field.name)
-
-            case (_, _, false) =>
-              None // types nested in maps and arrays are not used
+          (newSearchPath, field.dataType) match {
+            case (_, s: StructType) =>
+              findField(s, newSearchPath, currentPath)
 
-            case (Seq("key"), MapType(keyType, _, _), true) =>
-              // return the key type as a struct field to include nullability
-              Some((normalizedPath :+ field.name) -> StructField("key", keyType, nullable = false))
+            case _ if !includeCollections =>
+              throw QueryCompilationErrors.invalidFieldName(fieldNames, currentPath, context)
 
-            case (Seq("key", names @ _*), MapType(struct: StructType, _, _), true) =>
-              findField(struct, names, normalizedPath ++ Seq(field.name, "key"))
+            case (Seq("key", rest @ _*), MapType(keyType, _, _)) =>
+              findFieldInCollection(keyType, false, rest, currentPath, "key")
 
-            case (Seq("value"), MapType(_, valueType, isNullable), true) =>
-              // return the value type as a struct field to include nullability
-              Some((normalizedPath :+ field.name) ->
-                StructField("value", valueType, nullable = isNullable))
+            case (Seq("value", rest @ _*), MapType(_, valueType, isNullable)) =>
+              findFieldInCollection(valueType, isNullable, rest, currentPath, "value")
 
-            case (Seq("value", names @ _*), MapType(_, struct: StructType, _), true) =>
-              findField(struct, names, normalizedPath ++ Seq(field.name, "value"))
-
-            case (Seq("element"), ArrayType(elementType, isNullable), true) =>
-              // return the element type as a struct field to include nullability
-              Some((normalizedPath :+ field.name) ->
-                StructField("element", elementType, nullable = isNullable))
-
-            case (Seq("element", names @ _*), ArrayType(struct: StructType, _), true) =>
-              findField(struct, names, normalizedPath ++ Seq(field.name, "element"))
+            case (Seq("element", rest @ _*), ArrayType(elementType, isNullable)) =>
+              findFieldInCollection(elementType, isNullable, rest, currentPath, "element")
 
             case _ =>
-              throw QueryCompilationErrors.invalidFieldName(fieldNames, normalizedPath)
+              throw QueryCompilationErrors.invalidFieldName(fieldNames, currentPath, context)
           }
         }
       }
     }
+
+    def findFieldInCollection(
+        dt: DataType,
+        nullable: Boolean,
+        searchPath: Seq[String],
+        normalizedPath: Seq[String],
+        collectionFieldName: String): Option[(Seq[String], StructField)] = {
+      if (searchPath.isEmpty) {
+        Some(normalizedPath -> StructField(collectionFieldName, dt, nullable))
+      } else {
+        val newPath = normalizedPath :+ collectionFieldName
+        dt match {
+          case s: StructType =>
+            findField(s, searchPath, newPath)
+          case _ =>
+            throw QueryCompilationErrors.invalidFieldName(fieldNames, newPath, context)
+        }
+      }
+    }
+
     findField(this, fieldNames, Nil)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
index 18821b8..8db3831 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.types
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.internal.SQLConf
@@ -273,4 +275,111 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
       checkIntervalDDL(start, end, DT.fieldToString)
     }
   }
+
+  test("findNestedField") {
+    val innerStruct = new StructType()
+      .add("s11", "int")
+      .add("s12", "int")
+    val input = new StructType()
+      .add("s1", innerStruct)
+      .add("s2", new StructType().add("x", "int").add("X", "int"))
+      .add("m1", MapType(IntegerType, IntegerType))
+      .add("m2", MapType(
+        new StructType().add("a", "int"),
+        new StructType().add("b", "int")
+      ))
+      .add("a1", ArrayType(IntegerType))
+      .add("a2", ArrayType(new StructType().add("c", "int")))
+
+    def check(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field, resolver = caseInsensitiveResolution)
+      assert(res == expect)
+    }
+
+    def caseSensitiveCheck(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field, resolver = caseSensitiveResolution)
+      assert(res == expect)
+    }
+
+    def checkCollection(field: Seq[String], expect: Option[(Seq[String], StructField)]): Unit = {
+      val res = input.findNestedField(field,
+        includeCollections = true, resolver = caseInsensitiveResolution)
+      assert(res == expect)
+    }
+
+    // struct type
+    check(Seq("non_exist"), None)
+    check(Seq("S1"), Some(Nil -> StructField("s1", innerStruct)))
+    caseSensitiveCheck(Seq("S1"), None)
+    check(Seq("s1", "S12"), Some(Seq("s1") -> StructField("s12", IntegerType)))
+    caseSensitiveCheck(Seq("s1", "S12"), None)
+    check(Seq("S1.non_exist"), None)
+    var e = intercept[AnalysisException] {
+      check(Seq("S1", "S12", "S123"), None)
+    }
+    assert(e.getMessage.contains("Field name S1.S12.S123 is invalid: s1.s12 is not a struct"))
+
+    // ambiguous name
+    e = intercept[AnalysisException] {
+      check(Seq("S2", "x"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name S2.x is ambiguous and has 2 matching fields in the struct"))
+    caseSensitiveCheck(Seq("s2", "x"), Some(Seq("s2") -> StructField("x", IntegerType)))
+
+    // simple map type
+    e = intercept[AnalysisException] {
+      check(Seq("m1", "key"), None)
+    }
+    assert(e.getMessage.contains("Field name m1.key is invalid: m1 is not a struct"))
+    checkCollection(Seq("m1", "key"), Some(Seq("m1") -> StructField("key", IntegerType, false)))
+    checkCollection(Seq("M1", "value"), Some(Seq("m1") -> StructField("value", IntegerType)))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M1", "key", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name M1.key.name is invalid: m1.key is not a struct"))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M1", "value", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name M1.value.name is invalid: m1.value is not a struct"))
+
+    // map of struct
+    checkCollection(Seq("M2", "key", "A"),
+      Some(Seq("m2", "key") -> StructField("a", IntegerType)))
+    checkCollection(Seq("M2", "key", "non_exist"), None)
+    checkCollection(Seq("M2", "value", "b"),
+      Some(Seq("m2", "value") -> StructField("b", IntegerType)))
+    checkCollection(Seq("M2", "value", "non_exist"), None)
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("m2", "key", "A", "name"), None)
+    }
+    assert(e.getMessage.contains("Field name m2.key.A.name is invalid: m2.key.a is not a struct"))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("M2", "value", "b", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name M2.value.b.name is invalid: m2.value.b is not a struct"))
+
+    // simple array type
+    e = intercept[AnalysisException] {
+      check(Seq("A1", "element"), None)
+    }
+    assert(e.getMessage.contains("Field name A1.element is invalid: a1 is not a struct"))
+    checkCollection(Seq("A1", "element"), Some(Seq("a1") -> StructField("element", IntegerType)))
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("A1", "element", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name A1.element.name is invalid: a1.element is not a struct"))
+
+    // array of struct
+    checkCollection(Seq("A2", "element", "C"),
+      Some(Seq("a2", "element") -> StructField("c", IntegerType)))
+    checkCollection(Seq("A2", "element", "non_exist"), None)
+    e = intercept[AnalysisException] {
+      checkCollection(Seq("a2", "element", "C", "name"), None)
+    }
+    assert(e.getMessage.contains(
+      "Field name a2.element.C.name is invalid: a2.element.c is not a struct"))
+  }
}
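
For reference, a minimal sketch (not part of the patch) of the collection
pseudo-fields exercised by the new suite; like the tests, it assumes
Spark-internal access to the `private[sql]` method:

```scala
import org.apache.spark.sql.types._

val schema = new StructType()
  .add("m", MapType(IntegerType, new StructType().add("v", IntegerType)))

// With includeCollections = true, map keys/values (and array elements) are
// addressed via the "key"/"value"/"element" pseudo-fields; the map key is
// reported as non-nullable.
assert(schema.findNestedField(Seq("m", "key"), includeCollections = true) ==
  Some(Seq("m") -> StructField("key", IntegerType, nullable = false)))
assert(schema.findNestedField(Seq("m", "value", "v"), includeCollections = true) ==
  Some(Seq("m", "value") -> StructField("v", IntegerType)))
```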
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]