This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new af33722d563 [SPARK-41404][SQL] Refactor `ColumnVectorUtils#toBatch` to make `ColumnarBatchSuite#testRandomRows` test more primitive dataType
af33722d563 is described below

commit af33722d5633a9a37c4ecd31ca0c2630b070ef90
Author: yangjie01 <yangji...@baidu.com>
AuthorDate: Sat Dec 10 19:25:32 2022 -0800

    [SPARK-41404][SQL] Refactor `ColumnVectorUtils#toBatch` to make `ColumnarBatchSuite#testRandomRows` test more primitive dataType
    
    ### What changes were proposed in this pull request?
    This PR refactors `ColumnVectorUtils#toBatch` so that `ColumnarBatchSuite#testRandomRows` can test more primitive data types.
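    
    For context, a minimal sketch (not part of the patch) of the conversions the new branches rely on, using the same `DateTimeUtils` helpers the diff calls; the values are illustrative:
    
        import java.sql.{Date, Timestamp}
        import java.time.LocalDateTime
        import org.apache.spark.sql.catalyst.util.DateTimeUtils
    
        // DateType is stored internally as days since the epoch (Int).
        val days: Int = DateTimeUtils.fromJavaDate(Date.valueOf("2022-12-10"))
        // TimestampType is stored as microseconds since the epoch (Long).
        val micros: Long = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2022-12-10 19:25:32"))
        // TimestampNTZType is stored as time-zone-agnostic microseconds (Long).
        val ntzMicros: Long = DateTimeUtils.localDateTimeToMicros(LocalDateTime.parse("2022-12-10T19:25:32"))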
    
    ### Why are the changes needed?
    Let `ColumnarBatchSuite#testRandomRows` cover more primitive data types: `DateType`, `StringType`, `BinaryType`, `TimestampType`, and `TimestampNTZType`.
    
    ### Does this PR introduce _any_ user-facing change?
    No, this change is test-only.
    
    ### How was this patch tested?
    Pass GitHub Actions
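    
    Locally, the suite can typically be run with sbt (the exact invocation is a standard Spark build assumption, not part of this patch):
    
        build/sbt "sql/testOnly org.apache.spark.sql.execution.vectorized.ColumnarBatchSuite"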
    
    Closes #38933 from LuciferYang/toBatch-bugfix.
    
    Lead-authored-by: yangjie01 <yangji...@baidu.com>
    Co-authored-by: YangJie <yangji...@baidu.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../execution/vectorized/ColumnVectorUtils.java    | 11 +++-
 .../execution/vectorized/ColumnarBatchSuite.scala  | 64 +++++++++++++++++++++-
 2 files changed, 71 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 4efa06a781a..f89c10155a7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -20,6 +20,8 @@ import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
 import java.sql.Date;
+import java.sql.Timestamp;
+import java.time.LocalDateTime;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
@@ -146,6 +148,9 @@ public class ColumnVectorUtils {
       } else if (t == DataTypes.StringType) {
         byte[] b =((String)o).getBytes(StandardCharsets.UTF_8);
         dst.appendByteArray(b, 0, b.length);
+      } else if (t == DataTypes.BinaryType) {
+        byte[] b = (byte[]) o;
+        dst.appendByteArray(b, 0, b.length);
       } else if (t instanceof DecimalType) {
         DecimalType dt = (DecimalType) t;
         Decimal d = Decimal.apply((BigDecimal) o, dt.precision(), dt.scale());
@@ -165,7 +170,11 @@ public class ColumnVectorUtils {
         dst.getChild(1).appendInt(c.days);
         dst.getChild(2).appendLong(c.microseconds);
       } else if (t instanceof DateType) {
-        dst.appendInt(DateTimeUtils.fromJavaDate((Date)o));
+        dst.appendInt(DateTimeUtils.fromJavaDate((Date) o));
+      } else if (t instanceof TimestampType) {
+        dst.appendLong(DateTimeUtils.fromJavaTimestamp((Timestamp) o));
+      } else if (t instanceof TimestampNTZType) {
+        dst.appendLong(DateTimeUtils.localDateTimeToMicros((LocalDateTime) o));
       } else {
         throw new UnsupportedOperationException("Type " + t);
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 0395798d9e7..8ee6da10b5e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql.execution.vectorized
 import java.nio.ByteBuffer
 import java.nio.ByteOrder
 import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
 import java.time.LocalDateTime
 import java.util
-import java.util.NoSuchElementException
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -1379,12 +1379,26 @@ class ColumnarBatchSuite extends SparkFunSuite {
             "Seed = " + seed)
          case DoubleType => assert(doubleEquals(r1.getDouble(ordinal), r2.getDouble(ordinal)),
             "Seed = " + seed)
+          case DateType =>
+            assert(r1.getInt(ordinal) == DateTimeUtils.fromJavaDate(r2.getDate(ordinal)),
+              "Seed = " + seed)
+          case TimestampType =>
+            assert(r1.getLong(ordinal) ==
+              DateTimeUtils.fromJavaTimestamp(r2.getTimestamp(ordinal)),
+              "Seed = " + seed)
+          case TimestampNTZType =>
+            assert(r1.getLong(ordinal) ==
+              DateTimeUtils.localDateTimeToMicros(r2.getAs[LocalDateTime](ordinal)),
+              "Seed = " + seed)
           case t: DecimalType =>
             val d1 = r1.getDecimal(ordinal, t.precision, t.scale).toBigDecimal
             val d2 = r2.getDecimal(ordinal)
             assert(d1.compare(d2) == 0, "Seed = " + seed)
           case StringType =>
            assert(r1.getString(ordinal) == r2.getString(ordinal), "Seed = " + seed)
+          case BinaryType =>
+            assert(r1.getBinary(ordinal) sameElements r2.getAs[Array[Byte]](ordinal),
+              "Seed = " + seed)
           case CalendarIntervalType =>
            assert(r1.getInterval(ordinal) === r2.get(ordinal).asInstanceOf[CalendarInterval])
           case ArrayType(childType, n) =>
@@ -1406,6 +1420,50 @@ class ColumnarBatchSuite extends SparkFunSuite {
                     "Seed = " + seed)
                   i += 1
                 }
+              case StringType =>
+                var i = 0
+                while (i < a1.length) {
+                  assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed)
+                  if (a1(i) != null) {
+                    val s1 = a1(i).asInstanceOf[UTF8String].toString
+                    val s2 = a2(i).asInstanceOf[String]
+                    assert(s1 === s2, "Seed = " + seed)
+                  }
+                  i += 1
+                }
+              case DateType =>
+                var i = 0
+                while (i < a1.length) {
+                  assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed)
+                  if (a1(i) != null) {
+                    val i1 = a1(i).asInstanceOf[Int]
+                    val i2 = DateTimeUtils.fromJavaDate(a2(i).asInstanceOf[Date])
+                    assert(i1 === i2, "Seed = " + seed)
+                  }
+                  i += 1
+                }
+              case TimestampType =>
+                var i = 0
+                while (i < a1.length) {
+                  assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed)
+                  if (a1(i) != null) {
+                    val i1 = a1(i).asInstanceOf[Long]
+                    val i2 = DateTimeUtils.fromJavaTimestamp(a2(i).asInstanceOf[Timestamp])
+                    assert(i1 === i2, "Seed = " + seed)
+                  }
+                  i += 1
+                }
+              case TimestampNTZType =>
+                var i = 0
+                while (i < a1.length) {
+                  assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed)
+                  if (a1(i) != null) {
+                    val i1 = a1(i).asInstanceOf[Long]
+                    val i2 = DateTimeUtils.localDateTimeToMicros(a2(i).asInstanceOf[LocalDateTime])
+                    assert(i1 === i2, "Seed = " + seed)
+                  }
+                  i += 1
+                }
               case t: DecimalType =>
                 var i = 0
                 while (i < a1.length) {
@@ -1457,12 +1515,12 @@ class ColumnarBatchSuite extends SparkFunSuite {
    * results.
    */
   def testRandomRows(flatSchema: Boolean, numFields: Int): Unit = {
-    // TODO: Figure out why StringType doesn't work on jenkins.
     val types = Array(
      BooleanType, ByteType, FloatType, DoubleType, IntegerType, LongType, ShortType,
      DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
       DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
-      new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType)
+      new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType,
+      DateType, StringType, BinaryType, TimestampType, TimestampNTZType)
     val seed = System.nanoTime()
     val NUM_ROWS = 200
     val NUM_ITERS = 1000

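As a usage illustration (a sketch only: it assumes the existing `ColumnVectorUtils.toBatch(schema, mode, rowIterator)` helper that `testRandomRows` exercises, with illustrative values), a batch built from external rows can now carry the newly supported types:

    import java.sql.Timestamp
    import scala.collection.JavaConverters._
    import org.apache.spark.memory.MemoryMode
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
    import org.apache.spark.sql.types._

    val schema = new StructType()
      .add("ts", TimestampType)
      .add("bin", BinaryType)
    val rows = Seq(Row(Timestamp.valueOf("2022-12-10 19:25:32"), Array[Byte](1, 2, 3)))
    // toBatch routes each value through the appendValue branches added above.
    val batch = ColumnVectorUtils.toBatch(schema, MemoryMode.ON_HEAP, rows.iterator.asJava)
    assert(batch.numRows() == 1)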

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
