This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 09abea00cb0 [SPARK-42611][SQL] Insert char/varchar length checks for inner fields during resolution
09abea00cb0 is described below

commit 09abea00cb0d67336413ac8892617ca824429042
Author: aokolnychyi <aokolnyc...@apple.com>
AuthorDate: Wed Mar 1 15:50:10 2023 +0800

    [SPARK-42611][SQL] Insert char/varchar length checks for inner fields during resolution

    ### What changes were proposed in this pull request?

    This PR adds char/varchar length checks for inner fields during resolution when struct fields are reordered.

    ### Why are the changes needed?

    These checks are needed to handle nested char/varchar columns correctly. Prior to this change, we would lose the raw type information when constructing nested attributes and, as a result, would not insert the proper char/varchar length checks.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    This PR comes with tests that would previously fail.

    Closes #40206 from aokolnychyi/spark-42611.

    Authored-by: aokolnychyi <aokolnyc...@apple.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit d7d8af0dbb47e152b280226a7afcf0771b5a5ae8)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
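For readers skimming the patch below: during analysis, Spark stores a CHAR(n)/VARCHAR(n) column as StringType and keeps the declared type only in the field metadata; the fix re-expands that metadata before output columns are resolved, so nested fields keep their length constraints. A minimal sketch of that round trip, assuming only spark-catalyst on the classpath (RawTypeRoundTrip is an illustrative name, not part of the patch):

    import org.apache.spark.sql.catalyst.util.CharVarcharUtils
    import org.apache.spark.sql.types._

    object RawTypeRoundTrip extends App {
      // A table column declared as VARCHAR(5).
      val declared = StructType(Seq(StructField("n_c", VarcharType(5))))

      // In analyzed plans, char/varchar is replaced with StringType and the
      // declared type is recorded in the field metadata.
      val stored = CharVarcharUtils.replaceCharVarcharWithStringInSchema(declared)
      println(stored("n_c").dataType)                              // StringType
      println(CharVarcharUtils.getRawType(stored("n_c").metadata)) // Some(VarcharType(5))
    }

The patched resolveOutputColumns applies getRawType to every expected attribute (via attr.withDataType(...)), which restores the information needed to emit length checks for reordered inner fields.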
---
 .../catalyst/analysis/TableOutputResolver.scala    | 62 ++++++++++++++--------
 .../apache/spark/sql/CharVarcharTestSuite.scala    | 52 ++++++++++++++++++
 2 files changed, 93 insertions(+), 21 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index 61d24964d60..e1ee0defa23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -36,20 +36,25 @@ object TableOutputResolver {
       byName: Boolean,
       conf: SQLConf): LogicalPlan = {
 
-    if (expected.size < query.output.size) {
-      throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(tableName, expected, query)
+    val actualExpectedCols = expected.map { attr =>
+      attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType))
+    }
+
+    if (actualExpectedCols.size < query.output.size) {
+      throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(
+        tableName, actualExpectedCols, query)
     }
 
     val errors = new mutable.ArrayBuffer[String]()
     val resolved: Seq[NamedExpression] = if (byName) {
-      reorderColumnsByName(query.output, expected, conf, errors += _)
+      reorderColumnsByName(query.output, actualExpectedCols, conf, errors += _)
     } else {
-      if (expected.size > query.output.size) {
+      if (actualExpectedCols.size > query.output.size) {
         throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(
-          tableName, expected, query)
+          tableName, actualExpectedCols, query)
       }
 
-      query.output.zip(expected).flatMap {
+      query.output.zip(actualExpectedCols).flatMap {
         case (queryExpr, tableAttr) =>
           checkField(tableAttr, queryExpr, byName, conf, err => errors += err, Seq(tableAttr.name))
       }
@@ -254,28 +259,23 @@ object TableOutputResolver {
       addError: String => Unit,
       colPath: Seq[String]): Option[NamedExpression] = {
 
+    val attrTypeHasCharVarchar = CharVarcharUtils.hasCharVarchar(tableAttr.dataType)
+    val attrTypeWithoutCharVarchar = if (attrTypeHasCharVarchar) {
+      CharVarcharUtils.replaceCharVarcharWithString(tableAttr.dataType)
+    } else {
+      tableAttr.dataType
+    }
     val storeAssignmentPolicy = conf.storeAssignmentPolicy
     lazy val outputField = if (tableAttr.dataType.sameType(queryExpr.dataType) &&
       tableAttr.name == queryExpr.name &&
       tableAttr.metadata == queryExpr.metadata) {
       Some(queryExpr)
     } else {
-      val casted = storeAssignmentPolicy match {
-        case StoreAssignmentPolicy.ANSI =>
-          val cast = Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone),
-            ansiEnabled = true)
-          cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
-          checkCastOverflowInTableInsert(cast, colPath.quoted)
-        case StoreAssignmentPolicy.LEGACY =>
-          Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone),
-            ansiEnabled = false)
-        case _ =>
-          Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone))
-      }
-      val exprWithStrLenCheck = if (conf.charVarcharAsString) {
+      val casted = cast(queryExpr, attrTypeWithoutCharVarchar, conf, colPath.quoted)
+      val exprWithStrLenCheck = if (conf.charVarcharAsString || !attrTypeHasCharVarchar) {
         casted
       } else {
-        CharVarcharUtils.stringLengthCheck(casted, tableAttr)
+        CharVarcharUtils.stringLengthCheck(casted, tableAttr.dataType)
       }
       // Renaming is needed for handling the following cases like
       // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2
@@ -290,7 +290,7 @@ object TableOutputResolver {
       case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI =>
         // run the type check first to ensure type errors are present
         val canWrite = DataType.canWrite(
-          queryExpr.dataType, tableAttr.dataType, byName, conf.resolver, colPath.quoted,
+          queryExpr.dataType, attrTypeWithoutCharVarchar, byName, conf.resolver, colPath.quoted,
           storeAssignmentPolicy, addError)
         if (queryExpr.nullable && !tableAttr.nullable) {
           addError(s"Cannot write nullable values to non-null column '${colPath.quoted}'")
@@ -304,4 +304,24 @@ object TableOutputResolver {
       }
     }
   }
+
+  private def cast(
+      expr: Expression,
+      expectedType: DataType,
+      conf: SQLConf,
+      colName: String): Expression = {
+
+    conf.storeAssignmentPolicy match {
+      case StoreAssignmentPolicy.ANSI =>
+        val cast = Cast(expr, expectedType, Option(conf.sessionLocalTimeZone), ansiEnabled = true)
+        cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
+        checkCastOverflowInTableInsert(cast, colName)
+
+      case StoreAssignmentPolicy.LEGACY =>
+        Cast(expr, expectedType, Option(conf.sessionLocalTimeZone), ansiEnabled = false)
+
+      case _ =>
+        Cast(expr, expectedType, Option(conf.sessionLocalTimeZone))
+    }
+  }
 }
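To make the new cast() helper concrete: a standalone sketch of the policy dispatch and of how checkField layers the char/varchar length check on top of it. CastSketch and castForWrite are illustrative names, the ANSI overflow-check wrapper (checkCastOverflowInTableInsert) is omitted for brevity, and the example otherwise relies only on calls that appear in the diff above:

    import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
    import org.apache.spark.sql.catalyst.util.CharVarcharUtils
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
    import org.apache.spark.sql.types._

    object CastSketch extends App {
      val conf = new SQLConf // defaults; the real code uses the session conf

      def castForWrite(expr: Expression, expectedType: DataType): Expression =
        conf.storeAssignmentPolicy match {
          case StoreAssignmentPolicy.ANSI =>
            // ANSI writes get a runtime-failing cast, tagged so errors report
            // a table-insertion context instead of a generic CAST failure.
            val c = Cast(expr, expectedType, Option(conf.sessionLocalTimeZone), ansiEnabled = true)
            c.setTagValue(Cast.BY_TABLE_INSERTION, ())
            c
          case StoreAssignmentPolicy.LEGACY =>
            Cast(expr, expectedType, Option(conf.sessionLocalTimeZone), ansiEnabled = false)
          case _ =>
            Cast(expr, expectedType, Option(conf.sessionLocalTimeZone))
        }

      // checkField casts to the char/varchar-stripped type (StringType here),
      // then wraps the result in a length check against the raw VARCHAR(5).
      val casted = castForWrite(Literal("123456"), StringType)
      val checked = CharVarcharUtils.stringLengthCheck(casted, VarcharType(5))
      println(checked) // evaluating this expression fails: length 6 exceeds 5
    }

Separating the cast from the length check is the point of the patch: the cast targets the stripped type, so unless the raw type survives (via getRawType), no stringLengthCheck is ever attached. The tests below exercise exactly that path for structs nested in structs, arrays, and maps.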
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index c0ceebaa9a6..a6c310cd925 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -926,4 +926,56 @@ class DSV2CharVarcharTestSuite extends CharVarcharTestSuite
       }
     }
   }
+
+  test("SPARK-42611: check char/varchar length in reordered nested structs") {
+    Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+      withTable("t") {
+        sql(s"CREATE TABLE t(s STRUCT<n_c: $typ, n_i: INT>) USING $format")
+
+        val inputDF = sql("SELECT named_struct('n_i', 1, 'n_c', '123456') AS s")
+
+        val e = intercept[RuntimeException](inputDF.writeTo("t").append())
+        assert(e.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+      }
+    }
+  }
+
+  test("SPARK-42611: check char/varchar length in reordered structs within arrays") {
+    Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+      withTable("t") {
+        sql(s"CREATE TABLE t(a ARRAY<STRUCT<n_c: $typ, n_i: INT>>) USING $format")
+
+        val inputDF = sql("SELECT array(named_struct('n_i', 1, 'n_c', '123456')) AS a")
+
+        val e = intercept[SparkException](inputDF.writeTo("t").append())
+        assert(e.getCause.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+      }
+    }
+  }
+
+  test("SPARK-42611: check char/varchar length in reordered structs within map keys") {
+    Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+      withTable("t") {
+        sql(s"CREATE TABLE t(m MAP<STRUCT<n_c: $typ, n_i: INT>, INT>) USING $format")
+
+        val inputDF = sql("SELECT map(named_struct('n_i', 1, 'n_c', '123456'), 1) AS m")
+
+        val e = intercept[SparkException](inputDF.writeTo("t").append())
+        assert(e.getCause.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+      }
+    }
+  }
+
+  test("SPARK-42611: check char/varchar length in reordered structs within map values") {
+    Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+      withTable("t") {
+        sql(s"CREATE TABLE t(m MAP<INT, STRUCT<n_c: $typ, n_i: INT>>) USING $format")
+
+        val inputDF = sql("SELECT map(1, named_struct('n_i', 1, 'n_c', '123456')) AS m")
+
+        val e = intercept[SparkException](inputDF.writeTo("t").append())
+        assert(e.getCause.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+      }
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org