This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0c4ac71080f [SPARK-43040][SQL] Improve TimestampNTZ type support in JDBC data source
0c4ac71080f is described below

commit 0c4ac71080fc480b527e06aceeaf7d52a5161f31
Author: tianhanhu <adrianh...@gmail.com>
AuthorDate: Fri May 5 10:06:40 2023 +0800

    [SPARK-43040][SQL] Improve TimestampNTZ type support in JDBC data source
    
    ### What changes were proposed in this pull request?
    
    https://github.com/apache/spark/pull/36726 added TimestampNTZ type support in the JDBC
    data source, and https://github.com/apache/spark/pull/37013 applied a fix to pass more
    test cases with H2.
    
    The problem is that java.sql.Timestamp is a poorly defined class, and different JDBC
    drivers implement "getTimestamp" and "setTimestamp" with different expected behaviors in
    mind. The generic conversion implementation works with some JDBC dialects and their
    drivers but not with others. This issue was discovered when testing against a PostgreSQL
    database.
    
    This PR adds a `dialect` parameter to `makeGetter` so that dialect-specific conversions
    can be applied when reading a Java Timestamp into TimestampNTZType. `makeSetter` already
    has a `dialect` field, which is used for the conversion back to Java Timestamp.
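
    For illustration, a dialect can plug into the new hooks roughly as follows (a minimal
    sketch; `MyDialect` and its URL prefix are hypothetical, while the two overridden methods
    are the `JdbcDialect` hooks added by this patch):

        import java.sql.Timestamp
        import java.time.LocalDateTime

        import org.apache.spark.sql.jdbc.JdbcDialect

        // Hypothetical dialect whose driver stores and returns the literal wall-clock time,
        // so TimestampNTZ values can round-trip through java.sql.Timestamp's LocalDateTime
        // helpers (the same approach PostgresDialect takes in this patch).
        object MyDialect extends JdbcDialect {
          override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

          override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime =
            t.toLocalDateTime

          override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp =
            Timestamp.valueOf(ldt)
        }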
    
    ### Why are the changes needed?
    
    Fixes TimestampNTZ support for PostgreSQL and allows other JDBC dialects to provide
    dialect-specific implementations for converting between Java Timestamp and Spark
    TimestampNTZType.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing unit tests. New test cases were added to `PostgresIntegrationSuite` to cover
    TimestampNTZ reads and writes.
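
    For reference, reading such a table with the NTZ preference looks roughly like this
    (a sketch based on the integration tests below; the SparkSession, JDBC URL and table
    name are placeholders):

        // Ask the JDBC reader to map timestamp columns to TimestampNTZType.
        val df = spark.read
          .format("jdbc")
          .option("url", "jdbc:postgresql://localhost:5432/test") // placeholder URL
          .option("dbtable", "timestamp_ntz")                     // placeholder table
          .option("preferTimestampNTZ", "true")
          .load()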
    
    Closes #40678 from tianhanhu/SPARK-43040_jdbc_timestamp_ntz.
    
    Authored-by: tianhanhu <adrianh...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/jdbc/PostgresIntegrationSuite.scala  | 35 ++++++++++++++++++++
 .../sql/execution/datasources/jdbc/JDBCRDD.scala   |  3 +-
 .../sql/execution/datasources/jdbc/JdbcUtils.scala | 38 +++++++++++++++-------
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   | 30 +++++++++++++++--
 .../apache/spark/sql/jdbc/PostgresDialect.scala    | 11 ++++++-
 5 files changed, 102 insertions(+), 15 deletions(-)

diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index ff5127ce350..f840876fc5d 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
 import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Connection, Date, Timestamp}
 import java.text.SimpleDateFormat
+import java.time.LocalDateTime
 import java.util.Properties
 
 import org.apache.spark.sql.Column
@@ -140,6 +141,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
       "c0 money)").executeUpdate()
     conn.prepareStatement("INSERT INTO money_types VALUES " +
       "('$1,000.00')").executeUpdate()
+
+    conn.prepareStatement(s"CREATE TABLE timestamp_ntz(v timestamp)").executeUpdate()
+    conn.prepareStatement(s"""INSERT INTO timestamp_ntz VALUES
+      |('2013-04-05 12:01:02'),
+      |('2013-04-05 18:01:02.123'),
+      |('2013-04-05 18:01:02.123456')""".stripMargin).executeUpdate()
   }
 
   test("Type mapping for various types") {
@@ -381,4 +388,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(row(0).length === 1)
     assert(row(0).getString(0) === "$1,000.00")
   }
+
+  test("SPARK-43040: timestamp_ntz read test") {
+    val prop = new Properties
+    prop.setProperty("preferTimestampNTZ", "true")
+    val df = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz", prop)
+    val row = df.collect()
+    assert(row.length === 3)
+    assert(row(0).length === 1)
+    assert(row(0) === Row(LocalDateTime.of(2013, 4, 5, 12, 1, 2)))
+    assert(row(1) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123000000)))
+    assert(row(2) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123456000)))
+  }
+
+  test("SPARK-43040: timestamp_ntz roundtrip test") {
+    val prop = new Properties
+    prop.setProperty("preferTimestampNTZ", "true")
+
+    val sparkQuery = """
+      |select
+      |  timestamp_ntz'2020-12-10 11:22:33' as col0
+      """.stripMargin
+
+    val df_expected = sqlContext.sql(sparkQuery)
+    df_expected.write.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop)
+
+    val df_actual = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop)
+    assert(df_actual.collect()(0) == df_expected.collect()(0))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 70e29f5d719..e241951abe3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -273,7 +273,8 @@ private[jdbc] class JDBCRDD(
     stmt.setFetchSize(options.fetchSize)
     stmt.setQueryTimeout(options.queryTimeout)
     rs = stmt.executeQuery()
-    val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
+    val rowsIterator =
+      JdbcUtils.resultSetToSparkInternalRows(rs, dialect, schema, inputMetrics)
 
     CompletionIterator[InternalRow, Iterator[InternalRow]](
       new InterruptibleIterator(context, rowsIterator), close())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index fe53ba91d95..d907ce6b100 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -38,12 +38,12 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros, localDateToDays, toJavaDate, toJavaTimestamp, toJavaTimestampNoRebase}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
 import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
 import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
 import org.apache.spark.sql.connector.expressions.NamedReference
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.unsafe.types.UTF8String
@@ -316,21 +316,31 @@ object JdbcUtils extends Logging with SQLConfHelper {
   /**
    * Convert a [[ResultSet]] into an iterator of Catalyst Rows.
    */
-  def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
+  def resultSetToRows(
+      resultSet: ResultSet,
+      schema: StructType): Iterator[Row] = {
+    resultSetToRows(resultSet, schema, NoopDialect)
+  }
+
+  def resultSetToRows(
+      resultSet: ResultSet,
+      schema: StructType,
+      dialect: JdbcDialect): Iterator[Row] = {
     val inputMetrics =
       Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics)
     val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer()
-    val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
+    val internalRows = resultSetToSparkInternalRows(resultSet, dialect, schema, inputMetrics)
     internalRows.map(fromRow)
   }
 
   private[spark] def resultSetToSparkInternalRows(
       resultSet: ResultSet,
+      dialect: JdbcDialect,
       schema: StructType,
       inputMetrics: InputMetrics): Iterator[InternalRow] = {
     new NextIterator[InternalRow] {
       private[this] val rs = resultSet
-      private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
+      private[this] val getters: Array[JDBCValueGetter] = makeGetters(dialect, schema)
       private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
 
       override protected def close(): Unit = {
@@ -368,12 +378,17 @@ object JdbcUtils extends Logging with SQLConfHelper {
    * Creates `JDBCValueGetter`s according to [[StructType]], which can set
    * each value from `ResultSet` to each field of [[InternalRow]] correctly.
    */
-  private def makeGetters(schema: StructType): Array[JDBCValueGetter] = {
+  private def makeGetters(
+      dialect: JdbcDialect,
+      schema: StructType): Array[JDBCValueGetter] = {
     val replaced = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)
-    replaced.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
+    replaced.fields.map(sf => makeGetter(sf.dataType, dialect, sf.metadata))
   }
 
-  private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
+  private def makeGetter(
+      dt: DataType,
+      dialect: JdbcDialect,
+      metadata: Metadata): JDBCValueGetter = dt match {
     case BooleanType =>
       (rs: ResultSet, row: InternalRow, pos: Int) =>
         row.setBoolean(pos, rs.getBoolean(pos + 1))
@@ -478,7 +493,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
       (rs: ResultSet, row: InternalRow, pos: Int) =>
         val t = rs.getTimestamp(pos + 1)
         if (t != null) {
-          row.setLong(pos, DateTimeUtils.fromJavaTimestampNoRebase(t))
+          row.setLong(pos,
+            DateTimeUtils.localDateTimeToMicros(dialect.convertJavaTimestampToTimestampNTZ(t)))
         } else {
           row.update(pos, null)
         }
@@ -596,8 +612,8 @@ object JdbcUtils extends Logging with SQLConfHelper {
 
     case TimestampNTZType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
-        val micros = localDateTimeToMicros(row.getAs[java.time.LocalDateTime](pos))
-        stmt.setTimestamp(pos + 1, toJavaTimestampNoRebase(micros))
+        stmt.setTimestamp(pos + 1,
+          dialect.convertTimestampNTZToJavaTimestamp(row.getAs[java.time.LocalDateTime](pos)))
 
     case DateType =>
       if (conf.datetimeJava8ApiEnabled) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index e7a74ee3aa9..93a311be2f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.jdbc
 
 import java.sql.{Connection, Date, Driver, Statement, Timestamp}
-import java.time.{Instant, LocalDate}
+import java.time.{Instant, LocalDate, LocalDateTime}
 import java.util
 
 import scala.collection.mutable.ArrayBuilder
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateTimeToMicros, toJavaTimestampNoRebase}
 import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
 import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
@@ -104,6 +105,31 @@ abstract class JdbcDialect extends Serializable with Logging {
    */
   def getJDBCType(dt: DataType): Option[JdbcType] = None
 
+  /**
+   * Convert java.sql.Timestamp to a LocalDateTime representing the same wall-clock time as the
+   * value stored in a remote database.
+   * JDBC dialects should override this function to provide implementations that suit their
+   * JDBC drivers.
+   * @param t Timestamp returned from JDBC driver getTimestamp method.
+   * @return A LocalDateTime representing the same wall clock time as the timestamp in database.
+   */
+  @Since("3.5.0")
+  def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = {
+    DateTimeUtils.microsToLocalDateTime(DateTimeUtils.fromJavaTimestampNoRebase(t))
+  }
+
+  /**
+   * Converts a LocalDateTime representing a TimestampNTZ type to an
+   * instance of `java.sql.Timestamp`.
+   * @param ldt representing a TimestampNTZType.
+   * @return A Java Timestamp representing this LocalDateTime.
+   */
+  @Since("3.5.0")
+  def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = {
+    val micros = localDateTimeToMicros(ldt)
+    toJavaTimestampNoRebase(micros)
+  }
+
   /**
    * Returns a factory for creating connections to the given JDBC URL.
    * In general, creating a connection has nothing to do with JDBC partition 
id.
@@ -682,6 +708,6 @@ object JdbcDialects {
 /**
  * NOOP dialect object, always returning the neutral element.
  */
-private object NoopDialect extends JdbcDialect {
+object NoopDialect extends JdbcDialect {
   override def canHandle(url : String): Boolean = true
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index b53a0e66ba7..b42d575ae2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.jdbc
 
-import java.sql.{Connection, SQLException, Types}
+import java.sql.{Connection, SQLException, Timestamp, Types}
+import java.time.LocalDateTime
 import java.util
 import java.util.Locale
 
@@ -102,6 +103,14 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper {
     case _ => None
   }
 
+  override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = {
+    t.toLocalDateTime
+  }
+
+  override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = {
+    Timestamp.valueOf(ldt)
+  }
+
   override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
     case StringType => Some(JdbcType("TEXT", Types.VARCHAR))
     case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
