This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new cc249781da6 [SPARK-43742][SQL] Refactor default column value resolution
cc249781da6 is described below

commit cc249781da6497fe1531c8b99a0b9fc4143a2f83
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Sun May 28 16:24:29 2023 -0700

[SPARK-43742][SQL] Refactor default column value resolution

### What changes were proposed in this pull request?

This PR refactors the default column value resolution so that we don't need an extra DS v2 API for external v2 sources. The general idea is to split the default column value resolution into two parts:

1. Resolve the column "DEFAULT" to the column's default expression. This applies to `Project`/`UnresolvedInlineTable` under `InsertIntoStatement`, and to assignment expressions in `UpdateTable`/`MergeIntoTable`.
2. Fill missing columns of the input query with the column default values. This does not apply to UPDATE and non-INSERT actions of MERGE, as they use the column from the target table as the default value.

The first part should be done for all data sources, as it's part of column resolution. The second part should not be applied to v2 data sources with `ACCEPT_ANY_SCHEMA`, as they are free to define how to handle missing columns.

More concretely, this PR:
1. Puts the column "DEFAULT" resolution logic in the rule `ResolveReferences`, with two new virtual rules. This follows https://github.com/apache/spark/pull/38888
2. Puts the missing-column handling in `TableOutputResolver`, which is shared by both the v1 and v2 insertion resolution rules. External v2 data sources can add custom catalyst rules to handle missing columns themselves.
3. Removes the old rule `ResolveDefaultColumns`. Note that, with the refactor, we no longer need to manually look up the table: column default values are handled after the target table of INSERT/UPDATE/MERGE is resolved.
4. Removes the rule `ResolveUserSpecifiedColumns` and merges it into `PreprocessTableInsertion`. Both rules resolve v1 insertions, and it's tricky to reason about their interactions; it's clearer to resolve the insertion in one pass.

### Why are the changes needed?

Code cleanup, and removal of an unneeded DS v2 API.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Updated tests.

Closes #41262 from cloud-fan/def-val.
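To make the two parts concrete, here is a minimal SQL sketch (adapted from the example in the scaladoc of the removed `ResolveDefaultColumns` rule further down in this diff; it assumes the default-column feature is enabled and a table provider that supports column defaults):

```sql
CREATE TABLE t (a INT DEFAULT 4, b INT NOT NULL DEFAULT 5);

-- Part 1: an explicit DEFAULT reference is resolved to the column's default expression.
INSERT INTO t VALUES (1, DEFAULT);   -- stores (1, 5)

-- Part 2: a column omitted from the insert column list is filled with its default value.
INSERT INTO t (a) VALUES (2);        -- stores (2, 5)

SELECT * FROM t;
-- (1, 5)
-- (2, 5)
```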
Authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Dongjoon Hyun <dongj...@apache.org> --- core/src/main/resources/error/error-classes.json | 27 +- .../connector/write/SupportsCustomSchemaWrite.java | 38 -- .../spark/sql/catalyst/analysis/Analyzer.scala | 142 +---- .../sql/catalyst/analysis/AssignmentUtils.scala | 8 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 23 +- .../catalyst/analysis/ColumnResolutionHelper.scala | 8 + .../analysis/ResolveColumnDefaultInInsert.scala | 164 +++++ .../catalyst/analysis/ResolveDefaultColumns.scala | 681 --------------------- .../catalyst/analysis/ResolveInsertionBase.scala | 70 +++ .../analysis/ResolveReferencesInUpdate.scala | 72 +++ .../catalyst/analysis/TableOutputResolver.scala | 43 +- .../catalyst/util/ResolveDefaultColumnsUtil.scala | 86 +++ .../spark/sql/errors/QueryCompilationErrors.scala | 25 +- .../execution/datasources/DataSourceStrategy.scala | 3 +- .../spark/sql/execution/datasources/rules.scala | 31 +- .../analyzer-results/postgreSQL/numeric.sql.out | 8 +- .../sql-tests/results/postgreSQL/numeric.sql.out | 8 +- .../org/apache/spark/sql/SQLInsertTestSuite.scala | 24 - .../analysis/ResolveDefaultColumnsSuite.scala | 222 ++----- .../spark/sql/connector/DataSourceV2SQLSuite.scala | 33 +- .../command/AlignMergeAssignmentsSuite.scala | 37 +- .../execution/command/PlanResolutionSuite.scala | 153 ++--- .../org/apache/spark/sql/sources/InsertSuite.scala | 325 +++++----- .../org/apache/spark/sql/hive/InsertSuite.scala | 21 +- .../spark/sql/hive/execution/HiveQuerySuite.scala | 8 +- 25 files changed, 855 insertions(+), 1405 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index f7c0879e1a2..07ff6e1c7c2 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -834,7 +834,18 @@ }, "INSERT_COLUMN_ARITY_MISMATCH" : { "message" : [ - "<tableName> requires that the data to be inserted have the same number of columns as the target table: target table has <targetColumns> column(s) but the inserted data has <insertedColumns> column(s), including <staticPartCols> partition column(s) having constant value(s)." + "Cannot write to '<tableName>', <reason>:", + "Table columns: <tableColumns>.", + "Data columns: <dataColumns>." + ], + "sqlState" : "21S01" + }, + "INSERT_PARTITION_COLUMN_ARITY_MISMATCH" : { + "message" : [ + "Cannot write to '<tableName>', <reason>:", + "Table columns: <tableColumns>.", + "Partition columns with static values: <staticPartCols>.", + "Data columns: <dataColumns>." ], "sqlState" : "21S01" }, @@ -3489,20 +3500,6 @@ "Cannot resolve column name \"<colName>\" among (<fieldNames>)." ] }, - "_LEGACY_ERROR_TEMP_1202" : { - "message" : [ - "Cannot write to '<tableName>', too many data columns:", - "Table columns: <tableColumns>.", - "Data columns: <dataColumns>." - ] - }, - "_LEGACY_ERROR_TEMP_1203" : { - "message" : [ - "Cannot write to '<tableName>', not enough data columns:", - "Table columns: <tableColumns>.", - "Data columns: <dataColumns>." 
- ] - }, "_LEGACY_ERROR_TEMP_1204" : { "message" : [ "Cannot write incompatible data to table '<tableName>':", diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsCustomSchemaWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsCustomSchemaWrite.java deleted file mode 100644 index 9435625a1c4..00000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsCustomSchemaWrite.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.write; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.types.StructType; - -/** - * Trait for tables that support custom schemas for write operations including INSERT INTO commands - * whose target table columns have explicit or implicit default values. - * - * @since 3.4.1 - */ -@Evolving -public interface SupportsCustomSchemaWrite { - /** - * Represents a table with a custom schema to use for resolving DEFAULT column references when - * inserting into the table. For example, this can be useful for excluding hidden pseudocolumns. - * - * @return the new schema to use for this process. 
- */ - StructType customSchemaForInserts(); -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 820d7df2807..c1ec9728e56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS, CharVarcharUtils, StringUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS, CharVarcharUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -55,8 +55,7 @@ import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssig import org.apache.spark.sql.internal.connector.V1Function import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.DAY -import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} -import org.apache.spark.util.collection.{Utils => CUtils} +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and @@ -280,7 +279,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor KeepLegacyOutputs), Batch("Resolution", fixedPoint, new ResolveCatalogs(catalogManager) :: - ResolveUserSpecifiedColumns :: ResolveInsertInto :: ResolveRelations :: ResolvePartitionSpec :: @@ -313,7 +311,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor TimeWindowing :: SessionWindowing :: ResolveWindowTime :: - ResolveDefaultColumns(ResolveRelations.resolveRelationOrTempView) :: ResolveInlineTables :: ResolveLambdaVariables :: ResolveTimeZone :: @@ -1080,7 +1077,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor def apply(plan: LogicalPlan) : LogicalPlan = plan.resolveOperatorsUpWithPruning(AlwaysProcess.fn, ruleId) { - case i @ InsertIntoStatement(table, _, _, _, _, _) if i.query.resolved => + case i @ InsertIntoStatement(table, _, _, _, _, _) => val relation = table match { case u: UnresolvedRelation if !u.isStreaming => resolveRelation(u).getOrElse(u) @@ -1280,53 +1277,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } /** Handle INSERT INTO for DSv2 */ - object ResolveInsertInto extends Rule[LogicalPlan] { - - /** Add a project to use the table column names for INSERT INTO BY NAME */ - private def createProjectForByNameQuery(i: InsertIntoStatement): LogicalPlan = { - SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver) - - if (i.userSpecifiedCols.size != i.query.output.size) { - throw QueryCompilationErrors.writeTableWithMismatchedColumnsError( - i.userSpecifiedCols.size, i.query.output.size, i.query) - } - val projectByName = i.userSpecifiedCols.zip(i.query.output) - .map { case (userSpecifiedCol, queryOutputCol) => - val resolvedCol = i.table.resolve(Seq(userSpecifiedCol), resolver) - .getOrElse( - throw QueryCompilationErrors.unresolvedAttributeError( - "UNRESOLVED_COLUMN", 
userSpecifiedCol, i.table.output.map(_.name), i.origin)) - (queryOutputCol.dataType, resolvedCol.dataType) match { - case (input: StructType, expected: StructType) => - // Rename inner fields of the input column to pass the by-name INSERT analysis. - Alias(Cast(queryOutputCol, renameFieldsInStruct(input, expected)), resolvedCol.name)() - case _ => - Alias(queryOutputCol, resolvedCol.name)() - } - } - Project(projectByName, i.query) - } - - private def renameFieldsInStruct(input: StructType, expected: StructType): StructType = { - if (input.length == expected.length) { - val newFields = input.zip(expected).map { case (f1, f2) => - (f1.dataType, f2.dataType) match { - case (s1: StructType, s2: StructType) => - f1.copy(name = f2.name, dataType = renameFieldsInStruct(s1, s2)) - case _ => - f1.copy(name = f2.name) - } - } - StructType(newFields) - } else { - input - } - } - + object ResolveInsertInto extends ResolveInsertionBase { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( AlwaysProcess.fn, ruleId) { - case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _) - if i.query.resolved => + case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _) if i.query.resolved => // ifPartitionNotExists is append with validation, but validation is not supported if (i.ifPartitionNotExists) { throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name) @@ -1529,6 +1483,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + // Don't wait other rules to resolve the child plans of `InsertIntoStatement` as we need + // to resolve column "DEFAULT" in the child plans so that they must be unresolved. + case i: InsertIntoStatement => ResolveColumnDefaultInInsert(i) + // Wait for other rules to resolve child plans first case p: LogicalPlan if !p.childrenResolved => p @@ -1648,6 +1606,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // implementation and should be resolved based on the table schema. 
o.copy(deleteExpr = resolveExpressionByPlanOutput(o.deleteExpr, o.table)) + case u: UpdateTable => ResolveReferencesInUpdate(u) + case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _) if !m.resolved && targetTable.resolved && sourceTable.resolved => @@ -1798,7 +1758,18 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case MergeResolvePolicy.SOURCE => Project(Nil, mergeInto.sourceTable) case MergeResolvePolicy.TARGET => Project(Nil, mergeInto.targetTable) } - resolveMergeExprOrFail(c, resolvePlan) + val resolvedExpr = resolveExprInAssignment(c, resolvePlan) + val withDefaultResolved = if (conf.enableDefaultColumns) { + resolveColumnDefaultInAssignmentValue( + resolvedKey, + resolvedExpr, + QueryCompilationErrors + .defaultReferencesNotAllowedInComplexExpressionsInMergeInsertsOrUpdates()) + } else { + resolvedExpr + } + checkResolvedMergeExpr(withDefaultResolved, resolvePlan) + withDefaultResolved case o => o } Assignment(resolvedKey, resolvedValue) @@ -1806,15 +1777,13 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = { - val resolved = resolveExpressionByPlanChildren(e, p) - resolved.references.filter { attribute: Attribute => - !attribute.resolved && - // We exclude attribute references named "DEFAULT" from consideration since they are - // handled exclusively by the ResolveDefaultColumns analysis rule. That rule checks the - // MERGE command for such references and either replaces each one with a corresponding - // value, or returns a custom error message. - normalizeFieldName(attribute.name) != normalizeFieldName(CURRENT_DEFAULT_COLUMN_NAME) - }.foreach { a => + val resolved = resolveExprInAssignment(e, p) + checkResolvedMergeExpr(resolved, p) + resolved + } + + private def checkResolvedMergeExpr(e: Expression, p: LogicalPlan): Unit = { + e.references.filter(!_.resolved).foreach { a => // Note: This will throw error only on unresolved attribute issues, // not other resolution errors like mismatched data types. val cols = p.inputSet.toSeq.map(_.sql).mkString(", ") @@ -1824,10 +1793,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor "sqlExpr" -> a.sql, "cols" -> cols)) } - resolved match { - case Alias(child: ExtractValue, _) => child - case other => other - } } // Expand the star expression using the input plan first. If failed, try resolve @@ -3359,53 +3324,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } - /** - * A special rule to reorder columns for DSv1 when users specify a column list in INSERT INTO. - * DSv2 is handled by [[ResolveInsertInto]] separately. 
- */ - object ResolveUserSpecifiedColumns extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( - AlwaysProcess.fn, ruleId) { - case i: InsertIntoStatement if !i.table.isInstanceOf[DataSourceV2Relation] && - i.table.resolved && i.query.resolved && i.userSpecifiedCols.nonEmpty => - val resolved = resolveUserSpecifiedColumns(i) - val projection = addColumnListOnQuery(i.table.output, resolved, i.query) - i.copy(userSpecifiedCols = Nil, query = projection) - } - - private def resolveUserSpecifiedColumns(i: InsertIntoStatement): Seq[NamedExpression] = { - SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver) - - i.userSpecifiedCols.map { col => - i.table.resolve(Seq(col), resolver).getOrElse { - val candidates = i.table.output.map(_.qualifiedName) - val orderedCandidates = StringUtils.orderSuggestedIdentifiersBySimilarity(col, candidates) - throw QueryCompilationErrors - .unresolvedAttributeError("UNRESOLVED_COLUMN", col, orderedCandidates, i.origin) - } - } - } - - private def addColumnListOnQuery( - tableOutput: Seq[Attribute], - cols: Seq[NamedExpression], - query: LogicalPlan): LogicalPlan = { - if (cols.size != query.output.size) { - throw QueryCompilationErrors.writeTableWithMismatchedColumnsError( - cols.size, query.output.size, query) - } - val nameToQueryExpr = CUtils.toMap(cols, query.output) - // Static partition columns in the table output should not appear in the column list - // they will be handled in another rule ResolveInsertInto - val reordered = tableOutput.flatMap { nameToQueryExpr.get(_).orElse(None) } - if (reordered == query.output) { - query - } else { - Project(reordered, query) - } - } - } - private def validateStoreAssignmentPolicy(): Unit = { // SPARK-28730: LEGACY store assignment policy is disallowed in data source v2. 
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala index 6d8118548fb..069cef6b361 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal} import org.apache.spark.sql.catalyst.plans.logical.Assignment import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{DataType, StructType} @@ -103,8 +104,11 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { case assignment if assignment.key.semanticEquals(attr) => assignment } val resolvedValue = if (matchingAssignments.isEmpty) { - errors += s"No assignment for '${attr.name}'" - attr + val defaultExpr = getDefaultValueExprOrNullLit(attr, conf) + if (defaultExpr.isEmpty) { + errors += s"No assignment for '${attr.name}'" + } + defaultExpr.getOrElse(attr) } else if (matchingAssignments.length > 1) { val conflictingValuesStr = matchingAssignments.map(_.value.sql).mkString(", ") errors += s"Multiple assignments for '${attr.name}': $conflictingValuesStr" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index cafabb22d10..98b50e77198 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -161,6 +161,21 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB } def checkAnalysis0(plan: LogicalPlan): Unit = { + // The target table is not a child plan of the insert command. We should report errors for table + // not found first, instead of errors in the input query of the insert command, by doing a + // top-down traversal. + plan.foreach { + case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _, _) => + u.tableNotFound(u.multipartIdentifier) + + // TODO (SPARK-27484): handle streaming write commands when we have them. + case write: V2WriteCommand if write.table.isInstanceOf[UnresolvedRelation] => + val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier + write.table.tableNotFound(tblName) + + case _ => + } + // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { @@ -197,14 +212,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass = "_LEGACY_ERROR_TEMP_2313", messageParameters = Map("name" -> u.name)) - case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _, _) => - u.tableNotFound(u.multipartIdentifier) - - // TODO (SPARK-27484): handle streaming write commands when we have them. 
- case write: V2WriteCommand if write.table.isInstanceOf[UnresolvedRelation] => - val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier - write.table.tableNotFound(tblName) - case command: V2PartitionCommand => command.table match { case r @ ResolvedTable(_, _, table, _) => table match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 318a23c36af..98cbdea72d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -384,6 +384,14 @@ trait ColumnResolutionHelper extends Logging { allowOuter = allowOuter) } + def resolveExprInAssignment(expr: Expression, hostPlan: LogicalPlan): Expression = { + resolveExpressionByPlanChildren(expr, hostPlan) match { + // Assignment key and value does not need the alias when resolving nested columns. + case Alias(child: ExtractValue, _) => child + case other => other + } + } + private def resolveExpressionByPlanId( e: Expression, q: LogicalPlan): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala new file mode 100644 index 00000000000..f7919664926 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{containsExplicitDefaultColumn, getDefaultValueExprOrNullLit, isExplicitDefaultColumn} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructField + +/** + * A virtual rule to resolve column "DEFAULT" in [[Project]] and [[UnresolvedInlineTable]] under + * [[InsertIntoStatement]]. It's only used by the real rule `ResolveReferences`. + * + * This virtual rule is triggered if: + * 1. The column "DEFAULT" can't be resolved normally by `ResolveReferences`. This is guaranteed as + * `ResolveReferences` resolves the query plan bottom up. 
This means that when we reach here to + * resolve [[InsertIntoStatement]], its child plans have already been resolved by + * `ResolveReferences`. + * 2. The plan nodes between [[Project]] and [[InsertIntoStatement]] are + * all unary nodes that inherit the output columns from its child. + * 3. The plan nodes between [[UnresolvedInlineTable]] and [[InsertIntoStatement]] are either + * [[Project]], or [[Aggregate]], or [[SubqueryAlias]]. + */ +case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolutionHelper { + // TODO (SPARK-43752): support v2 write commands as well. + def apply(plan: LogicalPlan): LogicalPlan = plan match { + case i: InsertIntoStatement if conf.enableDefaultColumns && i.table.resolved && + i.query.containsPattern(UNRESOLVED_ATTRIBUTE) => + val staticPartCols = i.partitionSpec.filter(_._2.isDefined).keySet.map(normalizeFieldName) + // For INSERT with static partitions, such as `INSERT INTO t PARTITION(c=1) SELECT ...`, the + // input query schema should match the table schema excluding columns with static + // partition values. + val expectedQuerySchema = i.table.schema.filter { field => + !staticPartCols.contains(normalizeFieldName(field.name)) + } + // Normally, we should match the query schema with the table schema by position. If the n-th + // column of the query is the DEFAULT column, we should get the default value expression + // defined for the n-th column of the table. However, if the INSERT has a column list, such as + // `INSERT INTO t(b, c, a)`, the matching should be by name. For example, the first column of + // the query should match the column 'b' of the table. + // To simplify the implementation, `resolveColumnDefault` always does by-position match. If + // the INSERT has a column list, we reorder the table schema w.r.t. the column list and pass + // the reordered schema as the expected schema to `resolveColumnDefault`. + if (i.userSpecifiedCols.isEmpty) { + i.withNewChildren(Seq(resolveColumnDefault(i.query, expectedQuerySchema))) + } else { + val colNamesToFields: Map[String, StructField] = expectedQuerySchema.map { field => + normalizeFieldName(field.name) -> field + }.toMap + val reorder = i.userSpecifiedCols.map { col => + colNamesToFields.get(normalizeFieldName(col)) + } + if (reorder.forall(_.isDefined)) { + i.withNewChildren(Seq(resolveColumnDefault(i.query, reorder.flatten))) + } else { + i + } + } + + case _ => plan + } + + /** + * Resolves the column "DEFAULT" in [[Project]] and [[UnresolvedInlineTable]]. A column is a + * "DEFAULT" column if all the following conditions are met: + * 1. The expression inside project list or inline table expressions is a single + * [[UnresolvedAttribute]] with name "DEFAULT". This means `SELECT DEFAULT, ...` is valid but + * `SELECT DEFAULT + 1, ...` is not. + * 2. The project list or inline table expressions have less elements than the expected schema. + * To find the default value definition, we need to find the matching column for expressions + * inside project list or inline table expressions. This matching is by position and it + * doesn't make sense if we have more expressions than the columns of expected schema. + * 3. The plan nodes between [[Project]] and [[InsertIntoStatement]] are + * all unary nodes that inherit the output columns from its child. + * 4. The plan nodes between [[UnresolvedInlineTable]] and [[InsertIntoStatement]] are either + * [[Project]], or [[Aggregate]], or [[SubqueryAlias]]. 
+ */ + private def resolveColumnDefault( + plan: LogicalPlan, + expectedQuerySchema: Seq[StructField], + acceptProject: Boolean = true, + acceptInlineTable: Boolean = true): LogicalPlan = { + plan match { + case _: SubqueryAlias => + plan.mapChildren( + resolveColumnDefault(_, expectedQuerySchema, acceptProject, acceptInlineTable)) + + case _: GlobalLimit | _: LocalLimit | _: Offset | _: Sort if acceptProject => + plan.mapChildren( + resolveColumnDefault(_, expectedQuerySchema, acceptInlineTable = false)) + + case p: Project if acceptProject && p.child.resolved && + p.containsPattern(UNRESOLVED_ATTRIBUTE) && + p.projectList.length <= expectedQuerySchema.length => + val newProjectList = p.projectList.zipWithIndex.map { + case (u: UnresolvedAttribute, i) if isExplicitDefaultColumn(u) => + Alias(getDefaultValueExprOrNullLit(expectedQuerySchema(i)), u.name)() + case (other, _) if containsExplicitDefaultColumn(other) => + throw QueryCompilationErrors + .defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList() + case (other, _) => other + } + val newChild = resolveColumnDefault(p.child, expectedQuerySchema, acceptProject = false) + val newProj = p.copy(projectList = newProjectList, child = newChild) + newProj.copyTagsFrom(p) + newProj + + case _: Project | _: Aggregate if acceptInlineTable => + plan.mapChildren(resolveColumnDefault(_, expectedQuerySchema, acceptProject = false)) + + case inlineTable: UnresolvedInlineTable if acceptInlineTable && + inlineTable.containsPattern(UNRESOLVED_ATTRIBUTE) && + inlineTable.rows.forall(exprs => exprs.length <= expectedQuerySchema.length) => + val newRows = inlineTable.rows.map { exprs => + exprs.zipWithIndex.map { + case (u: UnresolvedAttribute, i) if isExplicitDefaultColumn(u) => + getDefaultValueExprOrNullLit(expectedQuerySchema(i)) + case (other, _) if containsExplicitDefaultColumn(other) => + throw QueryCompilationErrors + .defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList() + case (other, _) => other + } + } + val newInlineTable = inlineTable.copy(rows = newRows) + newInlineTable.copyTagsFrom(inlineTable) + newInlineTable + + case other => other + } + } + + /** + * Normalizes a schema field name suitable for use in looking up into maps keyed by schema field + * names. + * @param str the field name to normalize + * @return the normalized result + */ + private def normalizeFieldName(str: String): String = { + if (SQLConf.get.caseSensitiveAnalysis) { + str + } else { + str.toLowerCase() + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala deleted file mode 100644 index 13e9866645a..00000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala +++ /dev/null @@ -1,681 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.catalog.UnresolvedCatalogRelation -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ -import org.apache.spark.sql.connector.write.SupportsCustomSchemaWrite -import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ - -/** - * This is a rule to process DEFAULT columns in statements such as CREATE/REPLACE TABLE. - * - * Background: CREATE TABLE and ALTER TABLE invocations support setting column default values for - * later operations. Following INSERT, UPDATE, and MERGE commands may then reference the value - * using the DEFAULT keyword as needed. - * - * Example: - * CREATE TABLE T(a INT DEFAULT 4, b INT NOT NULL DEFAULT 5); - * INSERT INTO T VALUES (1, 2); - * INSERT INTO T VALUES (1, DEFAULT); - * INSERT INTO T VALUES (DEFAULT, 6); - * SELECT * FROM T; - * (1, 2) - * (1, 5) - * (4, 6) - * - * @param resolveRelation function to resolve relations from the catalog. This should generally map - * to the 'resolveRelationOrTempView' method of the ResolveRelations rule. - */ -case class ResolveDefaultColumns( - resolveRelation: UnresolvedRelation => LogicalPlan) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperatorsWithPruning( - (_ => SQLConf.get.enableDefaultColumns), ruleId) { - case i: InsertIntoStatement if insertsFromInlineTable(i) => - resolveDefaultColumnsForInsertFromInlineTable(i) - case i: InsertIntoStatement if insertsFromProject(i).isDefined => - resolveDefaultColumnsForInsertFromProject(i) - case u: UpdateTable => - resolveDefaultColumnsForUpdate(u) - case m: MergeIntoTable => - resolveDefaultColumnsForMerge(m) - } - } - - /** - * Checks if a logical plan is an INSERT INTO command where the inserted data comes from a VALUES - * list, with possible projection(s), aggregate(s), and/or alias(es) in between. - */ - private def insertsFromInlineTable(i: InsertIntoStatement): Boolean = { - var query = i.query - while (query.children.size == 1) { - query match { - case _: Project | _: Aggregate | _: SubqueryAlias => - query = query.children(0) - case _ => - return false - } - } - query match { - case u: UnresolvedInlineTable - if u.rows.nonEmpty && u.rows.forall(_.size == u.rows(0).size) => - true - case r: LocalRelation - if r.data.nonEmpty && r.data.forall(_.numFields == r.data(0).numFields) => - true - case _ => - false - } - } - - /** - * Checks if a logical plan is an INSERT INTO command where the inserted data comes from a SELECT - * list, with possible other unary operators like sorting and/or alias(es) in between. 
- */ - private def insertsFromProject(i: InsertIntoStatement): Option[Project] = { - var node = i.query - def matches(node: LogicalPlan): Boolean = node match { - case _: GlobalLimit | _: LocalLimit | _: Offset | _: SubqueryAlias | _: Sort => true - case _ => false - } - while (matches(node)) { - node = node.children.head - } - node match { - case p: Project => Some(p) - case _ => None - } - } - - /** - * Resolves DEFAULT column references for an INSERT INTO command satisfying the - * [[insertsFromInlineTable]] method. - */ - private def resolveDefaultColumnsForInsertFromInlineTable(i: InsertIntoStatement): LogicalPlan = { - val children = mutable.Buffer.empty[LogicalPlan] - var node = i.query - while (node.children.size == 1) { - children.append(node) - node = node.children(0) - } - val insertTableSchemaWithoutPartitionColumns: Option[StructType] = - getInsertTableSchemaWithoutPartitionColumns(i) - insertTableSchemaWithoutPartitionColumns.map { schema: StructType => - val regenerated: InsertIntoStatement = - regenerateUserSpecifiedCols(i, schema) - val (expanded: LogicalPlan, addedDefaults: Boolean) = - addMissingDefaultValuesForInsertFromInlineTable(node, schema, i.userSpecifiedCols.size) - val replaced: Option[LogicalPlan] = - replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults) - replaced.map { r: LogicalPlan => - node = r - for (child <- children.reverse) { - node = child.withNewChildren(Seq(node)) - } - regenerated.copy(query = node) - }.getOrElse(i) - }.getOrElse(i) - } - - /** - * Resolves DEFAULT column references for an INSERT INTO command whose query is a general - * projection. - */ - private def resolveDefaultColumnsForInsertFromProject(i: InsertIntoStatement): LogicalPlan = { - val insertTableSchemaWithoutPartitionColumns: Option[StructType] = - getInsertTableSchemaWithoutPartitionColumns(i) - insertTableSchemaWithoutPartitionColumns.map { schema => - val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i, schema) - val project: Project = insertsFromProject(i).get - if (project.projectList.exists(_.isInstanceOf[Star])) { - i - } else { - val (expanded: Project, addedDefaults: Boolean) = - addMissingDefaultValuesForInsertFromProject(project, schema, i.userSpecifiedCols.size) - val replaced: Option[LogicalPlan] = - replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults) - replaced.map { r => - // Replace the INSERT INTO source relation, copying unary operators until we reach the - // original projection which we replace with the new projection with new values. - def replace(plan: LogicalPlan): LogicalPlan = plan match { - case _: Project => r - case u: UnaryNode => u.withNewChildren(Seq(replace(u.child))) - } - regenerated.copy(query = replace(regenerated.query)) - }.getOrElse(i) - } - }.getOrElse(i) - } - - /** - * Resolves DEFAULT column references for an UPDATE command. - */ - private def resolveDefaultColumnsForUpdate(u: UpdateTable): LogicalPlan = { - // Return a more descriptive error message if the user tries to use a DEFAULT column reference - // inside an UPDATE command's WHERE clause; this is not allowed. 
- u.condition.foreach { c: Expression => - if (c.find(isExplicitDefaultColumn).isDefined) { - throw QueryCompilationErrors.defaultReferencesNotAllowedInUpdateWhereClause() - } - } - val schemaForTargetTable: Option[StructType] = getSchemaForTargetTable(u.table) - schemaForTargetTable.map { schema => - val defaultExpressions: Seq[Expression] = schema.fields.map { - case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "UPDATE") - case _ => Literal(null) - } - // Create a map from each column name in the target table to its DEFAULT expression. - val columnNamesToExpressions: Map[String, Expression] = - mapStructFieldNamesToExpressions(schema, defaultExpressions) - // For each assignment in the UPDATE command's SET clause with a DEFAULT column reference on - // the right-hand side, look up the corresponding expression from the above map. - val newAssignments: Option[Seq[Assignment]] = - replaceExplicitDefaultValuesForUpdateAssignments( - u.assignments, CommandType.Update, columnNamesToExpressions) - newAssignments.map { n => - u.copy(assignments = n) - }.getOrElse(u) - }.getOrElse(u) - } - - /** - * Resolves DEFAULT column references for a MERGE INTO command. - */ - private def resolveDefaultColumnsForMerge(m: MergeIntoTable): LogicalPlan = { - val schema: StructType = getSchemaForTargetTable(m.targetTable).getOrElse(return m) - // Return a more descriptive error message if the user tries to use a DEFAULT column reference - // inside an UPDATE command's WHERE clause; this is not allowed. - m.mergeCondition.foreach { c: Expression => - if (c.find(isExplicitDefaultColumn).isDefined) { - throw QueryCompilationErrors.defaultReferencesNotAllowedInMergeCondition() - } - } - val columnsWithDefaults = ArrayBuffer.empty[String] - val defaultExpressions: Seq[Expression] = schema.fields.map { - case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => - columnsWithDefaults.append(normalizeFieldName(f.name)) - analyze(f, "MERGE") - case _ => Literal(null) - } - val columnNamesToExpressions: Map[String, Expression] = - mapStructFieldNamesToExpressions(schema, defaultExpressions) - var replaced = false - val newMatchedActions: Seq[MergeAction] = m.matchedActions.map { action: MergeAction => - replaceExplicitDefaultValuesInMergeAction(action, columnNamesToExpressions).map { r => - replaced = true - r - }.getOrElse(action) - } - val newNotMatchedActions: Seq[MergeAction] = m.notMatchedActions.map { action: MergeAction => - val expanded = addMissingDefaultValuesForMergeAction(action, m, columnsWithDefaults.toSeq) - replaceExplicitDefaultValuesInMergeAction(expanded, columnNamesToExpressions).map { r => - replaced = true - r - }.getOrElse(action) - } - val newNotMatchedBySourceActions: Seq[MergeAction] = - m.notMatchedBySourceActions.map { action: MergeAction => - replaceExplicitDefaultValuesInMergeAction(action, columnNamesToExpressions).map { r => - replaced = true - r - }.getOrElse(action) - } - if (replaced) { - m.copy(matchedActions = newMatchedActions, - notMatchedActions = newNotMatchedActions, - notMatchedBySourceActions = newNotMatchedBySourceActions) - } else { - m - } - } - - /** Adds a new expressions to a merge action to generate missing default column values. 
*/ - def addMissingDefaultValuesForMergeAction( - action: MergeAction, - m: MergeIntoTable, - columnNamesWithDefaults: Seq[String]): MergeAction = { - action match { - case i: InsertAction => - val targetColumns: Set[String] = i.assignments.map(_.key).flatMap { expr => - expr match { - case a: AttributeReference => Seq(normalizeFieldName(a.name)) - case u: UnresolvedAttribute => Seq(u.nameParts.map(normalizeFieldName).mkString(".")) - case _ => Seq() - } - }.toSet - val targetTable: String = m.targetTable match { - case SubqueryAlias(id, _) => id.name - case d: DataSourceV2Relation => d.name - } - val missingColumnNamesWithDefaults = columnNamesWithDefaults.filter { name => - !targetColumns.contains(normalizeFieldName(name)) && - !targetColumns.contains( - s"${normalizeFieldName(targetTable)}.${normalizeFieldName(name)}") - } - val newAssignments: Seq[Assignment] = missingColumnNamesWithDefaults.map { key => - Assignment(UnresolvedAttribute(key), UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME)) - } - i.copy(assignments = i.assignments ++ newAssignments) - case _ => - action - } - } - - /** - * Replaces unresolved DEFAULT column references with corresponding values in one action of a - * MERGE INTO command. - */ - private def replaceExplicitDefaultValuesInMergeAction( - action: MergeAction, - columnNamesToExpressions: Map[String, Expression]): Option[MergeAction] = { - action match { - case u: UpdateAction => - val replaced: Option[Seq[Assignment]] = - replaceExplicitDefaultValuesForUpdateAssignments( - u.assignments, CommandType.Merge, columnNamesToExpressions) - replaced.map { r => - Some(u.copy(assignments = r)) - }.getOrElse(None) - case i: InsertAction => - val replaced: Option[Seq[Assignment]] = - replaceExplicitDefaultValuesForUpdateAssignments( - i.assignments, CommandType.Merge, columnNamesToExpressions) - replaced.map { r => - Some(i.copy(assignments = r)) - }.getOrElse(None) - case _ => Some(action) - } - } - - /** - * Regenerates user-specified columns of an InsertIntoStatement based on the names in the - * insertTableSchemaWithoutPartitionColumns field of this class. - */ - private def regenerateUserSpecifiedCols( - i: InsertIntoStatement, - insertTableSchemaWithoutPartitionColumns: StructType): InsertIntoStatement = { - if (i.userSpecifiedCols.nonEmpty) { - i.copy( - userSpecifiedCols = insertTableSchemaWithoutPartitionColumns.fields.map(_.name)) - } else { - i - } - } - - /** - * Returns true if an expression is an explicit DEFAULT column reference. - */ - private def isExplicitDefaultColumn(expr: Expression): Boolean = expr match { - case u: UnresolvedAttribute if u.name.equalsIgnoreCase(CURRENT_DEFAULT_COLUMN_NAME) => true - case _ => false - } - - /** - * Updates an inline table to generate missing default column values. - * Returns the resulting plan plus a boolean indicating whether such values were added. 
- */ - def addMissingDefaultValuesForInsertFromInlineTable( - node: LogicalPlan, - insertTableSchemaWithoutPartitionColumns: StructType, - numUserSpecifiedColumns: Int): (LogicalPlan, Boolean) = { - val schema = insertTableSchemaWithoutPartitionColumns - val newDefaultExpressions: Seq[UnresolvedAttribute] = - getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, node.output.size) - val newNames: Seq[String] = schema.fields.map(_.name) - val resultPlan: LogicalPlan = node match { - case _ if newDefaultExpressions.isEmpty => - node - case table: UnresolvedInlineTable => - table.copy( - names = newNames, - rows = table.rows.map { row => row ++ newDefaultExpressions }) - case local: LocalRelation => - val newDefaultExpressionsRow = new GenericInternalRow( - // Note that this code path only runs when there is a user-specified column list of fewer - // column than the target table; otherwise, the above 'newDefaultExpressions' is empty and - // we match the first case in this list instead. - schema.fields.drop(local.output.size).map { - case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => - analyze(f, "INSERT") match { - case lit: Literal => lit.value - case _ => null - } - case _ => null - }) - LocalRelation( - output = schema.toAttributes, - data = local.data.map { row => - new JoinedRow(row, newDefaultExpressionsRow) - }) - case _ => - node - } - (resultPlan, newDefaultExpressions.nonEmpty) - } - - /** - * Adds a new expressions to a projection to generate missing default column values. - * Returns the logical plan plus a boolean indicating if such defaults were added. - */ - private def addMissingDefaultValuesForInsertFromProject( - project: Project, - insertTableSchemaWithoutPartitionColumns: StructType, - numUserSpecifiedColumns: Int): (Project, Boolean) = { - val schema = insertTableSchemaWithoutPartitionColumns - val newDefaultExpressions: Seq[Expression] = - getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, project.projectList.size) - val newAliases: Seq[NamedExpression] = - newDefaultExpressions.zip(schema.fields).map { - case (expr, field) => Alias(expr, field.name)() - } - (project.copy(projectList = project.projectList ++ newAliases), - newDefaultExpressions.nonEmpty) - } - - /** - * This is a helper for the addMissingDefaultValuesForInsertFromInlineTable methods above. - */ - private def getNewDefaultExpressionsForInsert( - insertTableSchemaWithoutPartitionColumns: StructType, - numUserSpecifiedColumns: Int, - numProvidedValues: Int): Seq[UnresolvedAttribute] = { - val remainingFields: Seq[StructField] = if (numUserSpecifiedColumns > 0) { - insertTableSchemaWithoutPartitionColumns.fields.drop(numUserSpecifiedColumns) - } else { - Seq.empty - } - val numDefaultExpressionsToAdd = getStructFieldsForDefaultExpressions(remainingFields).size - // Limit the number of new DEFAULT expressions to the difference of the number of columns in - // the target table and the number of provided values in the source relation. This clamps the - // total final number of provided values to the number of columns in the target table. - .min(insertTableSchemaWithoutPartitionColumns.size - numProvidedValues) - Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME)) - } - - /** - * This is a helper for the getDefaultExpressionsForInsert methods above. 
- */ - private def getStructFieldsForDefaultExpressions(fields: Seq[StructField]): Seq[StructField] = { - if (SQLConf.get.useNullsForMissingDefaultColumnValues) { - fields - } else { - fields.takeWhile(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) - } - } - - /** - * Replaces unresolved DEFAULT column references with corresponding values in an INSERT INTO - * command from a logical plan. - */ - private def replaceExplicitDefaultValuesForInputOfInsertInto( - insertTableSchemaWithoutPartitionColumns: StructType, - input: LogicalPlan, - addedDefaults: Boolean): Option[LogicalPlan] = { - val schema = insertTableSchemaWithoutPartitionColumns - val defaultExpressions: Seq[Expression] = schema.fields.map { - case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "INSERT") - case _ => Literal(null) - } - // Check the type of `input` and replace its expressions accordingly. - // If necessary, return a more descriptive error message if the user tries to nest the DEFAULT - // column reference inside some other expression, such as DEFAULT + 1 (this is not allowed). - // - // Note that we don't need to check if "SQLConf.get.useNullsForMissingDefaultColumnValues" after - // this point because this method only takes responsibility to replace *existing* DEFAULT - // references. In contrast, the "getDefaultExpressionsForInsert" method will check that config - // and add new NULLs if needed. - input match { - case table: UnresolvedInlineTable => - replaceExplicitDefaultValuesForInlineTable(defaultExpressions, table) - case project: Project => - replaceExplicitDefaultValuesForProject(defaultExpressions, project) - case local: LocalRelation => - if (addedDefaults) { - Some(local) - } else { - None - } - } - } - - /** - * Replaces unresolved DEFAULT column references with corresponding values in an inline table. - */ - private def replaceExplicitDefaultValuesForInlineTable( - defaultExpressions: Seq[Expression], - table: UnresolvedInlineTable): Option[LogicalPlan] = { - var replaced = false - val updated: Seq[Seq[Expression]] = { - table.rows.map { row: Seq[Expression] => - for { - i <- row.indices - expr = row(i) - defaultExpr = if (i < defaultExpressions.size) defaultExpressions(i) else Literal(null) - } yield replaceExplicitDefaultReferenceInExpression( - expr, defaultExpr, CommandType.Insert, addAlias = false).map { e => - replaced = true - e - }.getOrElse(expr) - } - } - if (replaced) { - Some(table.copy(rows = updated)) - } else { - None - } - } - - /** - * Replaces unresolved DEFAULT column references with corresponding values in a projection. - */ - private def replaceExplicitDefaultValuesForProject( - defaultExpressions: Seq[Expression], - project: Project): Option[LogicalPlan] = { - var replaced = false - val updated: Seq[NamedExpression] = { - for { - i <- project.projectList.indices - projectExpr = project.projectList(i) - defaultExpr = if (i < defaultExpressions.size) defaultExpressions(i) else Literal(null) - } yield replaceExplicitDefaultReferenceInExpression( - projectExpr, defaultExpr, CommandType.Insert, addAlias = true).map { e => - replaced = true - e.asInstanceOf[NamedExpression] - }.getOrElse(projectExpr) - } - if (replaced) { - Some(project.copy(projectList = updated)) - } else { - None - } - } - - /** - * Represents a type of command we are currently processing. - */ - private object CommandType extends Enumeration { - val Insert, Update, Merge = Value - } - - /** - * Checks if a given input expression is an unresolved "DEFAULT" attribute reference. 
- * - * @param input the input expression to examine. - * @param defaultExpr the default to return if [[input]] is an unresolved "DEFAULT" reference. - * @param isInsert the type of command we are currently processing. - * @param addAlias if true, wraps the result with an alias of the original default column name. - * @return [[defaultExpr]] if [[input]] is an unresolved "DEFAULT" attribute reference. - */ - private def replaceExplicitDefaultReferenceInExpression( - input: Expression, - defaultExpr: Expression, - command: CommandType.Value, - addAlias: Boolean): Option[Expression] = { - input match { - case a@Alias(u: UnresolvedAttribute, _) - if isExplicitDefaultColumn(u) => - Some(Alias(defaultExpr, a.name)()) - case u: UnresolvedAttribute - if isExplicitDefaultColumn(u) => - if (addAlias) { - Some(Alias(defaultExpr, u.name)()) - } else { - Some(defaultExpr) - } - case expr@_ - if expr.find(isExplicitDefaultColumn).isDefined => - command match { - case CommandType.Insert => - throw QueryCompilationErrors - .defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList() - case CommandType.Update => - throw QueryCompilationErrors - .defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause() - case CommandType.Merge => - throw QueryCompilationErrors - .defaultReferencesNotAllowedInComplexExpressionsInMergeInsertsOrUpdates() - } - case _ => - None - } - } - - /** - * Looks up the schema for the table object of an INSERT INTO statement from the catalog. - */ - private def getInsertTableSchemaWithoutPartitionColumns( - enclosingInsert: InsertIntoStatement): Option[StructType] = { - val target: StructType = getSchemaForTargetTable(enclosingInsert.table).getOrElse(return None) - val schema: StructType = StructType(target.fields.dropRight(enclosingInsert.partitionSpec.size)) - // Rearrange the columns in the result schema to match the order of the explicit column list, - // if any. - val userSpecifiedCols: Seq[String] = enclosingInsert.userSpecifiedCols - if (userSpecifiedCols.isEmpty) { - return Some(schema) - } - val colNamesToFields: Map[String, StructField] = mapStructFieldNamesToFields(schema) - val userSpecifiedFields: Seq[StructField] = - userSpecifiedCols.map { - name: String => colNamesToFields.getOrElse(normalizeFieldName(name), return None) - } - val userSpecifiedColNames: Set[String] = userSpecifiedCols.toSet - .map(normalizeFieldName) - val nonUserSpecifiedFields: Seq[StructField] = - schema.fields.filter { - field => !userSpecifiedColNames.contains( - normalizeFieldName( - field.name - ) - ) - } - Some(StructType(userSpecifiedFields ++ - getStructFieldsForDefaultExpressions(nonUserSpecifiedFields))) - } - - /** - * Returns a map of the names of fields in a schema to the fields themselves. - */ - private def mapStructFieldNamesToFields(schema: StructType): Map[String, StructField] = { - schema.fields.map { - field: StructField => normalizeFieldName(field.name) -> field - }.toMap - } - - /** - * Returns a map of the names of fields in a schema to corresponding expressions. - */ - private def mapStructFieldNamesToExpressions( - schema: StructType, - expressions: Seq[Expression]): Map[String, Expression] = { - schema.fields.zip(expressions).map { - case (field: StructField, expression: Expression) => - normalizeFieldName(field.name) -> expression - }.toMap - } - - /** - * Returns the schema for the target table of a DML command, looking into the catalog if needed. 
- */ - private def getSchemaForTargetTable(table: LogicalPlan): Option[StructType] = { - val resolved = table match { - case r: UnresolvedRelation if !r.skipSchemaResolution && !r.isStreaming => - resolveRelation(r) - case other => - other - } - resolved.collectFirst { - case r: UnresolvedCatalogRelation => - r.tableMeta.schema - case DataSourceV2Relation(table: SupportsCustomSchemaWrite, _, _, _, _) => - table.customSchemaForInserts - case r: NamedRelation if !r.skipSchemaResolution => - r.schema - case v: View if v.isTempViewStoringAnalyzedPlan => - v.schema - } - } - - /** - * Replaces unresolved DEFAULT column references with corresponding values in a series of - * assignments in an UPDATE assignment, either comprising an UPDATE command or as part of a MERGE. - */ - private def replaceExplicitDefaultValuesForUpdateAssignments( - assignments: Seq[Assignment], - command: CommandType.Value, - columnNamesToExpressions: Map[String, Expression]): Option[Seq[Assignment]] = { - var replaced = false - val newAssignments: Seq[Assignment] = - for (assignment <- assignments) yield { - val destColName = assignment.key match { - case a: AttributeReference => a.name - case u: UnresolvedAttribute => u.nameParts.last - case _ => "" - } - val adjusted: String = normalizeFieldName(destColName) - val lookup: Option[Expression] = columnNamesToExpressions.get(adjusted) - val newValue: Expression = lookup.map { defaultExpr => - val updated: Option[Expression] = - replaceExplicitDefaultReferenceInExpression( - assignment.value, - defaultExpr, - command, - addAlias = false) - updated.map { e => - replaced = true - e - }.getOrElse(assignment.value) - }.getOrElse(assignment.value) - assignment.copy(value = newValue) - } - if (replaced) { - Some(newAssignments) - } else { - None - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInsertionBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInsertionBase.scala new file mode 100644 index 00000000000..71d36867951 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInsertionBase.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.SchemaUtils + +abstract class ResolveInsertionBase extends Rule[LogicalPlan] { + def resolver: Resolver = conf.resolver + + /** Add a project to use the table column names for INSERT INTO BY NAME */ + protected def createProjectForByNameQuery(i: InsertIntoStatement): LogicalPlan = { + SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver) + + if (i.userSpecifiedCols.size != i.query.output.size) { + throw QueryCompilationErrors.writeTableWithMismatchedColumnsError( + i.userSpecifiedCols.size, i.query.output.size, i.query) + } + val projectByName = i.userSpecifiedCols.zip(i.query.output) + .map { case (userSpecifiedCol, queryOutputCol) => + val resolvedCol = i.table.resolve(Seq(userSpecifiedCol), resolver) + .getOrElse( + throw QueryCompilationErrors.unresolvedAttributeError( + "UNRESOLVED_COLUMN", userSpecifiedCol, i.table.output.map(_.name), i.origin)) + (queryOutputCol.dataType, resolvedCol.dataType) match { + case (input: StructType, expected: StructType) => + // Rename inner fields of the input column to pass the by-name INSERT analysis. + Alias(Cast(queryOutputCol, renameFieldsInStruct(input, expected)), resolvedCol.name)() + case _ => + Alias(queryOutputCol, resolvedCol.name)() + } + } + Project(projectByName, i.query) + } + + private def renameFieldsInStruct(input: StructType, expected: StructType): StructType = { + if (input.length == expected.length) { + val newFields = input.zip(expected).map { case (f1, f2) => + (f1.dataType, f2.dataType) match { + case (s1: StructType, s2: StructType) => + f1.copy(name = f2.name, dataType = renameFieldsInStruct(s1, s2)) + case _ => + f1.copy(name = f2.name) + } + } + StructType(newFields) + } else { + input + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala new file mode 100644 index 00000000000..cebc1e25f92 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable} +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.resolveColumnDefaultInAssignmentValue +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * A virtual rule to resolve [[UnresolvedAttribute]] in [[UpdateTable]]. It's only used by the real + * rule `ResolveReferences`. The column resolution order for [[UpdateTable]] is: + * 1. Resolves the column to `AttributeReference`` with the output of the child plan. This + * includes metadata columns as well. + * 2. Resolves the column to a literal function which is allowed to be invoked without braces, e.g. + * `SELECT col, current_date FROM t`. + * 3. Resolves the column to the default value expression, if the column is the assignment value + * and the corresponding assignment key is a top-level column. + */ +case object ResolveReferencesInUpdate extends SQLConfHelper with ColumnResolutionHelper { + + def apply(u: UpdateTable): UpdateTable = { + assert(u.table.resolved) + if (u.resolved) return u + + val newAssignments = u.assignments.map { assign => + val resolvedKey = assign.key match { + case c if !c.resolved => + resolveExprInAssignment(c, u) + case o => o + } + val resolvedValue = assign.value match { + case c if !c.resolved => + val resolved = resolveExprInAssignment(c, u) + if (conf.enableDefaultColumns) { + resolveColumnDefaultInAssignmentValue( + resolvedKey, + resolved, + QueryCompilationErrors + .defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause()) + } else { + resolved + } + case o => o + } + val resolved = Assignment(resolvedKey, resolvedValue) + resolved.copyTagsFrom(assign) + resolved + } + + val newUpdate = u.copy( + assignments = newAssignments, + condition = u.condition.map(resolveExpressionByPlanChildren(_, u))) + newUpdate.copyTagsFrom(u) + newUpdate + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index ae159cd349f..b9aca30c754 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -36,7 +37,9 @@ object TableOutputResolver { expected: Seq[Attribute], query: LogicalPlan, byName: Boolean, - conf: SQLConf): LogicalPlan = { + conf: SQLConf, + // TODO: Only DS v1 writing will set it to true. We should enable in for DS v2 as well. 
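// A minimal sketch, with simplified stand-in types (not from this patch), of the assignment-value
// handling that ResolveReferencesInUpdate above delegates to resolveColumnDefaultInAssignmentValue:
// `col = DEFAULT` becomes the column's default expression (or a NULL literal when none is defined),
// while a DEFAULT reference buried inside a larger expression is rejected.
object UpdateDefaultSketch {
  sealed trait Expr
  case object DefaultRef extends Expr
  case class Plus(left: Expr, right: Expr) extends Expr
  case class Lit(value: Any) extends Expr

  def resolveAssignmentValue(value: Expr, columnDefault: Option[Expr]): Expr = value match {
    case DefaultRef => columnDefault.getOrElse(Lit(null))    // e.g. UPDATE t SET i = DEFAULT
    case e if containsDefault(e) =>                          // e.g. UPDATE t SET i = DEFAULT + 1
      throw new IllegalArgumentException("DEFAULT is not allowed in complex expressions")
    case other => other                                      // ordinary value, left as-is
  }

  private def containsDefault(e: Expr): Boolean = e match {
    case DefaultRef => true
    case Plus(l, r) => containsDefault(l) || containsDefault(r)
    case _ => false
  }
}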
+ supportColDefaultValue: Boolean = false): LogicalPlan = { val actualExpectedCols = expected.map { attr => attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)) @@ -49,14 +52,32 @@ object TableOutputResolver { val errors = new mutable.ArrayBuffer[String]() val resolved: Seq[NamedExpression] = if (byName) { - reorderColumnsByName(query.output, actualExpectedCols, conf, errors += _) + // If a top-level column does not have a corresponding value in the input query, fill with + // the column's default value. We need to pass `fillDefaultValue` as true here, if the + // `supportColDefaultValue` parameter is also true. + reorderColumnsByName( + query.output, + actualExpectedCols, + conf, + errors += _, + fillDefaultValue = supportColDefaultValue) } else { - if (actualExpectedCols.size > query.output.size) { + // If the target table needs more columns than the input query, fill them with + // the columns' default values, if the `supportColDefaultValue` parameter is true. + val fillDefaultValue = supportColDefaultValue && actualExpectedCols.size > query.output.size + val queryOutputCols = if (fillDefaultValue) { + query.output ++ actualExpectedCols.drop(query.output.size).flatMap { expectedCol => + getDefaultValueExprOrNullLit(expectedCol, conf) + } + } else { + query.output + } + if (actualExpectedCols.size > queryOutputCols.size) { throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError( tableName, actualExpectedCols, query) } - resolveColumnsByPosition(query.output, actualExpectedCols, conf, errors += _) + resolveColumnsByPosition(queryOutputCols, actualExpectedCols, conf, errors += _) } if (errors.nonEmpty) { @@ -156,14 +177,22 @@ object TableOutputResolver { expectedCols: Seq[Attribute], conf: SQLConf, addError: String => Unit, - colPath: Seq[String] = Nil): Seq[NamedExpression] = { + colPath: Seq[String] = Nil, + fillDefaultValue: Boolean = false): Seq[NamedExpression] = { val matchedCols = mutable.HashSet.empty[String] val reordered = expectedCols.flatMap { expectedCol => val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name)) val newColPath = colPath :+ expectedCol.name if (matched.isEmpty) { - addError(s"Cannot find data for output column '${newColPath.quoted}'") - None + val defaultExpr = if (fillDefaultValue) { + getDefaultValueExprOrNullLit(expectedCol, conf) + } else { + None + } + if (defaultExpr.isEmpty) { + addError(s"Cannot find data for output column '${newColPath.quoted}'") + } + defaultExpr } else if (matched.length > 1) { addError(s"Ambiguous column name in the input data: '${newColPath.quoted}'") None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index 0f5c413ed78..2169137685d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -155,6 +155,92 @@ object ResolveDefaultColumns { } } + /** + * Returns true if the unresolved column is an explicit DEFAULT column reference. + */ + def isExplicitDefaultColumn(col: UnresolvedAttribute): Boolean = { + col.name.equalsIgnoreCase(CURRENT_DEFAULT_COLUMN_NAME) + } + + /** + * Returns true if the given expression contains an explicit DEFAULT column reference. 
+ */ + def containsExplicitDefaultColumn(expr: Expression): Boolean = { + expr.exists { + case u: UnresolvedAttribute => isExplicitDefaultColumn(u) + case _ => false + } + } + + /** + * Resolves the column "DEFAULT" in UPDATE/MERGE assignment value expression if the following + * conditions are met: + * 1. The assignment value expression is a single `UnresolvedAttribute` with name "DEFAULT". This + * means `key = DEFAULT` is allowed but `key = DEFAULT + 1` is not. + * 2. The assignment key expression is a top-level column. This means `col = DEFAULT` is allowed + * but `col.field = DEFAULT` is not. + * + * The column "DEFAULT" will be resolved to the default value expression defined for the column of + * the assignment key. + */ + def resolveColumnDefaultInAssignmentValue( + key: Expression, + value: Expression, + invalidColumnDefaultException: Throwable): Expression = { + key match { + case attr: AttributeReference => + value match { + case u: UnresolvedAttribute if isExplicitDefaultColumn(u) => + getDefaultValueExprOrNullLit(attr) + case other if containsExplicitDefaultColumn(other) => + throw invalidColumnDefaultException + case other => other + } + case _ => value + } + } + + private def getDefaultValueExprOpt(field: StructField): Option[Expression] = { + if (field.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) { + Some(analyze(field, "INSERT")) + } else { + None + } + } + + /** + * Generates the expression of the default value for the given field. If there is no + * user-specified default value for this field, returns null literal. + */ + def getDefaultValueExprOrNullLit(field: StructField): Expression = { + getDefaultValueExprOpt(field).getOrElse(Literal(null, field.dataType)) + } + + /** + * Generates the expression of the default value for the given column. If there is no + * user-specified default value for this field, returns null literal. + */ + def getDefaultValueExprOrNullLit(attr: Attribute): Expression = { + val field = StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + getDefaultValueExprOrNullLit(field) + } + + /** + * Generates the aliased expression of the default value for the given column. If there is no + * user-specified default value for this column, returns a null literal or None w.r.t. the config + * `USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES`. + */ + def getDefaultValueExprOrNullLit(attr: Attribute, conf: SQLConf): Option[NamedExpression] = { + val field = StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + getDefaultValueExprOpt(field).orElse { + if (conf.useNullsForMissingDefaultColumnValues) { + Some(Literal(null, attr.dataType)) + } else { + None + } + }.map(expr => Alias(expr, attr.name)()) + } + /** * Parses and analyzes the DEFAULT column text in `field`, returning an error upon failure. 
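// A minimal sketch (plain Options, not from this patch) of the fallback order implemented by
// getDefaultValueExprOrNullLit(attr, conf) above: use the column's explicit DEFAULT expression
// when one is defined, otherwise fall back to a NULL literal only if the
// USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES config is enabled, otherwise return None so the
// caller reports the column as missing.
object MissingColumnFallbackSketch {
  def valueForMissingColumn(
      explicitDefault: Option[String],        // e.g. Some("42") for a column `s BIGINT DEFAULT 42`
      useNullsForMissingDefaults: Boolean): Option[String] = {
    explicitDefault.orElse(if (useNullsForMissingDefaults) Some("NULL") else None)
  }
}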
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 45a9a03df4d..1d80ca22550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, CreateStruct, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition} import org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.logical.{Assignment, InsertIntoStatement, Join, LogicalPlan, SerdeInfo, Window} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Join, LogicalPlan, SerdeInfo, Window} import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.connector.catalog._ @@ -1731,17 +1731,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { "normalizedPartCols" -> normalizedPartCols.mkString(", "))) } - def mismatchedInsertedDataColumnNumberError( - tableName: String, insert: InsertIntoStatement, staticPartCols: Set[String]): Throwable = { - new AnalysisException( - errorClass = "INSERT_COLUMN_ARITY_MISMATCH", - messageParameters = Map( - "tableName" -> tableName, - "targetColumns" -> insert.table.output.size.toString, - "insertedColumns" -> (insert.query.output.length + staticPartCols.size).toString, - "staticPartCols" -> staticPartCols.size.toString)) - } - def requestedPartitionsMismatchTablePartitionsError( tableName: String, normalizedPartSpec: Map[String, Option[String]], @@ -1751,7 +1740,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { messageParameters = Map( "tableName" -> tableName, "normalizedPartSpec" -> normalizedPartSpec.keys.mkString(","), - "partColNames" -> partColNames.mkString(","))) + "partColNames" -> partColNames.map(_.name).mkString(","))) } def ddlWithoutHiveSupportEnabledError(detail: String): Throwable = { @@ -2066,11 +2055,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { } def cannotWriteTooManyColumnsToTableError( - tableName: String, expected: Seq[Attribute], query: LogicalPlan): Throwable = { + tableName: String, + expected: Seq[Attribute], + query: LogicalPlan): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1202", + errorClass = "INSERT_COLUMN_ARITY_MISMATCH", messageParameters = Map( "tableName" -> tableName, + "reason" -> "too many data columns", "tableColumns" -> expected.map(c => s"'${c.name}'").mkString(", "), "dataColumns" -> query.output.map(c => s"'${c.name}'").mkString(", "))) } @@ -2078,9 +2070,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { def cannotWriteNotEnoughColumnsToTableError( tableName: String, expected: Seq[Attribute], query: LogicalPlan): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1203", + errorClass = "INSERT_COLUMN_ARITY_MISMATCH", messageParameters = Map( "tableName" -> tableName, + "reason" -> "not enough data columns", "tableColumns" -> expected.map(c => 
s"'${c.name}'").mkString(", "), "dataColumns" -> query.output.map(c => s"'${c.name}'").mkString(", "))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 69c7624605b..dd79e9b26d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -164,7 +164,8 @@ object DataSourceAnalysis extends Rule[LogicalPlan] { InsertIntoDataSourceDirCommand(storage, provider.get, query, overwrite) case i @ InsertIntoStatement( - l @ LogicalRelation(t: HadoopFsRelation, _, table, _), parts, _, query, overwrite, _) => + l @ LogicalRelation(t: HadoopFsRelation, _, table, _), parts, _, query, overwrite, _) + if query.resolved => // If the InsertIntoTable command is for a partitioned HadoopFsRelation and // the user has specified static partitions, we add a Project operator on top of the query // to include those constant column values in the query result. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index dc9d0999d1a..b3fdfc76c7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -363,7 +363,7 @@ case class PreprocessTableCreation(catalog: SessionCatalog) extends Rule[Logical * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -object PreprocessTableInsertion extends Rule[LogicalPlan] { +object PreprocessTableInsertion extends ResolveInsertionBase { private def preprocess( insert: InsertIntoStatement, tblName: String, @@ -376,11 +376,6 @@ object PreprocessTableInsertion extends Rule[LogicalPlan] { val staticPartCols = normalizedPartSpec.filter(_._2.isDefined).keySet val expectedColumns = insert.table.output.filterNot(a => staticPartCols.contains(a.name)) - if (expectedColumns.length != insert.query.schema.length) { - throw QueryCompilationErrors.mismatchedInsertedDataColumnNumberError( - tblName, insert, staticPartCols) - } - val partitionsTrackedByCatalog = catalogTable.isDefined && catalogTable.get.partitionColumnNames.nonEmpty && catalogTable.get.tracksPartitionsInCatalog @@ -393,8 +388,28 @@ object PreprocessTableInsertion extends Rule[LogicalPlan] { } } - val newQuery = TableOutputResolver.resolveOutputColumns( - tblName, expectedColumns, insert.query, byName = false, conf) + // Create a project if this INSERT has a user-specified column list. 
+ val isByName = insert.userSpecifiedCols.nonEmpty + val query = if (isByName) { + createProjectForByNameQuery(insert) + } else { + insert.query + } + val newQuery = try { + TableOutputResolver.resolveOutputColumns( + tblName, expectedColumns, query, byName = isByName, conf, supportColDefaultValue = true) + } catch { + case e: AnalysisException if staticPartCols.nonEmpty && + e.getErrorClass == "INSERT_COLUMN_ARITY_MISMATCH" => + val newException = e.copy( + errorClass = Some("INSERT_PARTITION_COLUMN_ARITY_MISMATCH"), + messageParameters = e.messageParameters ++ Map( + "tableColumns" -> insert.table.output.map(c => s"'${c.name}'").mkString(", "), + "staticPartCols" -> staticPartCols.toSeq.sorted.map(c => s"'$c'").mkString(", ") + )) + newException.setStackTrace(e.getStackTrace) + throw newException + } if (normalizedPartSpec.nonEmpty) { if (normalizedPartSpec.size != partColNames.length) { throw QueryCompilationErrors.requestedPartitionsMismatchTablePartitionsError( diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/numeric.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/numeric.sql.out index d32e2abe156..a6408f94579 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/numeric.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/numeric.sql.out @@ -3844,10 +3844,10 @@ org.apache.spark.sql.AnalysisException "errorClass" : "INSERT_COLUMN_ARITY_MISMATCH", "sqlState" : "21S01", "messageParameters" : { - "insertedColumns" : "5", - "staticPartCols" : "0", - "tableName" : "`spark_catalog`.`default`.`num_result`", - "targetColumns" : "3" + "dataColumns" : "'id', 'id', 'val', 'val', '(val * val)'", + "reason" : "too many data columns", + "tableColumns" : "'id1', 'id2', 'result'", + "tableName" : "`spark_catalog`.`default`.`num_result`" } } diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out index db81160bf03..5840e1164fa 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out @@ -3835,10 +3835,10 @@ org.apache.spark.sql.AnalysisException "errorClass" : "INSERT_COLUMN_ARITY_MISMATCH", "sqlState" : "21S01", "messageParameters" : { - "insertedColumns" : "5", - "staticPartCols" : "0", - "tableName" : "`spark_catalog`.`default`.`num_result`", - "targetColumns" : "3" + "dataColumns" : "'id', 'id', 'val', 'val', '(val * val)'", + "reason" : "too many data columns", + "tableColumns" : "'id1', 'id2', 'result'", + "tableName" : "`spark_catalog`.`default`.`num_result`" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala index 1997fce0f5c..904980d58d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala @@ -201,30 +201,6 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils { } } - test("insert with column list - missing columns") { - val v2Msg = "Cannot write incompatible data to table 'testcat.t1'" - val cols = Seq("c1", "c2", "c3", "c4") - - withTable("t1") { - createTable("t1", cols, Seq.fill(4)("int")) - val e1 = intercept[AnalysisException](sql(s"INSERT INTO t1 values(1)")) - assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 1") 
|| - e1.getMessage.contains("expected 4 columns but found 1") || - e1.getMessage.contains("not enough data columns") || - e1.getMessage.contains(v2Msg)) - } - - withTable("t1") { - createTable("t1", cols, Seq.fill(4)("int"), cols.takeRight(2)) - val e1 = intercept[AnalysisException] { - sql(s"INSERT INTO t1 partition(c3=3, c4=4) values(1)") - } - assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 3") || - e1.getMessage.contains("not enough data columns") || - e1.getMessage.contains(v2Msg)) - } - } - test("SPARK-34223: static partition with null raise NPE") { withTable("t") { sql(s"CREATE TABLE t(i STRING, c string) USING PARQUET PARTITIONED BY (c)") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala index b21b490ea08..30d32c8283c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala @@ -17,69 +17,84 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ -import org.apache.spark.sql.connector.catalog.{Table, TableCapability} -import org.apache.spark.sql.connector.write.SupportsCustomSchemaWrite -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField, StructType, TimestampType} -import org.apache.spark.sql.util.CaseInsensitiveStringMap class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { - val rule = ResolveDefaultColumns(null) - // This is the internal storage for the timestamp 2020-12-31 00:00:00.0. - val literal = Literal(1609401600000000L, TimestampType) - val table = UnresolvedInlineTable( - names = Seq("attr1"), - rows = Seq(Seq(literal))) - val localRelation = ResolveInlineTables(table).asInstanceOf[LocalRelation] + test("column without default value defined (null as default)") { + withTable("t") { + sql("create table t(c1 timestamp, c2 timestamp) using parquet") - def asLocalRelation(result: LogicalPlan): LocalRelation = result match { - case r: LocalRelation => r - case _ => fail(s"invalid result operator type: $result") - } + // INSERT with user-defined columns + sql("insert into t (c2) values (timestamp'2020-12-31')") + checkAnswer(spark.table("t"), + sql("select null, timestamp'2020-12-31'").collect().head) + sql("truncate table t") + sql("insert into t (c1) values (timestamp'2020-12-31')") + checkAnswer(spark.table("t"), + sql("select timestamp'2020-12-31', null").collect().head) - test("SPARK-43018: Add DEFAULTs for INSERT from VALUES list with user-defined columns") { - // Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with one user-specified - // column. We add a default value of NULL to the row as a result. 
- val insertTableSchemaWithoutPartitionColumns = StructType(Seq( - StructField("c1", TimestampType), - StructField("c2", TimestampType))) - val (result: LogicalPlan, _: Boolean) = - rule.addMissingDefaultValuesForInsertFromInlineTable( - localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 1) - val relation = asLocalRelation(result) - assert(relation.output.map(_.name) == Seq("c1", "c2")) - val data: Seq[Seq[Any]] = relation.data.map { row => - row.toSeq(StructType(relation.output.map(col => StructField(col.name, col.dataType)))) + // INSERT without user-defined columns + sql("truncate table t") + sql("insert into t values (timestamp'2020-12-31')") + checkAnswer(spark.table("t"), + sql("select timestamp'2020-12-31', null").collect().head) } - assert(data == Seq(Seq(literal.value, null))) - } - - test("SPARK-43018: Add no DEFAULTs for INSERT from VALUES list with no user-defined columns") { - // Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with zero user-specified - // columns. The table is unchanged because there are no default columns to add in this case. - val insertTableSchemaWithoutPartitionColumns = StructType(Seq( - StructField("c1", TimestampType), - StructField("c2", TimestampType))) - val (result: LogicalPlan, _: Boolean) = - rule.addMissingDefaultValuesForInsertFromInlineTable( - localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 0) - assert(asLocalRelation(result) == localRelation) } - test("SPARK-43018: INSERT timestamp values into a table with column DEFAULTs") { + test("column with default value defined") { withTable("t") { - sql("create table t(id int, ts timestamp) using parquet") - sql("insert into t (ts) values (timestamp'2020-12-31')") + sql("create table t(c1 timestamp DEFAULT timestamp'2020-01-01', " + + "c2 timestamp DEFAULT timestamp'2020-01-01') using parquet") + + // INSERT with user-defined columns + sql("insert into t (c1) values (timestamp'2020-12-31')") checkAnswer(spark.table("t"), - sql("select null, timestamp'2020-12-31'").collect().head) + sql("select timestamp'2020-12-31', timestamp'2020-01-01'").collect().head) + sql("truncate table t") + sql("insert into t (c2) values (timestamp'2020-12-31')") + checkAnswer(spark.table("t"), + sql("select timestamp'2020-01-01', timestamp'2020-12-31'").collect().head) + + // INSERT without user-defined columns + sql("truncate table t") + sql("insert into t values (timestamp'2020-12-31')") + checkAnswer(spark.table("t"), + sql("select timestamp'2020-12-31', timestamp'2020-01-01'").collect().head) } } + test("INSERT into partitioned tables") { + sql("create table t(c1 int, c2 int, c3 int, c4 int) using parquet partitioned by (c3, c4)") + + // INSERT without static partitions + sql("insert into t values (1, 2, 3)") + checkAnswer(spark.table("t"), Row(1, 2, 3, null)) + + // INSERT without static partitions but with column list + sql("truncate table t") + sql("insert into t (c2, c1, c4) values (1, 2, 3)") + checkAnswer(spark.table("t"), Row(2, 1, null, 3)) + + // INSERT with static partitions + sql("truncate table t") + sql("insert into t partition(c3=3, c4=4) values (1)") + checkAnswer(spark.table("t"), Row(1, null, 3, 4)) + + // INSERT with static partitions and with column list + sql("truncate table t") + sql("insert into t partition(c3=3, c4=4) (c2) values (1)") + checkAnswer(spark.table("t"), Row(null, 1, 3, 4)) + + // INSERT with partial static partitions + sql("truncate table t") + sql("insert into t partition(c3=3, c4) values (1, 2)") + 
checkAnswer(spark.table("t"), Row(1, 2, 3, null)) + + // INSERT with partial static partitions and with column list is not allowed + intercept[AnalysisException](sql("insert into t partition(c3=3, c4) (c1) values (1, 4)")) + } + test("SPARK-43085: Column DEFAULT assignment for target tables with multi-part names") { withDatabase("demos") { sql("create database demos") @@ -164,111 +179,4 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { } } } - - /** - * This is a new relation type that defines the 'customSchemaForInserts' method. - * Its implementation drops the last table column as it represents an internal pseudocolumn. - */ - case class TableWithCustomInsertSchema(output: Seq[Attribute], numMetadataColumns: Int) - extends Table with SupportsCustomSchemaWrite { - override def name: String = "t" - override def schema: StructType = StructType.fromAttributes(output) - override def capabilities(): java.util.Set[TableCapability] = - new java.util.HashSet[TableCapability]() - override def customSchemaForInserts: StructType = - StructType(schema.fields.dropRight(numMetadataColumns)) - } - - /** Helper method to generate a DSV2 relation using the above table type. */ - private def relationWithCustomInsertSchema( - output: Seq[AttributeReference], numMetadataColumns: Int): DataSourceV2Relation = { - DataSourceV2Relation( - TableWithCustomInsertSchema(output, numMetadataColumns), - output, - catalog = None, - identifier = None, - options = CaseInsensitiveStringMap.empty) - } - - test("SPARK-43313: Add missing default values for MERGE INSERT actions") { - val testRelation = SubqueryAlias( - "testRelation", - relationWithCustomInsertSchema(Seq( - AttributeReference( - "a", - StringType, - true, - new MetadataBuilder() - .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'a'") - .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'a'") - .build())(), - AttributeReference( - "b", - StringType, - true, - new MetadataBuilder() - .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'b'") - .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'b'") - .build())(), - AttributeReference( - "c", - StringType, - true, - new MetadataBuilder() - .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'c'") - .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'c'") - .build())(), - AttributeReference( - "pseudocolumn", - StringType, - true, - new MetadataBuilder() - .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'pseudocolumn'") - .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'pseudocolumn'") - .build())()), - numMetadataColumns = 1)) - val testRelation2 = - SubqueryAlias( - "testRelation2", - relationWithCustomInsertSchema(Seq( - AttributeReference("d", StringType)(), - AttributeReference("e", StringType)(), - AttributeReference("f", StringType)()), - numMetadataColumns = 0)) - val mergePlan = MergeIntoTable( - targetTable = testRelation, - sourceTable = testRelation2, - mergeCondition = EqualTo(testRelation.output.head, testRelation2.output.head), - matchedActions = Seq(DeleteAction(None)), - notMatchedActions = Seq( - InsertAction( - condition = None, - assignments = Seq( - Assignment( - key = UnresolvedAttribute("a"), - value = UnresolvedAttribute("DEFAULT")), - Assignment( - key = UnresolvedAttribute(Seq("testRelation", "b")), - value = Literal("xyz"))))), - notMatchedBySourceActions = Seq(DeleteAction(None))) - // Run the 'addMissingDefaultValuesForMergeAction' method of the 'ResolveDefaultColumns' rule - // on an MERGE INSERT action with two assignments, one to the target table's column 'a' and - 
// another to the target table's column 'b'. - val columnNamesWithDefaults = Seq("a", "b", "c") - val actualMergeAction = - rule.apply(mergePlan).asInstanceOf[MergeIntoTable].notMatchedActions.head - val expectedMergeAction = - InsertAction( - condition = None, - assignments = Seq( - Assignment(key = UnresolvedAttribute("a"), value = Literal("a")), - Assignment(key = UnresolvedAttribute(Seq("testRelation", "b")), value = Literal("xyz")), - Assignment(key = UnresolvedAttribute("c"), value = Literal("c")))) - assert(expectedMergeAction == actualMergeAction) - // Run the same method on another MERGE DELETE action. There is no change because this method - // only operates on MERGE INSERT actions. - assert(rule.addMissingDefaultValuesForMergeAction( - mergePlan.matchedActions.head, mergePlan, columnNamesWithDefaults) == - mergePlan.matchedActions.head) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index ff52239a1d9..7fdd049a977 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1117,15 +1117,8 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT INTO $t1(data, data) VALUES(5)") }, - errorClass = "_LEGACY_ERROR_TEMP_2305", - parameters = Map( - "numCols" -> "3", - "rowSize" -> "2", - "ri" -> "0"), - context = ExpectedContext( - fragment = s"INSERT INTO $t1(data, data)", - start = 0, - stop = 26)) + errorClass = "COLUMN_ALREADY_EXISTS", + parameters = Map("columnName" -> "`data`")) } } @@ -1151,15 +1144,8 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)") }, - errorClass = "_LEGACY_ERROR_TEMP_2305", - parameters = Map( - "numCols" -> "3", - "rowSize" -> "2", - "ri" -> "0"), - context = ExpectedContext( - fragment = s"INSERT OVERWRITE $t1(data, data)", - start = 0, - stop = 31)) + errorClass = "COLUMN_ALREADY_EXISTS", + parameters = Map("columnName" -> "`data`")) } } @@ -1186,15 +1172,8 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)") }, - errorClass = "_LEGACY_ERROR_TEMP_2305", - parameters = Map( - "numCols" -> "4", - "rowSize" -> "3", - "ri" -> "0"), - context = ExpectedContext( - fragment = s"INSERT OVERWRITE $t1(data, data)", - start = 0, - stop = 31)) + errorClass = "COLUMN_ALREADY_EXISTS", + parameters = Map("columnName" -> "`data`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala index e5691666339..a479e810e46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala @@ -602,14 +602,6 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuite { } test("invalid INSERT assignments") { - assertAnalysisException( - """MERGE INTO primitive_table t USING primitive_table src - |ON t.i = src.i - |WHEN NOT MATCHED THEN - | INSERT (i, txt) VALUES (src.i, src.txt) - |""".stripMargin, - "No assignment for 'l'") - assertAnalysisException( """MERGE INTO primitive_table t USING primitive_table src |ON t.i = src.i @@ 
-624,10 +616,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuite { |WHEN NOT MATCHED THEN | INSERT (s.n_i) VALUES (1) |""".stripMargin, - "INSERT assignment keys cannot be nested fields: t.s.`n_i` = 1", - "No assignment for 'i'", - "No assignment for 's'", - "No assignment for 'txt'") + "INSERT assignment keys cannot be nested fields: t.s.`n_i` = 1") } test("updates to nested structs in arrays") { @@ -866,6 +855,8 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuite { |ON t.b = s.b |WHEN MATCHED THEN | UPDATE SET t.i = DEFAULT + |WHEN NOT MATCHED AND (s.i = 1) THEN + | INSERT (b) VALUES (false) |WHEN NOT MATCHED THEN | INSERT (i, b) VALUES (DEFAULT, false) |WHEN NOT MATCHED BY SOURCE THEN @@ -889,8 +880,26 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuite { fail(s"Unexpected actions: $other") } - notMatchedActions match { - case Seq(InsertAction(None, assignments)) => + assert(notMatchedActions.length == 2) + notMatchedActions(0) match { + case InsertAction(Some(_), assignments) => + assignments match { + case Seq( + Assignment(b: AttributeReference, BooleanLiteral(false)), + Assignment(i: AttributeReference, IntegerLiteral(42))) => + + assert(b.name == "b") + assert(i.name == "i") + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + notMatchedActions(1) match { + case InsertAction(None, assignments) => assignments match { case Seq( Assignment(b: AttributeReference, BooleanLiteral(false)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 1b389d77142..7fa3873fc6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedIdentifier, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EqualTo, EvalMode, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTable, CreateTableAsSelect, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, InsertIntoStatement, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, OverwriteByExpression, OverwritePartitionsDynamic, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable} @@ -151,6 +151,7 @@ 
class PlanResolutionSuite extends AnalysisTest { case "defaultvalues" => defaultValues case "defaultvalues2" => defaultValues2 case "tablewithcolumnnameddefault" => tableWithColumnNamedDefault + case "v2TableWithAcceptAnySchemaCapability" => tableWithAcceptAnySchemaCapability case name => throw new NoSuchTableException(Seq(name)) } }) @@ -167,7 +168,6 @@ class PlanResolutionSuite extends AnalysisTest { case "v1HiveTable" => createV1TableMock(ident, provider = "hive") case "v2Table" => table case "v2Table1" => table1 - case "v2TableWithAcceptAnySchemaCapability" => tableWithAcceptAnySchemaCapability case "view" => createV1TableMock(ident, tableType = CatalogTableType.VIEW) case name => throw new NoSuchTableException(Seq(name)) } @@ -1023,12 +1023,12 @@ class PlanResolutionSuite extends AnalysisTest { val sql5 = s"UPDATE $tblName SET name=DEFAULT, age=DEFAULT" // Note: 'i' and 's' are the names of the columns in 'tblName'. val sql6 = s"UPDATE $tblName SET i=DEFAULT, s=DEFAULT" - val sql7 = s"UPDATE defaultvalues SET i=DEFAULT, s=DEFAULT" - val sql8 = s"UPDATE $tblName SET name='Robert', age=32 WHERE p=DEFAULT" - val sql9 = s"UPDATE defaultvalues2 SET i=DEFAULT" - // Note: 'i' is the correct column name, but since the table has ACCEPT_ANY_SCHEMA capability, - // DEFAULT column resolution should skip this table. - val sql10 = s"UPDATE v2TableWithAcceptAnySchemaCapability SET i=DEFAULT" + val sql7 = s"UPDATE testcat.defaultvalues SET i=DEFAULT, s=DEFAULT" + // UPDATE condition won't resolve column "DEFAULT" + val sql8 = s"UPDATE testcat.defaultvalues SET i=DEFAULT, s=DEFAULT WHERE i=DEFAULT" + val sql9 = s"UPDATE testcat.defaultvalues2 SET i=DEFAULT" + // Table with ACCEPT_ANY_SCHEMA can also resolve the column DEFAULT. + val sql10 = s"UPDATE testcat.v2TableWithAcceptAnySchemaCapability SET i=DEFAULT" val parsed1 = parseAndResolve(sql1) val parsed2 = parseAndResolve(sql2) @@ -1036,8 +1036,8 @@ class PlanResolutionSuite extends AnalysisTest { val parsed4 = parseAndResolve(sql4) val parsed5 = parseAndResolve(sql5) val parsed6 = parseAndResolve(sql6) - val parsed7 = parseAndResolve(sql7, true) - val parsed9 = parseAndResolve(sql9, true) + val parsed7 = parseAndResolve(sql7) + val parsed9 = parseAndResolve(sql9) val parsed10 = parseAndResolve(sql10) parsed1 match { @@ -1116,12 +1116,9 @@ class PlanResolutionSuite extends AnalysisTest { // Note that when resolving DEFAULT column references, the analyzer will insert literal // NULL values if the corresponding table does not define an explicit default value for // that column. This is intended. 
- Assignment(i: AttributeReference, - cast1 @ Cast(Literal(null, _), IntegerType, _, EvalMode.ANSI)), - Assignment(s: AttributeReference, - cast2 @ Cast(Literal(null, _), StringType, _, EvalMode.ANSI))), - None) if cast1.getTagValue(Cast.BY_TABLE_INSERTION).isDefined && - cast2.getTagValue(Cast.BY_TABLE_INSERTION).isDefined => + Assignment(i: AttributeReference, Literal(null, IntegerType)), + Assignment(s: AttributeReference, Literal(null, StringType))), + None) => assert(i.name == "i") assert(s.name == "s") @@ -1143,38 +1140,34 @@ class PlanResolutionSuite extends AnalysisTest { checkError( exception = intercept[AnalysisException] { - parseAndResolve(sql8) + parseAndResolve(sql8, checkAnalysis = true) }, - errorClass = "_LEGACY_ERROR_TEMP_1341", - parameters = Map.empty) + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`DEFAULT`", "proposal" -> "`i`, `s`"), + context = ExpectedContext( + fragment = "DEFAULT", + start = 62, + stop = 68)) parsed9 match { case UpdateTable( - _, - Seq(Assignment(i: AttributeReference, - cast @ Cast(Literal(null, _), StringType, _, EvalMode.ANSI))), - None) if cast.getTagValue(Cast.BY_TABLE_INSERTION).isDefined => + _, + Seq(Assignment(i: AttributeReference, Literal(null, StringType))), + None) => assert(i.name == "i") case _ => fail("Expect UpdateTable, but got:\n" + parsed9.treeString) } parsed10 match { - case u: UpdateTable => - assert(u.assignments.size == 1) - u.assignments(0).key match { - case i: AttributeReference => - assert(i.name == "i") - } - u.assignments(0).value match { - case d: UnresolvedAttribute => - assert(d.name == "DEFAULT") - } + case UpdateTable( + _, + Seq(Assignment(i: AttributeReference, Literal(null, IntegerType))), + None) => + assert(i.name == "i") - case _ => - fail("Expect UpdateTable, but got:\n" + parsed10.treeString) + case _ => fail("Expect UpdateTable, but got:\n" + parsed10.treeString) } - } val sql1 = "UPDATE non_existing SET id=1" @@ -1766,22 +1759,16 @@ class PlanResolutionSuite extends AnalysisTest { second match { case UpdateAction(Some(EqualTo(_: AttributeReference, StringLiteral("update"))), Seq( - Assignment(_: AttributeReference, - cast @ Cast(Literal(null, _), StringType, _, EvalMode.ANSI)), - Assignment(_: AttributeReference, _: AttributeReference))) - if cast.getTagValue(Cast.BY_TABLE_INSERTION).isDefined => + Assignment(_: AttributeReference, Literal(null, StringType)), + Assignment(_: AttributeReference, _: AttributeReference))) => case other => fail("unexpected second matched action " + other) } assert(m.notMatchedActions.length == 1) val negative = m.notMatchedActions(0) negative match { case InsertAction(Some(EqualTo(_: AttributeReference, StringLiteral("insert"))), - Seq(Assignment(i: AttributeReference, - cast1 @ Cast(Literal(null, _), IntegerType, _, EvalMode.ANSI)), - Assignment(s: AttributeReference, - cast2 @ Cast(Literal(null, _), StringType, _, EvalMode.ANSI)))) - if cast1.getTagValue(Cast.BY_TABLE_INSERTION).isDefined && - cast2.getTagValue(Cast.BY_TABLE_INSERTION).isDefined => + Seq(Assignment(i: AttributeReference, Literal(null, IntegerType)), + Assignment(s: AttributeReference, Literal(null, StringType)))) => assert(i.name == "i") assert(s.name == "s") case other => fail("unexpected not matched action " + other) @@ -1793,9 +1780,7 @@ class PlanResolutionSuite extends AnalysisTest { } m.notMatchedBySourceActions(1) match { case UpdateAction(Some(EqualTo(_: AttributeReference, StringLiteral("update"))), - Seq(Assignment(_: AttributeReference, - cast @ 
Cast(Literal(null, _), StringType, _, EvalMode.ANSI)))) - if cast.getTagValue(Cast.BY_TABLE_INSERTION).isDefined => + Seq(Assignment(_: AttributeReference, Literal(null, StringType)))) => case other => fail("unexpected second not matched by source action " + other) } @@ -1805,8 +1790,8 @@ class PlanResolutionSuite extends AnalysisTest { } // DEFAULT column reference in the merge condition: - // This MERGE INTO command includes an ON clause with a DEFAULT column reference. This is - // invalid and returns an error message. + // This MERGE INTO command includes an ON clause with a DEFAULT column reference. This + // DEFAULT column won't be resolved. val mergeWithDefaultReferenceInMergeCondition = s"""MERGE INTO testcat.tab AS target |USING testcat.tab1 AS source @@ -1821,14 +1806,19 @@ class PlanResolutionSuite extends AnalysisTest { | THEN UPDATE SET target.s = DEFAULT""".stripMargin checkError( exception = intercept[AnalysisException] { - parseAndResolve(mergeWithDefaultReferenceInMergeCondition) + parseAndResolve(mergeWithDefaultReferenceInMergeCondition, checkAnalysis = true) }, - errorClass = "_LEGACY_ERROR_TEMP_1342", - parameters = Map.empty) + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`DEFAULT`", + "proposal" -> "`target`.`i`, `source`.`i`, `target`.`s`, `source`.`s`"), + context = ExpectedContext( + fragment = "DEFAULT", + start = 76, + stop = 82)) // DEFAULT column reference within a complex expression: // This MERGE INTO command includes a WHEN MATCHED clause with a DEFAULT column reference as - // of a complex expression (DEFAULT + 1). This is invalid and returns an error message. + // of a complex expression (DEFAULT + 1). This is invalid and column won't be resolved. val mergeWithDefaultReferenceAsPartOfComplexExpression = s"""MERGE INTO testcat.tab AS target |USING testcat.tab1 AS source @@ -1890,7 +1880,7 @@ class PlanResolutionSuite extends AnalysisTest { // values. This test case covers that behavior. 
val mergeDefaultWithExplicitDefaultColumns = s""" - |MERGE INTO defaultvalues AS target + |MERGE INTO testcat.defaultvalues AS target |USING testcat.tab1 AS source |ON target.i = source.i |WHEN MATCHED AND (target.s = 31) THEN DELETE @@ -1902,7 +1892,7 @@ class PlanResolutionSuite extends AnalysisTest { |WHEN NOT MATCHED BY SOURCE AND (target.s = 31) | THEN UPDATE SET target.s = DEFAULT """.stripMargin - parseAndResolve(mergeDefaultWithExplicitDefaultColumns, true) match { + parseAndResolve(mergeDefaultWithExplicitDefaultColumns) match { case m: MergeIntoTable => val cond = m.mergeCondition cond match { @@ -2218,54 +2208,29 @@ class PlanResolutionSuite extends AnalysisTest { } } - test("MERGE INTO TABLE - skip resolution on v2 tables that accept any schema") { + test("MERGE INTO TABLE - skip filling missing cols on v2 tables that accept any schema") { val sql = s""" - |MERGE INTO v2TableWithAcceptAnySchemaCapability AS target + |MERGE INTO testcat.v2TableWithAcceptAnySchemaCapability AS target |USING v2Table AS source |ON target.i = source.i - |WHEN MATCHED AND (target.s='delete')THEN DELETE - |WHEN MATCHED AND (target.s='update') THEN UPDATE SET target.s = source.s - |WHEN NOT MATCHED AND (target.s=DEFAULT) - | THEN INSERT (target.i, target.s) values (source.i, source.s) - |WHEN NOT MATCHED BY SOURCE AND (target.s='delete') THEN DELETE - |WHEN NOT MATCHED BY SOURCE AND (target.s='update') THEN UPDATE SET target.s = target.i + |WHEN MATCHED THEN DELETE + |WHEN NOT MATCHED THEN INSERT (target.i) values (DEFAULT) """.stripMargin parseAndResolve(sql) match { case MergeIntoTable( SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(_)), SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(_)), - EqualTo(l: UnresolvedAttribute, r: UnresolvedAttribute), - Seq( - DeleteAction(Some(EqualTo(dl: UnresolvedAttribute, StringLiteral("delete")))), - UpdateAction( - Some(EqualTo(ul: UnresolvedAttribute, StringLiteral("update"))), - firstUpdateAssigns)), - Seq( - InsertAction( - Some(EqualTo(il: UnresolvedAttribute, UnresolvedAttribute(Seq("DEFAULT")))), - insertAssigns)), - Seq( - DeleteAction(Some(EqualTo(ndl: UnresolvedAttribute, StringLiteral("delete")))), - UpdateAction( - Some(EqualTo(nul: UnresolvedAttribute, StringLiteral("update"))), - secondUpdateAssigns))) => - assert(l.name == "target.i" && r.name == "source.i") - assert(dl.name == "target.s") - assert(ul.name == "target.s") - assert(il.name == "target.s") - assert(ndl.name == "target.s") - assert(nul.name == "target.s") - assert(firstUpdateAssigns.size == 1) - assert(firstUpdateAssigns.head.key.asInstanceOf[UnresolvedAttribute].name == "target.s") - assert(firstUpdateAssigns.head.value.asInstanceOf[UnresolvedAttribute].name == "source.s") - assert(insertAssigns.size == 2) + _, + Seq(DeleteAction(None)), + Seq(InsertAction(None, insertAssigns)), + Nil) => + // There is only one assignment, the missing col is not filled with default value + assert(insertAssigns.size == 1) + // Special case: Spark does not resolve any columns in MERGE if table accepts any schema. 
assert(insertAssigns.head.key.asInstanceOf[UnresolvedAttribute].name == "target.i") - assert(insertAssigns.head.value.asInstanceOf[UnresolvedAttribute].name == "source.i") - assert(secondUpdateAssigns.size == 1) - assert(secondUpdateAssigns.head.key.asInstanceOf[UnresolvedAttribute].name == "target.s") - assert(secondUpdateAssigns.head.value.asInstanceOf[UnresolvedAttribute].name == "target.i") + assert(insertAssigns.head.value.asInstanceOf[UnresolvedAttribute].name == "DEFAULT") case l => fail("Expected unresolved MergeIntoTable, but got:\n" + l.treeString) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 08d9da3e9da..d8e0a05f262 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -142,17 +142,6 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { ) } - test("SELECT clause generating a different number of columns is not allowed.") { - val message = intercept[AnalysisException] { - sql( - s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt - """.stripMargin) - }.getMessage - assert(message.contains("target table has 2 column(s) but the inserted data has 1 column(s)") - ) - } - test("INSERT OVERWRITE a JSONRelation multiple times") { sql( s""" @@ -642,16 +631,8 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { msg = intercept[AnalysisException] { sql("insert into t select 1, 2.0D, 3") }.getMessage - assert(msg.contains("`t` requires that the data to be inserted have the same number of " + - "columns as the target table: target table has 2 column(s)" + - " but the inserted data has 3 column(s)")) - - msg = intercept[AnalysisException] { - sql("insert into t select 1") - }.getMessage - assert(msg.contains("`t` requires that the data to be inserted have the same number of " + - "columns as the target table: target table has 2 column(s)" + - " but the inserted data has 1 column(s)")) + assert(msg.contains( + "Cannot write to '`spark_catalog`.`default`.`t`', too many data columns")) // Insert into table successfully. 
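// For illustration only (hypothetical table name, throwaway local session): the reworded arity
// error checked above, and the default-filling behaviour when fewer data columns are supplied.
import org.apache.spark.sql.{AnalysisException, SparkSession}

object InsertArityExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("insert-arity").getOrCreate()
    spark.sql("CREATE TABLE t_demo(i INT, s STRING) USING parquet")
    try {
      // Three data columns for a two-column table: INSERT_COLUMN_ARITY_MISMATCH,
      // "Cannot write to ..., too many data columns".
      spark.sql("INSERT INTO t_demo SELECT 1, 'a', 2")
    } catch {
      case e: AnalysisException => println(e.getMessage)
    }
    // One data column for a two-column table: `s` has no explicit DEFAULT, so it is filled with
    // NULL (assuming USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES is left at its default).
    spark.sql("INSERT INTO t_demo SELECT 1")
    spark.sql("SELECT * FROM t_demo").show()   // expected single row: (1, null)
    spark.stop()
  }
}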
sql("insert into t select 1, 2.0D") @@ -863,42 +844,45 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } test("Allow user to insert specified columns into insertable view") { - withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - sql("INSERT OVERWRITE TABLE jsonTable SELECT a, DEFAULT FROM jt") - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, null)) - ) + sql("INSERT OVERWRITE TABLE jsonTable SELECT a, DEFAULT FROM jt") + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, null)) + ) - sql("INSERT OVERWRITE TABLE jsonTable(a) SELECT a FROM jt") - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, null)) - ) + sql("INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt") + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, null)) + ) - sql("INSERT OVERWRITE TABLE jsonTable(b) SELECT b FROM jt") - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(null, s"str$i")) - ) - } + sql("INSERT OVERWRITE TABLE jsonTable(a) SELECT a FROM jt") + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, null)) + ) - val message = intercept[AnalysisException] { - sql("INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt") - }.getMessage - assert(message.contains("target table has 2 column(s) but the inserted data has 1 column(s)")) + sql("INSERT OVERWRITE TABLE jsonTable(b) SELECT b FROM jt") + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(null, s"str$i")) + ) + + withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") { + val message = intercept[AnalysisException] { + sql("INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt") + }.getMessage + assert(message.contains("Cannot write to 'unknown', not enough data columns")) + } } test("SPARK-38336 INSERT INTO statements with tables with default columns: positive tests") { - // When the USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES configuration is enabled, and no - // explicit DEFAULT value is available when the INSERT INTO statement provides fewer - // values than expected, NULL values are appended in their place. - withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - withTable("t") { - sql("create table t(i boolean, s bigint) using parquet") - sql("insert into t(i) values(true)") - checkAnswer(spark.table("t"), Row(true, null)) - } + // When the INSERT INTO statement provides fewer values than expected, NULL values are appended + // in their place. + withTable("t") { + sql("create table t(i boolean, s bigint) using parquet") + sql("insert into t(i) values(true)") + checkAnswer(spark.table("t"), Row(true, null)) } // The default value for the DEFAULT keyword is the NULL literal. withTable("t") { @@ -924,6 +908,11 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("insert into t(i) values(1)") checkAnswer(sql("select s + x from t where i = 1"), Seq(85L).map(i => Row(i))) } + withTable("t") { + sql("create table t(i int, s bigint default 42, x bigint) using parquet") + sql("insert into t values(1)") + checkAnswer(spark.table("t"), Row(1, 42L, null)) + } // The table has a partitioning column and a default value is injected. 
withTable("t") { sql("create table t(i boolean, s bigint, q int default 42) using parquet partitioned by (i)") @@ -998,45 +987,43 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { // There are three column types exercising various combinations of implicit and explicit // default column value references in the 'insert into' statements. Note these tests depend on // enabling the configuration to use NULLs for missing DEFAULT column values. - withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - for (useDataFrames <- Seq(false, true)) { - withTable("t1", "t2") { - sql("create table t1(j int, s bigint default 42, x bigint default 43) using parquet") - if (useDataFrames) { - Seq((1, 42, 43)).toDF.write.insertInto("t1") - Seq((2, 42, 43)).toDF.write.insertInto("t1") - Seq((3, 42, 43)).toDF.write.insertInto("t1") - Seq((4, 44, 43)).toDF.write.insertInto("t1") - Seq((5, 44, 43)).toDF.write.insertInto("t1") - } else { - sql("insert into t1(j) values(1)") - sql("insert into t1(j, s) values(2, default)") - sql("insert into t1(j, s, x) values(3, default, default)") - sql("insert into t1(j, s) values(4, 44)") - sql("insert into t1(j, s, x) values(5, 44, 45)") - } - sql("create table t2(j int, s bigint default 42, x bigint default 43) using parquet") - if (useDataFrames) { - spark.table("t1").where("j = 1").write.insertInto("t2") - spark.table("t1").where("j = 2").write.insertInto("t2") - spark.table("t1").where("j = 3").write.insertInto("t2") - spark.table("t1").where("j = 4").write.insertInto("t2") - spark.table("t1").where("j = 5").write.insertInto("t2") - } else { - sql("insert into t2(j) select j from t1 where j = 1") - sql("insert into t2(j, s) select j, default from t1 where j = 2") - sql("insert into t2(j, s, x) select j, default, default from t1 where j = 3") - sql("insert into t2(j, s) select j, s from t1 where j = 4") - sql("insert into t2(j, s, x) select j, s, default from t1 where j = 5") - } - checkAnswer( - spark.table("t2"), - Row(1, 42L, 43L) :: - Row(2, 42L, 43L) :: - Row(3, 42L, 43L) :: - Row(4, 44L, 43L) :: - Row(5, 44L, 43L) :: Nil) + for (useDataFrames <- Seq(false, true)) { + withTable("t1", "t2") { + sql("create table t1(j int, s bigint default 42, x bigint default 43) using parquet") + if (useDataFrames) { + Seq((1, 42, 43)).toDF.write.insertInto("t1") + Seq((2, 42, 43)).toDF.write.insertInto("t1") + Seq((3, 42, 43)).toDF.write.insertInto("t1") + Seq((4, 44, 43)).toDF.write.insertInto("t1") + Seq((5, 44, 43)).toDF.write.insertInto("t1") + } else { + sql("insert into t1(j) values(1)") + sql("insert into t1(j, s) values(2, default)") + sql("insert into t1(j, s, x) values(3, default, default)") + sql("insert into t1(j, s) values(4, 44)") + sql("insert into t1(j, s, x) values(5, 44, 45)") + } + sql("create table t2(j int, s bigint default 42, x bigint default 43) using parquet") + if (useDataFrames) { + spark.table("t1").where("j = 1").write.insertInto("t2") + spark.table("t1").where("j = 2").write.insertInto("t2") + spark.table("t1").where("j = 3").write.insertInto("t2") + spark.table("t1").where("j = 4").write.insertInto("t2") + spark.table("t1").where("j = 5").write.insertInto("t2") + } else { + sql("insert into t2(j) select j from t1 where j = 1") + sql("insert into t2(j, s) select j, default from t1 where j = 2") + sql("insert into t2(j, s, x) select j, default, default from t1 where j = 3") + sql("insert into t2(j, s) select j, s from t1 where j = 4") + sql("insert into t2(j, s, x) select j, s, default from t1 where j = 5") } + 
checkAnswer( + spark.table("t2"), + Row(1, 42L, 43L) :: + Row(2, 42L, 43L) :: + Row(3, 42L, 43L) :: + Row(4, 44L, 43L) :: + Row(5, 44L, 43L) :: Nil) } } } @@ -1131,7 +1118,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("insert into t select t1.id, t2.id, t1.val, t2.val, t1.val * t2.val " + "from num_data t1, num_data t2") }.getMessage.contains( - "requires that the data to be inserted have the same number of columns as the target")) + "Cannot write to '`spark_catalog`.`default`.`t`', too many data columns")) } // The default value is disabled per configuration. withTable("t") { @@ -1141,13 +1128,6 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { }.getMessage.contains("Support for DEFAULT column values is not allowed")) } } - // There is one trailing default value referenced implicitly by the INSERT INTO statement. - withTable("t") { - sql("create table t(i int, s bigint default 42, x bigint) using parquet") - assert(intercept[AnalysisException] { - sql("insert into t values(1)") - }.getMessage.contains("target table has 3 column(s) but the inserted data has 1 column(s)")) - } // The table has a partitioning column with a default value; this is not allowed. withTable("t") { sql("create table t(i boolean default true, s bigint, q int default 42) " + @@ -1170,7 +1150,8 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("create table t(i boolean, s bigint) using parquet") assert(intercept[AnalysisException] { sql("insert into t values(true)") - }.getMessage.contains("target table has 2 column(s) but the inserted data has 1 column(s)")) + }.getMessage.contains( + "Cannot write to '`spark_catalog`.`default`.`t`', not enough data columns")) } } } @@ -1203,48 +1184,44 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { Row(4, 43, false), Row(4, 42, false))) } - // When the USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES configuration is enabled, and no - // explicit DEFAULT value is available when the INSERT INTO statement provides fewer + // If no explicit DEFAULT value is available when the INSERT INTO statement provides fewer // values than expected, NULL values are appended in their place. 
- withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - withTable("t") { - sql("create table t(i boolean, s bigint) using parquet") - sql("insert into t (i) values (true)") - checkAnswer(spark.table("t"), Row(true, null)) - } - withTable("t") { - sql("create table t(i boolean default true, s bigint) using parquet") - sql("insert into t (i) values (default)") - checkAnswer(spark.table("t"), Row(true, null)) - } - withTable("t") { - sql("create table t(i boolean, s bigint default 42) using parquet") - sql("insert into t (s) values (default)") - checkAnswer(spark.table("t"), Row(null, 42L)) - } - withTable("t") { - sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") - sql("insert into t partition(i='true') (s) values(5)") - sql("insert into t partition(i='false') (q) select 43") - sql("insert into t partition(i='false') (q) select default") - checkAnswer(spark.table("t"), - Seq(Row(5, null, true), - Row(null, 43, false), - Row(null, null, false))) - } + withTable("t") { + sql("create table t(i boolean, s bigint) using parquet") + sql("insert into t (i) values (true)") + checkAnswer(spark.table("t"), Row(true, null)) + } + withTable("t") { + sql("create table t(i boolean default true, s bigint) using parquet") + sql("insert into t (i) values (default)") + checkAnswer(spark.table("t"), Row(true, null)) + } + withTable("t") { + sql("create table t(i boolean, s bigint default 42) using parquet") + sql("insert into t (s) values (default)") + checkAnswer(spark.table("t"), Row(null, 42L)) + } + withTable("t") { + sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") + sql("insert into t partition(i='true') (s) values(5)") + sql("insert into t partition(i='false') (q) select 43") + sql("insert into t partition(i='false') (q) select default") + checkAnswer(spark.table("t"), + Seq(Row(5, null, true), + Row(null, 43, false), + Row(null, null, false))) } } test("SPARK-38795 INSERT INTO with user specified columns and defaults: negative tests") { - val addOneColButExpectedTwo = "target table has 2 column(s) but the inserted data has 1 col" - val addTwoColButExpectedThree = "target table has 3 column(s) but the inserted data has 2 col" + val missingColError = "Cannot find data for output column " // The missing columns in these INSERT INTO commands do not have explicit default values.
withTable("t") { sql("create table t(i boolean, s bigint, q int default 43) using parquet") assert(intercept[AnalysisException] { sql("insert into t (i, q) select true from (select 1)") }.getMessage.contains("Cannot write to table due to mismatched user specified column " + - "size(3) and data column size(2)")) + "size(2) and data column size(1)")) } // When the USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES configuration is disabled, and no // explicit DEFAULT value is available when the INSERT INTO statement provides fewer @@ -1254,37 +1231,37 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("create table t(i boolean, s bigint) using parquet") assert(intercept[AnalysisException] { sql("insert into t (i) values (true)") - }.getMessage.contains(addOneColButExpectedTwo)) + }.getMessage.contains(missingColError + "'s'")) } withTable("t") { sql("create table t(i boolean default true, s bigint) using parquet") assert(intercept[AnalysisException] { sql("insert into t (i) values (default)") - }.getMessage.contains(addOneColButExpectedTwo)) + }.getMessage.contains(missingColError + "'s'")) } withTable("t") { sql("create table t(i boolean, s bigint default 42) using parquet") assert(intercept[AnalysisException] { sql("insert into t (s) values (default)") - }.getMessage.contains(addOneColButExpectedTwo)) + }.getMessage.contains(missingColError + "'i'")) } withTable("t") { sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") assert(intercept[AnalysisException] { sql("insert into t partition(i='true') (s) values(5)") - }.getMessage.contains(addTwoColButExpectedThree)) + }.getMessage.contains(missingColError + "'q'")) } withTable("t") { sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") assert(intercept[AnalysisException] { sql("insert into t partition(i='false') (q) select 43") - }.getMessage.contains(addTwoColButExpectedThree)) + }.getMessage.contains(missingColError + "'s'")) } withTable("t") { sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") assert(intercept[AnalysisException] { sql("insert into t partition(i='false') (q) select default") - }.getMessage.contains(addTwoColButExpectedThree)) + }.getMessage.contains(missingColError + "'s'")) } } // When the CASE_SENSITIVE configuration is enabled, then using different cases for the required @@ -1329,6 +1306,13 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("insert into t(i) values(1)") checkAnswer(spark.table("t"), Row(1, 42, 43)) } + withTable("t") { + sql(createTableIntCol) + sql("alter table t add column s bigint default 42") + sql("alter table t add column x bigint") + sql("insert into t values(1)") + checkAnswer(spark.table("t"), Row(1, 42, null)) + } // The table has a partitioning column and a default value is injected. withTable("t") { sql("create table t(i boolean, s bigint) using parquet partitioned by (i)") @@ -1386,31 +1370,29 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { // There are three column types exercising various combinations of implicit and explicit // default column value references in the 'insert into' statements. Note these tests depend on // enabling the configuration to use NULLs for missing DEFAULT column values. 
- withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - withTable("t1", "t2") { - sql("create table t1(j int) using parquet") - sql("alter table t1 add column s bigint default 42") - sql("alter table t1 add column x bigint default 43") - sql("insert into t1(j) values(1)") - sql("insert into t1(j, s) values(2, default)") - sql("insert into t1(j, s, x) values(3, default, default)") - sql("insert into t1(j, s) values(4, 44)") - sql("insert into t1(j, s, x) values(5, 44, 45)") - sql("create table t2(j int) using parquet") - sql("alter table t2 add columns s bigint default 42, x bigint default 43") - sql("insert into t2(j) select j from t1 where j = 1") - sql("insert into t2(j, s) select j, default from t1 where j = 2") - sql("insert into t2(j, s, x) select j, default, default from t1 where j = 3") - sql("insert into t2(j, s) select j, s from t1 where j = 4") - sql("insert into t2(j, s, x) select j, s, default from t1 where j = 5") - checkAnswer( - spark.table("t2"), - Row(1, 42L, 43L) :: - Row(2, 42L, 43L) :: - Row(3, 42L, 43L) :: - Row(4, 44L, 43L) :: - Row(5, 44L, 43L) :: Nil) - } + withTable("t1", "t2") { + sql("create table t1(j int) using parquet") + sql("alter table t1 add column s bigint default 42") + sql("alter table t1 add column x bigint default 43") + sql("insert into t1(j) values(1)") + sql("insert into t1(j, s) values(2, default)") + sql("insert into t1(j, s, x) values(3, default, default)") + sql("insert into t1(j, s) values(4, 44)") + sql("insert into t1(j, s, x) values(5, 44, 45)") + sql("create table t2(j int) using parquet") + sql("alter table t2 add columns s bigint default 42, x bigint default 43") + sql("insert into t2(j) select j from t1 where j = 1") + sql("insert into t2(j, s) select j, default from t1 where j = 2") + sql("insert into t2(j, s, x) select j, default, default from t1 where j = 3") + sql("insert into t2(j, s) select j, s from t1 where j = 4") + sql("insert into t2(j, s, x) select j, s, default from t1 where j = 5") + checkAnswer( + spark.table("t2"), + Row(1, 42L, 43L) :: + Row(2, 42L, 43L) :: + Row(3, 42L, 43L) :: + Row(4, 44L, 43L) :: + Row(5, 44L, 43L) :: Nil) } } @@ -1476,15 +1458,6 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { }.getMessage.contains("Support for DEFAULT column values is not allowed")) } } - // There is one trailing default value referenced implicitly by the INSERT INTO statement. 
- withTable("t") { - sql("create table t(i int) using parquet") - sql("alter table t add column s bigint default 42") - sql("alter table t add column x bigint") - assert(intercept[AnalysisException] { - sql("insert into t values(1)") - }.getMessage.contains("target table has 3 column(s) but the inserted data has 1 column(s)")) - } } test("SPARK-38838 INSERT INTO with defaults set by ALTER TABLE ALTER COLUMN: positive tests") { @@ -2351,14 +2324,12 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { checkAnswer(spark.table("t1"), Row(1, "str1")) } - withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - withTable("t1") { - sql("CREATE TABLE t1(c1 int, c2 string, c3 int) using parquet") - sql("INSERT INTO TABLE t1(c1, c2) select * from jt where a=1") - checkAnswer(spark.table("t1"), Row(1, "str1", null)) - sql("INSERT INTO TABLE t1 select *, 2 from jt where a=2") - checkAnswer(spark.table("t1"), Seq(Row(1, "str1", null), Row(2, "str2", 2))) - } + withTable("t1") { + sql("CREATE TABLE t1(c1 int, c2 string, c3 int) using parquet") + sql("INSERT INTO TABLE t1(c1, c2) select * from jt where a=1") + checkAnswer(spark.table("t1"), Row(1, "str1", null)) + sql("INSERT INTO TABLE t1 select *, 2 from jt where a=2") + checkAnswer(spark.table("t1"), Seq(Row(1, "str1", null), Row(2, "str2", 2))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index ea1e9a7e048..9cc26d894ba 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -359,8 +359,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter val e = intercept[AnalysisException] { sql(s"INSERT INTO TABLE $tableName PARTITION(b=1, c=2) SELECT 1, 2, 3") } - assert(e.message.contains( - "target table has 4 column(s) but the inserted data has 5 column(s)")) + assert(e.message.contains("Cannot write to") && e.message.contains("too many data columns")) } testPartitionedTable("SPARK-16037: INSERT statement should match columns by position") { @@ -382,6 +381,9 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter sql(s"INSERT INTO TABLE $tableName PARTITION (c=11, b=10) SELECT 9, 12") + // The data is missing a column. The default value for the missing column is null. + sql(s"INSERT INTO TABLE $tableName PARTITION (c=15, b=16) SELECT 13") + // c is defined twice. Analyzer will complain. intercept[AnalysisException] { sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, c=16) SELECT 13") @@ -397,11 +399,6 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, d=16) SELECT 13") } - // The data is missing a column. - intercept[AnalysisException] { - sql(s"INSERT INTO TABLE $tableName PARTITION (c=15, b=16) SELECT 13") - } - // d is not a partitioning column. 
intercept[AnalysisException] { sql(s"INSERT INTO TABLE $tableName PARTITION (b=15, d=15) SELECT 13, 14") @@ -436,6 +433,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter Row(5, 6, 7, 8) :: Row(9, 10, 11, 12) :: Row(13, 14, 15, 16) :: + Row(13, 16, 15, null) :: Row(17, 18, 19, 20) :: Row(21, 22, 23, 24) :: Row(25, 26, 27, 28) :: Nil @@ -473,13 +471,14 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter } } - testPartitionedTable("insertInto() should reject missing columns") { + testPartitionedTable("insertInto() should reject missing columns if null default is disabled") { tableName => withTable("t") { sql("CREATE TABLE t (a INT, b INT)") - - intercept[AnalysisException] { - spark.table("t").write.insertInto(tableName) + withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") { + intercept[AnalysisException] { + spark.table("t").write.insertInto(tableName) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 10f18a9ef2e..4eae3933bf5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1258,13 +1258,11 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd """INSERT INTO TABLE dp_test PARTITION(dp) |SELECT key, value, key % 5 FROM src""".stripMargin) }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH", - sqlState = "21S01", + errorClass = "_LEGACY_ERROR_TEMP_1169", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`dp_test`", - "targetColumns" -> "4", - "insertedColumns" -> "3", - "staticPartCols" -> "0")) + "normalizedPartSpec" -> "dp", + "partColNames" -> "dp,sp")) sql("SET hive.exec.dynamic.partition.mode=nonstrict")
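For context, the behaviour asserted by the updated InsertSuite cases can be reproduced with a small standalone program. The sketch below is illustrative only and is not part of this patch: the object name, app name and local-session setup are assumptions, and it relies on the default-column configurations being left at their default values. It shows that an omitted INSERT column is filled from its declared default when one exists and padded with NULL otherwise, while disabling SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES brings back an analysis error for a missing column that has no default.

import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.internal.SQLConf

// Hypothetical driver object, not part of the patch.
object DefaultColumnSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("default-column-sketch").getOrCreate()
    try {
      spark.sql("CREATE TABLE t(i BOOLEAN, s BIGINT DEFAULT 42, x BIGINT) USING parquet")

      // Omitted trailing columns: `s` takes its declared default, `x` has no default and is
      // padded with NULL, mirroring the positive tests above.
      spark.sql("INSERT INTO t(i) VALUES (true)")
      assert(spark.table("t").collect().toSeq == Seq(Row(true, 42L, null)))

      // With NULL padding disabled, omitting a column that has no default is rejected again
      // during analysis, as in the negative tests above.
      spark.conf.set(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key, "false")
      try {
        spark.sql("INSERT INTO t(i, s) VALUES (false, DEFAULT)") // column `x` is not provided
        assert(false, "expected an AnalysisException for the missing column `x`")
      } catch {
        case e: AnalysisException => println(e.getMessage)
      }
    } finally {
      spark.sql("DROP TABLE IF EXISTS t")
      spark.stop()
    }
  }
}

Note that the DEFAULT keyword in the second INSERT still resolves even with the config disabled; only the padding of columns that are absent from the column list is affected.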