This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new ef5f7c73a [SEDONA-669] Fix timestamp_nz for GeoParquet reader and writer (#1661)
ef5f7c73a is described below
commit ef5f7c73aab0a66869522568d81647ff8dbde446
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Oct 30 15:14:27 2024 +0800
[SEDONA-669] Fix timestamp_nz for GeoParquet reader and writer (#1661)
* Fix timestamp_nz for geoparquet format
* Backport the fix to Spark 3.4
* Overwrite the geoparquet output directory to avoid test failures when running with other tests
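For context, a minimal Scala usage sketch of what this fix enables. It is illustrative only and not part of the patch; it assumes a SparkSession `spark` with Sedona's SQL functions registered and a writable path `/tmp/geoparquet_ntz` (both hypothetical):

    import java.time.LocalDateTime
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions.expr
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampNTZType}

    // Build a DataFrame with a TimestampNTZ column and a geometry column.
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("ts_ntz", TimestampNTZType, nullable = false)))
    val rows = Seq(
      Row(1, LocalDateTime.parse("2024-10-04T12:34:56")),
      Row(2, LocalDateTime.parse("2024-10-04T15:30:00")))
    val df = spark
      .createDataFrame(spark.sparkContext.parallelize(rows), schema)
      .withColumn("geom", expr("ST_Point(id, id)"))

    // With this fix, the TimestampNTZ column is written as INT64
    // (TIMESTAMP, isAdjustedToUTC=false, MICROS) and read back as TimestampNTZType.
    df.write.format("geoparquet").mode("overwrite").save("/tmp/geoparquet_ntz")
    val back = spark.read.format("geoparquet").load("/tmp/geoparquet_ntz")
    assert(back.schema("ts_ntz").dataType == TimestampNTZType)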
---
.../datasources/parquet/GeoParquetFileFormat.scala | 3 ++
.../parquet/GeoParquetRowConverter.scala | 32 ++++++++++++++++++++++
.../parquet/GeoParquetSchemaConverter.scala | 12 ++++++++
.../parquet/GeoParquetWriteSupport.scala | 5 ++++
.../org/apache/sedona/sql/geoparquetIOTests.scala | 32 ++++++++++++++++++++--
.../parquet/GeoParquetRowConverter.scala | 32 ++++++++++++++++++++++
.../parquet/GeoParquetSchemaConverter.scala | 12 ++++++++
.../parquet/GeoParquetWriteSupport.scala | 5 ++++
.../org/apache/sedona/sql/geoparquetIOTests.scala | 29 ++++++++++++++++++++
9 files changed, 159 insertions(+), 3 deletions(-)
diff --git a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
index 325a72098..cdb9834b8 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
@@ -202,6 +202,9 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter])
hadoopConf.setBoolean(
SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
sparkSession.sessionState.conf.isParquetINT96AsTimestamp)
+ hadoopConf.setBoolean(
+ SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key,
+ sparkSession.sessionState.conf.parquetInferTimestampNTZEnabled)
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
diff --git a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
index 3e04a0a29..c50172874 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.parquet
import org.apache.parquet.column.Dictionary
import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
import org.apache.parquet.schema.OriginalType.LIST
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
import org.apache.parquet.schema.{GroupType, OriginalType, Type}
@@ -312,6 +314,25 @@ private[parquet] class GeoParquetRowConverter(
}
}
+ case TimestampNTZType
+ if canReadAsTimestampNTZ(parquetType) &&
+ parquetType.getLogicalTypeAnnotation
+ .asInstanceOf[TimestampLogicalTypeAnnotation]
+ .getUnit == TimeUnit.MICROS =>
+ new ParquetPrimitiveConverter(updater)
+
+ case TimestampNTZType
+ if canReadAsTimestampNTZ(parquetType) &&
+ parquetType.getLogicalTypeAnnotation
+ .asInstanceOf[TimestampLogicalTypeAnnotation]
+ .getUnit == TimeUnit.MILLIS =>
+ new ParquetPrimitiveConverter(updater) {
+ override def addLong(value: Long): Unit = {
+ val micros = DateTimeUtils.millisToMicros(value)
+ updater.setLong(micros)
+ }
+ }
+
case DateType =>
new ParquetPrimitiveConverter(updater) {
override def addInt(value: Int): Unit = {
@@ -379,6 +400,17 @@ private[parquet] class GeoParquetRowConverter(
}
}
+ // Only INT64 column with Timestamp logical annotation `isAdjustedToUTC=false`
+ // can be read as Spark's TimestampNTZ type. This is to avoid mistakes in reading the timestamp
+ // values.
+ private def canReadAsTimestampNTZ(parquetType: Type): Boolean =
+ schemaConverter.isTimestampNTZEnabled() &&
+ parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 &&
+ parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] &&
+ !parquetType.getLogicalTypeAnnotation
+ .asInstanceOf[TimestampLogicalTypeAnnotation]
+ .isAdjustedToUTC
+
/**
* Parquet converter for strings. A dictionary is used to minimize string decoding cost.
*/
diff --git a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
index eab20875a..10dd9e01d 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
@@ -42,6 +42,8 @@ import org.apache.spark.sql.types._
* Whether unannotated BINARY fields should be assumed to be Spark SQL [[StringType]] fields.
* @param assumeInt96IsTimestamp
* Whether unannotated INT96 fields should be assumed to be Spark SQL [[TimestampType]] fields.
+ * @param inferTimestampNTZ
+ * Whether TimestampNTZType type is enabled.
* @param parameters
* Options for reading GeoParquet files.
*/
@@ -49,6 +51,7 @@ class GeoParquetToSparkSchemaConverter(
keyValueMetaData: java.util.Map[String, String],
assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get,
+ inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
parameters: Map[String, String]) {
private val geoParquetMetaData: GeoParquetMetaData =
@@ -61,6 +64,7 @@ class GeoParquetToSparkSchemaConverter(
keyValueMetaData = keyValueMetaData,
assumeBinaryIsString = conf.isParquetBinaryAsString,
assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp,
+ inferTimestampNTZ = conf.parquetInferTimestampNTZEnabled,
parameters = parameters)
def this(
@@ -70,8 +74,16 @@ class GeoParquetToSparkSchemaConverter(
keyValueMetaData = keyValueMetaData,
assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
+ inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
parameters = parameters)
+ /**
+ * Returns true if TIMESTAMP_NTZ type is enabled in this ParquetToSparkSchemaConverter.
+ */
+ def isTimestampNTZEnabled(): Boolean = {
+ inferTimestampNTZ
+ }
+
/**
* Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
*/
diff --git a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
index 3a6a89773..9d6b36740 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
@@ -308,6 +308,11 @@ class GeoParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
recordConsumer.addLong(millis)
}
+ case TimestampNTZType =>
+ // For TimestampNTZType column, Spark always output as INT64 with Timestamp annotation in
+ // MICROS time unit.
+ (row: SpecializedGetters, ordinal: Int) => recordConsumer.addLong(row.getLong(ordinal))
+
case BinaryType =>
(row: SpecializedGetters, ordinal: Int) =>
recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal)))
diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
index ccfd560c8..f5bd8b486 100644
--- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
+++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
@@ -32,15 +32,15 @@ import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.expressions.st_constructors.{ST_Point, ST_PolygonFromEnvelope}
import org.apache.spark.sql.sedona_sql.expressions.st_predicates.ST_Intersects
-import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.sql.types.StructField
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampNTZType}
import org.json4s.jackson.parseJson
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.io.WKTReader
import org.scalatest.BeforeAndAfterAll
import java.io.File
+import java.time.LocalDateTime
+import java.time.format.DateTimeFormatter
import java.util.Collections
import java.util.concurrent.atomic.AtomicLong
import scala.collection.JavaConverters._
@@ -732,6 +732,32 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
}
}
+ describe("Spark types tests") {
+ it("should support timestamp_ntz") {
+ // Write geoparquet files with a TimestampNTZ column
+ val schema = StructType(
+ Seq(
+ StructField("id", IntegerType, nullable = false),
+ StructField("timestamp_ntz", TimestampNTZType, nullable = false)))
+ val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
+ val data = Seq(
+ Row(1, LocalDateTime.parse("2024-10-04 12:34:56", formatter)),
+ Row(2, LocalDateTime.parse("2024-10-04 15:30:00", formatter)))
+ val df = sparkSession
+ .createDataFrame(sparkSession.sparkContext.parallelize(data), schema)
+ .withColumn("geom", expr("ST_Point(id, id)"))
+ df.write.format("geoparquet").mode("overwrite").save(geoparquetoutputlocation)
+
+ // Read it back
+ val df2 =
+ sparkSession.read.format("geoparquet").load(geoparquetoutputlocation).sort(col("id"))
+ assert(df2.schema.fields(1).dataType == TimestampNTZType)
+ val data1 = df.sort(col("id")).collect()
+ val data2 = df2.collect()
+ assert(data1 sameElements data2)
+ }
+ }
+
def validateGeoParquetMetadata(path: String)(body: org.json4s.JValue => Unit): Unit = {
val parquetFiles = new File(path).listFiles().filter(_.getName.endsWith(".parquet"))
parquetFiles.foreach { filePath =>
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
index 07fc77e2c..44c65ab3e 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.parquet
import org.apache.parquet.column.Dictionary
import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation
import org.apache.parquet.schema.OriginalType.LIST
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
import org.apache.parquet.schema.{GroupType, OriginalType, Type}
@@ -315,6 +317,25 @@ private[parquet] class GeoParquetRowConverter(
}
}
+ case TimestampNTZType
+ if canReadAsTimestampNTZ(parquetType) &&
+ parquetType.getLogicalTypeAnnotation
+ .asInstanceOf[TimestampLogicalTypeAnnotation]
+ .getUnit == TimeUnit.MICROS =>
+ new ParquetPrimitiveConverter(updater)
+
+ case TimestampNTZType
+ if canReadAsTimestampNTZ(parquetType) &&
+ parquetType.getLogicalTypeAnnotation
+ .asInstanceOf[TimestampLogicalTypeAnnotation]
+ .getUnit == TimeUnit.MILLIS =>
+ new ParquetPrimitiveConverter(updater) {
+ override def addLong(value: Long): Unit = {
+ val micros = DateTimeUtils.millisToMicros(value)
+ updater.setLong(micros)
+ }
+ }
+
case DateType =>
new ParquetPrimitiveConverter(updater) {
override def addInt(value: Int): Unit = {
@@ -382,6 +403,17 @@ private[parquet] class GeoParquetRowConverter(
}
}
+ // Only INT64 column with Timestamp logical annotation `isAdjustedToUTC=false`
+ // can be read as Spark's TimestampNTZ type. This is to avoid mistakes in reading the timestamp
+ // values.
+ private def canReadAsTimestampNTZ(parquetType: Type): Boolean =
+ schemaConverter.isTimestampNTZEnabled() &&
+ parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 &&
+ parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] &&
+ !parquetType.getLogicalTypeAnnotation
+ .asInstanceOf[TimestampLogicalTypeAnnotation]
+ .isAdjustedToUTC
+
/**
* Parquet converter for strings. A dictionary is used to minimize string decoding cost.
*/
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
index eab20875a..10dd9e01d 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
@@ -42,6 +42,8 @@ import org.apache.spark.sql.types._
* Whether unannotated BINARY fields should be assumed to be Spark SQL [[StringType]] fields.
* @param assumeInt96IsTimestamp
* Whether unannotated INT96 fields should be assumed to be Spark SQL [[TimestampType]] fields.
+ * @param inferTimestampNTZ
+ * Whether TimestampNTZType type is enabled.
* @param parameters
* Options for reading GeoParquet files.
*/
@@ -49,6 +51,7 @@ class GeoParquetToSparkSchemaConverter(
keyValueMetaData: java.util.Map[String, String],
assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get,
+ inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
parameters: Map[String, String]) {
private val geoParquetMetaData: GeoParquetMetaData =
@@ -61,6 +64,7 @@ class GeoParquetToSparkSchemaConverter(
keyValueMetaData = keyValueMetaData,
assumeBinaryIsString = conf.isParquetBinaryAsString,
assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp,
+ inferTimestampNTZ = conf.parquetInferTimestampNTZEnabled,
parameters = parameters)
def this(
@@ -70,8 +74,16 @@ class GeoParquetToSparkSchemaConverter(
keyValueMetaData = keyValueMetaData,
assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
+ inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
parameters = parameters)
+ /**
+ * Returns true if TIMESTAMP_NTZ type is enabled in this ParquetToSparkSchemaConverter.
+ */
+ def isTimestampNTZEnabled(): Boolean = {
+ inferTimestampNTZ
+ }
+
/**
* Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
*/
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
index fb5c92163..18f9f4f5c 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetWriteSupport.scala
@@ -309,6 +309,11 @@ class GeoParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
recordConsumer.addLong(millis)
}
+ case TimestampNTZType =>
+ // For TimestampNTZType column, Spark always output as INT64 with Timestamp annotation in
+ // MICROS time unit.
+ (row: SpecializedGetters, ordinal: Int) => recordConsumer.addLong(row.getLong(ordinal))
+
case BinaryType =>
(row: SpecializedGetters, ordinal: Int) =>
recordConsumer.addBinary(Binary.fromReusedByteArray(row.getBinary(ordinal)))
diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
index ccfd560c8..a6e74730a 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.sedona_sql.expressions.st_predicates.ST_Intersects
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.TimestampNTZType
import org.json4s.jackson.parseJson
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.io.WKTReader
@@ -43,6 +44,8 @@ import org.scalatest.BeforeAndAfterAll
import java.io.File
import java.util.Collections
import java.util.concurrent.atomic.AtomicLong
+import java.time.LocalDateTime
+import java.time.format.DateTimeFormatter
import scala.collection.JavaConverters._
class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
@@ -732,6 +735,32 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
}
}
+ describe("Spark types tests") {
+ it("should support timestamp_ntz") {
+ // Write geoparquet files with a TimestampNTZ column
+ val schema = StructType(
+ Seq(
+ StructField("id", IntegerType, nullable = false),
+ StructField("timestamp_ntz", TimestampNTZType, nullable = false)))
+ val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
+ val data = Seq(
+ Row(1, LocalDateTime.parse("2024-10-04 12:34:56", formatter)),
+ Row(2, LocalDateTime.parse("2024-10-04 15:30:00", formatter)))
+ val df = sparkSession
+ .createDataFrame(sparkSession.sparkContext.parallelize(data), schema)
+ .withColumn("geom", expr("ST_Point(id, id)"))
+ df.write.format("geoparquet").mode("overwrite").save(geoparquetoutputlocation)
+
+ // Read it back
+ val df2 =
+ sparkSession.read.format("geoparquet").load(geoparquetoutputlocation).sort(col("id"))
+ assert(df2.schema.fields(1).dataType == TimestampNTZType)
+ val data1 = df.sort(col("id")).collect()
+ val data2 = df2.collect()
+ assert(data1 sameElements data2)
+ }
+ }
+
def validateGeoParquetMetadata(path: String)(body: org.json4s.JValue => Unit): Unit = {
val parquetFiles = new File(path).listFiles().filter(_.getName.endsWith(".parquet"))
parquetFiles.foreach { filePath =>