This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4ed2dab [SPARK-36608][SQL] Support TimestampNTZ in Arrow
4ed2dab is described below
commit 4ed2dab5ee46226f29919eba05fe52e446ab2449
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Wed Sep 1 10:23:42 2021 +0900
[SPARK-36608][SQL] Support TimestampNTZ in Arrow
### What changes were proposed in this pull request?
This PR proposes to add the support of `TimestampNTZType` in Arrow APIs.
Now, Arrow can write `TimestampNTZType` as Timestamp with `null` timezone
in Arrow.
### Why are the changes needed?
To complete the support of `TimestampNTZType` in Apache Spark.
### Does this PR introduce _any_ user-facing change?
Yes, the Arrow APIs (`ArrowColumnVector`) can now write `TimestampNTZType`
### How was this patch tested?
Unittests were added.
Closes #33875 from HyukjinKwon/SPARK-36608-arrow.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../apache/spark/sql/vectorized/ArrowColumnVector.java | 17 +++++++++++++++++
.../scala/org/apache/spark/sql/util/ArrowUtils.scala | 4 ++++
.../apache/spark/sql/execution/arrow/ArrowWriter.scala | 13 +++++++++++++
.../spark/sql/execution/arrow/ArrowWriterSuite.scala | 8 ++++++--
4 files changed, 40 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index fe48670..0813701 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -160,6 +160,8 @@ public final class ArrowColumnVector extends ColumnVector {
accessor = new DateAccessor((DateDayVector) vector);
} else if (vector instanceof TimeStampMicroTZVector) {
accessor = new TimestampAccessor((TimeStampMicroTZVector) vector);
+ } else if (vector instanceof TimeStampMicroVector) {
+ accessor = new TimestampNTZAccessor((TimeStampMicroVector) vector);
} else if (vector instanceof MapVector) {
MapVector mapVector = (MapVector) vector;
accessor = new MapAccessor(mapVector);
@@ -444,6 +446,21 @@ public final class ArrowColumnVector extends ColumnVector {
}
}
+ private static class TimestampNTZAccessor extends ArrowVectorAccessor {
+
+ private final TimeStampMicroVector accessor;
+
+ TimestampNTZAccessor(TimeStampMicroVector vector) {
+ super(vector);
+ this.accessor = vector;
+ }
+
+ @Override
+ final long getLong(int rowId) {
+ return accessor.get(rowId);
+ }
+ }
+
private static class ArrayAccessor extends ArrowVectorAccessor {
private final ListVector accessor;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index d09d83d..4065d23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -53,6 +53,8 @@ private[sql] object ArrowUtils {
} else {
new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
}
+ case TimestampNTZType =>
+ new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
case NullType => ArrowType.Null.INSTANCE
case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
case _: DayTimeIntervalType => new ArrowType.Interval(IntervalUnit.DAY_TIME)
@@ -74,6 +76,8 @@ private[sql] object ArrowUtils {
case ArrowType.Binary.INSTANCE => BinaryType
case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
+ case ts: ArrowType.Timestamp
+ if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => TimestampNTZType
case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
case ArrowType.Null.INSTANCE => NullType
case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 887b0f8..c216d92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -61,6 +61,7 @@ object ArrowWriter {
case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
case (DateType, vector: DateDayVector) => new DateWriter(vector)
case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector)
+ case (TimestampNTZType, vector: TimeStampMicroVector) => new TimestampNTZWriter(vector)
case (ArrayType(_, _), vector: ListVector) =>
val elementVector = createFieldWriter(vector.getDataVector())
new ArrayWriter(vector, elementVector)
@@ -288,6 +289,18 @@ private[arrow] class TimestampWriter(
}
}
+private[arrow] class TimestampNTZWriter(
+ val valueVector: TimeStampMicroVector) extends ArrowFieldWriter {
+
+ override def setNull(): Unit = {
+ valueVector.setNull(count)
+ }
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setSafe(count, input.getLong(ordinal))
+ }
+}
+
private[arrow] class ArrayWriter(
val valueVector: ListVector,
val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index 146d9fc..f980a84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -30,12 +30,12 @@ class ArrowWriterSuite extends SparkFunSuite {
test("simple") {
def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = {
- val avroDatatype = dt match {
+ val datatype = dt match {
case _: DayTimeIntervalType => DayTimeIntervalType()
case _: YearMonthIntervalType => YearMonthIntervalType()
case tpe => tpe
}
- val schema = new StructType().add("value", avroDatatype, nullable = true)
+ val schema = new StructType().add("value", datatype, nullable = true)
val writer = ArrowWriter.create(schema, timeZoneId)
assert(writer.schema === schema)
@@ -61,6 +61,7 @@ class ArrowWriterSuite extends SparkFunSuite {
case BinaryType => reader.getBinary(rowId)
case DateType => reader.getInt(rowId)
case TimestampType => reader.getLong(rowId)
+ case TimestampNTZType => reader.getLong(rowId)
case _: YearMonthIntervalType => reader.getInt(rowId)
case _: DayTimeIntervalType => reader.getLong(rowId)
}
@@ -81,6 +82,7 @@ class ArrowWriterSuite extends SparkFunSuite {
check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()))
check(DateType, Seq(0, 1, 2, null, 4))
check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles")
+ check(TimestampNTZType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong))
check(NullType, Seq(null, null, null))
DataTypeTestUtils.yearMonthIntervalTypes
.foreach(check(_, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue)))
@@ -139,6 +141,7 @@ class ArrowWriterSuite extends SparkFunSuite {
case DoubleType => reader.getDoubles(0, data.size)
case DateType => reader.getInts(0, data.size)
case TimestampType => reader.getLongs(0, data.size)
+ case TimestampNTZType => reader.getLongs(0, data.size)
case _: YearMonthIntervalType => reader.getInts(0, data.size)
case _: DayTimeIntervalType => reader.getLongs(0, data.size)
}
@@ -155,6 +158,7 @@ class ArrowWriterSuite extends SparkFunSuite {
check(DoubleType, (0 until 10).map(_.toDouble))
check(DateType, (0 until 10))
check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles")
+ check(TimestampNTZType, (0 until 10).map(_ * 4.32e10.toLong))
DataTypeTestUtils.yearMonthIntervalTypes.foreach(check(_, (0 until 14)))
DataTypeTestUtils.dayTimeIntervalTypes.foreach(check(_, (-10 until 10).map(_ * 1000.toLong)))
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]