This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new bd6a3b4 [SPARK-38437][SQL] Lenient serialization of datetime from
datasource
bd6a3b4 is described below
commit bd6a3b4a001d29255f36bab9e9969cd919306fc2
Author: Max Gekk <[email protected]>
AuthorDate: Wed Mar 9 11:36:57 2022 +0300
[SPARK-38437][SQL] Lenient serialization of datetime from datasource
### What changes were proposed in this pull request?
In the PR, I propose to support the lenient mode by the row serializer used
by datasources to convert rows received from scans. Spark SQL will be able to
accept:
- `java.time.Instant` and `java.sql.Timestamp` for the `TIMESTAMP` type, and
- `java.time.LocalDate` and `java.sql.Date` for the `DATE` type
independently from the current value of the SQL config
`spark.sql.datetime.java8API.enabled`.
### Why are the changes needed?
A datasource might not be aware of the Spark SQL config
`spark.sql.datetime.java8API.enabled` if this datasource was developed before
the config was introduced by Spark version 3.0.0. In that case, it always
returns "legacy" timestamps/dates of the types
`java.sql.Timestamp`/`java.sql.Date` even if a user enabled Java 8 API. As
Spark expects `java.time.Instant` or `java.time.LocalDate` but gets
`java.sql.Timestamp` or `java.sql.Date`, the user observes the exception:
```java
ERROR SparkExecuteStatementOperation: Error executing query with
ac61b10a-486e-463b-8726-3b61da58582e, currentState RUNNING,
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0
in stage 2.0 failed 4 times, most recent failure: Lost task 0.3 in stage 2.0
(TID 8) (10.157.1.194 executor 0): java.lang.RuntimeException: Error while
encoding: java.lang.RuntimeException: java.sql.Timestamp is not a valid
external type for schema of timestamp
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null
else staticinvoke(class org.apache.spark.sql.catalyst.util.DateTimeUtils$,
TimestampType, instantToMicros,
validateexternaltype(getexternalrowfield(assertnotnull(input[0,
org.apache.spark.sql.Row, true]), 0, loan_perf_date), TimestampType), true,
false) AS loan_perf_date#1125
at
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:239)
```
This PR fixes the issue above. And after the changes, users can use legacy
datasource connectors with new Spark versions even when they need to enable
Java 8 API.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By running the affected test suites:
```
$ build/sbt "test:testOnly *CodeGenerationSuite"
$ build/sbt "test:testOnly *ObjectExpressionsSuite"
```
and new tests:
```
$ build/sbt "test:testOnly *RowEncoderSuite"
$ build/sbt "test:testOnly *TableScanSuite"
```
Closes #35756 from MaxGekk/dynamic-serializer-java-ts.
Authored-by: Max Gekk <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../spark/sql/catalyst/SerializerBuildHelper.scala | 18 +++++++++++
.../spark/sql/catalyst/encoders/RowEncoder.scala | 35 ++++++++++++++--------
.../sql/catalyst/expressions/objects/objects.scala | 26 ++++++++++++----
.../spark/sql/catalyst/util/DateTimeUtils.scala | 22 ++++++++++++++
.../sql/catalyst/encoders/RowEncoderSuite.scala | 23 ++++++++++++++
.../catalyst/expressions/CodeGenerationSuite.scala | 4 ++-
.../expressions/ObjectExpressionsSuite.scala | 8 +++--
.../execution/datasources/DataSourceStrategy.scala | 2 +-
.../apache/spark/sql/sources/TableScanSuite.scala | 27 +++++++++++++++++
9 files changed, 144 insertions(+), 21 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 3c17575..8dec923 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -86,6 +86,15 @@ object SerializerBuildHelper {
returnNullable = false)
}
+ def createSerializerForAnyTimestamp(inputObject: Expression): Expression = {
+ StaticInvoke(
+ DateTimeUtils.getClass,
+ TimestampType,
+ "anyToMicros",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
def createSerializerForLocalDateTime(inputObject: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,
@@ -113,6 +122,15 @@ object SerializerBuildHelper {
returnNullable = false)
}
+ def createSerializerForAnyDate(inputObject: Expression): Expression = {
+ StaticInvoke(
+ DateTimeUtils.getClass,
+ DateType,
+ "anyToDays",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
def createSerializerForJavaDuration(inputObject: Expression): Expression = {
StaticInvoke(
IntervalUtils.getClass,
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index d34d953..d7e497f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -66,23 +66,27 @@ import org.apache.spark.sql.types._
* }}}
*/
object RowEncoder {
- def apply(schema: StructType): ExpressionEncoder[Row] = {
+ def apply(schema: StructType, lenient: Boolean): ExpressionEncoder[Row] = {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- val serializer = serializerFor(inputObject, schema)
+ val serializer = serializerFor(inputObject, schema, lenient)
val deserializer = deserializerFor(GetColumnByOrdinal(0,
serializer.dataType), schema)
new ExpressionEncoder[Row](
serializer,
deserializer,
ClassTag(cls))
}
+ def apply(schema: StructType): ExpressionEncoder[Row] = {
+ apply(schema, lenient = false)
+ }
private def serializerFor(
inputObject: Expression,
- inputType: DataType): Expression = inputType match {
+ inputType: DataType,
+ lenient: Boolean): Expression = inputType match {
case dt if ScalaReflection.isNativeType(dt) => inputObject
- case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)
+ case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType,
lenient)
case udt: UserDefinedType[_] =>
val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
@@ -100,7 +104,9 @@ object RowEncoder {
Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
case TimestampType =>
- if (SQLConf.get.datetimeJava8ApiEnabled) {
+ if (lenient) {
+ createSerializerForAnyTimestamp(inputObject)
+ } else if (SQLConf.get.datetimeJava8ApiEnabled) {
createSerializerForJavaInstant(inputObject)
} else {
createSerializerForSqlTimestamp(inputObject)
@@ -109,7 +115,9 @@ object RowEncoder {
case TimestampNTZType => createSerializerForLocalDateTime(inputObject)
case DateType =>
- if (SQLConf.get.datetimeJava8ApiEnabled) {
+ if (lenient) {
+ createSerializerForAnyDate(inputObject)
+ } else if (SQLConf.get.datetimeJava8ApiEnabled) {
createSerializerForJavaLocalDate(inputObject)
} else {
createSerializerForSqlDate(inputObject)
@@ -144,7 +152,7 @@ object RowEncoder {
inputObject,
ObjectType(classOf[Object]),
element => {
- val value = serializerFor(ValidateExternalType(element, et), et)
+ val value = serializerFor(ValidateExternalType(element, et,
lenient), et, lenient)
expressionWithNullSafety(value, containsNull, WalkedTypePath())
})
}
@@ -156,7 +164,7 @@ object RowEncoder {
returnNullable = false),
"toSeq",
ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
- val convertedKeys = serializerFor(keys, ArrayType(kt, false))
+ val convertedKeys = serializerFor(keys, ArrayType(kt, false), lenient)
val values =
Invoke(
@@ -164,7 +172,7 @@ object RowEncoder {
returnNullable = false),
"toSeq",
ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
- val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
+ val convertedValues = serializerFor(values, ArrayType(vt,
valueNullable), lenient)
val nonNullOutput = NewInstance(
classOf[ArrayBasedMapData],
@@ -183,8 +191,10 @@ object RowEncoder {
val fieldValue = serializerFor(
ValidateExternalType(
GetExternalRowField(inputObject, index, field.name),
- field.dataType),
- field.dataType)
+ field.dataType,
+ lenient),
+ field.dataType,
+ lenient)
val convertedField = if (field.nullable) {
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(index) ::
Nil),
@@ -214,12 +224,13 @@ object RowEncoder {
* can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
* `org.apache.spark.sql.types.Decimal`.
*/
- def externalDataTypeForInput(dt: DataType): DataType = dt match {
+ def externalDataTypeForInput(dt: DataType, lenient: Boolean): DataType = dt
match {
// In order to support both Decimal and java/scala BigDecimal in external
row, we make this
// as java.lang.Object.
case _: DecimalType => ObjectType(classOf[java.lang.Object])
// In order to support both Array and Seq in external row, we make this as
java.lang.Object.
case _: ArrayType => ObjectType(classOf[java.lang.Object])
+ case _: DateType | _: TimestampType if lenient =>
ObjectType(classOf[java.lang.Object])
case _ => externalDataTypeFor(dt)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 4599c2a..6974ada 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -1875,14 +1875,14 @@ case class GetExternalRowField(
* Validates the actual data type of input expression at runtime. If it
doesn't match the
* expectation, throw an exception.
*/
-case class ValidateExternalType(child: Expression, expected: DataType)
+case class ValidateExternalType(child: Expression, expected: DataType,
lenient: Boolean)
extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] =
Seq(ObjectType(classOf[Object]))
override def nullable: Boolean = child.nullable
- override val dataType: DataType =
RowEncoder.externalDataTypeForInput(expected)
+ override val dataType: DataType =
RowEncoder.externalDataTypeForInput(expected, lenient)
private lazy val errMsg = s" is not a valid external type for schema of
${expected.simpleString}"
@@ -1896,6 +1896,14 @@ case class ValidateExternalType(child: Expression,
expected: DataType)
(value: Any) => {
value.getClass.isArray || value.isInstanceOf[Seq[_]]
}
+ case _: DateType =>
+ (value: Any) => {
+ value.isInstanceOf[java.sql.Date] ||
value.isInstanceOf[java.time.LocalDate]
+ }
+ case _: TimestampType =>
+ (value: Any) => {
+ value.isInstanceOf[java.sql.Timestamp] ||
value.isInstanceOf[java.time.Instant]
+ }
case _ =>
val dataTypeClazz = ScalaReflection.javaBoxedType(dataType)
(value: Any) => {
@@ -1918,13 +1926,21 @@ case class ValidateExternalType(child: Expression,
expected: DataType)
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val input = child.genCode(ctx)
val obj = input.value
-
+ def genCheckTypes(classes: Seq[Class[_]]): String = {
+ classes.map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
+ }
val typeCheck = expected match {
case _: DecimalType =>
- Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal],
classOf[Decimal])
- .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
+ genCheckTypes(Seq(
+ classOf[java.math.BigDecimal],
+ classOf[scala.math.BigDecimal],
+ classOf[Decimal]))
case _: ArrayType =>
s"$obj.getClass().isArray() || $obj instanceof
${classOf[scala.collection.Seq[_]].getName}"
+ case _: DateType =>
+ genCheckTypes(Seq(classOf[java.sql.Date],
classOf[java.time.LocalDate]))
+ case _: TimestampType =>
+ genCheckTypes(Seq(classOf[java.sql.Timestamp],
classOf[java.time.Instant]))
case _ =>
s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index c2ca436..7d2ead0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -108,6 +108,17 @@ object DateTimeUtils {
}
/**
+ * Converts an Java object to days.
+ *
+ * @param obj Either an object of `java.sql.Date` or `java.time.LocalDate`.
+ * @return The number of days since 1970-01-01.
+ */
+ def anyToDays(obj: Any): Int = obj match {
+ case d: Date => fromJavaDate(d)
+ case ld: LocalDate => localDateToDays(ld)
+ }
+
+ /**
* Converts days since the epoch 1970-01-01 in Proleptic Gregorian calendar
to a local date
* at the default JVM time zone in the hybrid calendar (Julian + Gregorian).
It rebases the given
* days from Proleptic Gregorian to the hybrid calendar at UTC time zone for
simplicity because
@@ -181,6 +192,17 @@ object DateTimeUtils {
}
/**
+ * Converts an Java object to microseconds.
+ *
+ * @param obj Either an object of `java.sql.Timestamp` or
`java.time.Instant`.
+ * @return The number of micros since the epoch.
+ */
+ def anyToMicros(obj: Any): Long = obj match {
+ case t: Timestamp => fromJavaTimestamp(t)
+ case i: Instant => instantToMicros(i)
+ }
+
+ /**
* Returns the number of microseconds since epoch from Julian day and
nanoseconds in a day.
*/
def fromJulianDay(days: Int, nanos: Long): Long = {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 44b06d9..c6bddfa 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -435,4 +435,27 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}
}
}
+
+ test("SPARK-38437: encoding TimestampType/DateType from any supported
datetime Java types") {
+ Seq(true, false).foreach { java8Api =>
+ withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
+ val schema = new StructType()
+ .add("t0", TimestampType)
+ .add("t1", TimestampType)
+ .add("d0", DateType)
+ .add("d1", DateType)
+ val encoder = RowEncoder(schema, lenient = true).resolveAndBind()
+ val instant = java.time.Instant.parse("2019-02-26T16:56:00Z")
+ val ld = java.time.LocalDate.parse("2022-03-08")
+ val row = encoder.createSerializer().apply(
+ Row(instant, java.sql.Timestamp.from(instant), ld,
java.sql.Date.valueOf(ld)))
+ val expectedMicros = DateTimeUtils.instantToMicros(instant)
+ assert(row.getLong(0) === expectedMicros)
+ assert(row.getLong(1) === expectedMicros)
+ val expectedDays = DateTimeUtils.localDateToDays(ld)
+ assert(row.getInt(2) === expectedDays)
+ assert(row.getInt(3) === expectedDays)
+ }
+ }
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 2b59d72..1e4499a 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -330,7 +330,9 @@ class CodeGenerationSuite extends SparkFunSuite with
ExpressionEvalHelper {
val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable =
true)
GenerateUnsafeProjection.generate(
ValidateExternalType(
- GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"),
IntegerType) :: Nil)
+ GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"),
+ IntegerType,
+ lenient = false) :: Nil)
}
test("SPARK-17160: field names are properly escaped by AssertTrue") {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 8d98965..585191f 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -498,13 +498,17 @@ class ObjectExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
(Array(3, 2, 1), ArrayType(IntegerType))
).foreach { case (input, dt) =>
val validateType = ValidateExternalType(
- GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt)
+ GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
+ dt,
+ lenient = false)
checkObjectExprEvaluation(validateType, input,
InternalRow.fromSeq(Seq(Row(input))))
}
checkExceptionInExpression[RuntimeException](
ValidateExternalType(
- GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
DoubleType),
+ GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
+ DoubleType,
+ lenient = false),
InternalRow.fromSeq(Seq(Row(1))),
"java.lang.Integer is not a valid external type for schema of double")
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index c386655..4e5014c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -802,7 +802,7 @@ object DataSourceStrategy
output: Seq[Attribute],
rdd: RDD[Row]): RDD[InternalRow] = {
if (relation.needConversion) {
- val toRow =
RowEncoder(StructType.fromAttributes(output)).createSerializer()
+ val toRow = RowEncoder(StructType.fromAttributes(output), lenient =
true).createSerializer()
rdd.mapPartitions { iterator =>
iterator.map(toRow)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 47bacde..8f263f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -109,6 +109,19 @@ case class AllDataTypesScan(
}
}
+class LegacyTimestampSource extends RelationProvider {
+ override def createRelation(ctx: SQLContext, parameters: Map[String,
String]): BaseRelation = {
+ new BaseRelation() with TableScan {
+ override val sqlContext: SQLContext = ctx
+ override val schema: StructType = StructType(StructField("col",
TimestampType) :: Nil)
+ override def buildScan(): RDD[Row] = {
+ sqlContext.sparkContext.parallelize(
+ Row(java.sql.Timestamp.valueOf("2022-03-08 12:13:14")) :: Nil)
+ }
+ }
+ }
+}
+
class TableScanSuite extends DataSourceTest with SharedSparkSession {
protected override lazy val sql = spark.sql _
@@ -420,4 +433,18 @@ class TableScanSuite extends DataSourceTest with
SharedSparkSession {
val comments =
planned.schema.fields.map(_.getComment().getOrElse("NO_COMMENT")).mkString(",")
assert(comments === "SN,SA,NO_COMMENT")
}
+
+ test("SPARK-38437: accept java.sql.Timestamp even when Java 8 API is
enabled") {
+ val tableName = "relationProviderWithLegacyTimestamps"
+ withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+ withTable (tableName) {
+ sql(s"""
+ |CREATE TABLE $tableName (col TIMESTAMP)
+ |USING
org.apache.spark.sql.sources.LegacyTimestampSource""".stripMargin)
+ checkAnswer(
+ spark.table(tableName),
+ Row(java.sql.Timestamp.valueOf("2022-03-08 12:13:14").toInstant) ::
Nil)
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]