This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 6c4977d  [SPARK-30993][SQL] Use its sql type for UDT when checking the type of length (fixed/var) or mutable
6c4977d is described below

commit 6c4977d38f13628abfa24129ae6844146672d96d
Author: Jungtaek Lim (HeartSaVioR) <kabhwan.opensou...@gmail.com>
AuthorDate: Mon Mar 2 22:33:11 2020 +0800

    [SPARK-30993][SQL] Use its sql type for UDT when checking the type of length (fixed/var) or mutable
    
    ### What changes were proposed in this pull request?
    
    This patch fixes a bug in UnsafeRow: `isFixedLength` and `isMutable` do not handle UDTs specially. These methods don't check the underlying SQL type of a UDT, and therefore always treat a UDT as variable-length and non-mutable.
    
    This causes no problem when a UDT represents a complex type, but when a UDT's underlying SQL type is fixed-length, it can lead to correctness issues, since this classification sometimes decides how the value is handled.
    
    We received a report on the user mailing list suspecting that mapGroupsWithState handles UDTs incorrectly, but after some investigation the problem turned out to come from GenerateUnsafeRowJoiner in the shuffle phase.
    
    
https://github.com/apache/spark/blob/0e2ca11d80c3921387d7b077cb64c3a0c06b08d7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala#L32-L43
    
    Here the position update should not be applied to a fixed-length column, but due to this bug the value of a UDT whose SQL type is fixed-length gets modified, which corrupts the value.
    
    ### Why are the changes needed?
    
    Misclassifying the length type (fixed vs. variable) of a UDT can corrupt its value when the row is passed as input to GenerateUnsafeRowJoiner, which causes a correctness issue.
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New UT added.
    
    Closes #27747 from HeartSaVioR/SPARK-30993.
    
    Authored-by: Jungtaek Lim (HeartSaVioR) <kabhwan.opensou...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit f24a46011c8cba086193f697d653b6eccd029e8f)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/expressions/UnsafeRow.java  |  8 +++++
 .../codegen/GenerateUnsafeRowJoinerSuite.scala     | 41 +++++++++++++++++++++-
 .../apache/spark/sql/UserDefinedTypeSuite.scala    | 37 +++++++++++++++++++
 3 files changed, 85 insertions(+), 1 deletion(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 23e7d1f..034894b 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -95,6 +95,10 @@ public final class UnsafeRow extends InternalRow implements 
Externalizable, Kryo
   }
 
   public static boolean isFixedLength(DataType dt) {
+    if (dt instanceof UserDefinedType) {
+      return isFixedLength(((UserDefinedType) dt).sqlType());
+    }
+
     if (dt instanceof DecimalType) {
       return ((DecimalType) dt).precision() <= Decimal.MAX_LONG_DIGITS();
     } else {
@@ -103,6 +107,10 @@ public final class UnsafeRow extends InternalRow 
implements Externalizable, Kryo
   }
 
   public static boolean isMutable(DataType dt) {
+    // Returns true if dt is in mutableFieldTypes, is any DecimalType, or is
+    // CalendarIntervalType.
+    // SPARK-30993: a UserDefinedType is classified by its underlying sqlType;
+    // without this check every UDT was treated as non-mutable regardless of it.
+    if (dt instanceof UserDefinedType) {
+      return isMutable(((UserDefinedType) dt).sqlType());
+    }
+
     return mutableFieldTypes.contains(dt) || dt instanceof DecimalType ||
       dt instanceof CalendarIntervalType;
   }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
index 81e2993..fb1ea7b 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions.codegen
 
+import java.time.{LocalDateTime, ZoneOffset}
+
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, 
UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -99,6 +101,23 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
     testConcatOnce(N, N, variable)
   }
 
+  // SPARK-30993: concatenates a row whose first column is a UDT backed by a
+  // fixed-length SQL type (WrappedDateTimeUDT -> LongType), followed by a
+  // variable-length string column, with a second row holding a single int.
+  // Before the fix the UDT column was classified as variable-length, so the
+  // joiner could apply a position update to it and corrupt the value.
+  test("SPARK-30993: UserDefinedType matched to fixed length SQL type 
shouldn't be corrupted") {
+    val schema1 = new StructType(Array(
+      StructField("date", new WrappedDateTimeUDT),
+      StructField("s", StringType),
+      StructField("i", IntegerType)))
+    val proj1 = UnsafeProjection.create(schema1.fields.map(_.dataType))
+    // The internal row holds the UDT column in its SQL form: a Long of epoch seconds.
+    val intRow1 = new GenericInternalRow(Array[Any](
+      LocalDateTime.now().toEpochSecond(ZoneOffset.UTC),
+      UTF8String.fromString("hello"), 1))
+
+    val schema2 = new StructType(Array(StructField("i", IntegerType)))
+    val proj2 = UnsafeProjection.create(schema2.fields.map(_.dataType))
+    val intRow2 = new GenericInternalRow(Array[Any](2))
+
+    // NOTE(review): the 4-argument testConcat overload is defined outside this
+    // hunk; presumably it checks the joined row against both input rows.
+    testConcat(schema1, proj1.apply(intRow1), schema2, proj2.apply(intRow2))
+  }
+
   private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: 
Seq[DataType]): Unit = {
     for (i <- 0 until 10) {
       testConcatOnce(numFields1, numFields2, candidateTypes)
@@ -204,3 +223,23 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
   }
 
 }
