This is an automated email from the ASF dual-hosted git repository.

dbtsai pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 2824fec9  [SPARK-30289][SQL] Partitioned by Nested Column for `InMemoryTable`
2824fec9 is described below

commit 2824fec9fa57444b7c64edb8226cf75bb87a2e5d
Author: DB Tsai <d_t...@apple.com>
AuthorDate: Fri Feb 14 21:46:01 2020 +0000

    [SPARK-30289][SQL] Partitioned by Nested Column for `InMemoryTable`

    ### What changes were proposed in this pull request?

    1. `InMemoryTable` was flattening the nested columns, and the flattened column names were
       then used to look up field indices, which is not correct.

    This PR implements partitioning by a nested column for `InMemoryTable`; a condensed usage
    sketch follows the diff below.

    ### Why are the changes needed?

    This PR implements partitioning by a nested column for `InMemoryTable`, so we can test this
    feature in DSv2.

    ### Does this PR introduce any user-facing change?

    No.

    ### How was this patch tested?

    Existing unit tests and new tests.

    Closes #26929 from dbtsai/addTests.

    Authored-by: DB Tsai <d_t...@apple.com>
    Signed-off-by: DB Tsai <d_t...@apple.com>
    (cherry picked from commit d0f961476031b62bda0d4d41f7248295d651ea92)
    Signed-off-by: DB Tsai <d_t...@apple.com>
---
 .../apache/spark/sql/connector/InMemoryTable.scala | 35 +++++++--
 .../apache/spark/sql/DataFrameWriterV2Suite.scala  | 86 +++++++++++++++++++++-
 2 files changed, 114 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index c9e4e0a..0187ae3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -26,7 +26,7 @@ import org.scalatest.Assertions._
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.{IdentityTransform, Transform}
+import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
@@ -59,10 +59,30 @@ class InMemoryTable(
 
   def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq
 
-  private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames)
-  private val partIndexes = partFieldNames.map(schema.fieldIndex)
+  private val partCols: Array[Array[String]] = partitioning.flatMap(_.references).map { ref =>
+    schema.findNestedField(ref.fieldNames(), includeCollections = false) match {
+      case Some(_) => ref.fieldNames()
+      case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.")
+    }
+  }
 
-  private def getKey(row: InternalRow): Seq[Any] = partIndexes.map(row.toSeq(schema)(_))
+  private def getKey(row: InternalRow): Seq[Any] = {
+    def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = {
+      val index = schema.fieldIndex(fieldNames(0))
+      val value = row.toSeq(schema).apply(index)
+      if (fieldNames.length > 1) {
+        (value, schema(index).dataType) match {
+          case (row: InternalRow, nestedSchema: StructType) =>
+            extractor(fieldNames.drop(1), nestedSchema, row)
+          case (_, dataType) =>
+            throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}")
+        }
+      } else {
+        value
+      }
+    }
+    partCols.map(fieldNames => extractor(fieldNames, schema, row))
+  }
 
   def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
     data.foreach(_.rows.foreach { row =>
@@ -146,8 +166,10 @@ class InMemoryTable(
   }
 
   private class Overwrite(filters: Array[Filter]) extends TestBatchWrite {
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
-      val deleteKeys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
+      val deleteKeys = InMemoryTable.filtersToKeys(
+        dataMap.keys, partCols.map(_.toSeq.quoted), filters)
       dataMap --= deleteKeys
       withData(messages.map(_.asInstanceOf[BufferedRows]))
     }
@@ -161,7 +183,8 @@ class InMemoryTable(
   }
 
   override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
-    dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+    dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
index d49dc58..cd15708 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
@@ -17,20 +17,24 @@
 
 package org.apache.spark.sql
 
+import java.sql.Timestamp
+
 import scala.collection.JavaConverters._
 
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
-import org.apache.spark.sql.connector.InMemoryTableCatalog
+import org.apache.spark.sql.connector.{InMemoryTable, InMemoryTableCatalog}
 import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
 import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.types.TimestampType
 import org.apache.spark.sql.util.QueryExecutionListener
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter {
@@ -550,4 +554,84 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     assert(replaced.partitioning.isEmpty)
     assert(replaced.properties === defaultOwnership.asJava)
   }
+
+  test("SPARK-30289 Create: partitioned by nested column") {
+    val schema = new StructType().add("ts", new StructType()
+      .add("created", TimestampType)
+      .add("modified", TimestampType)
+      .add("timezone", StringType))
+
+    val data = Seq(
+      Row(Row(Timestamp.valueOf("2019-06-01 10:00:00"), Timestamp.valueOf("2019-09-02 07:00:00"),
+        "America/Los_Angeles")),
+      Row(Row(Timestamp.valueOf("2019-08-26 18:00:00"), Timestamp.valueOf("2019-09-26 18:00:00"),
+        "America/Los_Angeles")),
+      Row(Row(Timestamp.valueOf("2018-11-23 18:00:00"), Timestamp.valueOf("2018-12-22 18:00:00"),
+        "America/New_York")))
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema)
+
+    df.writeTo("testcat.table_name")
+      .partitionedBy($"ts.timezone")
+      .create()
+
+    val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+      .asInstanceOf[InMemoryTable]
+
+    assert(table.name === "testcat.table_name")
+    assert(table.partitioning === Seq(IdentityTransform(FieldReference(Array("ts", "timezone")))))
+    checkAnswer(spark.table(table.name), data)
+    assert(table.dataMap.toArray.length == 2)
+    assert(table.dataMap(Seq(UTF8String.fromString("America/Los_Angeles"))).rows.size == 2)
+    assert(table.dataMap(Seq(UTF8String.fromString("America/New_York"))).rows.size == 1)
+
+    // TODO: `DataSourceV2Strategy` can not translate nested fields into source filter yet
+    // so the following sql will fail.
+    // sql("DELETE FROM testcat.table_name WHERE ts.timezone = \"America/Los_Angeles\"")
+  }
+
+  test("SPARK-30289 Create: partitioned by multiple transforms on nested columns") {
+    spark.table("source")
+      .withColumn("ts", struct(
+        lit("2019-06-01 10:00:00.000000").cast("timestamp") as "created",
+        lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
+        lit("America/Los_Angeles") as "timezone"))
+      .writeTo("testcat.table_name")
+      .tableProperty("allow-unsupported-transforms", "true")
+      .partitionedBy(
+        years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"),
+        years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified")
+      )
+      .create()
+
+    val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+    assert(table.name === "testcat.table_name")
+    assert(table.partitioning === Seq(
+      YearsTransform(FieldReference(Array("ts", "created"))),
+      MonthsTransform(FieldReference(Array("ts", "created"))),
+      DaysTransform(FieldReference(Array("ts", "created"))),
+      HoursTransform(FieldReference(Array("ts", "created"))),
+      YearsTransform(FieldReference(Array("ts", "modified"))),
+      MonthsTransform(FieldReference(Array("ts", "modified"))),
+      DaysTransform(FieldReference(Array("ts", "modified"))),
+      HoursTransform(FieldReference(Array("ts", "modified")))))
+  }
+
+  test("SPARK-30289 Create: partitioned by bucket(4, ts.timezone)") {
+    spark.table("source")
+      .withColumn("ts", struct(
+        lit("2019-06-01 10:00:00.000000").cast("timestamp") as "created",
+        lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
+        lit("America/Los_Angeles") as "timezone"))
+      .writeTo("testcat.table_name")
+      .tableProperty("allow-unsupported-transforms", "true")
+      .partitionedBy(bucket(4, $"ts.timezone"))
+      .create()
+
+    val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+    assert(table.name === "testcat.table_name")
+    assert(table.partitioning === Seq(BucketTransform(LiteralValue(4, IntegerType),
+      Seq(FieldReference(Seq("ts", "timezone"))))))
+  }
 }
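
For quick reference, below is a minimal usage sketch of the behavior this patch enables, condensed
from the new DataFrameWriterV2Suite test above. It is an illustration only, not part of the commit:
it assumes a local SparkSession and registers the test-only `InMemoryTableCatalog` (which lives in
Spark's sql/catalyst test sources, so it is not on a normal application classpath) under the
illustrative catalog name `testcat`, the same way the suite's setup does; the table name
`table_name` is likewise illustrative.

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.connector.InMemoryTableCatalog
  import org.apache.spark.sql.functions.{lit, struct}

  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  import spark.implicits._

  // Register the test-only in-memory DSv2 catalog under the name `testcat`,
  // mirroring what DataFrameWriterV2Suite does in its setup.
  spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)

  // A DataFrame with a nested struct column `ts`.
  val df = spark.range(3).withColumn("ts", struct(
    lit("2019-06-01 10:00:00").cast("timestamp") as "created",
    lit("America/Los_Angeles") as "timezone"))

  // Partitioning by the nested field `ts.timezone` is what this change enables:
  // InMemoryTable now resolves the multi-part reference against the nested schema
  // instead of looking the name up in a flattened top-level schema.
  df.writeTo("testcat.table_name")
    .partitionedBy($"ts.timezone")
    .create()

Under the hood, the new `getKey` walks each entry of `partCols` recursively through the struct, so
rows land in `dataMap` keyed by the `UTF8String` value of `ts.timezone`; that grouping is exactly
what the new assertions on `table.dataMap` in the first test verify.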