This is an automated email from the ASF dual-hosted git repository.

brkyvz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 3821d75  [SPARK-28612][SQL] Add DataFrameWriterV2 API
3821d75 is described below

commit 3821d75b836afae55a2a92c14b379bf4ec8a5362
Author: Ryan Blue <b...@apache.org>
AuthorDate: Sat Aug 31 21:28:20 2019 -0700

    [SPARK-28612][SQL] Add DataFrameWriterV2 API

    ## What changes were proposed in this pull request?

    This adds a new write API as proposed in the [SPIP to standardize logical plans](https://issues.apache.org/jira/browse/SPARK-23521). This new API:

    * Uses clear verbs to execute writes, like `append`, `overwrite`, `create`, and `replace` that correspond to the new logical plans.
    * Only creates v2 logical plans so the behavior is always consistent.
    * Does not allow table configuration options for operations that cannot change table configuration. For example, `partitionedBy` can only be called when the writer executes `create` or `replace`.

    Here are a few example uses of the new API:

    ```scala
    df.writeTo("catalog.db.table").append()
    df.writeTo("catalog.db.table").overwrite($"date" === "2019-06-01")
    df.writeTo("catalog.db.table").overwritePartitions()
    df.writeTo("catalog.db.table").asParquet.create()
    df.writeTo("catalog.db.table").partitionedBy(days($"ts")).createOrReplace()
    df.writeTo("catalog.db.table").using("abc").replace()
    ```

    ## How was this patch tested?

    Added `DataFrameWriterV2Suite` that tests the new write API. Existing tests for v2 plans.

    Closes #25354 from rdblue/SPARK-28612-add-data-frame-writer-v2.

    Authored-by: Ryan Blue <b...@apache.org>
    Signed-off-by: Burak Yavuz <brk...@gmail.com>
---
 .../catalyst/expressions/PartitionTransforms.scala |  77 ++++
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   6 +-
 .../plans/logical/basicLogicalOperators.scala      |  47 +-
 .../datasources/v2/DataSourceV2Implicits.scala     |   9 +
 .../apache/spark/sql/connector/InMemoryTable.scala |   5 +-
 .../org/apache/spark/sql/DataFrameWriter.scala     |  11 +-
 .../org/apache/spark/sql/DataFrameWriterV2.scala   | 365 +++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  28 ++
 .../datasources/v2/DataSourceV2Strategy.scala      |  20 +-
 .../datasources/v2/V2WriteSupportCheck.scala       |   6 +-
 .../scala/org/apache/spark/sql/functions.scala     |  64 +++
 .../sql/sources/v2/DataFrameWriterV2Suite.scala    | 508 +++++++++++++++++++++
 12 files changed, 1110 insertions(+), 36 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
new file mode 100644
index 0000000..e48fd8a
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.types.{DataType, IntegerType} + +/** + * Base class for expressions that are converted to v2 partition transforms. + * + * Subclasses represent abstract transform functions with concrete implementations that are + * determined by data source implementations. Because the concrete implementation is not known, + * these expressions are [[Unevaluable]]. + * + * These expressions are used to pass transformations from the DataFrame API: + * + * {{{ + * df.writeTo("catalog.db.table").partitionedBy($"category", days($"timestamp")).create() + * }}} + */ +abstract class PartitionTransformExpression extends Expression with Unevaluable { + override def nullable: Boolean = true +} + +/** + * Expression for the v2 partition transform years. + */ +case class Years(child: Expression) extends PartitionTransformExpression { + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Seq(child) +} + +/** + * Expression for the v2 partition transform months. + */ +case class Months(child: Expression) extends PartitionTransformExpression { + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Seq(child) +} + +/** + * Expression for the v2 partition transform days. + */ +case class Days(child: Expression) extends PartitionTransformExpression { + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Seq(child) +} + +/** + * Expression for the v2 partition transform hours. + */ +case class Hours(child: Expression) extends PartitionTransformExpression { + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Seq(child) +} + +/** + * Expression for the v2 partition transform bucket. 
+ */ +case class Bucket(numBuckets: Literal, child: Expression) extends PartitionTransformExpression { + override def dataType: DataType = IntegerType + override def children: Seq[Expression] = Seq(numBuckets, child) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dcb6af6..0cb5941 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2506,7 +2506,7 @@ class Analyzer( */ object ResolveOutputRelation extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { - case append @ AppendData(table, query, isByName) + case append @ AppendData(table, query, _, isByName) if table.resolved && query.resolved && !append.outputResolved => val projection = TableOutputResolver.resolveOutputColumns( @@ -2518,7 +2518,7 @@ class Analyzer( append } - case overwrite @ OverwriteByExpression(table, _, query, isByName) + case overwrite @ OverwriteByExpression(table, _, query, _, isByName) if table.resolved && query.resolved && !overwrite.outputResolved => val projection = TableOutputResolver.resolveOutputColumns( @@ -2530,7 +2530,7 @@ class Analyzer( overwrite } - case overwrite @ OverwritePartitionsDynamic(table, query, isByName) + case overwrite @ OverwritePartitionsDynamic(table, query, _, isByName) if table.resolved && query.resolved && !overwrite.outputResolved => val projection = TableOutputResolver.resolveOutputColumns( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 0be61cf..6e1825e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -489,7 +489,7 @@ case class ReplaceTableAsSelect( override def tableSchema: StructType = query.schema override def children: Seq[LogicalPlan] = Seq(query) - override lazy val resolved: Boolean = { + override lazy val resolved: Boolean = childrenResolved && { // the table schema is created from the query schema, so the only resolution needed is to check // that the columns referenced by the table's partitioning exist in the query schema val references = partitioning.flatMap(_.references).toSet @@ -507,15 +507,22 @@ case class ReplaceTableAsSelect( case class AppendData( table: NamedRelation, query: LogicalPlan, + writeOptions: Map[String, String], isByName: Boolean) extends V2WriteCommand object AppendData { - def byName(table: NamedRelation, df: LogicalPlan): AppendData = { - new AppendData(table, df, isByName = true) + def byName( + table: NamedRelation, + df: LogicalPlan, + writeOptions: Map[String, String] = Map.empty): AppendData = { + new AppendData(table, df, writeOptions, isByName = true) } - def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = { - new AppendData(table, query, isByName = false) + def byPosition( + table: NamedRelation, + query: LogicalPlan, + writeOptions: Map[String, String] = Map.empty): AppendData = { + new AppendData(table, query, writeOptions, isByName = false) } } @@ -526,19 +533,26 @@ case class OverwriteByExpression( table: NamedRelation, deleteExpr: Expression, query: 
LogicalPlan, + writeOptions: Map[String, String], isByName: Boolean) extends V2WriteCommand { override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved } object OverwriteByExpression { def byName( - table: NamedRelation, df: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { - OverwriteByExpression(table, deleteExpr, df, isByName = true) + table: NamedRelation, + df: LogicalPlan, + deleteExpr: Expression, + writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, df, writeOptions, isByName = true) } def byPosition( - table: NamedRelation, query: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { - OverwriteByExpression(table, deleteExpr, query, isByName = false) + table: NamedRelation, + query: LogicalPlan, + deleteExpr: Expression, + writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, query, writeOptions, isByName = false) } } @@ -548,15 +562,22 @@ object OverwriteByExpression { case class OverwritePartitionsDynamic( table: NamedRelation, query: LogicalPlan, + writeOptions: Map[String, String], isByName: Boolean) extends V2WriteCommand object OverwritePartitionsDynamic { - def byName(table: NamedRelation, df: LogicalPlan): OverwritePartitionsDynamic = { - OverwritePartitionsDynamic(table, df, isByName = true) + def byName( + table: NamedRelation, + df: LogicalPlan, + writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, df, writeOptions, isByName = true) } - def byPosition(table: NamedRelation, query: LogicalPlan): OverwritePartitionsDynamic = { - OverwritePartitionsDynamic(table, query, isByName = false) + def byPosition( + table: NamedRelation, + query: LogicalPlan, + writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, query, writeOptions, isByName = false) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala index 2d59c42..ab33e8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.sources.v2.{SupportsDelete, SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.util.CaseInsensitiveStringMap object DataSourceV2Implicits { implicit class TableHelper(table: Table) { @@ -53,4 +56,10 @@ object DataSourceV2Implicits { def supportsAny(capabilities: TableCapability*): Boolean = capabilities.exists(supports) } + + implicit class OptionsHelper(options: Map[String, String]) { + def asOptions: CaseInsensitiveStringMap = { + new CaseInsensitiveStringMap(options.asJava) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 0dea1e3..2dc4f8b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -41,8 +41,11 
@@ class InMemoryTable( override val properties: util.Map[String, String]) extends Table with SupportsRead with SupportsWrite with SupportsDelete { + private val allowUnsupportedTransforms = + properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean + partitioning.foreach { t => - if (!t.isInstanceOf[IdentityTransform]) { + if (!t.isInstanceOf[IdentityTransform] && !allowUnsupportedTransforms) { throw new IllegalArgumentException(s"Transform $t must be IdentityTransform") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index d0a1d41..13d38d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -271,13 +271,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { modeForDSV2 match { case SaveMode.Append => runCommand(df.sparkSession, "save") { - AppendData.byName(relation, df.logicalPlan) + AppendData.byName(relation, df.logicalPlan, extraOptions.toMap) } case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => // truncate the table runCommand(df.sparkSession, "save") { - OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) + OverwriteByExpression.byName( + relation, df.logicalPlan, Literal(true), extraOptions.toMap) } case other => @@ -383,7 +384,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val command = modeForDSV2 match { case SaveMode.Append => - AppendData.byPosition(table, df.logicalPlan) + AppendData.byPosition(table, df.logicalPlan, extraOptions.toMap) case SaveMode.Overwrite => val conf = df.sparkSession.sessionState.conf @@ -391,9 +392,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC if (dynamicPartitionOverwrite) { - OverwritePartitionsDynamic.byPosition(table, df.logicalPlan) + OverwritePartitionsDynamic.byPosition(table, df.logicalPlan, extraOptions.toMap) } else { - OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true)) + OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true), extraOptions.toMap) } case other => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala new file mode 100644 index 0000000..57b212e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalog.v2.expressions.{LogicalExpressions, Transform} +import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.IntegerType + +/** + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 API. + * + * @since 3.0.0 + */ +@Experimental +final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) + extends CreateTableWriter[T] { + + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util._ + import df.sparkSession.sessionState.analyzer.CatalogObjectIdentifier + + private val df: DataFrame = ds.toDF() + + private val sparkSession = ds.sparkSession + + private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table) + + private val (catalog, identifier) = { + val CatalogObjectIdentifier(maybeCatalog, identifier) = tableName + val catalog = maybeCatalog.orElse(sparkSession.sessionState.analyzer.sessionCatalog) + .getOrElse(throw new AnalysisException( + s"No catalog specified for table ${identifier.quoted} and no default v2 catalog is set")) + .asTableCatalog + + (catalog, identifier) + } + + private val logicalPlan = df.queryExecution.logical + + private var provider: Option[String] = None + + private val options = new mutable.HashMap[String, String]() + + private val properties = new mutable.HashMap[String, String]() + + private var partitioning: Option[Seq[Transform]] = None + + override def using(provider: String): CreateTableWriter[T] = { + this.provider = Some(provider) + this + } + + override def option(key: String, value: String): DataFrameWriterV2[T] = { + this.options.put(key, value) + this + } + + override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = { + options.foreach { + case (key, value) => + this.options.put(key, value) + } + this + } + + override def options(options: java.util.Map[String, String]): DataFrameWriterV2[T] = { + this.options(options.asScala) + this + } + + override def tableProperty(property: String, value: String): DataFrameWriterV2[T] = { + this.properties.put(property, value) + this + } + + @scala.annotation.varargs + override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = { + val asTransforms = (column +: columns).map(_.expr).map { + case Years(attr: Attribute) => + LogicalExpressions.years(attr.name) + case Months(attr: Attribute) => + LogicalExpressions.months(attr.name) + case Days(attr: Attribute) => + LogicalExpressions.days(attr.name) + case Hours(attr: Attribute) => + LogicalExpressions.hours(attr.name) + case Bucket(Literal(numBuckets: Int, IntegerType), attr: Attribute) => + LogicalExpressions.bucket(numBuckets, attr.name) + case attr: Attribute => + LogicalExpressions.identity(attr.name) + case expr => + throw new AnalysisException(s"Invalid 
partition transformation: ${expr.sql}") + } + + this.partitioning = Some(asTransforms) + this + } + + override def create(): Unit = { + // create and replace could alternatively create ParsedPlan statements, like + // `CreateTableFromDataFrameStatement(UnresolvedRelation(tableName), ...)`, to keep the catalog + // resolution logic in the analyzer. + runCommand("create") { + CreateTableAsSelect( + catalog, + identifier, + partitioning.getOrElse(Seq.empty), + logicalPlan, + properties = provider.map(p => properties + ("provider" -> p)).getOrElse(properties).toMap, + writeOptions = options.toMap, + ignoreIfExists = false) + } + } + + override def replace(): Unit = { + internalReplace(orCreate = false) + } + + override def createOrReplace(): Unit = { + internalReplace(orCreate = true) + } + + + /** + * Append the contents of the data frame to the output table. + * + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be + * validated to ensure it is compatible with the existing table. + * + * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist + */ + @throws(classOf[NoSuchTableException]) + def append(): Unit = { + val append = loadTable(catalog, identifier) match { + case Some(t) => + AppendData.byName(DataSourceV2Relation.create(t), logicalPlan, options.toMap) + case _ => + throw new NoSuchTableException(identifier) + } + + runCommand("append")(append) + } + + /** + * Overwrite rows matching the given filter condition with the contents of the data frame in + * the output table. + * + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. + * The data frame will be validated to ensure it is compatible with the existing table. + * + * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist + */ + @throws(classOf[NoSuchTableException]) + def overwrite(condition: Column): Unit = { + val overwrite = loadTable(catalog, identifier) match { + case Some(t) => + OverwriteByExpression.byName( + DataSourceV2Relation.create(t), logicalPlan, condition.expr, options.toMap) + case _ => + throw new NoSuchTableException(identifier) + } + + runCommand("overwrite")(overwrite) + } + + /** + * Overwrite all partition for which the data frame contains at least one row with the contents + * of the data frame in the output table. + * + * This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces + * partitions dynamically depending on the contents of the data frame. + * + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be + * validated to ensure it is compatible with the existing table. + * + * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist + */ + @throws(classOf[NoSuchTableException]) + def overwritePartitions(): Unit = { + val dynamicOverwrite = loadTable(catalog, identifier) match { + case Some(t) => + OverwritePartitionsDynamic.byName( + DataSourceV2Relation.create(t), logicalPlan, options.toMap) + case _ => + throw new NoSuchTableException(identifier) + } + + runCommand("overwritePartitions")(dynamicOverwrite) + } + + /** + * Wrap an action to track the QueryExecution and time cost, then report to the user-registered + * callback functions. 
+ */ + private def runCommand(name: String)(command: LogicalPlan): Unit = { + val qe = sparkSession.sessionState.executePlan(command) + // call `QueryExecution.toRDD` to trigger the execution of commands. + SQLExecution.withNewExecutionId(sparkSession, qe, Some(name))(qe.toRdd) + } + + private def internalReplace(orCreate: Boolean): Unit = { + runCommand("replace") { + ReplaceTableAsSelect( + catalog, + identifier, + partitioning.getOrElse(Seq.empty), + logicalPlan, + properties = provider.map(p => properties + ("provider" -> p)).getOrElse(properties).toMap, + writeOptions = options.toMap, + orCreate = orCreate) + } + } +} + +/** + * Configuration methods common to create/replace operations and insert/overwrite operations. + * @tparam R builder type to return + */ +trait WriteConfigMethods[R] { + /** + * Add a write option. + * + * @since 3.0.0 + */ + def option(key: String, value: String): R + + /** + * Add a boolean output option. + * + * @since 3.0.0 + */ + def option(key: String, value: Boolean): R = option(key, value.toString) + + /** + * Add a long output option. + * + * @since 3.0.0 + */ + def option(key: String, value: Long): R = option(key, value.toString) + + /** + * Add a double output option. + * + * @since 3.0.0 + */ + def option(key: String, value: Double): R = option(key, value.toString) + + /** + * Add write options from a Scala Map. + * + * @since 3.0.0 + */ + def options(options: scala.collection.Map[String, String]): R + + /** + * Add write options from a Java Map. + * + * @since 3.0.0 + */ + def options(options: java.util.Map[String, String]): R +} + +/** + * Trait to restrict calls to create and replace operations. + */ +trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { + /** + * Create a new table from the contents of the data frame. + * + * The new table's schema, partition layout, properties, and other configuration will be + * based on the configuration set on this writer. + * + * If the output table exists, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException]]. + * + * @throws org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException + * If the table already exists + */ + @throws(classOf[TableAlreadyExistsException]) + def create(): Unit + + /** + * Replace an existing table with the contents of the data frame. + * + * The existing table's schema, partition layout, properties, and other configuration will be + * replaced with the contents of the data frame and the configuration set on this writer. + * + * If the output table does not exist, this operation will fail with + * [[org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException]]. + * + * @throws org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException + * If the table already exists + */ + @throws(classOf[CannotReplaceMissingTableException]) + def replace(): Unit + + /** + * Create a new table or replace an existing table with the contents of the data frame. + * + * The output table's schema, partition layout, properties, and other configuration will be based + * on the contents of the data frame and the configuration set on this writer. If the table + * exists, its configuration and data will be replaced. + */ + def createOrReplace(): Unit + + /** + * Partition the output table created by `create`, `createOrReplace`, or `replace` using + * the given columns or transforms. + * + * When specified, the table data will be stored by these values for efficient reads. 
+ * + * For example, when a table is partitioned by day, it may be stored in a directory layout like: + * <ul> + * <li>`table/day=2019-06-01/`</li> + * <li>`table/day=2019-06-02/`</li> + * </ul> + * + * Partitioning is one of the most widely used techniques to optimize physical data layout. + * It provides a coarse-grained index for skipping unnecessary data reads when queries have + * predicates on the partitioned columns. In order for partitioning to work well, the number + * of distinct values in each column should typically be less than tens of thousands. + * + * @since 3.0.0 + */ + def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] + + /** + * Specifies a provider for the underlying output data source. Spark's default catalog supports + * "parquet", "json", etc. + * + * @since 3.0.0 + */ + def using(provider: String): CreateTableWriter[T] + + /** + * Add a table property. + */ + def tableProperty(property: String, value: String): CreateTableWriter[T] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7c25397..23360df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3191,6 +3191,34 @@ class Dataset[T] private[sql]( } /** + * Create a write configuration builder for v2 sources. + * + * This builder is used to configure and execute write operations. For example, to append to an + * existing table, run: + * + * {{{ + * df.writeTo("catalog.db.table").append() + * }}} + * + * This can also be used to create or replace existing tables: + * + * {{{ + * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace() + * }}} + * + * @group basic + * @since 3.0.0 + */ + def writeTo(table: String): DataFrameWriterV2[T] = { + // TODO: streaming could be adapted to use this interface + if (isStreaming) { + logicalPlan.failAnalysis( + "'writeTo' can not be called on streaming Dataset/DataFrame") + } + new DataFrameWriterV2[T](table, this) + } + + /** * Interface for saving the content of the streaming Dataset out into external storage. 
* * @group basic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index a934c09..b5a573c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.datasources.v2 -import java.util.UUID - import scala.collection.JavaConverters._ import scala.collection.mutable @@ -34,7 +32,6 @@ import org.apache.spark.sql.sources import org.apache.spark.sql.sources.v2.TableCapability import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.sources.v2.writer.V1WriteBuilder import org.apache.spark.sql.util.CaseInsensitiveStringMap object DataSourceV2Strategy extends Strategy with PredicateHelper { @@ -212,15 +209,15 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { orCreate = orCreate) :: Nil } - case AppendData(r: DataSourceV2Relation, query, _) => + case AppendData(r: DataSourceV2Relation, query, writeOptions, _) => r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - AppendDataExecV1(v1, r.options, query) :: Nil + AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil case v2 => - AppendDataExec(v2, r.options, planLater(query)) :: Nil + AppendDataExec(v2, writeOptions.asOptions, planLater(query)) :: Nil } - case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) => + case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) => // fail if any filter cannot be converted. correctness depends on removing all matching data. val filters = splitConjunctivePredicates(deleteExpr).map { filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse( @@ -228,13 +225,14 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { }.toArray r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - OverwriteByExpressionExecV1(v1, filters, r.options, query) :: Nil + OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query) :: Nil case v2 => - OverwriteByExpressionExec(v2, filters, r.options, planLater(query)) :: Nil + OverwriteByExpressionExec(v2, filters, writeOptions.asOptions, planLater(query)) :: Nil } - case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) => - OverwritePartitionsDynamicExec(r.table.asWritable, r.options, planLater(query)) :: Nil + case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) => + OverwritePartitionsDynamicExec( + r.table.asWritable, writeOptions.asOptions, planLater(query)) :: Nil case DeleteFromTable(r: DataSourceV2Relation, condition) => // fail if any filter cannot be converted. correctness depends on removing all matching data. 
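The strategy changes above mean the physical write now receives the options set on the writer itself (via `writeOptions.asOptions`) rather than the read options carried by the relation. A minimal sketch of the user-facing side, assuming a v2 catalog named `testcat` is configured and the target table already exists; the option key is purely hypothetical and its interpretation is left to the table implementation:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Options set here are collected by DataFrameWriterV2, stored on the AppendData
// logical plan as writeOptions, and handed to AppendDataExec / AppendDataExecV1 as a
// CaseInsensitiveStringMap through the asOptions implicit added in this patch.
Seq((1L, "a"), (2L, "b")).toDF("id", "data")
  .writeTo("testcat.table_name")
  .option("write.batch-size", "1000") // hypothetical key; the source decides what to do with it
  .append()
```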
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala index 5648d54..5a093ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala @@ -29,14 +29,14 @@ object V2WriteSupportCheck extends (LogicalPlan => Unit) { def failAnalysis(msg: String): Unit = throw new AnalysisException(msg) override def apply(plan: LogicalPlan): Unit = plan foreach { - case AppendData(rel: DataSourceV2Relation, _, _) if !rel.table.supports(BATCH_WRITE) => + case AppendData(rel: DataSourceV2Relation, _, _, _) if !rel.table.supports(BATCH_WRITE) => failAnalysis(s"Table does not support append in batch mode: ${rel.table}") - case OverwritePartitionsDynamic(rel: DataSourceV2Relation, _, _) + case OverwritePartitionsDynamic(rel: DataSourceV2Relation, _, _, _) if !rel.table.supports(BATCH_WRITE) || !rel.table.supports(OVERWRITE_DYNAMIC) => failAnalysis(s"Table does not support dynamic overwrite in batch mode: ${rel.table}") - case OverwriteByExpression(rel: DataSourceV2Relation, expr, _, _) => + case OverwriteByExpression(rel: DataSourceV2Relation, expr, _, _, _) => expr match { case Literal(true, BooleanType) => if (!rel.table.supports(BATCH_WRITE) || diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6b8127b..0ece755 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -69,6 +69,7 @@ import org.apache.spark.util.Utils * @groupname window_funcs Window functions * @groupname string_funcs String functions * @groupname collection_funcs Collection functions + * @groupname partition_transforms Partition transform functions * @groupname Ungrouped Support functions for DataFrames * @since 1.3.0 */ @@ -3942,6 +3943,69 @@ object functions { */ def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String].asJava) + // turn off style check that object names must start with a capital letter + // scalastyle:off + object partitioning { + // scalastyle:on + + /** + * A transform for timestamps and dates to partition data into years. + * + * @group partition_transforms + * @since 3.0.0 + */ + def years(e: Column): Column = withExpr { Years(e.expr) } + + /** + * A transform for timestamps and dates to partition data into months. + * + * @group partition_transforms + * @since 3.0.0 + */ + def months(e: Column): Column = withExpr { Months(e.expr) } + + /** + * A transform for timestamps and dates to partition data into days. + * + * @group partition_transforms + * @since 3.0.0 + */ + def days(e: Column): Column = withExpr { Days(e.expr) } + + /** + * A transform for timestamps to partition data into hours. + * + * @group partition_transforms + * @since 3.0.0 + */ + def hours(e: Column): Column = withExpr { Hours(e.expr) } + + /** + * A transform for any type that partitions by a hash of the input column. 
+ * + * @group partition_transforms + * @since 3.0.0 + */ + def bucket(numBuckets: Column, e: Column): Column = withExpr { + numBuckets.expr match { + case lit @ Literal(_, IntegerType) => + Bucket(lit, e.expr) + case _ => + throw new AnalysisException(s"Invalid number of buckets: bucket($numBuckets, $e)") + } + } + + /** + * A transform for any type that partitions by a hash of the input column. + * + * @group partition_transforms + * @since 3.0.0 + */ + def bucket(numBuckets: Int, e: Column): Column = withExpr { + Bucket(Literal(numBuckets), e.expr) + } + } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala new file mode 100644 index 0000000..810a192 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala @@ -0,0 +1,508 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2 + +import scala.collection.JavaConverters._ + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalog.v2.{ Identifier, TableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} +import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.InMemoryTableCatalog +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} + +class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter { + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + import org.apache.spark.sql.functions._ + import testImplicits._ + + private def catalog(name: String): TableCatalog = { + spark.sessionState.catalogManager.catalog(name).asTableCatalog + } + + before { + spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + + val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") + df.createOrReplaceTempView("source") + val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L, "f"))).toDF("id", "data") + df2.createOrReplaceTempView("source2") + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.clear() + } + + test("Append: basic append") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + spark.table("source").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + spark.table("source2").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"), Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) + } + + test("Append: by name not position") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + val exc = intercept[AnalysisException] { + spark.table("source").withColumnRenamed("data", "d").writeTo("testcat.table_name").append() + } + + assert(exc.getMessage.contains("Cannot find data for output column")) + assert(exc.getMessage.contains("'data'")) + + checkAnswer( + spark.table("testcat.table_name"), + Seq()) + } + + test("Append: fail if table does not exist") { + val exc = intercept[NoSuchTableException] { + spark.table("source").writeTo("testcat.table_name").append() + } + + assert(exc.getMessage.contains("table_name")) + } + + test("Overwrite: overwrite by expression: true") { + spark.sql( + "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + spark.table("source").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + spark.table("source2").writeTo("testcat.table_name").overwrite(lit(true)) + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) + } + + test("Overwrite: overwrite by expression: id = 3") { + spark.sql( + "CREATE TABLE 
testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + spark.table("source").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + spark.table("source2").writeTo("testcat.table_name").overwrite($"id" === 3) + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) + } + + test("Overwrite: by name not position") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + val exc = intercept[AnalysisException] { + spark.table("source").withColumnRenamed("data", "d") + .writeTo("testcat.table_name").overwrite(lit(true)) + } + + assert(exc.getMessage.contains("Cannot find data for output column")) + assert(exc.getMessage.contains("'data'")) + + checkAnswer( + spark.table("testcat.table_name"), + Seq()) + } + + test("Overwrite: fail if table does not exist") { + val exc = intercept[NoSuchTableException] { + spark.table("source").writeTo("testcat.table_name").overwrite(lit(true)) + } + + assert(exc.getMessage.contains("table_name")) + } + + test("OverwritePartitions: overwrite conflicting partitions") { + spark.sql( + "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + spark.table("source").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + spark.table("source2").withColumn("id", $"id" - 2) + .writeTo("testcat.table_name").overwritePartitions() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "d"), Row(3L, "e"), Row(4L, "f"))) + } + + test("OverwritePartitions: overwrite all rows if not partitioned") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + spark.table("source").writeTo("testcat.table_name").append() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + spark.table("source2").writeTo("testcat.table_name").overwritePartitions() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) + } + + test("OverwritePartitions: by name not position") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + checkAnswer(spark.table("testcat.table_name"), Seq.empty) + + val exc = intercept[AnalysisException] { + spark.table("source").withColumnRenamed("data", "d") + .writeTo("testcat.table_name").overwritePartitions() + } + + assert(exc.getMessage.contains("Cannot find data for output column")) + assert(exc.getMessage.contains("'data'")) + + checkAnswer( + spark.table("testcat.table_name"), + Seq()) + } + + test("OverwritePartitions: fail if table does not exist") { + val exc = intercept[NoSuchTableException] { + spark.table("source").writeTo("testcat.table_name").overwritePartitions() + } + + assert(exc.getMessage.contains("table_name")) + } + + test("Create: basic behavior") { + spark.table("source").writeTo("testcat.table_name").create() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = catalog("testcat").loadTable(Identifier.of(Array(), 
"table_name")) + + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning.isEmpty) + assert(table.properties.isEmpty) + } + + test("Create: with using") { + spark.table("source").writeTo("testcat.table_name").using("foo").create() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning.isEmpty) + assert(table.properties === Map("provider" -> "foo").asJava) + } + + test("Create: with property") { + spark.table("source").writeTo("testcat.table_name").tableProperty("prop", "value").create() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning.isEmpty) + assert(table.properties === Map("prop" -> "value").asJava) + } + + test("Create: identity partitioned table") { + spark.table("source").writeTo("testcat.table_name").partitionedBy($"id").create() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning === Seq(IdentityTransform(FieldReference("id")))) + assert(table.properties.isEmpty) + } + + test("Create: partitioned by years(ts)") { + spark.table("source") + .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(partitioning.years($"ts")) + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(YearsTransform(FieldReference("ts")))) + } + + test("Create: partitioned by months(ts)") { + spark.table("source") + .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(partitioning.months($"ts")) + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(MonthsTransform(FieldReference("ts")))) + } + + test("Create: partitioned by days(ts)") { + spark.table("source") + .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(partitioning.days($"ts")) + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(DaysTransform(FieldReference("ts")))) + } + + test("Create: partitioned by hours(ts)") { + spark.table("source") + .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp")) + 
.writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(partitioning.hours($"ts")) + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(HoursTransform(FieldReference("ts")))) + } + + test("Create: partitioned by bucket(4, id)") { + spark.table("source") + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(partitioning.bucket(4, $"id")) + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === + Seq(BucketTransform(LiteralValue(4, IntegerType), Seq(FieldReference("id"))))) + } + + test("Create: fail if table already exists") { + spark.sql( + "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + + val exc = intercept[TableAlreadyExistsException] { + spark.table("source").writeTo("testcat.table_name").create() + } + + assert(exc.getMessage.contains("table_name")) + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + // table should not have been changed + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning === Seq(IdentityTransform(FieldReference("id")))) + assert(table.properties === Map("provider" -> "foo").asJava) + } + + test("Replace: basic behavior") { + spark.sql( + "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + spark.sql("INSERT INTO TABLE testcat.table_name SELECT * FROM source") + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + // validate the initial table + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning === Seq(IdentityTransform(FieldReference("id")))) + assert(table.properties === Map("provider" -> "foo").asJava) + + spark.table("source2") + .withColumn("even_or_odd", when(($"id" % 2) === 0, "even").otherwise("odd")) + .writeTo("testcat.table_name").replace() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) + + val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + // validate the replacement table + assert(replaced.name === "testcat.table_name") + assert(replaced.schema === new StructType() + .add("id", LongType) + .add("data", StringType) + .add("even_or_odd", StringType)) + assert(replaced.partitioning.isEmpty) + assert(replaced.properties.isEmpty) + } + + test("Replace: partitioned table") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + spark.sql("INSERT INTO TABLE testcat.table_name SELECT * FROM source") + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + // validate the initial table + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning.isEmpty) + assert(table.properties === 
Map("provider" -> "foo").asJava) + + spark.table("source2") + .withColumn("even_or_odd", when(($"id" % 2) === 0, "even").otherwise("odd")) + .writeTo("testcat.table_name").partitionedBy($"id").replace() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) + + val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + // validate the replacement table + assert(replaced.name === "testcat.table_name") + assert(replaced.schema === new StructType() + .add("id", LongType) + .add("data", StringType) + .add("even_or_odd", StringType)) + assert(replaced.partitioning === Seq(IdentityTransform(FieldReference("id")))) + assert(replaced.properties.isEmpty) + } + + test("Replace: fail if table does not exist") { + val exc = intercept[CannotReplaceMissingTableException] { + spark.table("source").writeTo("testcat.table_name").replace() + } + + assert(exc.getMessage.contains("table_name")) + } + + test("CreateOrReplace: table does not exist") { + spark.table("source2").writeTo("testcat.table_name").createOrReplace() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f"))) + + val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + // validate the replacement table + assert(replaced.name === "testcat.table_name") + assert(replaced.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(replaced.partitioning.isEmpty) + assert(replaced.properties.isEmpty) + } + + test("CreateOrReplace: table exists") { + spark.sql( + "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") + spark.sql("INSERT INTO TABLE testcat.table_name SELECT * FROM source") + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))) + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + // validate the initial table + assert(table.name === "testcat.table_name") + assert(table.schema === new StructType().add("id", LongType).add("data", StringType)) + assert(table.partitioning === Seq(IdentityTransform(FieldReference("id")))) + assert(table.properties === Map("provider" -> "foo").asJava) + + spark.table("source2") + .withColumn("even_or_odd", when(($"id" % 2) === 0, "even").otherwise("odd")) + .writeTo("testcat.table_name").createOrReplace() + + checkAnswer( + spark.table("testcat.table_name"), + Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even"))) + + val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + // validate the replacement table + assert(replaced.name === "testcat.table_name") + assert(replaced.schema === new StructType() + .add("id", LongType) + .add("data", StringType) + .add("even_or_odd", StringType)) + assert(replaced.partitioning.isEmpty) + assert(replaced.properties.isEmpty) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org