This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4ed2dab  [SPARK-36608][SQL] Support TimestampNTZ in Arrow
4ed2dab is described below

commit 4ed2dab5ee46226f29919eba05fe52e446ab2449
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Wed Sep 1 10:23:42 2021 +0900

    [SPARK-36608][SQL] Support TimestampNTZ in Arrow
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add the support of `TimestampNTZType` in Arrow APIs.
    Now, `TimestampNTZType` is written as an Arrow Timestamp
    with a `null` timezone.
    
    ### Why are the changes needed?
    
    To complete the support of `TimestampNTZType` in Apache Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the Arrow APIs (`ArrowWriter`, `ArrowColumnVector`) can now write and read `TimestampNTZType`.
    
    ### How was this patch tested?
    
    Unit tests were added.
    
    Closes #33875 from HyukjinKwon/SPARK-36608-arrow.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../apache/spark/sql/vectorized/ArrowColumnVector.java  | 17 +++++++++++++++++
 .../scala/org/apache/spark/sql/util/ArrowUtils.scala    |  4 ++++
 .../apache/spark/sql/execution/arrow/ArrowWriter.scala  | 13 +++++++++++++
 .../spark/sql/execution/arrow/ArrowWriterSuite.scala    |  8 ++++++--
 4 files changed, 40 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index fe48670..0813701 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -160,6 +160,8 @@ public final class ArrowColumnVector extends ColumnVector {
       accessor = new DateAccessor((DateDayVector) vector);
     } else if (vector instanceof TimeStampMicroTZVector) {
       accessor = new TimestampAccessor((TimeStampMicroTZVector) vector);
+    } else if (vector instanceof TimeStampMicroVector) {
+      accessor = new TimestampNTZAccessor((TimeStampMicroVector) vector);
     } else if (vector instanceof MapVector) {
       MapVector mapVector = (MapVector) vector;
       accessor = new MapAccessor(mapVector);
@@ -444,6 +446,21 @@ public final class ArrowColumnVector extends ColumnVector {
     }
   }
 
+  private static class TimestampNTZAccessor extends ArrowVectorAccessor {
+
+    private final TimeStampMicroVector accessor;
+
+    TimestampNTZAccessor(TimeStampMicroVector vector) {
+      super(vector);
+      this.accessor = vector;
+    }
+
+    @Override
+    final long getLong(int rowId) {
+      return accessor.get(rowId);
+    }
+  }
+
   private static class ArrayAccessor extends ArrowVectorAccessor {
 
     private final ListVector accessor;
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index d09d83d..4065d23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -53,6 +53,8 @@ private[sql] object ArrowUtils {
       } else {
         new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
       }
+    case TimestampNTZType =>
+      new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
     case NullType => ArrowType.Null.INSTANCE
     case _: YearMonthIntervalType => new 
ArrowType.Interval(IntervalUnit.YEAR_MONTH)
     case _: DayTimeIntervalType => new 
ArrowType.Interval(IntervalUnit.DAY_TIME)
@@ -74,6 +76,8 @@ private[sql] object ArrowUtils {
     case ArrowType.Binary.INSTANCE => BinaryType
     case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
     case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
+    case ts: ArrowType.Timestamp
+      if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => 
TimestampNTZType
     case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => 
TimestampType
     case ArrowType.Null.INSTANCE => NullType
     case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => 
YearMonthIntervalType()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 887b0f8..c216d92 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -61,6 +61,7 @@ object ArrowWriter {
       case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
       case (DateType, vector: DateDayVector) => new DateWriter(vector)
       case (TimestampType, vector: TimeStampMicroTZVector) => new 
TimestampWriter(vector)
+      case (TimestampNTZType, vector: TimeStampMicroVector) => new 
TimestampNTZWriter(vector)
       case (ArrayType(_, _), vector: ListVector) =>
         val elementVector = createFieldWriter(vector.getDataVector())
         new ArrayWriter(vector, elementVector)
@@ -288,6 +289,18 @@ private[arrow] class TimestampWriter(
   }
 }
 
+private[arrow] class TimestampNTZWriter(
+    val valueVector: TimeStampMicroVector) extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getLong(ordinal))
+  }
+}
+
 private[arrow] class ArrayWriter(
     val valueVector: ListVector,
     val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index 146d9fc..f980a84 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -30,12 +30,12 @@ class ArrowWriterSuite extends SparkFunSuite {
 
   test("simple") {
     def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = 
{
-      val avroDatatype = dt match {
+      val datatype = dt match {
         case _: DayTimeIntervalType => DayTimeIntervalType()
         case _: YearMonthIntervalType => YearMonthIntervalType()
         case tpe => tpe
       }
-      val schema = new StructType().add("value", avroDatatype, nullable = true)
+      val schema = new StructType().add("value", datatype, nullable = true)
       val writer = ArrowWriter.create(schema, timeZoneId)
       assert(writer.schema === schema)
 
@@ -61,6 +61,7 @@ class ArrowWriterSuite extends SparkFunSuite {
             case BinaryType => reader.getBinary(rowId)
             case DateType => reader.getInt(rowId)
             case TimestampType => reader.getLong(rowId)
+            case TimestampNTZType => reader.getLong(rowId)
             case _: YearMonthIntervalType => reader.getInt(rowId)
             case _: DayTimeIntervalType => reader.getLong(rowId)
           }
@@ -81,6 +82,7 @@ class ArrowWriterSuite extends SparkFunSuite {
     check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, 
"d".getBytes()))
     check(DateType, Seq(0, 1, 2, null, 4))
     check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), 
"America/Los_Angeles")
+    check(TimestampNTZType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong))
     check(NullType, Seq(null, null, null))
     DataTypeTestUtils.yearMonthIntervalTypes
       .foreach(check(_, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue)))
@@ -139,6 +141,7 @@ class ArrowWriterSuite extends SparkFunSuite {
         case DoubleType => reader.getDoubles(0, data.size)
         case DateType => reader.getInts(0, data.size)
         case TimestampType => reader.getLongs(0, data.size)
+        case TimestampNTZType => reader.getLongs(0, data.size)
         case _: YearMonthIntervalType => reader.getInts(0, data.size)
         case _: DayTimeIntervalType => reader.getLongs(0, data.size)
       }
@@ -155,6 +158,7 @@ class ArrowWriterSuite extends SparkFunSuite {
     check(DoubleType, (0 until 10).map(_.toDouble))
     check(DateType, (0 until 10))
     check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), 
"America/Los_Angeles")
+    check(TimestampNTZType, (0 until 10).map(_ * 4.32e10.toLong))
     DataTypeTestUtils.yearMonthIntervalTypes.foreach(check(_, (0 until 14)))
     DataTypeTestUtils.dayTimeIntervalTypes.foreach(check(_, (-10 until 
10).map(_ * 1000.toLong)))
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to