This is an automated email from the ASF dual-hosted git repository. brkyvz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 443904a [SPARK-27845][SQL] DataSourceV2: InsertTable 443904a is described below commit 443904a14044ff32421e577dc26d0d53112ceaba Author: Ryan Blue <b...@apache.org> AuthorDate: Thu Jul 25 15:05:51 2019 -0700 [SPARK-27845][SQL] DataSourceV2: InsertTable ## What changes were proposed in this pull request? Support multiple catalogs in the following InsertTable use cases: - INSERT INTO [TABLE] catalog.db.tbl - INSERT OVERWRITE TABLE catalog.db.tbl Support matrix: Overwrite|Partitioned Table|Partition Clause |Partition Overwrite Mode|Action ---------|-----------------|-----------------|------------------------|----- false|*|*|*|AppendData true|no|(empty)|*|OverwriteByExpression(true) true|yes|p1,p2 or p1 or p2 or (empty)|STATIC|OverwriteByExpression(true) true|yes|p1,p2 or p1 or p2 or (empty)|DYNAMIC|OverwritePartitionsDynamic true|yes|p1=23,p2=3|*|OverwriteByExpression(p1=23 and p2=3) true|yes|p1=23,p2 or p1=23|STATIC|OverwriteByExpression(p1=23) true|yes|p1=23,p2 or p1=23|DYNAMIC|OverwritePartitionsDynamic Notes: - Assume the partitioned table has 2 partition columns: p1 and p2. - `STATIC` is the default Partition Overwrite Mode for data source tables. - DSv2 tables currently do not support `IfPartitionNotExists`. ## How was this patch tested? New tests. All existing catalyst and sql/core tests. Closes #24832 from jzhuge/SPARK-27845-pr. 
Lead-authored-by: Ryan Blue <b...@apache.org> Co-authored-by: John Zhuge <jzh...@apache.org> Signed-off-by: Burak Yavuz <brk...@gmail.com> --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 4 +- .../spark/sql/catalyst/analysis/Analyzer.scala | 137 +++++++++- .../apache/spark/sql/catalyst/dsl/package.scala | 13 +- .../spark/sql/catalyst/parser/AstBuilder.scala | 36 ++- .../plans/logical/sql/InsertIntoStatement.scala | 50 ++++ .../spark/sql/catalyst/parser/DDLParserSuite.scala | 113 +++++++- .../sql/catalyst/parser/PlanParserSuite.scala | 19 +- .../datasources/DataSourceResolution.scala | 1 + .../sql/sources/v2/DataSourceV2SQLSuite.scala | 301 ++++++++++++++++++++- .../sql/sources/v2/TestInMemoryTableCatalog.scala | 137 +++++++--- .../org/apache/spark/sql/hive/InsertSuite.scala | 3 +- 11 files changed, 738 insertions(+), 76 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 0a142c2..517ef9d 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -294,8 +294,8 @@ query ; insertInto - : INSERT OVERWRITE TABLE tableIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable - | INSERT INTO TABLE? tableIdentifier partitionSpec? #insertIntoTable + : INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable + | INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? #insertIntoTable | INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir | INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? 
#insertOverwriteDir ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e55cdfe..021fb26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,6 +25,8 @@ import scala.util.Random import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog, TableChange} +import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform} +import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util.loadTable import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes @@ -34,12 +36,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, InsertIntoStatement} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import 
org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode +import org.apache.spark.sql.sources.v2.Table import org.apache.spark.sql.types._ /** @@ -167,6 +171,7 @@ class Analyzer( Batch("Resolution", fixedPoint, ResolveTableValuedFunctions :: ResolveAlterTable :: + ResolveInsertInto :: ResolveTables :: ResolveRelations :: ResolveReferences :: @@ -757,6 +762,136 @@ class Analyzer( } } + object ResolveInsertInto extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case i @ InsertIntoStatement( + UnresolvedRelation(CatalogObjectIdentifier(Some(tableCatalog), ident)), _, _, _, _) + if i.query.resolved => + loadTable(tableCatalog, ident) + .map(DataSourceV2Relation.create) + .map(relation => { + // ifPartitionNotExists is append with validation, but validation is not supported + if (i.ifPartitionNotExists) { + throw new AnalysisException( + s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}") + } + + val partCols = partitionColumnNames(relation.table) + validatePartitionSpec(partCols, i.partitionSpec) + + val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get) + val query = addStaticPartitionColumns(relation, i.query, staticPartitions) + val dynamicPartitionOverwrite = partCols.size > staticPartitions.size && + conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + + if (!i.overwrite) { + AppendData.byPosition(relation, query) + } else if (dynamicPartitionOverwrite) { + OverwritePartitionsDynamic.byPosition(relation, query) + } else { + OverwriteByExpression.byPosition( + relation, query, staticDeleteExpression(relation, staticPartitions)) + } + }) + .getOrElse(i) + + case i @ InsertIntoStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _) + if i.query.resolved => + InsertIntoTable(i.table, i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists) + } + + private def 
partitionColumnNames(table: Table): Seq[String] = { + // get partition column names. in v2, partition columns are columns that are stored using an + // identity partition transform because the partition values and the column values are + // identical. otherwise, partition values are produced by transforming one or more source + // columns and cannot be set directly in a query's PARTITION clause. + table.partitioning.flatMap { + case IdentityTransform(FieldReference(Seq(name))) => Some(name) + case _ => None + } + } + + private def validatePartitionSpec( + partitionColumnNames: Seq[String], + partitionSpec: Map[String, Option[String]]): Unit = { + // check that each partition name is a partition column. otherwise, it is not valid + partitionSpec.keySet.foreach { partitionName => + partitionColumnNames.find(name => conf.resolver(name, partitionName)) match { + case Some(_) => + case None => + throw new AnalysisException( + s"PARTITION clause cannot contain a non-partition column name: $partitionName") + } + } + } + + private def addStaticPartitionColumns( + relation: DataSourceV2Relation, + query: LogicalPlan, + staticPartitions: Map[String, String]): LogicalPlan = { + + if (staticPartitions.isEmpty) { + query + + } else { + // add any static value as a literal column + val withStaticPartitionValues = { + // for each static name, find the column name it will replace and check for unknowns. + val outputNameToStaticName = staticPartitions.keySet.map(staticName => + relation.output.find(col => conf.resolver(col.name, staticName)) match { + case Some(attr) => + attr.name -> staticName + case _ => + throw new AnalysisException( + s"Cannot add static value for unknown column: $staticName") + }).toMap + + val queryColumns = query.output.iterator + + // for each output column, add the static value as a literal, or use the next input + // column. this does not fail if input columns are exhausted and adds remaining columns + // at the end. 
both cases will be caught by ResolveOutputRelation and will fail the + // query with a helpful error message. + relation.output.flatMap { col => + outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match { + case Some(staticValue) => + Some(Alias(Cast(Literal(staticValue), col.dataType), col.name)()) + case _ if queryColumns.hasNext => + Some(queryColumns.next) + case _ => + None + } + } ++ queryColumns + } + + Project(withStaticPartitionValues, query) + } + } + + private def staticDeleteExpression( + relation: DataSourceV2Relation, + staticPartitions: Map[String, String]): Expression = { + if (staticPartitions.isEmpty) { + Literal(true) + } else { + staticPartitions.map { case (name, value) => + relation.output.find(col => conf.resolver(col.name, name)) match { + case Some(attr) => + // the delete expression must reference the table's column names, but these attributes + // are not available when CheckAnalysis runs because the relation is not a child of + // the logical operation. instead, expressions are resolved after + // ResolveOutputRelation runs, using the query's column names that will match the + // table names at that point. because resolution happens after a future rule, create + // an UnresolvedAttribute. + EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType)) + case None => + throw new AnalysisException(s"Unknown static partition column: $name") + } + }.reduce(And) + } + } + } + /** * Resolve ALTER TABLE statements that use a DSv2 catalog. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 54fc1f9..796043f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.sql._ import org.apache.spark.sql.types._ /** @@ -379,10 +380,14 @@ package object dsl { Generate(generator, unrequiredChildIndex, outer, alias, outputNames.map(UnresolvedAttribute(_)), logicalPlan) - def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = - InsertIntoTable( - analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, overwrite, ifPartitionNotExists = false) + def insertInto(tableName: String): LogicalPlan = insertInto(table(tableName)) + + def insertInto( + table: LogicalPlan, + partition: Map[String, Option[String]] = Map.empty, + overwrite: Boolean = false, + ifPartitionNotExists: Boolean = false): LogicalPlan = + InsertIntoStatement(table, partition, logicalPlan, overwrite, ifPartitionNotExists) def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a7a3b96..7d1ff15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import 
org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement [...] +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, InsertIntoStatement, QualifiedColType, ReplaceT [...] import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -239,9 +239,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging /** * Parameters used for writing query to a table: - * (tableIdentifier, partitionKeys, exists). + * (multipartIdentifier, partitionKeys, ifPartitionNotExists). */ - type InsertTableParams = (TableIdentifier, Map[String, Option[String]], Boolean) + type InsertTableParams = (Seq[String], Map[String, Option[String]], Boolean) /** * Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider). 
@@ -263,11 +263,21 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging query: LogicalPlan): LogicalPlan = withOrigin(ctx) { ctx match { case table: InsertIntoTableContext => - val (tableIdent, partitionKeys, exists) = visitInsertIntoTable(table) - InsertIntoTable(UnresolvedRelation(tableIdent), partitionKeys, query, false, exists) + val (tableIdent, partition, ifPartitionNotExists) = visitInsertIntoTable(table) + InsertIntoStatement( + UnresolvedRelation(tableIdent), + partition, + query, + overwrite = false, + ifPartitionNotExists) case table: InsertOverwriteTableContext => - val (tableIdent, partitionKeys, exists) = visitInsertOverwriteTable(table) - InsertIntoTable(UnresolvedRelation(tableIdent), partitionKeys, query, true, exists) + val (tableIdent, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table) + InsertIntoStatement( + UnresolvedRelation(tableIdent), + partition, + query, + overwrite = true, + ifPartitionNotExists) case dir: InsertOverwriteDirContext => val (isLocal, storage, provider) = visitInsertOverwriteDir(dir) InsertIntoDir(isLocal, storage, provider, query, overwrite = true) @@ -284,9 +294,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitInsertIntoTable( ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) { - val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier) val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + if (ctx.EXISTS != null) { + operationNotAllowed("INSERT INTO ... 
IF NOT EXISTS", ctx) + } + (tableIdent, partitionKeys, false) } @@ -296,13 +310,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitInsertOverwriteTable( ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) { assert(ctx.OVERWRITE() != null) - val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier) val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { - throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. Specified " + - "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) + operationNotAllowed("IF NOT EXISTS with dynamic partitions: " + + dynamicPartitionKeys.keys.mkString(", "), ctx) } (tableIdent, partitionKeys, ctx.EXISTS() != null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertIntoStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertIntoStatement.scala new file mode 100644 index 0000000..c4210ea --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertIntoStatement.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * An INSERT INTO statement, as parsed from SQL. + * + * @param table the logical plan representing the table. + * @param query the logical plan representing data to write to. + * @param overwrite overwrite existing table or partitions. + * @param partitionSpec a map from the partition key to the partition value (optional). + * If the value is missing, dynamic partition insert will be performed. + * As an example, `INSERT INTO tbl PARTITION (a=1, b=2) AS` would have + * Map('a' -> Some('1'), 'b' -> Some('2')), + * and `INSERT INTO tbl PARTITION (a=1, b) AS ...` + * would have Map('a' -> Some('1'), 'b' -> None). + * @param ifPartitionNotExists If true, only write if the partition does not exist. + * Only valid for static partitions. 
+ */ +case class InsertIntoStatement( + table: LogicalPlan, + partitionSpec: Map[String, Option[String]], + query: LogicalPlan, + overwrite: Boolean, + ifPartitionNotExists: Boolean) extends ParsedStatement { + + require(overwrite || !ifPartitionNotExists, + "IF NOT EXISTS is only valid in INSERT OVERWRITE") + require(partitionSpec.values.forall(_.nonEmpty) || !ifPartitionNotExists, + "IF NOT EXISTS is only valid with static partitions") + + override def children: Seq[LogicalPlan] = query :: Nil +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index dd84170..0635f8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.parser import java.util.Locale +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} -import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, 
ReplaceTableAsSelectStatement [...] +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, InsertIntoStatement, QualifiedColType, ReplaceT [...] import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String @@ -616,6 +617,112 @@ class DDLParserSuite extends AnalysisTest { } } + test("insert table: basic append") { + Seq( + "INSERT INTO TABLE testcat.ns1.ns2.tbl SELECT * FROM source", + "INSERT INTO testcat.ns1.ns2.tbl SELECT * FROM source" + ).foreach { sql => + parseCompare(sql, + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map.empty, + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))), + overwrite = false, ifPartitionNotExists = false)) + } + } + + test("insert table: append from another catalog") { + parseCompare("INSERT INTO TABLE testcat.ns1.ns2.tbl SELECT * FROM testcat2.db.tbl", + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map.empty, + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("testcat2", "db", "tbl"))), + overwrite = false, ifPartitionNotExists = false)) + } + + test("insert table: append with partition") { + parseCompare( + """ + |INSERT INTO testcat.ns1.ns2.tbl + |PARTITION (p1 = 3, p2) + |SELECT * FROM source + """.stripMargin, + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map("p1" -> Some("3"), "p2" -> None), + Project(Seq(UnresolvedStar(None)), 
UnresolvedRelation(Seq("source"))), + overwrite = false, ifPartitionNotExists = false)) + } + + test("insert table: overwrite") { + Seq( + "INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl SELECT * FROM source", + "INSERT OVERWRITE testcat.ns1.ns2.tbl SELECT * FROM source" + ).foreach { sql => + parseCompare(sql, + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map.empty, + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))), + overwrite = true, ifPartitionNotExists = false)) + } + } + + test("insert table: overwrite with partition") { + parseCompare( + """ + |INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl + |PARTITION (p1 = 3, p2) + |SELECT * FROM source + """.stripMargin, + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map("p1" -> Some("3"), "p2" -> None), + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))), + overwrite = true, ifPartitionNotExists = false)) + } + + test("insert table: overwrite with partition if not exists") { + parseCompare( + """ + |INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl + |PARTITION (p1 = 3) IF NOT EXISTS + |SELECT * FROM source + """.stripMargin, + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map("p1" -> Some("3")), + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))), + overwrite = true, ifPartitionNotExists = true)) + } + + test("insert table: if not exists with dynamic partition fails") { + val exc = intercept[AnalysisException] { + parsePlan( + """ + |INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl + |PARTITION (p1 = 3, p2) IF NOT EXISTS + |SELECT * FROM source + """.stripMargin) + } + + assert(exc.getMessage.contains("IF NOT EXISTS with dynamic partitions")) + assert(exc.getMessage.contains("p2")) + } + + test("insert table: if not exists without overwrite fails") { + val exc = intercept[AnalysisException] { + parsePlan( + """ + |INSERT INTO TABLE testcat.ns1.ns2.tbl + |PARTITION (p1 = 
3) IF NOT EXISTS + |SELECT * FROM source + """.stripMargin) + } + + assert(exc.getMessage.contains("INSERT INTO ... IF NOT EXISTS")) + } + private case class TableSpec( name: Seq[String], schema: Option[StructType], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index fb245ee..61f8c3b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, Un import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.sql.InsertIntoStatement import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType @@ -184,13 +185,15 @@ class PlanParserSuite extends AnalysisTest { } test("insert into") { + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ val sql = "select * from t" val plan = table("t").select(star()) def insert( partition: Map[String, Option[String]], overwrite: Boolean = false, ifPartitionNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifPartitionNotExists) + InsertIntoStatement(table("s"), partition, plan, overwrite, ifPartitionNotExists) // Single inserts assertEqual(s"insert overwrite table s $sql", @@ -205,17 +208,7 @@ class PlanParserSuite extends AnalysisTest { // Multi insert val plan2 = table("t").where('x > 5).select(star()) assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", - InsertIntoTable( - table("s"), Map.empty, plan.limit(1), false, ifPartitionNotExists = false).union( - 
InsertIntoTable( - table("u"), Map.empty, plan2, false, ifPartitionNotExists = false))) - } - - test ("insert with if not exists") { - val sql = "select * from t" - intercept(s"insert overwrite table s partition (e = 1, x) if not exists $sql", - "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [x]") - intercept[ParseException](parsePlan(s"insert overwrite table s if not exists $sql")) + plan.limit(1).insertInto("s").union(plan2.insertInto("u"))) } test("aggregation") { @@ -619,7 +612,7 @@ class PlanParserSuite extends AnalysisTest { comparePlans( parsePlan( "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"), - InsertIntoTable(table("s"), Map.empty, + InsertIntoStatement(table("s"), Map.empty, UnresolvedHint("REPARTITION", Seq(Literal(100)), UnresolvedHint("COALESCE", Seq(Literal(500)), UnresolvedHint("COALESCE", Seq(Literal(10)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 8685d2f..a51678d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -162,6 +162,7 @@ case class DataSourceResolution( case DataSourceV2Relation(CatalogTableAsV2(catalogTable), _, _) => UnresolvedCatalogRelation(catalogTable) + } object V1WriteProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala index c173bdb..a3e029f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -22,12 +22,12 @@ import scala.collection.JavaConverters._ import 
org.scalatest.BeforeAndAfter import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalog.v2.Identifier import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 -import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG +import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType, LongType, MapType, StringType, StructField, StructType, TimestampType} @@ -1349,4 +1349,301 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn assert(updated.properties == Map("provider" -> "foo").asJava) } } + + test("InsertInto: append") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo") + sql(s"INSERT INTO $t1 SELECT id, data FROM source") + checkAnswer(spark.table(t1), spark.table("source")) + } + } + + test("InsertInto: append - across catalog") { + val t1 = "testcat.ns1.ns2.tbl" + val t2 = "testcat2.db.tbl" + withTable(t1, t2) { + sql(s"CREATE TABLE $t1 USING foo AS SELECT * FROM source") + sql(s"CREATE TABLE $t2 (id bigint, data string) USING foo") + sql(s"INSERT INTO $t2 SELECT * FROM $t1") + checkAnswer(spark.table(t2), spark.table("source")) + } + } + + test("InsertInto: append to partitioned table - without PARTITION clause") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO TABLE $t1 SELECT * FROM source") + checkAnswer(spark.table(t1), 
spark.table("source")) + } + } + + test("InsertInto: append to partitioned table - with PARTITION clause") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO TABLE $t1 PARTITION (id) SELECT * FROM source") + checkAnswer(spark.table(t1), spark.table("source")) + } + } + + test("InsertInto: dynamic PARTITION clause fails with non-partition column") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + + val exc = intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $t1 PARTITION (data) SELECT * FROM source") + } + + assert(spark.table(t1).count === 0) + assert(exc.getMessage.contains("PARTITION clause cannot contain a non-partition column name")) + assert(exc.getMessage.contains("data")) + } + } + + test("InsertInto: static PARTITION clause fails with non-partition column") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (data)") + + val exc = intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $t1 PARTITION (id=1) SELECT data FROM source") + } + + assert(spark.table(t1).count === 0) + assert(exc.getMessage.contains("PARTITION clause cannot contain a non-partition column name")) + assert(exc.getMessage.contains("id")) + } + } + + test("InsertInto: fails when missing a column") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, missing string) USING foo") + val exc = intercept[AnalysisException] { + sql(s"INSERT INTO $t1 SELECT id, data FROM source") + } + + assert(spark.table(t1).count === 0) + assert(exc.getMessage.contains(s"Cannot write to '$t1', not enough data columns")) + } + } + + test("InsertInto: fails when an extra column is present") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) 
USING foo") + val exc = intercept[AnalysisException] { + sql(s"INSERT INTO $t1 SELECT id, data, 'fruit' FROM source") + } + + assert(spark.table(t1).count === 0) + assert(exc.getMessage.contains(s"Cannot write to '$t1', too many data columns")) + } + } + + test("InsertInto: append to partitioned table - static clause") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 PARTITION (id = 23) SELECT data FROM source") + checkAnswer(spark.table(t1), sql("SELECT 23, data FROM source")) + } + } + + test("InsertInto: overwrite non-partitioned table") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 USING foo AS SELECT * FROM source") + sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM source2") + checkAnswer(spark.table(t1), spark.table("source2")) + } + } + + test("InsertInto: overwrite - dynamic clause - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'also-deleted')") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"))) + } + } + } + + test("InsertInto: overwrite - dynamic clause - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'keep')") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "keep"))) + } + } + } + + test("InsertInto: overwrite - missing 
clause - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'also-deleted')") + sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"))) + } + } + } + + test("InsertInto: overwrite - missing clause - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'keep')") + sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "keep"))) + } + } + } + + test("InsertInto: overwrite - static clause") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p1 int) USING foo PARTITIONED BY (p1)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 23), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p1 = 23) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 23), + Row(2, "b", 23), + Row(3, "c", 23), + Row(4, "keep", 2))) + } + } + + test("InsertInto: overwrite - mixed clause - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id, p = 2) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2))) + } + } 
+ } + + test("InsertInto: overwrite - mixed clause reordered - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2))) + } + } + } + + test("InsertInto: overwrite - implicit dynamic partition - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2))) + } + } + } + + test("InsertInto: overwrite - mixed clause - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2), + Row(4, "keep", 2))) + } + } + } + + test("InsertInto: overwrite - mixed clause reordered - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY 
(id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id, p = 2) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2), + Row(4, "keep", 2))) + } + } + } + + test("InsertInto: overwrite - implicit dynamic partition - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2), + Row(4, "keep", 2))) + } + } + } + + test("InsertInto: overwrite - multiple static partitions - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM source") + checkAnswer(spark.table(t1), Seq( + Row(2, "a", 2), + Row(2, "b", 2), + Row(2, "c", 2), + Row(4, "keep", 2))) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala index 9539808..19a41be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala @@ -24,12 +24,13 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import 
org.apache.spark.sql.catalog.v2.{CatalogV2Implicits, Identifier, StagingTableCatalog, TableCatalog, TableChange} -import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.catalog.v2.expressions.{IdentityTransform, Transform} import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.sources.{And, EqualTo, Filter} import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} -import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -70,12 +71,8 @@ class TestInMemoryTableCatalog extends TableCatalog { throw new TableAlreadyExistsException(ident) } TestInMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) - if (partitions.nonEmpty) { - throw new UnsupportedOperationException( - s"Catalog $name: Partitioned tables are not supported") - } - val table = new InMemoryTable(s"$name.${ident.quoted}", schema, properties) + val table = new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties) tables.put(ident, table) @@ -93,7 +90,8 @@ class TestInMemoryTableCatalog extends TableCatalog { throw new IllegalArgumentException(s"Cannot drop all fields") } - val newTable = new InMemoryTable(table.name, schema, properties, table.data) + val newTable = new InMemoryTable(table.name, schema, table.partitioning, properties) + .withData(table.data) tables.put(ident, newTable) @@ 
-118,28 +116,43 @@ class TestInMemoryTableCatalog extends TableCatalog { class InMemoryTable( val name: String, val schema: StructType, + override val partitioning: Array[Transform], override val properties: util.Map[String, String]) extends Table with SupportsRead with SupportsWrite { - def this( - name: String, - schema: StructType, - properties: util.Map[String, String], - data: Array[BufferedRows]) = { - this(name, schema, properties) - replaceData(data) + partitioning.foreach { t => + if (!t.isInstanceOf[IdentityTransform]) { + throw new IllegalArgumentException(s"Transform $t must be IdentityTransform") + } } - def rows: Seq[InternalRow] = data.flatMap(_.rows) + @volatile var dataMap: mutable.Map[Seq[Any], BufferedRows] = mutable.Map.empty + + def data: Array[BufferedRows] = dataMap.values.toArray + + def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq - @volatile var data: Array[BufferedRows] = Array.empty + private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames) + private val partIndexes = partFieldNames.map(schema.fieldIndex(_)) - def replaceData(buffers: Array[BufferedRows]): Unit = synchronized { - data = buffers + private def getKey(row: InternalRow): Seq[Any] = partIndexes.map(row.toSeq(schema)(_)) + + def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { + data.foreach(_.rows.foreach { row => + val key = getKey(row) + dataMap += dataMap.get(key) + .map(key -> _.withRow(row)) + .getOrElse(key -> new BufferedRows().withRow(row)) + }) + this } override def capabilities: util.Set[TableCapability] = Set( - TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE).asJava + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.OVERWRITE_BY_FILTER, + TableCapability.OVERWRITE_DYNAMIC, + TableCapability.TRUNCATE).asJava override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { () => new 
InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition])) @@ -157,43 +170,86 @@ class InMemoryTable( override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { TestInMemoryTableCatalog.maybeSimulateFailedTableWrite(options) - new WriteBuilder with SupportsTruncate { - private var shouldTruncate: Boolean = false + + new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite { + private var writer: BatchWrite = Append override def truncate(): WriteBuilder = { - shouldTruncate = true + assert(writer == Append) + writer = TruncateAndAppend this } - override def buildForBatch(): BatchWrite = { - if (shouldTruncate) TruncateAndAppend else Append + override def overwrite(filters: Array[Filter]): WriteBuilder = { + assert(writer == Append) + writer = new Overwrite(filters) + this } + + override def overwriteDynamicPartitions(): WriteBuilder = { + assert(writer == Append) + writer = DynamicOverwrite + this + } + + override def buildForBatch(): BatchWrite = writer } } - private object TruncateAndAppend extends BatchWrite { + private abstract class TestBatchWrite extends BatchWrite { override def createBatchWriterFactory(): DataWriterFactory = { BufferedRowsWriterFactory } - override def commit(messages: Array[WriterCommitMessage]): Unit = { - replaceData(messages.map(_.asInstanceOf[BufferedRows])) + override def abort(messages: Array[WriterCommitMessage]): Unit = { } + } - override def abort(messages: Array[WriterCommitMessage]): Unit = { + private object Append extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + withData(messages.map(_.asInstanceOf[BufferedRows])) } } - private object Append extends BatchWrite { - override def createBatchWriterFactory(): DataWriterFactory = { - BufferedRowsWriterFactory + private object DynamicOverwrite extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + 
val newData = messages.map(_.asInstanceOf[BufferedRows]) + dataMap --= newData.flatMap(_.rows.map(getKey)) + withData(newData) } + } - override def commit(messages: Array[WriterCommitMessage]): Unit = { - replaceData(data ++ messages.map(_.asInstanceOf[BufferedRows])) + private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + val deleteKeys = dataMap.keys.filter { partValues => + filters.flatMap(splitAnd).forall { + case EqualTo(attr, value) => + partFieldNames.zipWithIndex.find(_._1 == attr) match { + case Some((_, partIndex)) => + value == partValues(partIndex) + case _ => + throw new IllegalArgumentException(s"Unknown filter attribute: $attr") + } + case f => + throw new IllegalArgumentException(s"Unsupported filter type: $f") + } + } + dataMap --= deleteKeys + withData(messages.map(_.asInstanceOf[BufferedRows])) } - override def abort(messages: Array[WriterCommitMessage]): Unit = { + private def splitAnd(filter: Filter): Seq[Filter] = { + filter match { + case And(left, right) => splitAnd(left) ++ splitAnd(right) + case _ => filter :: Nil + } + } + } + + private object TruncateAndAppend extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + dataMap.clear + withData(messages.map(_.asInstanceOf[BufferedRows])) } } } @@ -231,7 +287,7 @@ class TestStagingInMemoryCatalog validateStagedTable(partitions, properties) new TestStagedCreateTable( ident, - new InMemoryTable(s"$name.${ident.quoted}", schema, properties)) + new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties)) } override def stageReplace( @@ -242,7 +298,7 @@ class TestStagingInMemoryCatalog validateStagedTable(partitions, properties) new TestStagedReplaceTable( ident, - new InMemoryTable(s"$name.${ident.quoted}", schema, properties)) + new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties)) } 
override def stageCreateOrReplace( @@ -253,7 +309,7 @@ class TestStagingInMemoryCatalog validateStagedTable(partitions, properties) new TestStagedCreateOrReplaceTable( ident, - new InMemoryTable(s"$name.${ident.quoted}", schema, properties)) + new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties)) } private def validateStagedTable( @@ -335,6 +391,11 @@ class TestStagingInMemoryCatalog class BufferedRows extends WriterCommitMessage with InputPartition with Serializable { val rows = new mutable.ArrayBuffer[InternalRow]() + + def withRow(row: InternalRow): BufferedRows = { + rows.append(row) + this + } } private object BufferedRowsReaderFactory extends PartitionReaderFactory { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index 70307ed..73f5bbd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -201,8 +201,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter |SELECT 7, 8, 3 """.stripMargin) } - assert(e.getMessage.contains( - "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [c]")) + assert(e.getMessage.contains("IF NOT EXISTS with dynamic partitions: c")) // If the partition already exists, the insert will overwrite the data // unless users specify IF NOT EXISTS --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org