This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 30e3fcb  [SPARK-32168][SQL] Fix hidden partitioning correctness bug in 
SQL overwrite
30e3fcb is described below

commit 30e3fcbb081b6b07a12316c9dd63c66f77b8fec7
Author: Ryan Blue <[email protected]>
AuthorDate: Wed Jul 8 16:06:40 2020 -0700

    [SPARK-32168][SQL] Fix hidden partitioning correctness bug in SQL overwrite
    
    ### What changes were proposed in this pull request?
    
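    When converting an `INSERT OVERWRITE` query to a v2 overwrite plan, Spark 
attempts to detect when a dynamic overwrite and a static overwrite will produce 
the same result so that it can use the static overwrite. That check only 
compared the number of partition columns with the number of static partition 
values (`partCols.size > staticPartitions.size`), which does not account for 
hidden partitions such as `days(ts)`. As a result, Spark incorrectly concludes 
that the dynamic and static overwrites are equivalent when a table has hidden 
partitions.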
    
    This updates the analyzer rule `ResolveInsertInto` to always use a dynamic 
overwrite when the mode is dynamic, and static when the mode is static. This 
avoids the problem by not trying to determine whether the two plans are 
equivalent and always using the one that corresponds to the partition overwrite 
mode.
    
    ### Why are the changes needed?
    
    This is a correctness bug. If a table has hidden partitions, all existing 
data in those partitions is dropped, instead of only the changed partitions 
being overwritten dynamically.
    
    This only affects SQL commands (not `DataFrameWriter`) writing to tables 
that have hidden partitions. It is also only a problem when the partition 
overwrite mode is dynamic.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it fixes the correctness bug detailed above.
    
    ### How was this patch tested?
    
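    * This updates the in-memory table to support the hidden partition 
transforms `years`, `months`, `days`, `hours`, and `bucket`, and adds a test 
case to `DataSourceV2SQLSuite` in which the table is partitioned by the hidden 
partition function `days(ts)`. This test fails without the fix to 
`ResolveInsertInto`. Because these transforms are now supported, tests no 
longer need to set the `allow-unsupported-transforms` table property.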
    * This updates the test case `InsertInto: overwrite - multiple static 
partitions - dynamic mode` in `InsertIntoTests`. The expected result of the SQL 
command is unchanged, but because the command now uses a dynamic overwrite, the 
test is now defined with `dynamicOverwriteTest`.
    
    Closes #28993 from rdblue/fix-insert-overwrite-v2-conversion.
    
    Authored-by: Ryan Blue <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
    (cherry picked from commit 3bb1ac597a6603e8224cb99349419d950ad7318e)
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  4 +-
 .../apache/spark/sql/connector/InMemoryTable.scala | 66 +++++++++++++++++++---
 .../execution/datasources/v2/BatchScanExec.scala   |  2 +-
 .../apache/spark/sql/DataFrameWriterV2Suite.scala  |  7 ---
 .../spark/sql/connector/DataSourceV2SQLSuite.scala | 38 ++++++++++++-
 .../spark/sql/connector/InsertIntoTests.scala      | 27 ++++-----
 6 files changed, 107 insertions(+), 37 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6fb103e..243b555 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1041,12 +1041,10 @@ class Analyzer(
 
         val staticPartitions = 
i.partitionSpec.filter(_._2.isDefined).mapValues(_.get)
         val query = addStaticPartitionColumns(r, i.query, staticPartitions)
-        val dynamicPartitionOverwrite = partCols.size > staticPartitions.size 
&&
-          conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
 
         if (!i.overwrite) {
           AppendData.byPosition(r, query)
-        } else if (dynamicPartitionOverwrite) {
+        } else if (conf.partitionOverwriteMode == 
PartitionOverwriteMode.DYNAMIC) {
           OverwritePartitionsDynamic.byPosition(r, query)
         } else {
           OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, 
staticPartitions))
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index 3d7026e..616fc72 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connector
 
+import java.time.{Instant, ZoneId}
+import java.time.temporal.ChronoUnit
 import java.util
 
 import scala.collection.JavaConverters._
@@ -25,12 +27,13 @@ import scala.collection.mutable
 import org.scalatest.Assertions._
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.{IdentityTransform, 
NamedReference, Transform}
+import org.apache.spark.sql.connector.expressions.{BucketTransform, 
DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, 
YearsTransform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, DateType, StructType, 
TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
@@ -46,10 +49,15 @@ class InMemoryTable(
   private val allowUnsupportedTransforms =
     properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean
 
-  partitioning.foreach { t =>
-    if (!t.isInstanceOf[IdentityTransform] && !allowUnsupportedTransforms) {
-      throw new IllegalArgumentException(s"Transform $t must be 
IdentityTransform")
-    }
+  partitioning.foreach {
+    case _: IdentityTransform =>
+    case _: YearsTransform =>
+    case _: MonthsTransform =>
+    case _: DaysTransform =>
+    case _: HoursTransform =>
+    case _: BucketTransform =>
+    case t if !allowUnsupportedTransforms =>
+      throw new IllegalArgumentException(s"Transform $t is not a supported 
transform")
   }
 
   // The key `Seq[Any]` is the partition values.
@@ -66,8 +74,14 @@ class InMemoryTable(
     }
   }
 
+  private val UTC = ZoneId.of("UTC")
+  private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate
+
   private def getKey(row: InternalRow): Seq[Any] = {
-    def extractor(fieldNames: Array[String], schema: StructType, row: 
InternalRow): Any = {
+    def extractor(
+        fieldNames: Array[String],
+        schema: StructType,
+        row: InternalRow): (Any, DataType) = {
       val index = schema.fieldIndex(fieldNames(0))
       val value = row.toSeq(schema).apply(index)
       if (fieldNames.length > 1) {
@@ -78,10 +92,44 @@ class InMemoryTable(
             throw new IllegalArgumentException(s"Unsupported type, 
${dataType.simpleString}")
         }
       } else {
-        value
+        (value, schema(index).dataType)
       }
     }
-    partCols.map(fieldNames => extractor(fieldNames, schema, row))
+
+    partitioning.map {
+      case IdentityTransform(ref) =>
+        extractor(ref.fieldNames, schema, row)._1
+      case YearsTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (days: Int, DateType) =>
+            ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, 
DateTimeUtils.daysToLocalDate(days))
+          case (micros: Long, TimestampType) =>
+            val localDate = 
DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
+            ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
+        }
+      case MonthsTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (days: Int, DateType) =>
+            ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, 
DateTimeUtils.daysToLocalDate(days))
+          case (micros: Long, TimestampType) =>
+            val localDate = 
DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
+            ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate)
+        }
+      case DaysTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (days, DateType) =>
+            days
+          case (micros: Long, TimestampType) =>
+            ChronoUnit.DAYS.between(Instant.EPOCH, 
DateTimeUtils.microsToInstant(micros))
+        }
+      case HoursTransform(ref) =>
+        extractor(ref.fieldNames, schema, row) match {
+          case (micros: Long, TimestampType) =>
+            ChronoUnit.HOURS.between(Instant.EPOCH, 
DateTimeUtils.microsToInstant(micros))
+        }
+      case BucketTransform(numBuckets, ref) =>
+        (extractor(ref.fieldNames, schema, row).hashCode() & 
Integer.MAX_VALUE) % numBuckets
+    }
   }
 
   def withData(data: Array[BufferedRows]): InMemoryTable = 
dataMap.synchronized {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index e4e7887..c199df6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -40,7 +40,7 @@ case class BatchScanExec(
 
   override def hashCode(): Int = batch.hashCode()
 
-  override lazy val partitions: Seq[InputPartition] = 
batch.planInputPartitions()
+  @transient override lazy val partitions: Seq[InputPartition] = 
batch.planInputPartitions()
 
   override lazy val readerFactory: PartitionReaderFactory = 
batch.createReaderFactory()
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
index ac2ebd8..508eefa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
@@ -336,7 +336,6 @@ class DataFrameWriterV2Suite extends QueryTest with 
SharedSparkSession with Befo
     spark.table("source")
         .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
         .writeTo("testcat.table_name")
-        .tableProperty("allow-unsupported-transforms", "true")
         .partitionedBy(years($"ts"))
         .create()
 
@@ -350,7 +349,6 @@ class DataFrameWriterV2Suite extends QueryTest with 
SharedSparkSession with Befo
     spark.table("source")
         .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
         .writeTo("testcat.table_name")
-        .tableProperty("allow-unsupported-transforms", "true")
         .partitionedBy(months($"ts"))
         .create()
 
@@ -364,7 +362,6 @@ class DataFrameWriterV2Suite extends QueryTest with 
SharedSparkSession with Befo
     spark.table("source")
         .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
         .writeTo("testcat.table_name")
-        .tableProperty("allow-unsupported-transforms", "true")
         .partitionedBy(days($"ts"))
         .create()
 
@@ -378,7 +375,6 @@ class DataFrameWriterV2Suite extends QueryTest with 
SharedSparkSession with Befo
     spark.table("source")
         .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
         .writeTo("testcat.table_name")
-        .tableProperty("allow-unsupported-transforms", "true")
         .partitionedBy(hours($"ts"))
         .create()
 
@@ -391,7 +387,6 @@ class DataFrameWriterV2Suite extends QueryTest with 
SharedSparkSession with Befo
   test("Create: partitioned by bucket(4, id)") {
     spark.table("source")
         .writeTo("testcat.table_name")
-        .tableProperty("allow-unsupported-transforms", "true")
         .partitionedBy(bucket(4, $"id"))
         .create()
 
@@ -596,7 +591,6 @@ class DataFrameWriterV2Suite extends QueryTest with 
SharedSparkSession with Befo
         lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
         lit("America/Los_Angeles") as "timezone"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(
         years($"ts.created"), months($"ts.created"), days($"ts.created"), 
hours($"ts.created"),
         years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), 
hours($"ts.modified")
@@ -624,7 +618,6 @@ class DataFrameWriterV2Suite extends QueryTest with 
SharedSparkSession with Befo
         lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
         lit("America/Los_Angeles") as "timezone"))
       .writeTo("testcat.table_name")
-      .tableProperty("allow-unsupported-transforms", "true")
       .partitionedBy(bucket(4, $"ts.timezone"))
       .create()
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 8462ce5..ef5558a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.connector
 
+import java.sql.Timestamp
+import java.time.LocalDate
+
 import scala.collection.JavaConverters._
 
 import org.apache.spark.SparkException
@@ -27,7 +30,7 @@ import org.apache.spark.sql.connector.catalog._
 import 
org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import 
org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, 
PartitionOverwriteMode, V2_SESSION_CATALOG_IMPLEMENTATION}
 import org.apache.spark.sql.internal.connector.SimpleTableProvider
 import org.apache.spark.sql.sources.SimpleScanSource
 import org.apache.spark.sql.types.{BooleanType, LongType, StringType, 
StructField, StructType}
@@ -1630,7 +1633,6 @@ class DataSourceV2SQLSuite
         """
           |CREATE TABLE testcat.t (id int, `a.b` string) USING foo
           |CLUSTERED BY (`a.b`) INTO 4 BUCKETS
-          |OPTIONS ('allow-unsupported-transforms'=true)
         """.stripMargin)
 
       val testCatalog = 
catalog("testcat").asTableCatalog.asInstanceOf[InMemoryTableCatalog]
@@ -2476,6 +2478,38 @@ class DataSourceV2SQLSuite
     }
   }
 
+  test("SPARK-32168: INSERT OVERWRITE - hidden days partition - dynamic mode") 
{
+    def testTimestamp(daysOffset: Int): Timestamp = {
+      Timestamp.valueOf(LocalDate.of(2020, 1, 1 + daysOffset).atStartOfDay())
+    }
+
+    withSQLConf(PARTITION_OVERWRITE_MODE.key -> 
PartitionOverwriteMode.DYNAMIC.toString) {
+      val t1 = s"${catalogAndNamespace}tbl"
+      withTable(t1) {
+        val df = spark.createDataFrame(Seq(
+          (testTimestamp(1), "a"),
+          (testTimestamp(2), "b"),
+          (testTimestamp(3), "c"))).toDF("ts", "data")
+        df.createOrReplaceTempView("source_view")
+
+        sql(s"CREATE TABLE $t1 (ts timestamp, data string) " +
+            s"USING $v2Format PARTITIONED BY (days(ts))")
+        sql(s"INSERT INTO $t1 VALUES " +
+            s"(CAST(date_add('2020-01-01', 2) AS timestamp), 'dummy'), " +
+            s"(CAST(date_add('2020-01-01', 4) AS timestamp), 'keep')")
+        sql(s"INSERT OVERWRITE TABLE $t1 SELECT ts, data FROM source_view")
+
+        val expected = spark.createDataFrame(Seq(
+          (testTimestamp(1), "a"),
+          (testTimestamp(2), "b"),
+          (testTimestamp(3), "c"),
+          (testTimestamp(4), "keep"))).toDF("ts", "data")
+
+        verifyTable(t1, expected)
+      }
+    }
+  }
+
   private def testV1Command(sqlCommand: String, sqlParams: String): Unit = {
     val e = intercept[AnalysisException] {
       sql(s"$sqlCommand $sqlParams")
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala
index b88ad52..2cc7a1f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala
@@ -446,21 +446,18 @@ trait InsertIntoSQLOnlyTests
       }
     }
 
-    test("InsertInto: overwrite - multiple static partitions - dynamic mode") {
-      // Since all partitions are provided statically, this should be 
supported by everyone
-      withSQLConf(PARTITION_OVERWRITE_MODE.key -> 
PartitionOverwriteMode.DYNAMIC.toString) {
-        val t1 = s"${catalogAndNamespace}tbl"
-        withTableAndData(t1) { view =>
-          sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " +
-            s"USING $v2Format PARTITIONED BY (id, p)")
-          sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)")
-          sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT 
data FROM $view")
-          verifyTable(t1, Seq(
-            (2, "a", 2),
-            (2, "b", 2),
-            (2, "c", 2),
-            (4, "keep", 2)).toDF("id", "data", "p"))
-        }
+    dynamicOverwriteTest("InsertInto: overwrite - multiple static partitions - 
dynamic mode") {
+      val t1 = s"${catalogAndNamespace}tbl"
+      withTableAndData(t1) { view =>
+        sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " +
+          s"USING $v2Format PARTITIONED BY (id, p)")
+        sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)")
+        sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data 
FROM $view")
+        verifyTable(t1, Seq(
+          (2, "a", 2),
+          (2, "b", 2),
+          (2, "c", 2),
+          (4, "keep", 2)).toDF("id", "data", "p"))
       }
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to