This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a6ac63d14b5 [SPARK-44449][CONNECT] Upcasting for direct Arrow 
Deserialization
a6ac63d14b5 is described below

commit a6ac63d14b56d939dda1aa2a8e74308efc8e1b93
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Mon Jul 24 20:46:01 2023 -0400

    [SPARK-44449][CONNECT] Upcasting for direct Arrow Deserialization
    
    ### What changes were proposed in this pull request?
    This PR adds upcasting to direct Arrow deserialization for the Spark 
Connect Scala Client. This is implemented by decoupling the leaf encoders from 
concrete vector implementations: each leaf encoder is now tied to an 
`ArrowVectorReader` instance that reads the data it needs from an arbitrary 
vector type, provided the read can be done without data loss (this is checked 
both at compile time and at runtime).
    
    ### Why are the changes needed?
    Direct Arrow deserialization did not support upcasting yet. Without it, 
Connect would regress compared to Spark 3.4.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it brings the Arrow encoders up to par with the existing Catalyst 
encoding framework.
    
    ### How was this patch tested?
    Added tests to `ArrowEncoderSuite`.
    Re-enabled tests that relied on upcasting.
    
    Closes #42076 from hvanhovell/SPARK-44449.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |   5 +-
 .../spark/sql/connect/client/SparkResult.scala     |   6 +-
 .../connect/client/arrow/ArrowDeserializer.scala   | 221 ++++++++++--------
 .../connect/client/arrow/ArrowVectorReader.scala   | 259 +++++++++++++++++++++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  10 +-
 .../sql/KeyValueGroupedDatasetE2ETestSuite.scala   |   5 +-
 .../connect/client/arrow/ArrowEncoderSuite.scala   | 217 +++++++++++++++--
 7 files changed, 595 insertions(+), 128 deletions(-)
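
As a reading aid for the diff below, here is a minimal, hypothetical sketch of
the decoupling idea (the names SimpleVectorReader and SimpleIntVectorReader
are simplified illustrations, not the exact code in this commit): leaf
deserializers call into a reader abstraction instead of a concrete vector
class, and a reader backed by a narrow vector also implements the wider read
methods, so lossless upcasts come for free.

    import org.apache.arrow.vector.IntVector

    // Hypothetical, simplified sketch of the reader abstraction.
    abstract class SimpleVectorReader {
      def isNull(i: Int): Boolean
      def getInt(i: Int): Int = throw new UnsupportedOperationException
      def getLong(i: Int): Long = throw new UnsupportedOperationException
      def getDouble(i: Int): Double = throw new UnsupportedOperationException
    }

    // A reader over an INT vector can also serve long/double reads,
    // because widening an int is always lossless.
    class SimpleIntVectorReader(v: IntVector) extends SimpleVectorReader {
      override def isNull(i: Int): Boolean = v.isNull(i)
      override def getInt(i: Int): Int = v.get(i)
      override def getLong(i: Int): Long = getInt(i)
      override def getDouble(i: Int): Double = getInt(i)
    }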

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index b37e3884038..161b5a0217e 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -126,7 +126,6 @@ class SparkSession private[sql] (
   private def createDataset[T](encoder: AgnosticEncoder[T], data: 
Iterator[T]): Dataset[T] = {
     newDataset(encoder) { builder =>
       if (data.nonEmpty) {
-        val timeZoneId = conf.get("spark.sql.session.timeZone")
         val arrowData = ArrowSerializer.serialize(data, encoder, allocator, 
timeZoneId)
         if (arrowData.size() <= 
conf.get("spark.sql.session.localRelationCacheThreshold").toInt) {
           builder.getLocalRelationBuilder
@@ -529,9 +528,11 @@ class SparkSession private[sql] (
     client.semanticHash(plan).getSemanticHash.getResult
   }
 
+  private[sql] def timeZoneId: String = conf.get("spark.sql.session.timeZone")
+
   private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): 
SparkResult[T] = {
     val value = client.execute(plan)
-    val result = new SparkResult(value, allocator, encoder)
+    val result = new SparkResult(value, allocator, encoder, timeZoneId)
     cleaner.register(result)
     result
   }
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index eed8bd3f37d..e3055b2678f 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.util.ArrowUtils
 private[sql] class SparkResult[T](
     responses: java.util.Iterator[proto.ExecutePlanResponse],
     allocator: BufferAllocator,
-    encoder: AgnosticEncoder[T])
+    encoder: AgnosticEncoder[T],
+    timeZoneId: String)
     extends AutoCloseable
     with Cleanable { self =>
 
@@ -213,7 +214,8 @@ private[sql] class SparkResult[T](
             new ConcatenatingArrowStreamReader(
               allocator,
               Iterator.single(new ResultMessageIterator(destructive)),
-              destructive))
+              destructive),
+            timeZoneId)
         }
       }
 
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 154866d699a..01aba9cb0ce 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -28,7 +28,7 @@ import scala.collection.mutable
 import scala.reflect.ClassTag
 
 import org.apache.arrow.memory.BufferAllocator
-import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, 
DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, 
IntervalYearVector, IntVector, NullVector, SmallIntVector, 
TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, 
VarCharVector, VectorSchemaRoot}
+import org.apache.arrow.vector.{FieldVector, VarCharVector, VectorSchemaRoot}
 import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
 import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.arrow.vector.util.Text
@@ -37,7 +37,6 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
 import org.apache.spark.sql.types.Decimal
 
@@ -54,13 +53,14 @@ object ArrowDeserializers {
   def deserializeFromArrow[T](
       input: Iterator[Array[Byte]],
       encoder: AgnosticEncoder[T],
-      allocator: BufferAllocator): CloseableIterator[T] = {
+      allocator: BufferAllocator,
+      timeZoneId: String): CloseableIterator[T] = {
     try {
       val reader = new ConcatenatingArrowStreamReader(
         allocator,
         input.map(bytes => new MessageIterator(new 
ByteArrayInputStream(bytes), allocator)),
         destructive = true)
-      new ArrowDeserializingIterator(encoder, reader)
+      new ArrowDeserializingIterator(encoder, reader, timeZoneId)
     } catch {
       case _: IOException =>
         new EmptyDeserializingIterator(encoder)
@@ -72,7 +72,8 @@ object ArrowDeserializers {
    */
   private[arrow] def deserializerFor[T](
       encoder: AgnosticEncoder[T],
-      root: VectorSchemaRoot): Deserializer[T] = {
+      root: VectorSchemaRoot,
+      timeZoneId: String): Deserializer[T] = {
     val data: AnyRef = if (encoder.isStruct) {
       root
     } else {
@@ -80,138 +81,141 @@ object ArrowDeserializers {
       // by convention we bind to the first one.
       root.getVector(0)
     }
-    deserializerFor(encoder, data).asInstanceOf[Deserializer[T]]
+    deserializerFor(encoder, data, timeZoneId).asInstanceOf[Deserializer[T]]
   }
 
   private[arrow] def deserializerFor(
       encoder: AgnosticEncoder[_],
-      data: AnyRef): Deserializer[Any] = {
+      data: AnyRef,
+      timeZoneId: String): Deserializer[Any] = {
     (encoder, data) match {
-      case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) =>
-        new FieldDeserializer[Boolean, BitVector](v) {
-          def value(i: Int): Boolean = vector.get(i) != 0
+      case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[Boolean](encoder, v, timeZoneId) {
+          override def value(i: Int): Boolean = reader.getBoolean(i)
         }
-      case (PrimitiveByteEncoder | BoxedByteEncoder, v: TinyIntVector) =>
-        new FieldDeserializer[Byte, TinyIntVector](v) {
-          def value(i: Int): Byte = vector.get(i)
+      case (PrimitiveByteEncoder | BoxedByteEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[Byte](encoder, v, timeZoneId) {
+          override def value(i: Int): Byte = reader.getByte(i)
         }
-      case (PrimitiveShortEncoder | BoxedShortEncoder, v: SmallIntVector) =>
-        new FieldDeserializer[Short, SmallIntVector](v) {
-          def value(i: Int): Short = vector.get(i)
+      case (PrimitiveShortEncoder | BoxedShortEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[Short](encoder, v, timeZoneId) {
+          override def value(i: Int): Short = reader.getShort(i)
         }
-      case (PrimitiveIntEncoder | BoxedIntEncoder, v: IntVector) =>
-        new FieldDeserializer[Int, IntVector](v) {
-          def value(i: Int): Int = vector.get(i)
+      case (PrimitiveIntEncoder | BoxedIntEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[Int](encoder, v, timeZoneId) {
+          override def value(i: Int): Int = reader.getInt(i)
         }
-      case (PrimitiveLongEncoder | BoxedLongEncoder, v: BigIntVector) =>
-        new FieldDeserializer[Long, BigIntVector](v) {
-          def value(i: Int): Long = vector.get(i)
+      case (PrimitiveLongEncoder | BoxedLongEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[Long](encoder, v, timeZoneId) {
+          override def value(i: Int): Long = reader.getLong(i)
         }
-      case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: Float4Vector) =>
-        new FieldDeserializer[Float, Float4Vector](v) {
-          def value(i: Int): Float = vector.get(i)
+      case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[Float](encoder, v, timeZoneId) {
+          override def value(i: Int): Float = reader.getFloat(i)
         }
-      case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: Float8Vector) =>
-        new FieldDeserializer[Double, Float8Vector](v) {
-          def value(i: Int): Double = vector.get(i)
+      case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[Double](encoder, v, timeZoneId) {
+          override def value(i: Int): Double = reader.getDouble(i)
         }
-      case (NullEncoder, v: NullVector) =>
-        new FieldDeserializer[Any, NullVector](v) {
-          def value(i: Int): Any = null
+      case (NullEncoder, _: FieldVector) =>
+        new Deserializer[Any] {
+          def get(i: Int): Any = null
         }
-      case (StringEncoder, v: VarCharVector) =>
-        new FieldDeserializer[String, VarCharVector](v) {
-          def value(i: Int): String = getString(vector, i)
+      case (StringEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[String](encoder, v, timeZoneId) {
+          override def value(i: Int): String = reader.getString(i)
         }
-      case (JavaEnumEncoder(tag), v: VarCharVector) =>
+      case (JavaEnumEncoder(tag), v: FieldVector) =>
         // It would be nice if we can get Enum.valueOf working...
         val valueOf = methodLookup.findStatic(
           tag.runtimeClass,
           "valueOf",
           MethodType.methodType(tag.runtimeClass, classOf[String]))
-        new FieldDeserializer[Enum[_], VarCharVector](v) {
-          def value(i: Int): Enum[_] = {
-            valueOf.invoke(getString(vector, i)).asInstanceOf[Enum[_]]
+        new LeafFieldDeserializer[Enum[_]](encoder, v, timeZoneId) {
+          override def value(i: Int): Enum[_] = {
+            valueOf.invoke(reader.getString(i)).asInstanceOf[Enum[_]]
           }
         }
-      case (ScalaEnumEncoder(parent, _), v: VarCharVector) =>
+      case (ScalaEnumEncoder(parent, _), v: FieldVector) =>
         val mirror = scala.reflect.runtime.currentMirror
         val module = mirror.classSymbol(parent).module.asModule
         val enumeration = 
mirror.reflectModule(module).instance.asInstanceOf[Enumeration]
-        new FieldDeserializer[Enumeration#Value, VarCharVector](v) {
-          def value(i: Int): Enumeration#Value = 
enumeration.withName(getString(vector, i))
+        new LeafFieldDeserializer[Enumeration#Value](encoder, v, timeZoneId) {
+          override def value(i: Int): Enumeration#Value = {
+            enumeration.withName(reader.getString(i))
+          }
         }
-      case (BinaryEncoder, v: VarBinaryVector) =>
-        new FieldDeserializer[Array[Byte], VarBinaryVector](v) {
-          def value(i: Int): Array[Byte] = vector.get(i)
+      case (BinaryEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[Array[Byte]](encoder, v, timeZoneId) {
+          override def value(i: Int): Array[Byte] = reader.getBytes(i)
         }
-      case (SparkDecimalEncoder(_), v: DecimalVector) =>
-        new FieldDeserializer[Decimal, DecimalVector](v) {
-          def value(i: Int): Decimal = Decimal(vector.getObject(i))
+      case (SparkDecimalEncoder(_), v: FieldVector) =>
+        new LeafFieldDeserializer[Decimal](encoder, v, timeZoneId) {
+          override def value(i: Int): Decimal = reader.getDecimal(i)
         }
-      case (ScalaDecimalEncoder(_), v: DecimalVector) =>
-        new FieldDeserializer[BigDecimal, DecimalVector](v) {
-          def value(i: Int): BigDecimal = BigDecimal(vector.getObject(i))
+      case (ScalaDecimalEncoder(_), v: FieldVector) =>
+        new LeafFieldDeserializer[BigDecimal](encoder, v, timeZoneId) {
+          override def value(i: Int): BigDecimal = reader.getScalaDecimal(i)
         }
-      case (JavaDecimalEncoder(_, _), v: DecimalVector) =>
-        new FieldDeserializer[JBigDecimal, DecimalVector](v) {
-          def value(i: Int): JBigDecimal = vector.getObject(i)
+      case (JavaDecimalEncoder(_, _), v: FieldVector) =>
+        new LeafFieldDeserializer[JBigDecimal](encoder, v, timeZoneId) {
+          override def value(i: Int): JBigDecimal = reader.getJavaDecimal(i)
         }
-      case (ScalaBigIntEncoder, v: DecimalVector) =>
-        new FieldDeserializer[BigInt, DecimalVector](v) {
-          def value(i: Int): BigInt = new 
BigInt(vector.getObject(i).toBigInteger)
+      case (ScalaBigIntEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[BigInt](encoder, v, timeZoneId) {
+          override def value(i: Int): BigInt = reader.getScalaBigInt(i)
         }
-      case (JavaBigIntEncoder, v: DecimalVector) =>
-        new FieldDeserializer[JBigInteger, DecimalVector](v) {
-          def value(i: Int): JBigInteger = vector.getObject(i).toBigInteger
+      case (JavaBigIntEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[JBigInteger](encoder, v, timeZoneId) {
+          override def value(i: Int): JBigInteger = reader.getJavaBigInt(i)
         }
-      case (DayTimeIntervalEncoder, v: DurationVector) =>
-        new FieldDeserializer[Duration, DurationVector](v) {
-          def value(i: Int): Duration = vector.getObject(i)
+      case (DayTimeIntervalEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[Duration](encoder, v, timeZoneId) {
+          override def value(i: Int): Duration = reader.getDuration(i)
         }
-      case (YearMonthIntervalEncoder, v: IntervalYearVector) =>
-        new FieldDeserializer[Period, IntervalYearVector](v) {
-          def value(i: Int): Period = vector.getObject(i).normalized()
+      case (YearMonthIntervalEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[Period](encoder, v, timeZoneId) {
+          override def value(i: Int): Period = reader.getPeriod(i)
         }
-      case (DateEncoder(_), v: DateDayVector) =>
-        new FieldDeserializer[java.sql.Date, DateDayVector](v) {
-          def value(i: Int): java.sql.Date = 
DateTimeUtils.toJavaDate(vector.get(i))
+      case (DateEncoder(_), v: FieldVector) =>
+        new LeafFieldDeserializer[java.sql.Date](encoder, v, timeZoneId) {
+          override def value(i: Int): java.sql.Date = reader.getDate(i)
         }
-      case (LocalDateEncoder(_), v: DateDayVector) =>
-        new FieldDeserializer[LocalDate, DateDayVector](v) {
-          def value(i: Int): LocalDate = 
DateTimeUtils.daysToLocalDate(vector.get(i))
+      case (LocalDateEncoder(_), v: FieldVector) =>
+        new LeafFieldDeserializer[LocalDate](encoder, v, timeZoneId) {
+          override def value(i: Int): LocalDate = reader.getLocalDate(i)
         }
-      case (TimestampEncoder(_), v: TimeStampMicroTZVector) =>
-        new FieldDeserializer[java.sql.Timestamp, TimeStampMicroTZVector](v) {
-          def value(i: Int): java.sql.Timestamp = 
DateTimeUtils.toJavaTimestamp(vector.get(i))
+      case (TimestampEncoder(_), v: FieldVector) =>
+        new LeafFieldDeserializer[java.sql.Timestamp](encoder, v, timeZoneId) {
+          override def value(i: Int): java.sql.Timestamp = 
reader.getTimestamp(i)
         }
-      case (InstantEncoder(_), v: TimeStampMicroTZVector) =>
-        new FieldDeserializer[Instant, TimeStampMicroTZVector](v) {
-          def value(i: Int): Instant = 
DateTimeUtils.microsToInstant(vector.get(i))
+      case (InstantEncoder(_), v: FieldVector) =>
+        new LeafFieldDeserializer[Instant](encoder, v, timeZoneId) {
+          override def value(i: Int): Instant = reader.getInstant(i)
         }
-      case (LocalDateTimeEncoder, v: TimeStampMicroVector) =>
-        new FieldDeserializer[LocalDateTime, TimeStampMicroVector](v) {
-          def value(i: Int): LocalDateTime = 
DateTimeUtils.microsToLocalDateTime(vector.get(i))
+      case (LocalDateTimeEncoder, v: FieldVector) =>
+        new LeafFieldDeserializer[LocalDateTime](encoder, v, timeZoneId) {
+          override def value(i: Int): LocalDateTime = 
reader.getLocalDateTime(i)
         }
 
       case (OptionEncoder(value), v) =>
-        val deserializer = deserializerFor(value, v)
+        val deserializer = deserializerFor(value, v, timeZoneId)
         new Deserializer[Any] {
           override def get(i: Int): Any = Option(deserializer.get(i))
         }
 
       case (ArrayEncoder(element, _), v: ListVector) =>
-        val deserializer = deserializerFor(element, v.getDataVector)
-        new FieldDeserializer[AnyRef, ListVector](v) {
+        val deserializer = deserializerFor(element, v.getDataVector, 
timeZoneId)
+        new VectorFieldDeserializer[AnyRef, ListVector](v) {
           def value(i: Int): AnyRef = getArray(vector, i, 
deserializer)(element.clsTag)
         }
 
       case (IterableEncoder(tag, element, _, _), v: ListVector) =>
-        val deserializer = deserializerFor(element, v.getDataVector)
+        val deserializer = deserializerFor(element, v.getDataVector, 
timeZoneId)
         if (isSubClass(Classes.WRAPPED_ARRAY, tag)) {
           // Wrapped array is a bit special because we need to use an array of 
the element type.
           // Some parts of our codebase (unfortunately) rely on this for type 
inference on results.
-          new FieldDeserializer[mutable.WrappedArray[Any], ListVector](v) {
+          new VectorFieldDeserializer[mutable.WrappedArray[Any], 
ListVector](v) {
             def value(i: Int): mutable.WrappedArray[Any] = {
               val array = getArray(vector, i, deserializer)(element.clsTag)
               ScalaCollectionUtils.wrap(array)
@@ -219,7 +223,7 @@ object ArrowDeserializers {
           }
         } else if (isSubClass(Classes.ITERABLE, tag)) {
           val companion = ScalaCollectionUtils.getIterableCompanion(tag)
-          new FieldDeserializer[Iterable[Any], ListVector](v) {
+          new VectorFieldDeserializer[Iterable[Any], ListVector](v) {
             def value(i: Int): Iterable[Any] = {
               val builder = companion.newBuilder[Any]
               loadListIntoBuilder(vector, i, deserializer, builder)
@@ -228,7 +232,7 @@ object ArrowDeserializers {
           }
         } else if (isSubClass(Classes.JLIST, tag)) {
           val newInstance = resolveJavaListCreator(tag)
-          new FieldDeserializer[JList[Any], ListVector](v) {
+          new VectorFieldDeserializer[JList[Any], ListVector](v) {
             def value(i: Int): JList[Any] = {
               var index = v.getElementStartIndex(i)
               val end = v.getElementEndIndex(i)
@@ -246,12 +250,13 @@ object ArrowDeserializers {
 
       case (MapEncoder(tag, key, value, _), v: MapVector) =>
         val structVector = v.getDataVector.asInstanceOf[StructVector]
-        val keyDeserializer = deserializerFor(key, 
structVector.getChild(MapVector.KEY_NAME))
+        val keyDeserializer =
+          deserializerFor(key, structVector.getChild(MapVector.KEY_NAME), 
timeZoneId)
         val valueDeserializer =
-          deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME))
+          deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME), 
timeZoneId)
         if (isSubClass(Classes.MAP, tag)) {
           val companion = ScalaCollectionUtils.getMapCompanion(tag)
-          new FieldDeserializer[Map[Any, Any], MapVector](v) {
+          new VectorFieldDeserializer[Map[Any, Any], MapVector](v) {
             def value(i: Int): Map[Any, Any] = {
               val builder = companion.newBuilder[Any, Any]
               var index = v.getElementStartIndex(i)
@@ -266,7 +271,7 @@ object ArrowDeserializers {
           }
         } else if (isSubClass(Classes.JMAP, tag)) {
           val newInstance = resolveJavaMapCreator(tag)
-          new FieldDeserializer[JMap[Any, Any], MapVector](v) {
+          new VectorFieldDeserializer[JMap[Any, Any], MapVector](v) {
             def value(i: Int): JMap[Any, Any] = {
               val map = newInstance()
               var index = v.getElementStartIndex(i)
@@ -288,12 +293,12 @@ object ArrowDeserializers {
           ScalaReflection.findConstructor(tag.runtimeClass, 
fields.map(_.enc.clsTag.runtimeClass))
         val deserializers = if (isTuple(tag.runtimeClass)) {
           fields.zip(vectors).map { case (field, vector) =>
-            deserializerFor(field.enc, vector)
+            deserializerFor(field.enc, vector, timeZoneId)
           }
         } else {
           val lookup = createFieldLookup(vectors)
           fields.map { field =>
-            deserializerFor(field.enc, lookup(field.name))
+            deserializerFor(field.enc, lookup(field.name), timeZoneId)
           }
         }
         new StructFieldSerializer[Any](struct) {
@@ -305,7 +310,7 @@ object ArrowDeserializers {
       case (r @ RowEncoder(fields), StructVectors(struct, vectors)) =>
         val lookup = createFieldLookup(vectors)
         val deserializers = fields.toArray.map { field =>
-          deserializerFor(field.enc, lookup(field.name))
+          deserializerFor(field.enc, lookup(field.name), timeZoneId)
         }
         new StructFieldSerializer[Any](struct) {
           def value(i: Int): Any = {
@@ -320,7 +325,7 @@ object ArrowDeserializers {
         val lookup = createFieldLookup(vectors)
         val setters = fields.map { field =>
           val vector = lookup(field.name)
-          val deserializer = deserializerFor(field.enc, vector)
+          val deserializer = deserializerFor(field.enc, vector, timeZoneId)
           val setter = methodLookup.findVirtual(
             tag.runtimeClass,
             field.writeMethod.get,
@@ -478,9 +483,9 @@ object ArrowDeserializers {
     def get(i: Int): E
   }
 
-  abstract class FieldDeserializer[E, V <: FieldVector](val vector: V) extends 
Deserializer[E] {
+  abstract class FieldDeserializer[E] extends Deserializer[E] {
     def value(i: Int): E
-    def isNull(i: Int): Boolean = vector.isNull(i)
+    def isNull(i: Int): Boolean
     override def get(i: Int): E = {
       if (!isNull(i)) {
         value(i)
@@ -490,8 +495,23 @@ object ArrowDeserializers {
     }
   }
 
+  abstract class LeafFieldDeserializer[E](val reader: ArrowVectorReader)
+      extends FieldDeserializer[E] {
+    def this(encoder: AgnosticEncoder[_], vector: FieldVector, timeZoneId: 
String) = {
+      this(ArrowVectorReader(encoder.dataType, vector, timeZoneId))
+    }
+    def value(i: Int): E
+    def isNull(i: Int): Boolean = reader.isNull(i)
+  }
+
+  abstract class VectorFieldDeserializer[E, V <: FieldVector](val vector: V)
+      extends FieldDeserializer[E] {
+    def value(i: Int): E
+    def isNull(i: Int): Boolean = vector.isNull(i)
+  }
+
   abstract class StructFieldSerializer[E](v: StructVector)
-      extends FieldDeserializer[E, StructVector](v) {
+      extends VectorFieldDeserializer[E, StructVector](v) {
     override def isNull(i: Int): Boolean = vector != null && vector.isNull(i)
   }
 }
@@ -505,11 +525,12 @@ class EmptyDeserializingIterator[E](val encoder: 
AgnosticEncoder[E])
 
 class ArrowDeserializingIterator[E](
     val encoder: AgnosticEncoder[E],
-    private[this] val reader: ArrowReader)
+    private[this] val reader: ArrowReader,
+    timeZoneId: String)
     extends CloseableIterator[E] {
   private[this] var index = 0
   private[this] val root = reader.getVectorSchemaRoot
-  private[this] val deserializer = ArrowDeserializers.deserializerFor(encoder, 
root)
+  private[this] val deserializer = ArrowDeserializers.deserializerFor(encoder, 
root, timeZoneId)
 
   override def hasNext: Boolean = {
     if (index >= root.getRowCount) {
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala
new file mode 100644
index 00000000000..9111b3b9ccf
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.arrow
+
+import java.math.{BigDecimal => JBigDecimal}
+import java.sql.{Date, Timestamp}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, 
ZoneOffset}
+
+import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, 
DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, 
IntervalYearVector, IntVector, NullVector, SmallIntVector, 
TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, 
VarCharVector}
+import org.apache.arrow.vector.util.Text
+
+import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.util.{DateFormatter, IntervalUtils, 
StringUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
+import org.apache.spark.sql.catalyst.util.DateTimeUtils._
+import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
+import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, Decimal, 
YearMonthIntervalType}
+import org.apache.spark.sql.util.ArrowUtils
+
+/**
+ * Base class for reading leaf values from an arrow vector. This reader has 
read methods for all
+ * leaf data types supported by the encoder framework. A subclass should 
always implement one of
+ * the read methods. If upcasting is allowed for the given vector, then all 
allowed read methods
+ * must be implemented.
+ */
+private[arrow] abstract class ArrowVectorReader {
+  def isNull(i: Int): Boolean
+  def getBoolean(i: Int): Boolean = unsupported()
+  def getByte(i: Int): Byte = unsupported()
+  def getShort(i: Int): Short = unsupported()
+  def getInt(i: Int): Int = unsupported()
+  def getLong(i: Int): Long = unsupported()
+  def getFloat(i: Int): Float = unsupported()
+  def getDouble(i: Int): Double = unsupported()
+  def getString(i: Int): String = unsupported()
+  def getBytes(i: Int): Array[Byte] = unsupported()
+  def getJavaDecimal(i: Int): JBigDecimal = unsupported()
+  def getJavaBigInt(i: Int): java.math.BigInteger = 
getJavaDecimal(i).toBigInteger
+  def getScalaDecimal(i: Int): BigDecimal = BigDecimal(getJavaDecimal(i))
+  def getScalaBigInt(i: Int): BigInt = BigInt(getJavaBigInt(i))
+  def getDecimal(i: Int): Decimal = Decimal(getJavaDecimal(i))
+  def getPeriod(i: Int): java.time.Period = unsupported()
+  def getDuration(i: Int): java.time.Duration = unsupported()
+  def getDate(i: Int): java.sql.Date = unsupported()
+  def getTimestamp(i: Int): java.sql.Timestamp = unsupported()
+  def getInstant(i: Int): java.time.Instant = unsupported()
+  def getLocalDate(i: Int): java.time.LocalDate = unsupported()
+  def getLocalDateTime(i: Int): java.time.LocalDateTime = unsupported()
+  private def unsupported(): Nothing = throw new 
UnsupportedOperationException()
+}
+
+object ArrowVectorReader {
+  def apply(
+      targetDataType: DataType,
+      vector: FieldVector,
+      timeZoneId: String): ArrowVectorReader = {
+    val vectorDataType = ArrowUtils.fromArrowType(vector.getField.getType)
+    if (!Cast.canUpCast(vectorDataType, targetDataType)) {
+      throw new RuntimeException(
+        s"Reading '$targetDataType' values from a ${vector.getClass} instance 
is not supported.")
+    }
+    vector match {
+      case v: BitVector => new BitVectorReader(v)
+      case v: TinyIntVector => new TinyIntVectorReader(v)
+      case v: SmallIntVector => new SmallIntVectorReader(v)
+      case v: IntVector => new IntVectorReader(v)
+      case v: BigIntVector => new BigIntVectorReader(v)
+      case v: Float4Vector => new Float4VectorReader(v)
+      case v: Float8Vector => new Float8VectorReader(v)
+      case v: DecimalVector => new DecimalVectorReader(v)
+      case v: VarCharVector => new VarCharVectorReader(v)
+      case v: VarBinaryVector => new VarBinaryVectorReader(v)
+      case v: DurationVector => new DurationVectorReader(v)
+      case v: IntervalYearVector => new IntervalYearVectorReader(v)
+      case v: DateDayVector => new DateDayVectorReader(v, timeZoneId)
+      case v: TimeStampMicroTZVector => new TimeStampMicroTZVectorReader(v)
+      case v: TimeStampMicroVector => new TimeStampMicroVectorReader(v, 
timeZoneId)
+      case _: NullVector => NullVectorReader
+      case _ => throw new RuntimeException("Unsupported Vector Type: " + 
vector.getClass)
+    }
+  }
+}
+
+private[arrow] object NullVectorReader extends ArrowVectorReader {
+  override def isNull(i: Int): Boolean = true
+}
+
+private[arrow] abstract class TypedArrowVectorReader[E <: FieldVector](val 
vector: E)
+    extends ArrowVectorReader {
+  override def isNull(i: Int): Boolean = vector.isNull(i)
+}
+
+private[arrow] class BitVectorReader(v: BitVector) extends 
TypedArrowVectorReader[BitVector](v) {
+  override def getBoolean(i: Int): Boolean = vector.get(i) > 0
+  override def getString(i: Int): String = String.valueOf(getBoolean(i))
+}
+
+private[arrow] class TinyIntVectorReader(v: TinyIntVector)
+    extends TypedArrowVectorReader[TinyIntVector](v) {
+  override def getByte(i: Int): Byte = vector.get(i)
+  override def getShort(i: Int): Short = getByte(i)
+  override def getInt(i: Int): Int = getByte(i)
+  override def getLong(i: Int): Long = getByte(i)
+  override def getFloat(i: Int): Float = getByte(i)
+  override def getDouble(i: Int): Double = getByte(i)
+  override def getString(i: Int): String = String.valueOf(getByte(i))
+  override def getJavaDecimal(i: Int): JBigDecimal = 
JBigDecimal.valueOf(getByte(i))
+}
+
+private[arrow] class SmallIntVectorReader(v: SmallIntVector)
+    extends TypedArrowVectorReader[SmallIntVector](v) {
+  override def getShort(i: Int): Short = vector.get(i)
+  override def getInt(i: Int): Int = getShort(i)
+  override def getLong(i: Int): Long = getShort(i)
+  override def getFloat(i: Int): Float = getShort(i)
+  override def getDouble(i: Int): Double = getShort(i)
+  override def getString(i: Int): String = String.valueOf(getShort(i))
+  override def getJavaDecimal(i: Int): JBigDecimal = 
JBigDecimal.valueOf(getShort(i))
+}
+
+private[arrow] class IntVectorReader(v: IntVector) extends 
TypedArrowVectorReader[IntVector](v) {
+  override def getInt(i: Int): Int = vector.get(i)
+  override def getLong(i: Int): Long = getInt(i)
+  override def getFloat(i: Int): Float = getInt(i)
+  override def getDouble(i: Int): Double = getInt(i)
+  override def getString(i: Int): String = String.valueOf(getInt(i))
+  override def getJavaDecimal(i: Int): JBigDecimal = 
JBigDecimal.valueOf(getInt(i))
+}
+
+private[arrow] class BigIntVectorReader(v: BigIntVector)
+    extends TypedArrowVectorReader[BigIntVector](v) {
+  override def getLong(i: Int): Long = vector.get(i)
+  override def getFloat(i: Int): Float = getLong(i)
+  override def getDouble(i: Int): Double = getLong(i)
+  override def getString(i: Int): String = String.valueOf(getLong(i))
+  override def getJavaDecimal(i: Int): JBigDecimal = 
JBigDecimal.valueOf(getLong(i))
+  override def getTimestamp(i: Int): Timestamp = toJavaTimestamp(getLong(i) * 
MICROS_PER_SECOND)
+  override def getInstant(i: Int): Instant = microsToInstant(getLong(i))
+}
+
+private[arrow] class Float4VectorReader(v: Float4Vector)
+    extends TypedArrowVectorReader[Float4Vector](v) {
+  override def getFloat(i: Int): Float = vector.get(i)
+  override def getDouble(i: Int): Double = getFloat(i)
+  override def getString(i: Int): String = String.valueOf(getFloat(i))
+  override def getJavaDecimal(i: Int): JBigDecimal = 
JBigDecimal.valueOf(getFloat(i))
+}
+
+private[arrow] class Float8VectorReader(v: Float8Vector)
+    extends TypedArrowVectorReader[Float8Vector](v) {
+  override def getDouble(i: Int): Double = vector.get(i)
+  override def getString(i: Int): String = String.valueOf(getDouble(i))
+  override def getJavaDecimal(i: Int): JBigDecimal = 
JBigDecimal.valueOf(getDouble(i))
+}
+
+private[arrow] class DecimalVectorReader(v: DecimalVector)
+    extends TypedArrowVectorReader[DecimalVector](v) {
+  override def getByte(i: Int): Byte = getJavaDecimal(i).byteValueExact()
+  override def getShort(i: Int): Short = getJavaDecimal(i).shortValueExact()
+  override def getInt(i: Int): Int = getJavaDecimal(i).intValueExact()
+  override def getLong(i: Int): Long = getJavaDecimal(i).longValueExact()
+  override def getFloat(i: Int): Float = getJavaDecimal(i).floatValue()
+  override def getDouble(i: Int): Double = getJavaDecimal(i).doubleValue()
+  override def getJavaDecimal(i: Int): JBigDecimal = vector.getObject(i)
+  override def getString(i: Int): String = getJavaDecimal(i).toPlainString
+}
+
+private[arrow] class VarCharVectorReader(v: VarCharVector)
+    extends TypedArrowVectorReader[VarCharVector](v) {
+  // This is currently a bit heavy on allocations:
+  // - byte array created in VarCharVector.get
+  // - CharBuffer created CharSetEncoder
+  // - char array in String
+  // By using direct buffers and reusing the char buffer
+  // we could get rid of the first two allocations.
+  override def getString(i: Int): String = Text.decode(vector.get(i))
+}
+
+private[arrow] class VarBinaryVectorReader(v: VarBinaryVector)
+    extends TypedArrowVectorReader[VarBinaryVector](v) {
+  override def getBytes(i: Int): Array[Byte] = vector.get(i)
+  override def getString(i: Int): String = 
StringUtils.getHexString(getBytes(i))
+}
+
+private[arrow] class DurationVectorReader(v: DurationVector)
+    extends TypedArrowVectorReader[DurationVector](v) {
+  override def getDuration(i: Int): Duration = vector.getObject(i)
+  override def getString(i: Int): String = {
+    IntervalUtils.toDayTimeIntervalString(
+      IntervalUtils.durationToMicros(getDuration(i)),
+      ANSI_STYLE,
+      DayTimeIntervalType.DEFAULT.startField,
+      DayTimeIntervalType.DEFAULT.endField)
+  }
+}
+
+private[arrow] class IntervalYearVectorReader(v: IntervalYearVector)
+    extends TypedArrowVectorReader[IntervalYearVector](v) {
+  override def getPeriod(i: Int): Period = vector.getObject(i).normalized()
+  override def getString(i: Int): String = {
+    IntervalUtils.toYearMonthIntervalString(
+      vector.get(i),
+      ANSI_STYLE,
+      YearMonthIntervalType.DEFAULT.startField,
+      YearMonthIntervalType.DEFAULT.endField)
+  }
+}
+
+private[arrow] class DateDayVectorReader(v: DateDayVector, timeZoneId: String)
+    extends TypedArrowVectorReader[DateDayVector](v) {
+  private val zone = getZoneId(timeZoneId)
+  private lazy val formatter = DateFormatter()
+  private def days(i: Int): Int = vector.get(i)
+  private def micros(i: Int): Long = daysToMicros(days(i), zone)
+  override def getDate(i: Int): Date = toJavaDate(days(i))
+  override def getLocalDate(i: Int): LocalDate = daysToLocalDate(days(i))
+  override def getTimestamp(i: Int): Timestamp = toJavaTimestamp(micros(i))
+  override def getInstant(i: Int): Instant = microsToInstant(micros(i))
+  override def getLocalDateTime(i: Int): LocalDateTime = 
microsToLocalDateTime(micros(i))
+  override def getString(i: Int): String = formatter.format(getLocalDate(i))
+}
+
+private[arrow] class TimeStampMicroTZVectorReader(v: TimeStampMicroTZVector)
+    extends TypedArrowVectorReader[TimeStampMicroTZVector](v) {
+  private val zone = getZoneId(v.getTimeZone)
+  private lazy val formatter = TimestampFormatter.getFractionFormatter(zone)
+  private def utcMicros(i: Int): Long = convertTz(vector.get(i), zone, 
ZoneOffset.UTC)
+  override def getLong(i: Int): Long = Math.floorDiv(vector.get(i), 
MICROS_PER_SECOND)
+  override def getTimestamp(i: Int): Timestamp = toJavaTimestamp(vector.get(i))
+  override def getInstant(i: Int): Instant = microsToInstant(vector.get(i))
+  override def getLocalDateTime(i: Int): LocalDateTime = 
microsToLocalDateTime(utcMicros(i))
+  override def getString(i: Int): String = formatter.format(vector.get(i))
+}
+
+private[arrow] class TimeStampMicroVectorReader(v: TimeStampMicroVector, 
timeZoneId: String)
+    extends TypedArrowVectorReader[TimeStampMicroVector](v) {
+  private val zone = getZoneId(timeZoneId)
+  private lazy val formatter = 
TimestampFormatter.getFractionFormatter(ZoneOffset.UTC)
+  private def tzMicros(i: Int): Long = convertTz(utcMicros(i), ZoneOffset.UTC, 
zone)
+  private def utcMicros(i: Int): Long = vector.get(i)
+  override def getTimestamp(i: Int): Timestamp = toJavaTimestamp(tzMicros(i))
+  override def getInstant(i: Int): Instant = microsToInstant(tzMicros(i))
+  override def getLocalDateTime(i: Int): LocalDateTime = 
microsToLocalDateTime(utcMicros(i))
+  override def getString(i: Int): String = formatter.format(utcMicros(i))
+}
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 07dd2a96bd8..b69151f75be 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -570,8 +570,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
     (col("id") / lit(10.0d)).as("b"),
     col("id"),
     lit("world").as("d"),
-    // TODO SPARK-44449 make this int again when upcasting is in.
-    (col("id") % 2).cast("double").as("a"))
+    (col("id") % 2).as("a"))
 
   private def validateMyTypeResult(result: Array[MyType]): Unit = {
     result.zipWithIndex.foreach { case (MyType(id, a, b), i) =>
@@ -818,11 +817,10 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
   }
 
   test("toJSON") {
-    // TODO SPARK-44449 make this int again when upcasting is in.
     val expected = Array(
-      """{"b":0.0,"id":0,"d":"world","a":0.0}""",
-      """{"b":0.1,"id":1,"d":"world","a":1.0}""",
-      """{"b":0.2,"id":2,"d":"world","a":0.0}""")
+      """{"b":0.0,"id":0,"d":"world","a":0}""",
+      """{"b":0.1,"id":1,"d":"world","a":1}""",
+      """{"b":0.2,"id":2,"d":"world","a":0}""")
     val result = spark
       .range(3)
       .select(generateMyTypeColumns: _*)
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index ab3e13da531..ad75887a7e2 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -483,9 +483,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
     val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
       .toDF("key", "seq", "value")
     val grouped = ds.groupBy($"value").as[String, (String, Int, Int)]
-    // TODO SPARK-44449 make this string again when upcasting is in.
-    val keys = grouped.keyAs[Int].keys.sort($"value")
-    checkDataset(keys, 1, 2, 10, 20)
+    val keys = grouped.keyAs[String].keys.sort($"value")
+    checkDataset(keys, "1", "2", "10", "20")
   }
 
   test("flatMapGroupsWithState") {
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index 16eec3eee31..3f8ac1cb8d1 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.connect.client.arrow
 
 import java.math.BigInteger
+import java.time.{Duration, Period, ZoneOffset}
 import java.util
 import java.util.{Collections, Objects}
 
@@ -32,11 +33,16 @@ import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, 
JavaTypeInference, ScalaReflection}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, 
CalendarIntervalEncoder, DateEncoder, EncoderField, InstantEncoder, 
IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, PrimitiveDoubleEncoder, 
PrimitiveFloatEncoder, RowEncoder, StringEncoder, TimestampEncoder, UDTEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, 
BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, 
BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, 
DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, 
IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, 
NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, 
PrimitiveDoubleEncoder, PrimitiveFloatEncoder, Primi [...]
 import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => 
toRowEncoder}
+import org.apache.spark.sql.catalyst.util.{DateFormatter, StringUtils, 
TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
+import org.apache.spark.sql.catalyst.util.DateTimeUtils._
+import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
+import org.apache.spark.sql.catalyst.util.IntervalUtils._
 import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
-import org.apache.spark.sql.types.{ArrayType, DataType, Decimal, DecimalType, 
IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, 
Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, 
UserDefinedType, YearMonthIntervalType}
 
 /**
  * Tests for encoding external data to and from arrow.
@@ -68,13 +74,31 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
       maxBatchSize: Long = 16 * 1024,
       batchSizeCheckInterval: Int = 128,
       inspectBatch: Array[Byte] => Unit = null): CloseableIterator[T] = {
+    roundTripWithDifferentIOEncoders(
+      encoder,
+      encoder,
+      iterator,
+      maxRecordsPerBatch,
+      maxBatchSize,
+      batchSizeCheckInterval,
+      inspectBatch)
+  }
+
+  private def roundTripWithDifferentIOEncoders[I, O](
+      inputEncoder: AgnosticEncoder[I],
+      outputEncoder: AgnosticEncoder[O],
+      iterator: Iterator[I],
+      maxRecordsPerBatch: Int = 4 * 1024,
+      maxBatchSize: Long = 16 * 1024,
+      batchSizeCheckInterval: Int = 128,
+      inspectBatch: Array[Byte] => Unit = null): CloseableIterator[O] = {
     // Use different allocators so we can pinpoint memory leaks better.
     val serializerAllocator = newAllocator("serialization")
     val deserializerAllocator = newAllocator("deserialization")
 
     val arrowIterator = ArrowSerializer.serialize(
       input = iterator,
-      enc = encoder,
+      enc = inputEncoder,
       allocator = serializerAllocator,
       maxRecordsPerBatch = maxRecordsPerBatch,
       maxBatchSize = maxBatchSize,
@@ -91,8 +115,12 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     }
 
     val resultIterator =
-      ArrowDeserializers.deserializeFromArrow(inspectedIterator, encoder, 
deserializerAllocator)
-    new CloseableIterator[T] {
+      ArrowDeserializers.deserializeFromArrow(
+        inspectedIterator,
+        outputEncoder,
+        deserializerAllocator,
+        timeZoneId = "UTC")
+    new CloseableIterator[O] {
       override def close(): Unit = {
         arrowIterator.close()
         resultIterator.close()
@@ -100,7 +128,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
         deserializerAllocator.close()
       }
       override def hasNext: Boolean = resultIterator.hasNext
-      override def next(): T = resultIterator.next()
+      override def next(): O = resultIterator.next()
     }
   }
 
@@ -156,11 +184,11 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
   }
 
   private def compareIterators[T](expected: Iterator[T], actual: Iterator[T]): 
Unit = {
-    expected.zipAll(actual, null, null).foreach { case (expected, actual) =>
-      assert(expected != null)
-      assert(actual != null)
-      assert(actual == expected)
+    while (expected.hasNext && actual.hasNext) {
+      assert(expected.next() == actual.next())
     }
+    assert(!expected.hasNext, "Less results produced than expected.")
+    assert(!actual.hasNext, "More results produced than expected.")
   }
 
   private class CountingBatchInspector extends (Array[Byte] => Unit) {
@@ -216,8 +244,11 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
 
   test("deserializing empty iterator") {
     withAllocator { allocator =>
-      val iterator =
-        ArrowDeserializers.deserializeFromArrow(Iterator.empty, 
singleIntEncoder, allocator)
+      val iterator = ArrowDeserializers.deserializeFromArrow(
+        Iterator.empty,
+        singleIntEncoder,
+        allocator,
+        timeZoneId = "UTC")
       assert(iterator.isEmpty)
       assert(allocator.getAllocatedMemory == 0)
     }
@@ -674,7 +705,11 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
         Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte)))
       val arrowBatches = serializeToArrow(Iterator.single(input), 
wideSchemaEncoder, allocator)
       val result =
-        ArrowDeserializers.deserializeFromArrow(arrowBatches, 
narrowSchemaEncoder, allocator)
+        ArrowDeserializers.deserializeFromArrow(
+          arrowBatches,
+          narrowSchemaEncoder,
+          allocator,
+          timeZoneId = "UTC")
       val actual = result.next()
       assert(result.isEmpty)
       assert(expected === actual)
@@ -687,7 +722,11 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     withAllocator { allocator =>
       val arrowBatches = serializeToArrow(Iterator.empty, narrowSchemaEncoder, 
allocator)
       intercept[AnalysisException] {
-        ArrowDeserializers.deserializeFromArrow(arrowBatches, 
wideSchemaEncoder, allocator)
+        ArrowDeserializers.deserializeFromArrow(
+          arrowBatches,
+          wideSchemaEncoder,
+          allocator,
+          timeZoneId = "UTC")
       }
       arrowBatches.close()
     }
@@ -704,12 +743,160 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     withAllocator { allocator =>
       val arrowBatches = serializeToArrow(Iterator.empty, 
duplicateSchemaEncoder, allocator)
       intercept[AnalysisException] {
-        ArrowDeserializers.deserializeFromArrow(arrowBatches, 
fooSchemaEncoder, allocator)
+        ArrowDeserializers.deserializeFromArrow(
+          arrowBatches,
+          fooSchemaEncoder,
+          allocator,
+          timeZoneId = "UTC")
       }
       arrowBatches.close()
     }
   }
 
+  /* ******************************************************************** *
+   * Arrow deserialization upcasting
+   * ******************************************************************** */
+  // Not supported: UDT, CalendarInterval
+  // Not tested: Char/Varchar.
+  private case class UpCastTestCase[I](input: AgnosticEncoder[I], generator: 
Int => I) {
+    def test[O](output: AgnosticEncoder[O], convert: I => O): this.type = {
+      val name = "upcast " + input.dataType.catalogString + " to " + 
output.dataType.catalogString
+      ArrowEncoderSuite.this.test(name) {
+        def data: Iterator[I] = Iterator.tabulate(5)(generator)
+        val result = roundTripWithDifferentIOEncoders(input, output, data)
+        try {
+          compareIterators(data.map(convert), result)
+        } finally {
+          result.close()
+        }
+      }
+      this
+    }
+
+    def nullTest[O](e: AgnosticEncoder[O]): this.type = {
+      test(e, _.asInstanceOf[O])
+    }
+  }
+
+  private val timestampFormatter = 
TimestampFormatter.getFractionFormatter(ZoneOffset.UTC)
+  private val dateFormatter = DateFormatter()
+
+  private def scalaDecimalEncoder(precision: Int, scale: Int = 0): 
ScalaDecimalEncoder = {
+    ScalaDecimalEncoder(DecimalType(precision, scale))
+  }
+
+  UpCastTestCase(NullEncoder, _ => null)
+    .nullTest(BoxedBooleanEncoder)
+    .nullTest(BoxedByteEncoder)
+    .nullTest(BoxedShortEncoder)
+    .nullTest(BoxedIntEncoder)
+    .nullTest(BoxedLongEncoder)
+    .nullTest(BoxedFloatEncoder)
+    .nullTest(BoxedDoubleEncoder)
+    .nullTest(StringEncoder)
+    .nullTest(DateEncoder(false))
+    .nullTest(TimestampEncoder(false))
+  UpCastTestCase(PrimitiveBooleanEncoder, _ % 2 == 0)
+    .test(StringEncoder, _.toString)
+  UpCastTestCase(PrimitiveByteEncoder, i => i.toByte)
+    .test(PrimitiveShortEncoder, _.toShort)
+    .test(PrimitiveIntEncoder, _.toInt)
+    .test(PrimitiveLongEncoder, _.toLong)
+    .test(PrimitiveFloatEncoder, _.toFloat)
+    .test(PrimitiveDoubleEncoder, _.toDouble)
+    .test(scalaDecimalEncoder(3), BigDecimal(_))
+    .test(scalaDecimalEncoder(5, 2), BigDecimal(_))
+    .test(StringEncoder, _.toString)
+  UpCastTestCase(PrimitiveShortEncoder, i => i.toShort)
+    .test(PrimitiveIntEncoder, _.toInt)
+    .test(PrimitiveLongEncoder, _.toLong)
+    .test(PrimitiveFloatEncoder, _.toFloat)
+    .test(PrimitiveDoubleEncoder, _.toDouble)
+    .test(scalaDecimalEncoder(5), BigDecimal(_))
+    .test(scalaDecimalEncoder(10, 5), BigDecimal(_))
+    .test(StringEncoder, _.toString)
+  UpCastTestCase(PrimitiveIntEncoder, i => i)
+    .test(PrimitiveLongEncoder, _.toLong)
+    .test(PrimitiveFloatEncoder, _.toFloat)
+    .test(PrimitiveDoubleEncoder, _.toDouble)
+    .test(scalaDecimalEncoder(10), BigDecimal(_))
+    .test(scalaDecimalEncoder(13, 3), BigDecimal(_))
+    .test(StringEncoder, _.toString)
+  UpCastTestCase(PrimitiveLongEncoder, i => i.toLong)
+    .test(PrimitiveFloatEncoder, _.toFloat)
+    .test(PrimitiveDoubleEncoder, _.toDouble)
+    .test(scalaDecimalEncoder(20), BigDecimal(_))
+    .test(scalaDecimalEncoder(25, 5), BigDecimal(_))
+    .test(TimestampEncoder(false), s => toJavaTimestamp(s * MICROS_PER_SECOND))
+    .test(StringEncoder, _.toString)
+  UpCastTestCase(PrimitiveFloatEncoder, i => i.toFloat)
+    .test(PrimitiveDoubleEncoder, _.toDouble)
+    .test(StringEncoder, _.toString)
+  UpCastTestCase(PrimitiveDoubleEncoder, i => i.toDouble)
+    .test(StringEncoder, _.toString)
+  UpCastTestCase(scalaDecimalEncoder(2), BigDecimal(_))
+    .test(PrimitiveByteEncoder, _.toByte)
+    .test(PrimitiveShortEncoder, _.toShort)
+    .test(PrimitiveIntEncoder, _.toInt)
+    .test(PrimitiveLongEncoder, _.toLong)
+    .test(scalaDecimalEncoder(7, 5), identity)
+    .test(StringEncoder, _.toString())
+  UpCastTestCase(scalaDecimalEncoder(4), BigDecimal(_))
+    .test(PrimitiveShortEncoder, _.toShort)
+    .test(PrimitiveIntEncoder, _.toInt)
+    .test(PrimitiveLongEncoder, _.toLong)
+    .test(scalaDecimalEncoder(10, 1), identity)
+    .test(StringEncoder, _.toString())
+  UpCastTestCase(scalaDecimalEncoder(9), BigDecimal(_))
+    .test(PrimitiveIntEncoder, _.toInt)
+    .test(PrimitiveLongEncoder, _.toLong)
+    .test(scalaDecimalEncoder(13, 4), identity)
+    .test(StringEncoder, _.toString())
+  UpCastTestCase(scalaDecimalEncoder(19), BigDecimal(_))
+    .test(PrimitiveLongEncoder, _.toLong)
+    .test(scalaDecimalEncoder(23, 1), identity)
+    .test(StringEncoder, _.toString())
+  UpCastTestCase(scalaDecimalEncoder(7, 3), BigDecimal(_))
+    .test(scalaDecimalEncoder(9, 5), identity)
+    .test(scalaDecimalEncoder(23, 3), identity)
+  UpCastTestCase(DateEncoder(false), i => toJavaDate(i))
+    .test(
+      TimestampEncoder(false),
+      date => toJavaTimestamp(daysToMicros(fromJavaDate(date), 
ZoneOffset.UTC)))
+    .test(
+      LocalDateTimeEncoder,
+      date => microsToLocalDateTime(daysToMicros(fromJavaDate(date), 
ZoneOffset.UTC)))
+    .test(StringEncoder, date => dateFormatter.format(date))
+  UpCastTestCase(TimestampEncoder(false), i => toJavaTimestamp(i))
+    .test(PrimitiveLongEncoder, ts => Math.floorDiv(fromJavaTimestamp(ts), 
MICROS_PER_SECOND))
+    .test(LocalDateTimeEncoder, ts => 
microsToLocalDateTime(fromJavaTimestamp(ts)))
+    .test(StringEncoder, ts => timestampFormatter.format(ts))
+  UpCastTestCase(LocalDateTimeEncoder, i => microsToLocalDateTime(i))
+    .test(TimestampEncoder(false), ldt => 
toJavaTimestamp(localDateTimeToMicros(ldt)))
+    .test(StringEncoder, ldt => timestampFormatter.format(ldt))
+  UpCastTestCase(DayTimeIntervalEncoder, i => Duration.ofDays(i))
+    .test(
+      StringEncoder,
+      { i =>
+        toDayTimeIntervalString(
+          durationToMicros(i),
+          ANSI_STYLE,
+          DayTimeIntervalType.DEFAULT.startField,
+          DayTimeIntervalType.DEFAULT.endField)
+      })
+  UpCastTestCase(YearMonthIntervalEncoder, i => Period.ofMonths(i))
+    .test(
+      StringEncoder,
+      { i =>
+        toYearMonthIntervalString(
+          periodToMonths(i),
+          ANSI_STYLE,
+          YearMonthIntervalType.DEFAULT.startField,
+          YearMonthIntervalType.DEFAULT.endField)
+      })
+  UpCastTestCase(BinaryEncoder, i => Array.tabulate(10)(j => (64 + j + 
i).toByte))
+    .test(StringEncoder, bytes => StringUtils.getHexString(bytes))
+
   /* ******************************************************************** *
    * Arrow serialization/deserialization specific errors
    * ******************************************************************** */
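
As a hedged, user-level illustration of what the added upcasting enables
(assuming a Spark Connect session named `spark`; the case class `Wide` is
illustrative and not part of this commit):

    import spark.implicits._

    case class Wide(id: Long)
    // The query produces an int column; with this change the Arrow reader
    // upcasts it to the long field of Wide instead of failing, matching the
    // existing Catalyst encoder behavior.
    val wide = spark.range(3).selectExpr("cast(id as int) as id").as[Wide]
    wide.collect()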


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
