This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bd6a3b4  [SPARK-38437][SQL] Lenient serialization of datetime from 
datasource
bd6a3b4 is described below

commit bd6a3b4a001d29255f36bab9e9969cd919306fc2
Author: Max Gekk <[email protected]>
AuthorDate: Wed Mar 9 11:36:57 2022 +0300

    [SPARK-38437][SQL] Lenient serialization of datetime from datasource
    
    ### What changes were proposed in this pull request?
    In the PR, I propose to support the lenient mode by the row serializer used 
by datasources to convert rows received from scans. Spark SQL will be able to 
accept:
    - `java.time.Instant` and `java.sql.Timestamp` for the `TIMESTAMP` type, and
    - `java.time.LocalDate` and `java.sql.Date` for the `DATE` type
    
    independently from the current value of the SQL config 
`spark.sql.datetime.java8API.enabled`.
    
    ### Why are the changes needed?
    A datasource might not be aware of the Spark SQL config 
`spark.sql.datetime.java8API.enabled` if this datasource was developed before 
the config was introduced by Spark version 3.0.0. In that case, it always 
returns "legacy" timestamps/dates of the types 
`java.sql.Timestamp`/`java.sql.Date` even if a user enabled Java 8 API. As 
Spark expects `java.time.Instant` or `java.time.LocalDate` but gets 
`java.sql.Timestamp` or `java.sql.Date`, the user observes the exception:
    ```java
    ERROR SparkExecuteStatementOperation: Error executing query with 
ac61b10a-486e-463b-8726-3b61da58582e, currentState RUNNING,
    org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 
in stage 2.0 failed 4 times, most recent failure: Lost task 0.3 in stage 2.0 
(TID 8) (10.157.1.194 executor 0): java.lang.RuntimeException: Error while 
encoding: java.lang.RuntimeException: java.sql.Timestamp is not a valid 
external type for schema of timestamp
    if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null 
else staticinvoke(class org.apache.spark.sql.catalyst.util.DateTimeUtils$, 
TimestampType, instantToMicros, 
validateexternaltype(getexternalrowfield(assertnotnull(input[0, 
org.apache.spark.sql.Row, true]), 0, loan_perf_date), TimestampType), true, 
false) AS loan_perf_date#1125
    at 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:239)
    ```
    
    This PR fixes the issue above. And after the changes, users can use legacy 
datasource connectors with new Spark versions even when they need to enable 
Java 8 API.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    By running the affected test suites:
    ```
    $ build/sbt "test:testOnly *CodeGenerationSuite"
    $ build/sbt "test:testOnly *ObjectExpressionsSuite"
    ```
    and new tests:
    ```
    $ build/sbt "test:testOnly *RowEncoderSuite"
    $ build/sbt "test:testOnly *TableScanSuite"
    ```
    
    Closes #35756 from MaxGekk/dynamic-serializer-java-ts.
    
    Authored-by: Max Gekk <[email protected]>
    Signed-off-by: Max Gekk <[email protected]>
---
 .../spark/sql/catalyst/SerializerBuildHelper.scala | 18 +++++++++++
 .../spark/sql/catalyst/encoders/RowEncoder.scala   | 35 ++++++++++++++--------
 .../sql/catalyst/expressions/objects/objects.scala | 26 ++++++++++++----
 .../spark/sql/catalyst/util/DateTimeUtils.scala    | 22 ++++++++++++++
 .../sql/catalyst/encoders/RowEncoderSuite.scala    | 23 ++++++++++++++
 .../catalyst/expressions/CodeGenerationSuite.scala |  4 ++-
 .../expressions/ObjectExpressionsSuite.scala       |  8 +++--
 .../execution/datasources/DataSourceStrategy.scala |  2 +-
 .../apache/spark/sql/sources/TableScanSuite.scala  | 27 +++++++++++++++++
 9 files changed, 144 insertions(+), 21 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 3c17575..8dec923 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -86,6 +86,15 @@ object SerializerBuildHelper {
       returnNullable = false)
   }
 