+
+/** Test-only wrapper around a LocalDateTime; WrappedDateTimeUDT declares it as its userClass. */
+private[sql] case class WrappedDateTime(dt: LocalDateTime)
+
+/**
+ * UserDefinedType for [[WrappedDateTime]] stored as a fixed-length LongType value
+ * holding epoch seconds in UTC. The sub-second part is dropped by serialize.
+ */
+private[sql] class WrappedDateTimeUDT extends UserDefinedType[WrappedDateTime] 
{
+  // Fixed-length SQL type: the case that SPARK-30993 needs to exercise.
+  override def sqlType: DataType = LongType
+
+  // Epoch seconds in UTC; nanoseconds are discarded.
+  override def serialize(obj: WrappedDateTime): Long = {
+    obj.dt.toEpochSecond(ZoneOffset.UTC)
+  }
+
+  // Rebuilds the value with zero nanoseconds; a non-Long datum raises a MatchError.
+  def deserialize(datum: Any): WrappedDateTime = datum match {
+    case value: Long =>
+      val v = LocalDateTime.ofEpochSecond(value, 0, ZoneOffset.UTC)
+      WrappedDateTime(v)
+  }
+
+  override def userClass: Class[WrappedDateTime] = classOf[WrappedDateTime]
+
+  // Returns this same instance.
+  private[spark] override def asNullable: WrappedDateTimeUDT = this
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index ffc2018d..157610f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import java.time.{LocalDateTime, ZoneOffset}
 import java.util.Arrays
 
 import org.apache.spark.rdd.RDD
@@ -103,6 +104,24 @@ private[spark] class ExampleSubTypeUDT extends 
UserDefinedType[IExampleSubType]
   override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
 }
 
+/** Test record with a LocalDateTime field; the SPARK-30993 test maps LocalDateTime to LocalDateTimeUDT. */
+private[sql] case class FooWithDate(date: LocalDateTime, s: String, i: Int)
+
+/**
+ * UserDefinedType for java.time.LocalDateTime stored as a fixed-length LongType
+ * value holding epoch seconds in UTC. The sub-second part is dropped by serialize.
+ */
+private[sql] class LocalDateTimeUDT extends UserDefinedType[LocalDateTime] {
+  // Fixed-length SQL type: the case affected by SPARK-30993.
+  override def sqlType: DataType = LongType
+
+  // Epoch seconds in UTC; nanoseconds are discarded.
+  override def serialize(obj: LocalDateTime): Long = {
+    obj.toEpochSecond(ZoneOffset.UTC)
+  }
+
+  // Rebuilds the value with zero nanoseconds; a non-Long datum raises a MatchError.
+  def deserialize(datum: Any): LocalDateTime = datum match {
+    case value: Long => LocalDateTime.ofEpochSecond(value, 0, ZoneOffset.UTC)
+  }
+
+  override def userClass: Class[LocalDateTime] = classOf[LocalDateTime]
+
+  // Returns this same instance.
+  private[spark] override def asNullable: LocalDateTimeUDT = this
+}
+
 class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with 
ParquetTest
     with ExpressionEvalHelper {
   import testImplicits._
@@ -287,4 +306,22 @@ class UserDefinedTypeSuite extends QueryTest with 
SharedSparkSession with Parque
     checkAnswer(spark.createDataFrame(data, schema).selectExpr("typeof(a)"),
       Seq(Row("array<double>")))
   }
+
+  // SPARK-30993: end-to-end check that a UDT backed by a fixed-length SQL type
+  // (LocalDateTimeUDT -> LongType) keeps its value through groupByKey/mapGroups.
+  test("SPARK-30993: UserDefinedType matched to fixed length SQL type 
shouldn't be corrupted") {
+    // Merges two records: keeps a's key i, concatenates the strings, takes b's date.
+    def concatFoo(a: FooWithDate, b: FooWithDate): FooWithDate = {
+      FooWithDate(b.date, a.s + b.s, a.i)
+    }
+
+    // NOTE(review): this registration is not undone at the end of the test; confirm
+    // UDTRegistration state cannot leak into other suites running in the same JVM.
+    UDTRegistration.register(classOf[LocalDateTime].getName, 
classOf[LocalDateTimeUDT].getName)
+
+    // Drop the sub-second part: LocalDateTimeUDT serializes to whole epoch seconds,
+    // so only second-precision values round-trip exactly.
+    val date = 
LocalDateTime.ofEpochSecond(LocalDateTime.now().toEpochSecond(ZoneOffset.UTC),
+      0, ZoneOffset.UTC)
+    val inputDS = List(FooWithDate(date, "Foo", 1), FooWithDate(date, "Foo", 
3),
+      FooWithDate(date, "Foo", 3)).toDS()
+    // Presumably exercises the GenerateUnsafeRowJoiner path named in the commit
+    // description, where the UDT value used to be corrupted.
+    val agg = inputDS.groupByKey(x => x.i).mapGroups((_, iter) => 
iter.reduce(concatFoo))
+    val result = agg.collect()
+
+    // Key 3 merges two "Foo" records; key 1 has a single record.
+    assert(result.toSet === Set(FooWithDate(date, "FooFoo", 3), 
FooWithDate(date, "Foo", 1)))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to