This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 0c4ac71080f [SPARK-43040][SQL] Improve TimestampNTZ type support in JDBC data source 0c4ac71080f is described below commit 0c4ac71080fc480b527e06aceeaf7d52a5161f31 Author: tianhanhu <adrianh...@gmail.com> AuthorDate: Fri May 5 10:06:40 2023 +0800 [SPARK-43040][SQL] Improve TimestampNTZ type support in JDBC data source ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/36726 supports TimestampNTZ type in JDBC data source and https://github.com/apache/spark/pull/37013 applies a fix to pass more test cases with H2. The problem is that Java Timestamp is a poorly defined class and different JDBC drivers implement "getTimestamp" and "setTimestamp" with different expected behaviors in mind. A single generic conversion works with some JDBC dialects and their drivers but not with others. This issue was discovered when testing with a PostgreSQL database. This PR adds a `dialect` parameter to `makeGetter` for applying dialect-specific conversions when reading a Java Timestamp into TimestampNTZType. `makeSetter` already has a `dialect` parameter, which is now used for converting back to a Java Timestamp. ### Why are the changes needed? Fix TimestampNTZ support for PostgreSQL. This also allows other JDBC dialects to provide a dialect-specific implementation for converting between Java Timestamp and Spark TimestampNTZType. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing unit tests. New test cases were added to `PostgresIntegrationSuite` to cover TimestampNTZ reads and writes. Closes #40678 from tianhanhu/SPARK-43040_jdbc_timestamp_ntz. 
Authored-by: tianhanhu <adrianh...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/jdbc/PostgresIntegrationSuite.scala | 35 ++++++++++++++++++++ .../sql/execution/datasources/jdbc/JDBCRDD.scala | 3 +- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 38 +++++++++++++++------- .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 30 +++++++++++++++-- .../apache/spark/sql/jdbc/PostgresDialect.scala | 11 ++++++- 5 files changed, 102 insertions(+), 15 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index ff5127ce350..f840876fc5d 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc import java.math.{BigDecimal => JBigDecimal} import java.sql.{Connection, Date, Timestamp} import java.text.SimpleDateFormat +import java.time.LocalDateTime import java.util.Properties import org.apache.spark.sql.Column @@ -140,6 +141,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { "c0 money)").executeUpdate() conn.prepareStatement("INSERT INTO money_types VALUES " + "('$1,000.00')").executeUpdate() + + conn.prepareStatement(s"CREATE TABLE timestamp_ntz(v timestamp)").executeUpdate() + conn.prepareStatement(s"""INSERT INTO timestamp_ntz VALUES + |('2013-04-05 12:01:02'), + |('2013-04-05 18:01:02.123'), + |('2013-04-05 18:01:02.123456')""".stripMargin).executeUpdate() } test("Type mapping for various types") { @@ -381,4 +388,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(row(0).length === 1) assert(row(0).getString(0) === "$1,000.00") } + + test("SPARK-43040: timestamp_ntz read test") 
{ + val prop = new Properties + prop.setProperty("preferTimestampNTZ", "true") + val df = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz", prop) + val row = df.collect() + assert(row.length === 3) + assert(row(0).length === 1) + assert(row(0) === Row(LocalDateTime.of(2013, 4, 5, 12, 1, 2))) + assert(row(1) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123000000))) + assert(row(2) === Row(LocalDateTime.of(2013, 4, 5, 18, 1, 2, 123456000))) + } + + test("SPARK-43040: timestamp_ntz roundtrip test") { + val prop = new Properties + prop.setProperty("preferTimestampNTZ", "true") + + val sparkQuery = """ + |select + | timestamp_ntz'2020-12-10 11:22:33' as col0 + """.stripMargin + + val df_expected = sqlContext.sql(sparkQuery) + df_expected.write.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop) + + val df_actual = sqlContext.read.jdbc(jdbcUrl, "timestamp_ntz_roundtrip", prop) + assert(df_actual.collect()(0) == df_expected.collect()(0)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 70e29f5d719..e241951abe3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -273,7 +273,8 @@ private[jdbc] class JDBCRDD( stmt.setFetchSize(options.fetchSize) stmt.setQueryTimeout(options.queryTimeout) rs = stmt.executeQuery() - val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) + val rowsIterator = + JdbcUtils.resultSetToSparkInternalRows(rs, dialect, schema, inputMetrics) CompletionIterator[InternalRow, Iterator[InternalRow]]( new InterruptibleIterator(context, rowsIterator), close()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala 
index fe53ba91d95..d907ce6b100 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -38,12 +38,12 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData} -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros, localDateToDays, toJavaDate, toJavaTimestamp, toJavaTimestampNoRebase} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.{Identifier, TableChange} import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex} import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} +import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType, NoopDialect} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.unsafe.types.UTF8String @@ -316,21 +316,31 @@ object JdbcUtils extends Logging with SQLConfHelper { /** * Convert a [[ResultSet]] into an iterator of Catalyst Rows. 
*/ - def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = { + def resultSetToRows( + resultSet: ResultSet, + schema: StructType): Iterator[Row] = { + resultSetToRows(resultSet, schema, NoopDialect) + } + + def resultSetToRows( + resultSet: ResultSet, + schema: StructType, + dialect: JdbcDialect): Iterator[Row] = { val inputMetrics = Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics) val fromRow = RowEncoder(schema).resolveAndBind().createDeserializer() - val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics) + val internalRows = resultSetToSparkInternalRows(resultSet, dialect, schema, inputMetrics) internalRows.map(fromRow) } private[spark] def resultSetToSparkInternalRows( resultSet: ResultSet, + dialect: JdbcDialect, schema: StructType, inputMetrics: InputMetrics): Iterator[InternalRow] = { new NextIterator[InternalRow] { private[this] val rs = resultSet - private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema) + private[this] val getters: Array[JDBCValueGetter] = makeGetters(dialect, schema) private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) override protected def close(): Unit = { @@ -368,12 +378,17 @@ object JdbcUtils extends Logging with SQLConfHelper { * Creates `JDBCValueGetter`s according to [[StructType]], which can set * each value from `ResultSet` to each field of [[InternalRow]] correctly. 
*/ - private def makeGetters(schema: StructType): Array[JDBCValueGetter] = { + private def makeGetters( + dialect: JdbcDialect, + schema: StructType): Array[JDBCValueGetter] = { val replaced = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) - replaced.fields.map(sf => makeGetter(sf.dataType, sf.metadata)) + replaced.fields.map(sf => makeGetter(sf.dataType, dialect, sf.metadata)) } - private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match { + private def makeGetter( + dt: DataType, + dialect: JdbcDialect, + metadata: Metadata): JDBCValueGetter = dt match { case BooleanType => (rs: ResultSet, row: InternalRow, pos: Int) => row.setBoolean(pos, rs.getBoolean(pos + 1)) @@ -478,7 +493,8 @@ object JdbcUtils extends Logging with SQLConfHelper { (rs: ResultSet, row: InternalRow, pos: Int) => val t = rs.getTimestamp(pos + 1) if (t != null) { - row.setLong(pos, DateTimeUtils.fromJavaTimestampNoRebase(t)) + row.setLong(pos, + DateTimeUtils.localDateTimeToMicros(dialect.convertJavaTimestampToTimestampNTZ(t))) } else { row.update(pos, null) } @@ -596,8 +612,8 @@ object JdbcUtils extends Logging with SQLConfHelper { case TimestampNTZType => (stmt: PreparedStatement, row: Row, pos: Int) => - val micros = localDateTimeToMicros(row.getAs[java.time.LocalDateTime](pos)) - stmt.setTimestamp(pos + 1, toJavaTimestampNoRebase(micros)) + stmt.setTimestamp(pos + 1, + dialect.convertTimestampNTZToJavaTimestamp(row.getAs[java.time.LocalDateTime](pos))) case DateType => if (conf.datetimeJava8ApiEnabled) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index e7a74ee3aa9..93a311be2f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Date, Driver, Statement, 
Timestamp} -import java.time.{Instant, LocalDate} +import java.time.{Instant, LocalDate, LocalDateTime} import java.util import scala.collection.mutable.ArrayBuilder @@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateTimeToMicros, toJavaTimestampNoRebase} import org.apache.spark.sql.connector.catalog.{Identifier, TableChange} import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction @@ -104,6 +105,31 @@ abstract class JdbcDialect extends Serializable with Logging { */ def getJDBCType(dt: DataType): Option[JdbcType] = None + /** + * Convert java.sql.Timestamp to a LocalDateTime representing the same wall-clock time as the + * value stored in a remote database. + * JDBC dialects should override this function to provide implementations that suit their + * JDBC drivers. + * @param t Timestamp returned from JDBC driver getTimestamp method. + * @return A LocalDateTime representing the same wall-clock time as the timestamp in the database. + */ + @Since("3.5.0") + def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = { + DateTimeUtils.microsToLocalDateTime(DateTimeUtils.fromJavaTimestampNoRebase(t)) + } + + /** + * Converts a LocalDateTime representing a TimestampNTZ type to an + * instance of `java.sql.Timestamp`. + * @param ldt representing a TimestampNTZType. + * @return A Java Timestamp representing this LocalDateTime. + */ + @Since("3.5.0") + def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = { + val micros = localDateTimeToMicros(ldt) + toJavaTimestampNoRebase(micros) + } + /** * Returns a factory for creating connections to the given JDBC URL.
* In general, creating a connection has nothing to do with JDBC partition id. @@ -682,6 +708,6 @@ object JdbcDialects { /** * NOOP dialect object, always returning the neutral element. */ -private object NoopDialect extends JdbcDialect { +object NoopDialect extends JdbcDialect { override def canHandle(url : String): Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index b53a0e66ba7..b42d575ae2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.jdbc -import java.sql.{Connection, SQLException, Types} +import java.sql.{Connection, SQLException, Timestamp, Types} +import java.time.LocalDateTime import java.util import java.util.Locale @@ -102,6 +103,14 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { case _ => None } + override def convertJavaTimestampToTimestampNTZ(t: Timestamp): LocalDateTime = { + t.toLocalDateTime + } + + override def convertTimestampNTZToJavaTimestamp(ldt: LocalDateTime): Timestamp = { + Timestamp.valueOf(ldt) + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Some(JdbcType("TEXT", Types.VARCHAR)) case BinaryType => Some(JdbcType("BYTEA", Types.BINARY)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org