+  def createSerializerForAnyTimestamp(inputObject: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      TimestampType,
+      "anyToMicros",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
   def createSerializerForLocalDateTime(inputObject: Expression): Expression = {
     StaticInvoke(
       DateTimeUtils.getClass,
@@ -113,6 +122,15 @@ object SerializerBuildHelper {
       returnNullable = false)
   }
 
+  def createSerializerForAnyDate(inputObject: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      DateType,
+      "anyToDays",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
   def createSerializerForJavaDuration(inputObject: Expression): Expression = {
     StaticInvoke(
       IntervalUtils.getClass,
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index d34d953..d7e497f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -66,23 +66,27 @@ import org.apache.spark.sql.types._
  * }}}
  */
 object RowEncoder {
-  def apply(schema: StructType): ExpressionEncoder[Row] = {
+  def apply(schema: StructType, lenient: Boolean): ExpressionEncoder[Row] = {
     val cls = classOf[Row]
     val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
-    val serializer = serializerFor(inputObject, schema)
+    val serializer = serializerFor(inputObject, schema, lenient)
     val deserializer = deserializerFor(GetColumnByOrdinal(0, 
serializer.dataType), schema)
     new ExpressionEncoder[Row](
       serializer,
       deserializer,
       ClassTag(cls))
   }
+  def apply(schema: StructType): ExpressionEncoder[Row] = {
+    apply(schema, lenient = false)
+  }
 
   private def serializerFor(
       inputObject: Expression,
-      inputType: DataType): Expression = inputType match {
+      inputType: DataType,
+      lenient: Boolean): Expression = inputType match {
     case dt if ScalaReflection.isNativeType(dt) => inputObject
 
-    case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)
+    case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType, 
lenient)
 
     case udt: UserDefinedType[_] =>
       val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
@@ -100,7 +104,9 @@ object RowEncoder {
       Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
 
     case TimestampType =>
-      if (SQLConf.get.datetimeJava8ApiEnabled) {
+      if (lenient) {
+        createSerializerForAnyTimestamp(inputObject)
+      } else if (SQLConf.get.datetimeJava8ApiEnabled) {
         createSerializerForJavaInstant(inputObject)
       } else {
         createSerializerForSqlTimestamp(inputObject)
@@ -109,7 +115,9 @@ object RowEncoder {
     case TimestampNTZType => createSerializerForLocalDateTime(inputObject)
 
     case DateType =>
-      if (SQLConf.get.datetimeJava8ApiEnabled) {
+      if (lenient) {
+        createSerializerForAnyDate(inputObject)
+      } else if (SQLConf.get.datetimeJava8ApiEnabled) {
         createSerializerForJavaLocalDate(inputObject)
       } else {
         createSerializerForSqlDate(inputObject)
@@ -144,7 +152,7 @@ object RowEncoder {
             inputObject,
             ObjectType(classOf[Object]),
             element => {
-              val value = serializerFor(ValidateExternalType(element, et), et)
+              val value = serializerFor(ValidateExternalType(element, et, 
lenient), et, lenient)
               expressionWithNullSafety(value, containsNull, WalkedTypePath())
             })
       }
@@ -156,7 +164,7 @@ object RowEncoder {
             returnNullable = false),
           "toSeq",
           ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
-      val convertedKeys = serializerFor(keys, ArrayType(kt, false))
+      val convertedKeys = serializerFor(keys, ArrayType(kt, false), lenient)
 
       val values =
         Invoke(
@@ -164,7 +172,7 @@ object RowEncoder {
             returnNullable = false),
           "toSeq",
           ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
-      val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
+      val convertedValues = serializerFor(values, ArrayType(vt, 
valueNullable), lenient)
 
       val nonNullOutput = NewInstance(
         classOf[ArrayBasedMapData],
@@ -183,8 +191,10 @@ object RowEncoder {
         val fieldValue = serializerFor(
           ValidateExternalType(
             GetExternalRowField(inputObject, index, field.name),
-            field.dataType),
-          field.dataType)
+            field.dataType,
+            lenient),
+          field.dataType,
+          lenient)
         val convertedField = if (field.nullable) {
           If(
             Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: 
Nil),
@@ -214,12 +224,13 @@ object RowEncoder {
    * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
    * `org.apache.spark.sql.types.Decimal`.
    */
-  def externalDataTypeForInput(dt: DataType): DataType = dt match {
+  def externalDataTypeForInput(dt: DataType, lenient: Boolean): DataType = dt 
match {
     // In order to support both Decimal and java/scala BigDecimal in external 
row, we make this
     // as java.lang.Object.
     case _: DecimalType => ObjectType(classOf[java.lang.Object])
     // In order to support both Array and Seq in external row, we make this as 
java.lang.Object.
     case _: ArrayType => ObjectType(classOf[java.lang.Object])
+    case _: DateType | _: TimestampType if lenient => 
ObjectType(classOf[java.lang.Object])
     case _ => externalDataTypeFor(dt)
   }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 4599c2a..6974ada 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -1875,14 +1875,14 @@ case class GetExternalRowField(
  * Validates the actual data type of input expression at runtime.  If it 
doesn't match the
  * expectation, throw an exception.
  */
-case class ValidateExternalType(child: Expression, expected: DataType)
+case class ValidateExternalType(child: Expression, expected: DataType, 
lenient: Boolean)
   extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
 
   override def inputTypes: Seq[AbstractDataType] = 
Seq(ObjectType(classOf[Object]))
 
   override def nullable: Boolean = child.nullable
 
-  override val dataType: DataType = 
RowEncoder.externalDataTypeForInput(expected)
+  override val dataType: DataType = 
RowEncoder.externalDataTypeForInput(expected, lenient)
 
   private lazy val errMsg = s" is not a valid external type for schema of 
${expected.simpleString}"
 
@@ -1896,6 +1896,14 @@ case class ValidateExternalType(child: Expression, 
expected: DataType)
       (value: Any) => {
         value.getClass.isArray || value.isInstanceOf[Seq[_]]
       }
+    case _: DateType =>
+      (value: Any) => {
+        value.isInstanceOf[java.sql.Date] || 
value.isInstanceOf[java.time.LocalDate]
+      }
+    case _: TimestampType =>
+      (value: Any) => {
+        value.isInstanceOf[java.sql.Timestamp] || 
value.isInstanceOf[java.time.Instant]
+      }
     case _ =>
       val dataTypeClazz = ScalaReflection.javaBoxedType(dataType)
       (value: Any) => {
@@ -1918,13 +1926,21 @@ case class ValidateExternalType(child: Expression, 
expected: DataType)
     val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
     val input = child.genCode(ctx)
     val obj = input.value
-
+    def genCheckTypes(classes: Seq[Class[_]]): String = {
+      classes.map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
+    }
     val typeCheck = expected match {
       case _: DecimalType =>
-        Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], 
classOf[Decimal])
-          .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
+        genCheckTypes(Seq(
+          classOf[java.math.BigDecimal],
+          classOf[scala.math.BigDecimal],
+          classOf[Decimal]))
       case _: ArrayType =>
         s"$obj.getClass().isArray() || $obj instanceof 
${classOf[scala.collection.Seq[_]].getName}"
+      case _: DateType =>
+        genCheckTypes(Seq(classOf[java.sql.Date], 
classOf[java.time.LocalDate]))
+      case _: TimestampType =>
+        genCheckTypes(Seq(classOf[java.sql.Timestamp], 
classOf[java.time.Instant]))
       case _ =>
         s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
     }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index c2ca436..7d2ead0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -108,6 +108,17 @@ object DateTimeUtils {
   }
 
   /**
+   * Converts an Java object to days.
+   *
+   * @param obj Either an object of `java.sql.Date` or `java.time.LocalDate`.
+   * @return The number of days since 1970-01-01.
+   */
+  def anyToDays(obj: Any): Int = obj match {
+    case d: Date => fromJavaDate(d)
+    case ld: LocalDate => localDateToDays(ld)
+  }
+
+  /**
    * Converts days since the epoch 1970-01-01 in Proleptic Gregorian calendar 
to a local date
    * at the default JVM time zone in the hybrid calendar (Julian + Gregorian). 
It rebases the given
    * days from Proleptic Gregorian to the hybrid calendar at UTC time zone for 
simplicity because
@@ -181,6 +192,17 @@ object DateTimeUtils {
   }
 
   /**
+   * Converts an Java object to microseconds.
+   *
+   * @param obj Either an object of `java.sql.Timestamp` or 
`java.time.Instant`.
+   * @return The number of micros since the epoch.
+   */
+  def anyToMicros(obj: Any): Long = obj match {
+    case t: Timestamp => fromJavaTimestamp(t)
+    case i: Instant => instantToMicros(i)
+  }
+
+  /**
    * Returns the number of microseconds since epoch from Julian day and 
nanoseconds in a day.
    */
   def fromJulianDay(days: Int, nanos: Long): Long = {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 44b06d9..c6bddfa 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -435,4 +435,27 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
       }
     }
   }
+
+  test("SPARK-38437: encoding TimestampType/DateType from any supported 
datetime Java types") {
+    Seq(true, false).foreach { java8Api =>
+      withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
+        val schema = new StructType()
+          .add("t0", TimestampType)
+          .add("t1", TimestampType)
+          .add("d0", DateType)
+          .add("d1", DateType)
+        val encoder = RowEncoder(schema, lenient = true).resolveAndBind()
+        val instant = java.time.Instant.parse("2019-02-26T16:56:00Z")
+        val ld = java.time.LocalDate.parse("2022-03-08")
+        val row = encoder.createSerializer().apply(
+          Row(instant, java.sql.Timestamp.from(instant), ld, 
java.sql.Date.valueOf(ld)))
+        val expectedMicros = DateTimeUtils.instantToMicros(instant)
+        assert(row.getLong(0) === expectedMicros)
+        assert(row.getLong(1) === expectedMicros)
+        val expectedDays = DateTimeUtils.localDateToDays(ld)
+        assert(row.getInt(2) === expectedDays)
+        assert(row.getInt(3) === expectedDays)
+      }
+    }
+  }
 }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 2b59d72..1e4499a 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -330,7 +330,9 @@ class CodeGenerationSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = 
true)
     GenerateUnsafeProjection.generate(
       ValidateExternalType(
-        GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"), 
IntegerType) :: Nil)
+        GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"),
+        IntegerType,
+        lenient = false) :: Nil)
   }
 
   test("SPARK-17160: field names are properly escaped by AssertTrue") {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 8d98965..585191f 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -498,13 +498,17 @@ class ObjectExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       (Array(3, 2, 1), ArrayType(IntegerType))
     ).foreach { case (input, dt) =>
       val validateType = ValidateExternalType(
-        GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt)
+        GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
+        dt,
+        lenient = false)
       checkObjectExprEvaluation(validateType, input, 
InternalRow.fromSeq(Seq(Row(input))))
     }
 
     checkExceptionInExpression[RuntimeException](
       ValidateExternalType(
-        GetExternalRowField(inputObject, index = 0, fieldName = "c0"), 
DoubleType),
+        GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
+        DoubleType,
+        lenient = false),
       InternalRow.fromSeq(Seq(Row(1))),
       "java.lang.Integer is not a valid external type for schema of double")
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index c386655..4e5014c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -802,7 +802,7 @@ object DataSourceStrategy
       output: Seq[Attribute],
       rdd: RDD[Row]): RDD[InternalRow] = {
     if (relation.needConversion) {
-      val toRow = 
RowEncoder(StructType.fromAttributes(output)).createSerializer()
+      val toRow = RowEncoder(StructType.fromAttributes(output), lenient = 
true).createSerializer()
       rdd.mapPartitions { iterator =>
         iterator.map(toRow)
       }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 47bacde..8f263f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -109,6 +109,19 @@ case class AllDataTypesScan(
   }
 }
 
+class LegacyTimestampSource extends RelationProvider {
+  override def createRelation(ctx: SQLContext, parameters: Map[String, 
String]): BaseRelation = {
+    new BaseRelation() with TableScan {
+      override val sqlContext: SQLContext = ctx
+      override val schema: StructType = StructType(StructField("col", 
TimestampType) :: Nil)
+      override def buildScan(): RDD[Row] = {
+        sqlContext.sparkContext.parallelize(
+          Row(java.sql.Timestamp.valueOf("2022-03-08 12:13:14")) :: Nil)
+      }
+    }
+  }
+}
+
 class TableScanSuite extends DataSourceTest with SharedSparkSession {
   protected override lazy val sql = spark.sql _
 
@@ -420,4 +433,18 @@ class TableScanSuite extends DataSourceTest with 
SharedSparkSession {
     val comments = 
planned.schema.fields.map(_.getComment().getOrElse("NO_COMMENT")).mkString(",")
     assert(comments === "SN,SA,NO_COMMENT")
   }
+
+  test("SPARK-38437: accept java.sql.Timestamp even when Java 8 API is 
enabled") {
+    val tableName = "relationProviderWithLegacyTimestamps"
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+      withTable (tableName) {
+        sql(s"""
+          |CREATE TABLE $tableName (col TIMESTAMP)
+          |USING 
org.apache.spark.sql.sources.LegacyTimestampSource""".stripMargin)
+        checkAnswer(
+          spark.table(tableName),
+          Row(java.sql.Timestamp.valueOf("2022-03-08 12:13:14").toInstant) :: 
Nil)
+      }
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to