This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5939b75b5fe [SPARK-44396][CONNECT] Direct Arrow Deserialization
5939b75b5fe is described below

commit 5939b75b5fe701cb63fedc64f57c9f0a15ef9202
Author: Herman van Hovell <[email protected]>
AuthorDate: Wed Jul 19 09:26:26 2023 -0400

    [SPARK-44396][CONNECT] Direct Arrow Deserialization
    
    ### What changes were proposed in this pull request?
    This PR adds direct Arrow-to-user-object deserialization to the Spark 
Connect Scala Client.
    
    ### Why are the changes needed?
    We want to decouple the Scala client from Catalyst. We need a way to encode 
user objects to Arrow and decode them from Arrow.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added tests to `ArrowEncoderSuite`.
    
    Closes #42011 from hvanhovell/SPARK-44396.
    
    Authored-by: Herman van Hovell <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
---
 connector/connect/client/jvm/pom.xml               |  19 +
 .../client/arrow/ScalaCollectionUtils.scala        |  38 ++
 .../client/arrow/ScalaCollectionUtils.scala        |  37 ++
 .../spark/sql/connect/client/SparkResult.scala     | 230 +++++----
 .../connect/client/arrow/ArrowDeserializer.scala   | 533 +++++++++++++++++++++
 .../connect/client/arrow/ArrowEncoderUtils.scala   |   3 +
 .../arrow/ConcatenatingArrowStreamReader.scala     | 185 +++++++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  49 +-
 .../sql/KeyValueGroupedDatasetE2ETestSuite.scala   |  36 +-
 .../spark/sql/application/ReplE2ESuite.scala       |   6 +-
 .../connect/client/arrow/ArrowEncoderSuite.scala   | 127 +++--
 11 files changed, 1085 insertions(+), 178 deletions(-)

diff --git a/connector/connect/client/jvm/pom.xml 
b/connector/connect/client/jvm/pom.xml
index 93cc782ab13..60ed0f3ba46 100644
--- a/connector/connect/client/jvm/pom.xml
+++ b/connector/connect/client/jvm/pom.xml
@@ -140,6 +140,7 @@
     </dependency>
   </dependencies>
   <build>
+    
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
     
<testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
     <plugins>
       <!-- Shade all Guava / Protobuf / Netty dependencies of this build -->
@@ -224,6 +225,24 @@
           </execution>
         </executions>
       </plugin>
+      <plugin>
+        <groupId>org.codehaus.mojo</groupId>
+        <artifactId>build-helper-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>add-sources</id>
+            <phase>generate-sources</phase>
+            <goals>
+              <goal>add-source</goal>
+            </goals>
+            <configuration>
+              <sources>
+                <source>src/main/scala-${scala.binary.version}</source>
+              </sources>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
     </plugins>
   </build>
 </project>
\ No newline at end of file
diff --git 
a/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala
 
b/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala
new file mode 100644
index 00000000000..c2e01d974e0
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.arrow
+
+import scala.collection.generic.{GenericCompanion, GenMapFactory}
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+import 
org.apache.spark.sql.connect.client.arrow.ArrowDeserializers.resolveCompanion
+
+/**
+ * A couple of scala version specific collection utility functions.
+ */
+private[arrow] object ScalaCollectionUtils {
+  def getIterableCompanion(tag: ClassTag[_]): GenericCompanion[Iterable] = {
+    ArrowDeserializers.resolveCompanion[GenericCompanion[Iterable]](tag)
+  }
+  def getMapCompanion(tag: ClassTag[_]): GenMapFactory[Map] = {
+    resolveCompanion[GenMapFactory[Map]](tag)
+  }
+  def wrap[T](array: AnyRef): mutable.WrappedArray[T] = {
+    mutable.WrappedArray.make(array)
+  }
+}
diff --git 
a/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala
 
b/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala
new file mode 100644
index 00000000000..8a80e341622
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.arrow
+
+import scala.collection.{mutable, IterableFactory, MapFactory}
+import scala.reflect.ClassTag
+
+import 
org.apache.spark.sql.connect.client.arrow.ArrowDeserializers.resolveCompanion
+
+/**
+ * A couple of scala version specific collection utility functions.
+ */
+private[arrow] object ScalaCollectionUtils {
+  def getIterableCompanion(tag: ClassTag[_]): IterableFactory[Iterable] = {
+    ArrowDeserializers.resolveCompanion[IterableFactory[Iterable]](tag)
+  }
+  def getMapCompanion(tag: ClassTag[_]): MapFactory[Map] = {
+    resolveCompanion[MapFactory[Map]](tag)
+  }
+  def wrap[T](array: AnyRef): mutable.WrappedArray[T] = {
+    mutable.WrappedArray.make(array.asInstanceOf[Array[T]])
+  }
+}
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index a727c86f70f..1cdc2035de6 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -16,53 +16,48 @@
  */
 package org.apache.spark.sql.connect.client
 
-import java.util.Collections
+import java.util.Objects
 
-import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import org.apache.arrow.memory.BufferAllocator
-import org.apache.arrow.vector.FieldVector
-import org.apache.arrow.vector.ipc.ArrowStreamReader
+import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
+import org.apache.arrow.vector.types.pojo
 
 import org.apache.spark.connect.proto
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, 
ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, 
UnboundRowEncoder}
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
-import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable}
+import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, 
ArrowDeserializingIterator, CloseableIterator, ConcatenatingArrowStreamReader, 
MessageIterator}
+import org.apache.spark.sql.connect.client.util.Cleanable
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, 
ColumnVector}
 
 private[sql] class SparkResult[T](
     responses: java.util.Iterator[proto.ExecutePlanResponse],
     allocator: BufferAllocator,
     encoder: AgnosticEncoder[T])
     extends AutoCloseable
-    with Cleanable {
+    with Cleanable { self =>
 
   private[this] var numRecords: Int = 0
   private[this] var structType: StructType = _
-  private[this] var boundEncoder: ExpressionEncoder[T] = _
-  private[this] var nextBatchIndex: Int = 0
-  private val idxToBatches = mutable.Map.empty[Int, ColumnarBatch]
-
-  private def createEncoder(schema: StructType): ExpressionEncoder[T] = {
-    val agnosticEncoder = createEncoder(encoder, 
schema).asInstanceOf[AgnosticEncoder[T]]
-    ExpressionEncoder(agnosticEncoder)
-  }
+  private[this] var arrowSchema: pojo.Schema = _
+  private[this] var nextResultIndex: Int = 0
+  private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])]
 
   /**
    * Update RowEncoder and recursively update the fields of the ProductEncoder 
if found.
    */
-  private def createEncoder(enc: AgnosticEncoder[_], dataType: DataType): 
AgnosticEncoder[_] = {
+  private def createEncoder[E](
+      enc: AgnosticEncoder[E],
+      dataType: DataType): AgnosticEncoder[E] = {
     enc match {
       case UnboundRowEncoder =>
         // Replace the row encoder with the encoder inferred from the schema.
-        RowEncoder.encoderFor(dataType.asInstanceOf[StructType])
+        RowEncoder
+          .encoderFor(dataType.asInstanceOf[StructType])
+          .asInstanceOf[AgnosticEncoder[E]]
       case ProductEncoder(clsTag, fields) if ProductEncoder.isTuple(clsTag) =>
         // Recursively continue updating the tuple product encoder
         val schema = dataType.asInstanceOf[StructType]
@@ -76,53 +71,61 @@ private[sql] class SparkResult[T](
     }
   }
 
-  private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean 
= {
-    while (responses.hasNext) {
+  private def processResponses(
+      stopOnSchema: Boolean = false,
+      stopOnArrowSchema: Boolean = false,
+      stopOnFirstNonEmptyResponse: Boolean = false): Boolean = {
+    var nonEmpty = false
+    var stop = false
+    while (!stop && responses.hasNext) {
       val response = responses.next()
       if (response.hasSchema) {
         // The original schema should arrive before ArrowBatches.
         structType =
           
DataTypeProtoConverter.toCatalystType(response.getSchema).asInstanceOf[StructType]
-      } else if (response.hasArrowBatch) {
+        stop |= stopOnSchema
+      }
+      if (response.hasArrowBatch) {
         val ipcStreamBytes = response.getArrowBatch.getData
-        val reader = new ArrowStreamReader(ipcStreamBytes.newInput(), 
allocator)
-        try {
-          val root = reader.getVectorSchemaRoot
-          if (structType == null) {
-            // If the schema is not available yet, fallback to the schema from 
Arrow.
-            structType = ArrowUtils.fromArrowSchema(root.getSchema)
-          }
-          // TODO: create encoders that directly operate on arrow vectors.
-          if (boundEncoder == null) {
-            boundEncoder = createEncoder(structType)
-              .resolveAndBind(DataTypeUtils.toAttributes(structType))
-          }
-          while (reader.loadNextBatch()) {
-            val rowCount = root.getRowCount
-            if (rowCount > 0) {
-              val vectors = root.getFieldVectors.asScala
-                .map(v => new ArrowColumnVector(transferToNewVector(v)))
-                .toArray[ColumnVector]
-              idxToBatches.put(nextBatchIndex, new ColumnarBatch(vectors, 
rowCount))
-              nextBatchIndex += 1
-              numRecords += rowCount
-              if (stopOnFirstNonEmptyResponse) {
-                return true
-              }
-            }
+        val reader = new MessageIterator(ipcStreamBytes.newInput(), allocator)
+        if (arrowSchema == null) {
+          arrowSchema = reader.schema
+          stop |= stopOnArrowSchema
+        } else if (arrowSchema != reader.schema) {
+          throw new IllegalStateException(
+            s"""Schema Mismatch between expected and received schema:
+               |=== Expected Schema ===
+               |$arrowSchema
+               |=== Received Schema ===
+               |${reader.schema}
+               |""".stripMargin)
+        }
+        if (structType == null) {
+          // If the schema is not available yet, fallback to the arrow schema.
+          structType = ArrowUtils.fromArrowSchema(reader.schema)
+        }
+        var numRecordsInBatch = 0
+        val messages = Seq.newBuilder[ArrowMessage]
+        while (reader.hasNext) {
+          val message = reader.next()
+          message match {
+            case batch: ArrowRecordBatch =>
+              numRecordsInBatch += batch.getLength
+            case _ =>
           }
-        } finally {
-          reader.close()
+          messages += message
+        }
+        // Skip the entire result if it is empty.
+        if (numRecordsInBatch > 0) {
+          numRecords += numRecordsInBatch
+          resultMap.put(nextResultIndex, (reader.bytesRead, messages.result()))
+          nextResultIndex += 1
+          nonEmpty |= true
+          stop |= stopOnFirstNonEmptyResponse
         }
       }
     }
-    false
-  }
-
-  private def transferToNewVector(in: FieldVector): FieldVector = {
-    val pair = in.getTransferPair(allocator)
-    pair.transfer()
-    pair.getTo.asInstanceOf[FieldVector]
+    nonEmpty
   }
 
   /**
@@ -130,7 +133,7 @@ private[sql] class SparkResult[T](
    */
   def length: Int = {
     // We need to process all responses to make sure numRecords is correct.
-    processResponses(stopOnFirstNonEmptyResponse = false)
+    processResponses()
     numRecords
   }
 
@@ -139,7 +142,9 @@ private[sql] class SparkResult[T](
    *   the schema of the result.
    */
   def schema: StructType = {
-    processResponses(stopOnFirstNonEmptyResponse = true)
+    if (structType == null) {
+      processResponses(stopOnSchema = true)
+    }
     structType
   }
 
@@ -172,52 +177,93 @@ private[sql] class SparkResult[T](
 
   private def buildIterator(destructive: Boolean): java.util.Iterator[T] with 
AutoCloseable = {
     new java.util.Iterator[T] with AutoCloseable {
-      private[this] var batchIndex: Int = -1
-      private[this] var iterator: java.util.Iterator[InternalRow] = 
Collections.emptyIterator()
-      private[this] var deserializer: Deserializer[T] = _
+      private[this] var iterator: CloseableIterator[T] = _
 
-      override def hasNext: Boolean = {
-        if (iterator.hasNext) {
-          return true
-        }
-
-        val nextBatchIndex = batchIndex + 1
-        if (destructive) {
-          idxToBatches.remove(batchIndex).foreach(_.close())
+      private def initialize(): Unit = {
+        if (iterator == null) {
+          iterator = new ArrowDeserializingIterator(
+            createEncoder(encoder, schema),
+            new ConcatenatingArrowStreamReader(
+              allocator,
+              Iterator.single(new ResultMessageIterator(destructive)),
+              destructive))
         }
+      }
 
-        val hasNextBatch = if (!idxToBatches.contains(nextBatchIndex)) {
-          processResponses(stopOnFirstNonEmptyResponse = true)
-        } else {
-          true
-        }
-        if (hasNextBatch) {
-          batchIndex = nextBatchIndex
-          iterator = idxToBatches(nextBatchIndex).rowIterator()
-          if (deserializer == null) {
-            deserializer = boundEncoder.createDeserializer()
-          }
-        }
-        hasNextBatch
+      override def hasNext: Boolean = {
+        initialize()
+        iterator.hasNext
       }
 
       override def next(): T = {
-        if (!hasNext) {
-          throw new NoSuchElementException
-        }
-        deserializer(iterator.next())
+        initialize()
+        iterator.next()
       }
 
-      override def close(): Unit = SparkResult.this.close()
+      override def close(): Unit = {
+        if (iterator != null) {
+          iterator.close()
+        }
+      }
     }
   }
 
   /**
    * Close this result, freeing any underlying resources.
    */
-  override def close(): Unit = {
-    idxToBatches.values.foreach(_.close())
+  override def close(): Unit = cleaner.close()
+
+  override val cleaner: AutoCloseable = new SparkResultCloseable(resultMap)
+
+  private class ResultMessageIterator(destructive: Boolean) extends 
AbstractMessageIterator {
+    private[this] var totalBytesRead = 0L
+    private[this] var nextResultIndex = 0
+    private[this] var current: Iterator[ArrowMessage] = Iterator.empty
+
+    override def bytesRead: Long = totalBytesRead
+
+    override def schema: pojo.Schema = {
+      if (arrowSchema == null) {
+        // We need a schema to proceed. Spark Connect will always
+        // return a result (with a schema) even if the result is empty.
+        processResponses(stopOnArrowSchema = true)
+        Objects.requireNonNull(arrowSchema)
+      }
+      arrowSchema
+    }
+
+    override def hasNext: Boolean = {
+      if (current.hasNext) {
+        return true
+      }
+      val hasNextResult = if (!resultMap.contains(nextResultIndex)) {
+        self.processResponses(stopOnFirstNonEmptyResponse = true)
+      } else {
+        true
+      }
+      if (hasNextResult) {
+        val Some((sizeInBytes, messages)) = if (destructive) {
+          resultMap.remove(nextResultIndex)
+        } else {
+          resultMap.get(nextResultIndex)
+        }
+        totalBytesRead += sizeInBytes
+        current = messages.iterator
+        nextResultIndex += 1
+      }
+      hasNextResult
+    }
+
+    override def next(): ArrowMessage = {
+      if (!hasNext) {
+        throw new NoSuchElementException()
+      }
+      current.next()
+    }
   }
+}
 
-  override def cleaner: AutoCloseable = 
AutoCloseables(idxToBatches.values.toSeq)
+private[client] class SparkResultCloseable(resultMap: mutable.Map[Int, (Long, 
Seq[ArrowMessage])])
+    extends AutoCloseable {
+  override def close(): Unit = 
resultMap.values.foreach(_._2.foreach(_.close()))
 }
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
new file mode 100644
index 00000000000..154866d699a
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -0,0 +1,533 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.arrow
+
+import java.io.{ByteArrayInputStream, IOException}
+import java.lang.invoke.{MethodHandles, MethodType}
+import java.lang.reflect.Modifier
+import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger}
+import java.time._
+import java.util
+import java.util.{List => JList, Locale, Map => JMap}
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, 
DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, 
IntervalYearVector, IntVector, NullVector, SmallIntVector, 
TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, 
VarCharVector, VectorSchemaRoot}
+import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.apache.arrow.vector.util.Text
+
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
+import org.apache.spark.sql.types.Decimal
+
+/**
+ * Helper class for converting arrow batches into user objects.
+ */
+object ArrowDeserializers {
+  import ArrowEncoderUtils._
+
+  /**
+   * Create an Iterator of `T`. This iterator takes an Iterator of Arrow IPC 
Streams, and
+   * deserializes these streams into one or more instances of `T`
+   */
+  def deserializeFromArrow[T](
+      input: Iterator[Array[Byte]],
+      encoder: AgnosticEncoder[T],
+      allocator: BufferAllocator): CloseableIterator[T] = {
+    try {
+      val reader = new ConcatenatingArrowStreamReader(
+        allocator,
+        input.map(bytes => new MessageIterator(new 
ByteArrayInputStream(bytes), allocator)),
+        destructive = true)
+      new ArrowDeserializingIterator(encoder, reader)
+    } catch {
+      case _: IOException =>
+        new EmptyDeserializingIterator(encoder)
+    }
+  }
+
+  /**
+   * Create a deserializer of `T` on top of the given `root`.
+   */
+  private[arrow] def deserializerFor[T](
+      encoder: AgnosticEncoder[T],
+      root: VectorSchemaRoot): Deserializer[T] = {
+    val data: AnyRef = if (encoder.isStruct) {
+      root
+    } else {
+      // The input schema is allowed to have multiple columns,
+      // by convention we bind to the first one.
+      root.getVector(0)
+    }
+    deserializerFor(encoder, data).asInstanceOf[Deserializer[T]]
+  }
+
+  private[arrow] def deserializerFor(
+      encoder: AgnosticEncoder[_],
+      data: AnyRef): Deserializer[Any] = {
+    (encoder, data) match {
+      case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) =>
+        new FieldDeserializer[Boolean, BitVector](v) {
+          def value(i: Int): Boolean = vector.get(i) != 0
+        }
+      case (PrimitiveByteEncoder | BoxedByteEncoder, v: TinyIntVector) =>
+        new FieldDeserializer[Byte, TinyIntVector](v) {
+          def value(i: Int): Byte = vector.get(i)
+        }
+      case (PrimitiveShortEncoder | BoxedShortEncoder, v: SmallIntVector) =>
+        new FieldDeserializer[Short, SmallIntVector](v) {
+          def value(i: Int): Short = vector.get(i)
+        }
+      case (PrimitiveIntEncoder | BoxedIntEncoder, v: IntVector) =>
+        new FieldDeserializer[Int, IntVector](v) {
+          def value(i: Int): Int = vector.get(i)
+        }
+      case (PrimitiveLongEncoder | BoxedLongEncoder, v: BigIntVector) =>
+        new FieldDeserializer[Long, BigIntVector](v) {
+          def value(i: Int): Long = vector.get(i)
+        }
+      case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: Float4Vector) =>
+        new FieldDeserializer[Float, Float4Vector](v) {
+          def value(i: Int): Float = vector.get(i)
+        }
+      case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: Float8Vector) =>
+        new FieldDeserializer[Double, Float8Vector](v) {
+          def value(i: Int): Double = vector.get(i)
+        }
+      case (NullEncoder, v: NullVector) =>
+        new FieldDeserializer[Any, NullVector](v) {
+          def value(i: Int): Any = null
+        }
+      case (StringEncoder, v: VarCharVector) =>
+        new FieldDeserializer[String, VarCharVector](v) {
+          def value(i: Int): String = getString(vector, i)
+        }
+      case (JavaEnumEncoder(tag), v: VarCharVector) =>
+        // It would be nice if we can get Enum.valueOf working...
+        val valueOf = methodLookup.findStatic(
+          tag.runtimeClass,
+          "valueOf",
+          MethodType.methodType(tag.runtimeClass, classOf[String]))
+        new FieldDeserializer[Enum[_], VarCharVector](v) {
+          def value(i: Int): Enum[_] = {
+            valueOf.invoke(getString(vector, i)).asInstanceOf[Enum[_]]
+          }
+        }
+      case (ScalaEnumEncoder(parent, _), v: VarCharVector) =>
+        val mirror = scala.reflect.runtime.currentMirror
+        val module = mirror.classSymbol(parent).module.asModule
+        val enumeration = 
mirror.reflectModule(module).instance.asInstanceOf[Enumeration]
+        new FieldDeserializer[Enumeration#Value, VarCharVector](v) {
+          def value(i: Int): Enumeration#Value = 
enumeration.withName(getString(vector, i))
+        }
+      case (BinaryEncoder, v: VarBinaryVector) =>
+        new FieldDeserializer[Array[Byte], VarBinaryVector](v) {
+          def value(i: Int): Array[Byte] = vector.get(i)
+        }
+      case (SparkDecimalEncoder(_), v: DecimalVector) =>
+        new FieldDeserializer[Decimal, DecimalVector](v) {
+          def value(i: Int): Decimal = Decimal(vector.getObject(i))
+        }
+      case (ScalaDecimalEncoder(_), v: DecimalVector) =>
+        new FieldDeserializer[BigDecimal, DecimalVector](v) {
+          def value(i: Int): BigDecimal = BigDecimal(vector.getObject(i))
+        }
+      case (JavaDecimalEncoder(_, _), v: DecimalVector) =>
+        new FieldDeserializer[JBigDecimal, DecimalVector](v) {
+          def value(i: Int): JBigDecimal = vector.getObject(i)
+        }
+      case (ScalaBigIntEncoder, v: DecimalVector) =>
+        new FieldDeserializer[BigInt, DecimalVector](v) {
+          def value(i: Int): BigInt = new 
BigInt(vector.getObject(i).toBigInteger)
+        }
+      case (JavaBigIntEncoder, v: DecimalVector) =>
+        new FieldDeserializer[JBigInteger, DecimalVector](v) {
+          def value(i: Int): JBigInteger = vector.getObject(i).toBigInteger
+        }
+      case (DayTimeIntervalEncoder, v: DurationVector) =>
+        new FieldDeserializer[Duration, DurationVector](v) {
+          def value(i: Int): Duration = vector.getObject(i)
+        }
+      case (YearMonthIntervalEncoder, v: IntervalYearVector) =>
+        new FieldDeserializer[Period, IntervalYearVector](v) {
+          def value(i: Int): Period = vector.getObject(i).normalized()
+        }
+      case (DateEncoder(_), v: DateDayVector) =>
+        new FieldDeserializer[java.sql.Date, DateDayVector](v) {
+          def value(i: Int): java.sql.Date = 
DateTimeUtils.toJavaDate(vector.get(i))
+        }
+      case (LocalDateEncoder(_), v: DateDayVector) =>
+        new FieldDeserializer[LocalDate, DateDayVector](v) {
+          def value(i: Int): LocalDate = 
DateTimeUtils.daysToLocalDate(vector.get(i))
+        }
+      case (TimestampEncoder(_), v: TimeStampMicroTZVector) =>
+        new FieldDeserializer[java.sql.Timestamp, TimeStampMicroTZVector](v) {
+          def value(i: Int): java.sql.Timestamp = 
DateTimeUtils.toJavaTimestamp(vector.get(i))
+        }
+      case (InstantEncoder(_), v: TimeStampMicroTZVector) =>
+        new FieldDeserializer[Instant, TimeStampMicroTZVector](v) {
+          def value(i: Int): Instant = 
DateTimeUtils.microsToInstant(vector.get(i))
+        }
+      case (LocalDateTimeEncoder, v: TimeStampMicroVector) =>
+        new FieldDeserializer[LocalDateTime, TimeStampMicroVector](v) {
+          def value(i: Int): LocalDateTime = 
DateTimeUtils.microsToLocalDateTime(vector.get(i))
+        }
+
+      case (OptionEncoder(value), v) =>
+        val deserializer = deserializerFor(value, v)
+        new Deserializer[Any] {
+          override def get(i: Int): Any = Option(deserializer.get(i))
+        }
+
+      case (ArrayEncoder(element, _), v: ListVector) =>
+        val deserializer = deserializerFor(element, v.getDataVector)
+        new FieldDeserializer[AnyRef, ListVector](v) {
+          def value(i: Int): AnyRef = getArray(vector, i, 
deserializer)(element.clsTag)
+        }
+
+      case (IterableEncoder(tag, element, _, _), v: ListVector) =>
+        val deserializer = deserializerFor(element, v.getDataVector)
+        if (isSubClass(Classes.WRAPPED_ARRAY, tag)) {
+          // Wrapped array is a bit special because we need to use an array of 
the element type.
+          // Some parts of our codebase (unfortunately) rely on this for type 
inference on results.
+          new FieldDeserializer[mutable.WrappedArray[Any], ListVector](v) {
+            def value(i: Int): mutable.WrappedArray[Any] = {
+              val array = getArray(vector, i, deserializer)(element.clsTag)
+              ScalaCollectionUtils.wrap(array)
+            }
+          }
+        } else if (isSubClass(Classes.ITERABLE, tag)) {
+          val companion = ScalaCollectionUtils.getIterableCompanion(tag)
+          new FieldDeserializer[Iterable[Any], ListVector](v) {
+            def value(i: Int): Iterable[Any] = {
+              val builder = companion.newBuilder[Any]
+              loadListIntoBuilder(vector, i, deserializer, builder)
+              builder.result()
+            }
+          }
+        } else if (isSubClass(Classes.JLIST, tag)) {
+          val newInstance = resolveJavaListCreator(tag)
+          new FieldDeserializer[JList[Any], ListVector](v) {
+            def value(i: Int): JList[Any] = {
+              var index = v.getElementStartIndex(i)
+              val end = v.getElementEndIndex(i)
+              val list = newInstance(end - index)
+              while (index < end) {
+                list.add(deserializer.get(index))
+                index += 1
+              }
+              list
+            }
+          }
+        } else {
+          throw unsupportedCollectionType(tag.runtimeClass)
+        }
+
+      case (MapEncoder(tag, key, value, _), v: MapVector) =>
+        val structVector = v.getDataVector.asInstanceOf[StructVector]
+        val keyDeserializer = deserializerFor(key, 
structVector.getChild(MapVector.KEY_NAME))
+        val valueDeserializer =
+          deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME))
+        if (isSubClass(Classes.MAP, tag)) {
+          val companion = ScalaCollectionUtils.getMapCompanion(tag)
+          new FieldDeserializer[Map[Any, Any], MapVector](v) {
+            def value(i: Int): Map[Any, Any] = {
+              val builder = companion.newBuilder[Any, Any]
+              var index = v.getElementStartIndex(i)
+              val end = v.getElementEndIndex(i)
+              builder.sizeHint(end - index)
+              while (index < end) {
+                builder += (keyDeserializer.get(index) -> 
valueDeserializer.get(index))
+                index += 1
+              }
+              builder.result()
+            }
+          }
+        } else if (isSubClass(Classes.JMAP, tag)) {
+          val newInstance = resolveJavaMapCreator(tag)
+          new FieldDeserializer[JMap[Any, Any], MapVector](v) {
+            def value(i: Int): JMap[Any, Any] = {
+              val map = newInstance()
+              var index = v.getElementStartIndex(i)
+              val end = v.getElementEndIndex(i)
+              while (index < end) {
+                map.put(keyDeserializer.get(index), 
valueDeserializer.get(index))
+                index += 1
+              }
+              map
+            }
+          }
+        } else {
+          throw unsupportedCollectionType(tag.runtimeClass)
+        }
+
+      case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) =>
+        // We should try to make this work with MethodHandles.
+        val Some(constructor) =
+          ScalaReflection.findConstructor(tag.runtimeClass, 
fields.map(_.enc.clsTag.runtimeClass))
+        val deserializers = if (isTuple(tag.runtimeClass)) {
+          fields.zip(vectors).map { case (field, vector) =>
+            deserializerFor(field.enc, vector)
+          }
+        } else {
+          val lookup = createFieldLookup(vectors)
+          fields.map { field =>
+            deserializerFor(field.enc, lookup(field.name))
+          }
+        }
+        new StructFieldSerializer[Any](struct) {
+          def value(i: Int): Any = {
+            constructor(deserializers.map(_.get(i).asInstanceOf[AnyRef]))
+          }
+        }
+
+      case (r @ RowEncoder(fields), StructVectors(struct, vectors)) =>
+        val lookup = createFieldLookup(vectors)
+        val deserializers = fields.toArray.map { field =>
+          deserializerFor(field.enc, lookup(field.name))
+        }
+        new StructFieldSerializer[Any](struct) {
+          def value(i: Int): Any = {
+            val values = deserializers.map(_.get(i))
+            new GenericRowWithSchema(values, r.schema)
+          }
+        }
+
+      case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) =>
+        val constructor =
+          methodLookup.findConstructor(tag.runtimeClass, 
MethodType.methodType(classOf[Unit]))
+        val lookup = createFieldLookup(vectors)
+        val setters = fields.map { field =>
+          val vector = lookup(field.name)
+          val deserializer = deserializerFor(field.enc, vector)
+          val setter = methodLookup.findVirtual(
+            tag.runtimeClass,
+            field.writeMethod.get,
+            MethodType.methodType(classOf[Unit], 
field.enc.clsTag.runtimeClass))
+          (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i))
+        }
+        new StructFieldSerializer[Any](struct) {
+          def value(i: Int): Any = {
+            val instance = constructor.invoke()
+            setters.foreach(_(instance, i))
+            instance
+          }
+        }
+
+      case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
+        throw QueryExecutionErrors.unsupportedDataTypeError(encoder.dataType)
+
+      case _ =>
+        throw new RuntimeException(
+          s"Unsupported Encoder($encoder)/Vector(${data.getClass}) 
combination.")
+    }
+  }
+
+  private val methodLookup = MethodHandles.lookup()
+
+  /**
+   * Look up, via Scala runtime reflection, the companion object of the class described by `tag`
+   * and return it cast to `T`. Here the class is always a Scala collection, and the companion is
+   * used to obtain a builder for that collection.
+   */
+  private[arrow] def resolveCompanion[T](tag: ClassTag[_]): T = {
+    val runtimeMirror = scala.reflect.runtime.currentMirror
+    val classSymbol = runtimeMirror.classSymbol(tag.runtimeClass)
+    val companionModule = classSymbol.companion.asModule
+    val companionInstance = runtimeMirror.reflectModule(companionModule).instance
+    companionInstance.asInstanceOf[T]
+  }
+
+  /**
+   * Create a function that creates a [[util.List]] instance. The int parameter of the creator
+   * function is a size hint.
+   *
+   * If the [[ClassTag]] `tag` points to an interface or abstract class we try to use
+   * [[util.ArrayList]]; when `ArrayList` is not assignable to the requested type an
+   * unsupported-collection exception is thrown. For concrete classes we try to use a
+   * constructor that takes a single [[Int]] argument, it is assumed this is a size hint. If no
+   * such constructor exists we fallback to a no-args constructor.
+   */
+  private def resolveJavaListCreator(tag: ClassTag[_]): Int => JList[Any] = {
+    val cls = tag.runtimeClass
+    val modifiers = cls.getModifiers
+    if (Modifier.isInterface(modifiers) || Modifier.isAbstract(modifiers)) {
+      // Abstract class or interface; we try to use ArrayList. unsupportedCollectionType only
+      // creates the exception, so it must be thrown explicitly. Otherwise we would hand out
+      // ArrayLists for a type they are not an instance of.
+      if (!cls.isAssignableFrom(classOf[util.ArrayList[_]])) {
+        throw unsupportedCollectionType(cls)
+      }
+      (size: Int) => new util.ArrayList[Any](size)
+    } else {
+      try {
+        // Try to use a constructor that (hopefully) takes a size argument.
+        val ctor = methodLookup.findConstructor(
+          tag.runtimeClass,
+          MethodType.methodType(classOf[Unit], Integer.TYPE))
+        size => ctor.invoke(size).asInstanceOf[JList[Any]]
+      } catch {
+        case _: java.lang.NoSuchMethodException =>
+          // Use a no-args constructor.
+          val ctor =
+            methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit]))
+          _ => ctor.invoke().asInstanceOf[JList[Any]]
+      }
+    }
+  }
+
+  /**
+   * Create a function that creates a [[util.Map]] instance.
+   *
+   * If the [[ClassTag]] `tag` points to an interface or abstract class we try to use
+   * [[util.HashMap]]; when `HashMap` is not assignable to the requested type an
+   * unsupported-collection exception is thrown. For concrete classes we use a no-args
+   * constructor.
+   */
+  private def resolveJavaMapCreator(tag: ClassTag[_]): () => JMap[Any, Any] = {
+    val cls = tag.runtimeClass
+    val modifiers = cls.getModifiers
+    if (Modifier.isInterface(modifiers) || Modifier.isAbstract(modifiers)) {
+      // Abstract class or interface; we try to use HashMap. unsupportedCollectionType only
+      // creates the exception, so it must be thrown explicitly. Otherwise we would hand out
+      // HashMaps for a type they are not an instance of.
+      if (!cls.isAssignableFrom(classOf[java.util.HashMap[_, _]])) {
+        throw unsupportedCollectionType(cls)
+      }
+      () => new util.HashMap[Any, Any]()
+    } else {
+      // Use a no-args constructor.
+      val ctor =
+        methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit]))
+      () => ctor.invoke().asInstanceOf[JMap[Any, Any]]
+    }
+  }
+
+  /**
+   * Create a function that looks up a [[FieldVector]] in `fields` by name, ignoring case.
+   *
+   * If two fields have names that are equal under case-insensitive comparison, an exception is
+   * thrown while building the lookup. The returned function throws an exception when no field
+   * matches the requested name.
+   *
+   * A small note on the binding process in general: over-complete schemas are currently allowed,
+   * meaning that the data can have more columns than the encoder. In that case the
+   * over-complete (unbound) columns are ignored.
+   */
+  private def createFieldLookup(fields: Seq[FieldVector]): String => FieldVector = {
+    def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
+    val byKey = mutable.Map.empty[String, FieldVector]
+    for (field <- fields) {
+      val key = normalize(field.getName)
+      if (byKey.contains(key)) {
+        throw QueryCompilationErrors.ambiguousColumnOrFieldError(
+          field.getName :: Nil,
+          fields.count(other => normalize(other.getName) == key))
+      }
+      byKey(key) = field
+    }
+    (name: String) =>
+      byKey.getOrElse(normalize(name), throw QueryCompilationErrors.columnNotFoundError(name))
+  }
+
+  /** Whether `cls` is one of Scala's `TupleN` classes, judged by its fully qualified name. */
+  private def isTuple(cls: Class[_]): Boolean = {
+    val className = cls.getName
+    className.startsWith("scala.Tuple")
+  }
+
+  /** Decode the value at row `i` of `v` into a [[String]] using `Text.decode`. */
+  private def getString(v: VarCharVector, i: Int): String = {
+    // This is currently a bit heavy on allocations:
+    // - byte array created in VarCharVector.get
+    // - CharBuffer created by the CharsetDecoder
+    // - char array in String
+    // By using direct buffers and reusing the char buffer
+    // we could get rid of the first two allocations.
+    Text.decode(v.get(i))
+  }
+
+  /**
+   * Append every element of the list stored at row `i` of `v` to `builder`, deserializing each
+   * element with `deserializer`. The builder receives the element count as a size hint first.
+   */
+  private def loadListIntoBuilder(
+      v: ListVector,
+      i: Int,
+      deserializer: Deserializer[Any],
+      builder: mutable.Builder[Any, _]): Unit = {
+    val start = v.getElementStartIndex(i)
+    val count = v.getElementEndIndex(i) - start
+    builder.sizeHint(count)
+    // Plain while loop on purpose: this runs once per list element in the hot path.
+    var offset = 0
+    while (offset < count) {
+      builder += deserializer.get(start + offset)
+      offset += 1
+    }
+  }
+
+  /**
+   * Materialize the list stored at row `i` of `v` as an array. The implicit `tag` determines the
+   * runtime element type of the resulting array.
+   */
+  private def getArray(v: ListVector, i: Int, deserializer: Deserializer[Any])(implicit
+      tag: ClassTag[Any]): AnyRef = {
+    val arrayBuilder = mutable.ArrayBuilder.make[Any]
+    loadListIntoBuilder(v, i, deserializer, arrayBuilder)
+    val materialized = arrayBuilder.result()
+    materialized
+  }
+
+  /** Produces a deserialized value of type `E` for a given index into its underlying data. */
+  abstract class Deserializer[+E] {
+    /** Returns the deserialized value at index `i`. */
+    def get(i: Int): E
+  }
+
+  /**
+   * A [[Deserializer]] backed by a single Arrow [[FieldVector]]. Subclasses implement `value`
+   * for non-null slots; `get` returns `null` for every slot where `isNull` is true.
+   */
+  abstract class FieldDeserializer[E, V <: FieldVector](val vector: V) extends Deserializer[E] {
+    def value(i: Int): E
+    def isNull(i: Int): Boolean = vector.isNull(i)
+    override def get(i: Int): E = if (isNull(i)) null.asInstanceOf[E] else value(i)
+  }
+
+  /**
+   * A [[FieldDeserializer]] over a [[StructVector]]. Unlike the base class it tolerates a `null`
+   * vector, in which case no row is reported as null.
+   *
+   * NOTE(review): the `null` vector case presumably covers top-level rows read straight from a
+   * VectorSchemaRoot, which has no struct vector of its own; confirm against `StructVectors`.
+   * NOTE(review): despite its name, this class is a deserializer.
+   */
+  abstract class StructFieldSerializer[E](v: StructVector)
+      extends FieldDeserializer[E, StructVector](v) {
+    override def isNull(i: Int): Boolean = vector != null && vector.isNull(i)
+  }
+}
+
+/**
+ * A [[CloseableIterator]] that yields no elements. It exposes the `encoder` it was created with;
+ * closing it is a no-op.
+ */
+class EmptyDeserializingIterator[E](val encoder: AgnosticEncoder[E])
+    extends CloseableIterator[E] {
+  override def hasNext: Boolean = false
+  override def next(): E = throw new NoSuchElementException()
+  override def close(): Unit = {}
+}
+
+/**
+ * A [[CloseableIterator]] that deserializes the rows loaded by `reader` into objects of type `E`
+ * using `encoder`. Batches are loaded lazily as the iterator advances. Closing the iterator
+ * closes the reader.
+ */
+class ArrowDeserializingIterator[E](
+    val encoder: AgnosticEncoder[E],
+    private[this] val reader: ArrowReader)
+    extends CloseableIterator[E] {
+  // Index of the next row to return within the currently loaded batch.
+  private[this] var index = 0
+  private[this] val root = reader.getVectorSchemaRoot
+  private[this] val deserializer = ArrowDeserializers.deserializerFor(encoder, root)
+
+  override def hasNext: Boolean = {
+    // Keep loading batches until one has unconsumed rows or the reader is exhausted. A generic
+    // ArrowReader may report success for a batch with zero rows (ConcatenatingArrowStreamReader
+    // skips those itself, other readers may not); such a batch must not be mistaken for the end
+    // of the data.
+    while (index >= root.getRowCount && reader.loadNextBatch()) {
+      index = 0
+    }
+    index < root.getRowCount
+  }
+
+  override def next(): E = {
+    if (!hasNext) {
+      throw new NoSuchElementException()
+    }
+    val result = deserializer.get(index)
+    index += 1
+    result
+  }
+
+  override def close(): Unit = reader.close()
+}
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
index f6b140bae55..ed273369854 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
@@ -24,8 +24,11 @@ import org.apache.arrow.vector.complex.StructVector
 
 private[arrow] object ArrowEncoderUtils {
   object Classes {
+    val WRAPPED_ARRAY: Class[_] = 
classOf[scala.collection.mutable.WrappedArray[_]]
     val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]]
+    val MAP: Class[_] = classOf[scala.collection.Map[_, _]]
     val JLIST: Class[_] = classOf[java.util.List[_]]
+    val JMAP: Class[_] = classOf[java.util.Map[_, _]]
   }
 
   def isSubClass(cls: Class[_], tag: ClassTag[_]): Boolean = {
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
new file mode 100644
index 00000000000..90963c831c2
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.arrow
+
+import java.io.{InputStream, IOException}
+import java.nio.channels.Channels
+
+import org.apache.arrow.flatbuf.MessageHeader
+import org.apache.arrow.memory.{ArrowBuf, BufferAllocator}
+import org.apache.arrow.vector.ipc.{ArrowReader, ReadChannel}
+import org.apache.arrow.vector.ipc.message.{ArrowDictionaryBatch, 
ArrowMessage, ArrowRecordBatch, MessageChannelReader, MessageResult, 
MessageSerializer}
+import org.apache.arrow.vector.types.pojo.Schema
+
+/**
+ * An [[ArrowReader]] that concatenates multiple [[MessageIterator]]s into a single stream. Each
+ * iterator represents a single IPC stream. The concatenated streams all must have the same
+ * schema. If the schema is different an [[IllegalStateException]] is thrown.
+ *
+ * In some cases we want to retain the messages (see `SparkResult`). Normally a stream reader
+ * closes its messages when it consumes them. In order to prevent that from happening in
+ * non-destructive mode we clone the messages before passing them to the reading logic.
+ */
+class ConcatenatingArrowStreamReader(
+    allocator: BufferAllocator,
+    input: Iterator[AbstractMessageIterator],
+    destructive: Boolean)
+    extends ArrowReader(allocator) {
+
+  // Bytes read by all message iterators that have been fully consumed; the bytes read by
+  // `current` are added on demand in bytesRead().
+  private[this] var totalBytesRead: Long = 0
+  // The message iterator (IPC stream) that is currently being read.
+  private[this] var current: AbstractMessageIterator = _
+
+  override protected def readSchema(): Schema = {
+    // readSchema() should only be called once during initialization.
+    assert(current == null)
+    if (!input.hasNext) {
+      // ArrowStreamReader throws the same exception.
+      throw new IOException("Unexpected end of input. Missing schema.")
+    }
+    current = input.next()
+    current.schema
+  }
+
+  /**
+   * Return the next message across all input streams, or `null` when every stream has been
+   * exhausted. Empty streams are skipped. Throws an [[IllegalStateException]] when a stream's
+   * schema differs from the schema the reader was initialized with.
+   */
+  private def nextMessage(): ArrowMessage = {
+    // readSchema() should have been invoked at this point so 'current' should be initialized.
+    assert(current != null)
+    // Try to find a non-empty message iterator.
+    while (!current.hasNext && input.hasNext) {
+      totalBytesRead += current.bytesRead
+      current = input.next()
+      val expectedSchema = getVectorSchemaRoot.getSchema
+      if (current.schema != expectedSchema) {
+        throw new IllegalStateException(
+          "Cannot concatenate Arrow streams with different schemas. " +
+            s"Expected: $expectedSchema, got: ${current.schema}")
+      }
+    }
+    if (current.hasNext) {
+      current.next()
+    } else {
+      null
+    }
+  }
+
+  override def loadNextBatch(): Boolean = {
+    // Keep looping until we load a non-empty batch or until we exhaust the input.
+    var message = nextMessage()
+    while (message != null) {
+      message match {
+        case rb: ArrowRecordBatch =>
+          loadRecordBatch(cloneIfNonDestructive(rb))
+          if (getVectorSchemaRoot.getRowCount > 0) {
+            return true
+          }
+        case db: ArrowDictionaryBatch =>
+          loadDictionary(cloneIfNonDestructive(db))
+      }
+      message = nextMessage()
+    }
+    false
+  }
+
+  private def cloneIfNonDestructive(batch: ArrowRecordBatch): ArrowRecordBatch = {
+    if (destructive) {
+      return batch
+    }
+    cloneRecordBatch(batch)
+  }
+
+  private def cloneIfNonDestructive(batch: ArrowDictionaryBatch): ArrowDictionaryBatch = {
+    if (destructive) {
+      return batch
+    }
+    new ArrowDictionaryBatch(
+      batch.getDictionaryId,
+      cloneRecordBatch(batch.getDictionary),
+      batch.isDelta)
+  }
+
+  // Create a new batch over the same buffers. The buffers are retained, so closing the clone
+  // (which happens when it is consumed) leaves the original message intact.
+  private def cloneRecordBatch(batch: ArrowRecordBatch): ArrowRecordBatch = {
+    new ArrowRecordBatch(
+      batch.getLength,
+      batch.getNodes,
+      batch.getBuffers,
+      batch.getBodyCompression,
+      /* alignBuffers = */ true,
+      /* retainBuffers = */ true)
+  }
+
+  override def bytesRead(): Long = {
+    if (current != null) {
+      totalBytesRead + current.bytesRead
+    } else {
+      0
+    }
+  }
+
+  // This reader has no read source of its own; the input message iterators are not closed here.
+  override def closeReadSource(): Unit = ()
+}
+
+/**
+ * An iterator over the messages of a single Arrow IPC stream that also exposes the stream's
+ * schema and how many bytes have been read from it.
+ */
+trait AbstractMessageIterator extends Iterator[ArrowMessage] {
+  /** The schema of the stream. */
+  def schema: Schema
+  /** The number of bytes read from the underlying stream so far. */
+  def bytesRead: Long
+}
+
+/**
+ * Decode an Arrow IPC stream into individual messages. Please note that this iterator MUST have a
+ * valid IPC stream as its input, otherwise construction will fail.
+ *
+ * The schema message is consumed eagerly during construction. After that only record batch and
+ * dictionary batch messages are supported; any other message type causes an [[IOException]].
+ */
+class MessageIterator(input: InputStream, allocator: BufferAllocator)
+    extends AbstractMessageIterator {
+  private[this] val in = new ReadChannel(Channels.newChannel(input))
+  private[this] val reader = new MessageChannelReader(in, allocator)
+  // Message read ahead by hasNext; null when nothing has been read ahead.
+  private[this] var result: MessageResult = _
+
+  // Eagerly read the schema.
+  val schema: Schema = {
+    val schemaResult = reader.readNext()
+    if (schemaResult == null) {
+      throw new IOException("Unexpected end of input. Missing schema.")
+    }
+    MessageSerializer.deserializeSchema(schemaResult.getMessage)
+  }
+
+  override def bytesRead: Long = reader.bytesRead()
+
+  override def hasNext: Boolean = {
+    if (result == null) {
+      result = reader.readNext()
+    }
+    result != null
+  }
+
+  override def next(): ArrowMessage = {
+    if (!hasNext) {
+      throw new NoSuchElementException()
+    }
+    val pending = result
+    // Clear the read-ahead slot before decoding, so a failure below does not leave the iterator
+    // stuck on the same message.
+    result = null
+    val message = pending.getMessage
+    message.headerType() match {
+      case MessageHeader.RecordBatch =>
+        MessageSerializer.deserializeRecordBatch(message, bodyBuffer(pending))
+      case MessageHeader.DictionaryBatch =>
+        MessageSerializer.deserializeDictionaryBatch(message, bodyBuffer(pending))
+      case other =>
+        throw new IOException(s"Unexpected Arrow IPC message header type: $other")
+    }
+  }
+
+  // getBodyBuffer may return null for a message without a body; deserialization still needs a
+  // buffer, so fall back to the allocator's empty buffer.
+  private def bodyBuffer(messageResult: MessageResult): ArrowBuf =
+    Option(messageResult.getBodyBuffer).getOrElse(allocator.getEmpty)
+}
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 73c04389c05..07dd2a96bd8 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -39,7 +39,6 @@ import 
org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.vectorized.ColumnarBatch
 
 class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with 
PrivateMethodTester {
 
@@ -571,7 +570,8 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
     (col("id") / lit(10.0d)).as("b"),
     col("id"),
     lit("world").as("d"),
-    (col("id") % 2).cast("int").as("a"))
+    // TODO SPARK-44449 make this int again when upcasting is in.
+    (col("id") % 2).cast("double").as("a"))
 
   private def validateMyTypeResult(result: Array[MyType]): Unit = {
     result.zipWithIndex.foreach { case (MyType(id, a, b), i) =>
@@ -818,10 +818,11 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
   }
 
   test("toJSON") {
+    // TODO SPARK-44449 make this int again when upcasting is in.
     val expected = Array(
-      """{"b":0.0,"id":0,"d":"world","a":0}""",
-      """{"b":0.1,"id":1,"d":"world","a":1}""",
-      """{"b":0.2,"id":2,"d":"world","a":0}""")
+      """{"b":0.0,"id":0,"d":"world","a":0.0}""",
+      """{"b":0.1,"id":1,"d":"world","a":1.0}""",
+      """{"b":0.2,"id":2,"d":"world","a":0.0}""")
     val result = spark
       .range(3)
       .select(generateMyTypeColumns: _*)
@@ -893,14 +894,12 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
 
   test("Dataset result destructive iterator") {
     // Helper methods for accessing private field `idxToBatches` from 
SparkResult
-    val _idxToBatches =
-      PrivateMethod[mutable.Map[Int, ColumnarBatch]](Symbol("idxToBatches"))
+    val getResultMap =
+      PrivateMethod[mutable.Map[Int, Any]](Symbol("resultMap"))
 
-    def getColumnarBatches(result: SparkResult[_]): Seq[ColumnarBatch] = {
-      val idxToBatches = result invokePrivate _idxToBatches()
-
-      // Sort by key to get stable results.
-      idxToBatches.toSeq.sortBy(_._1).map(_._2)
+    def assertResultsMapEmpty(result: SparkResult[_]): Unit = {
+      val resultMap = result invokePrivate getResultMap()
+      assert(resultMap.isEmpty)
     }
 
     val df = spark
@@ -911,25 +910,19 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
       try {
         // build and verify the destructive iterator
         val iterator = result.destructiveIterator
-        // batches is empty before traversing the result iterator
-        assert(getColumnarBatches(result).isEmpty)
-        var previousBatch: ColumnarBatch = null
-        val buffer = mutable.Buffer.empty[Long]
+        // resultMap Map is empty before traversing the result iterator
+        assertResultsMapEmpty(result)
+        val buffer = mutable.Set.empty[Long]
         while (iterator.hasNext) {
-          // always having 1 batch, since a columnar batch will be removed and 
closed after
-          // its data got consumed.
-          val batches = getColumnarBatches(result)
-          assert(batches.size === 1)
-          assert(batches.head != previousBatch)
-          previousBatch = batches.head
-
-          buffer.append(iterator.next())
+          // resultMap is empty during iteration because results get removed 
immediately on access.
+          assertResultsMapEmpty(result)
+          buffer += iterator.next()
         }
-        // Batches should be closed and removed after traversing all the 
records.
-        assert(getColumnarBatches(result).isEmpty)
+        // resultMap Map is empty afterward because all results have been 
removed.
+        assertResultsMapEmpty(result)
 
-        val expectedResult = Seq(6L, 7L, 8L)
-        assert(buffer.size === 3 && expectedResult.forall(buffer.contains))
+        val expectedResult = Set(6L, 7L, 8L)
+        assert(buffer.size === 3 && expectedResult == buffer)
       } finally {
         result.close()
       }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index e15069f2d9e..ab3e13da531 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -68,10 +68,11 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
   }
 
   test("keyAs - keys") {
+    // TODO SPARK-44449 make this long again when upcasting is in.
     // It is okay to cast from Long to Double, but not Long to Int.
     val values = spark
       .range(10)
-      .groupByKey(v => v % 2)
+      .groupByKey(v => (v % 2).toDouble)
       .keyAs[Double]
       .keys
       .collectAsList()
@@ -232,9 +233,10 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
   }
 
   test("agg, keyAs") {
+    // TODO SPARK-44449 make this long again when upcasting is in.
     val ds = spark
       .range(10)
-      .groupByKey(v => v % 2)
+      .groupByKey(v => (v % 2).toDouble)
       .keyAs[Double]
       .agg(count("*"))
 
@@ -244,7 +246,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
   test("typed aggregation: expr") {
     val session: SparkSession = spark
     import session.implicits._
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    // TODO SPARK-44449 make this int again when upcasting is in.
+    val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 
1L)).toDS()
 
     checkDatasetUnorderly(
       ds.groupByKey(_._1).agg(sum("_2").as[Long]),
@@ -254,7 +257,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
   }
 
   test("typed aggregation: expr, expr") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    // TODO SPARK-44449 make this int again when upcasting is in.
+    val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 
1L)).toDS()
 
     checkDatasetUnorderly(
       ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
@@ -264,7 +268,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
   }
 
   test("typed aggregation: expr, expr, expr") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    // TODO SPARK-44449 make this int again when upcasting is in.
+    val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 
1L)).toDS()
 
     checkDatasetUnorderly(
       ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], 
count("*")),
@@ -274,7 +279,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
   }
 
   test("typed aggregation: expr, expr, expr, expr") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    // TODO SPARK-44449 make this int again when upcasting is in.
+    val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 
1L)).toDS()
 
     checkDatasetUnorderly(
       ds.groupByKey(_._1)
@@ -289,7 +295,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
   }
 
   test("typed aggregation: expr, expr, expr, expr, expr") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    // TODO SPARK-44449 make this int again when upcasting is in.
+    val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 
1L)).toDS()
 
     checkDatasetUnorderly(
       ds.groupByKey(_._1)
@@ -305,7 +312,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
   }
 
   test("typed aggregation: expr, expr, expr, expr, expr, expr") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    // TODO SPARK-44449 make this int again when upcasting is in.
+    val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 
1L)).toDS()
 
     checkDatasetUnorderly(
       ds.groupByKey(_._1)
@@ -322,7 +330,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
   }
 
   test("typed aggregation: expr, expr, expr, expr, expr, expr, expr") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    // TODO SPARK-44449 make this int again when upcasting is in.
+    val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 
1L)).toDS()
 
     checkDatasetUnorderly(
       ds.groupByKey(_._1)
@@ -340,7 +349,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
   }
 
   test("typed aggregation: expr, expr, expr, expr, expr, expr, expr, expr") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    // TODO SPARK-44449 make this int again when upcasting is in.
+    val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 
1L)).toDS()
 
     checkDatasetUnorderly(
       ds.groupByKey(_._1)
@@ -473,9 +483,9 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest 
with SQLHelper {
     val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
       .toDF("key", "seq", "value")
     val grouped = ds.groupBy($"value").as[String, (String, Int, Int)]
-    val keys = grouped.keyAs[String].keys.sort($"value")
-
-    checkDataset(keys, "1", "2", "10", "20")
+    // TODO SPARK-44449 make this string again when upcasting is in.
+    val keys = grouped.keyAs[Int].keys.sort($"value")
+    checkDataset(keys, 1, 2, 10, 20)
   }
 
   test("flatMapGroupsWithState") {
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 58758a13840..800ce43a60d 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -208,8 +208,9 @@ class ReplE2ESuite extends RemoteSparkSession with 
BeforeAndAfterEach {
   }
 
   test("UDF Registration") {
+    // TODO SPARK-44449 make this long again when upcasting is in.
     val input = """
-        |class A(x: Int) { def get = x * 100 }
+        |class A(x: Int) { def get: Long = x * 100 }
         |val myUdf = udf((x: Int) => new A(x).get)
         |spark.udf.register("dummyUdf", myUdf)
         |spark.sql("select dummyUdf(id) from range(5)").as[Long].collect()
@@ -219,8 +220,9 @@ class ReplE2ESuite extends RemoteSparkSession with 
BeforeAndAfterEach {
   }
 
   test("UDF closure registration") {
+    // TODO SPARK-44449 make this int again when upcasting is in.
     val input = """
-        |class A(x: Int) { def get = x * 15 }
+        |class A(x: Int) { def get: Long = x * 15 }
         |spark.udf.register("directUdf", (x: Int) => new A(x).get)
         |spark.sql("select directUdf(id) from range(5)").as[Long].collect()
       """.stripMargin
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index 0c327484e47..16eec3eee31 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -21,24 +21,19 @@ import java.util
 import java.util.{Collections, Objects}
 
 import scala.beans.BeanProperty
-import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.reflect.classTag
-import scala.util.control.NonFatal
 
-import com.google.protobuf.ByteString
 import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
 import org.apache.arrow.vector.VarBinaryVector
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.SparkUnsupportedOperationException
-import org.apache.spark.connect.proto
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, 
JavaTypeInference, ScalaReflection}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, 
CalendarIntervalEncoder, DateEncoder, EncoderField, InstantEncoder, 
IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, PrimitiveDoubleEncoder, 
PrimitiveFloatEncoder, RowEncoder, StringEncoder, TimestampEncoder, UDTEncoder}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => 
toRowEncoder}
-import org.apache.spark.sql.connect.client.SparkResult
 import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
 import org.apache.spark.sql.types.{ArrayType, DataType, Decimal, DecimalType, 
IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType}
@@ -96,15 +91,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     }
 
     val resultIterator =
-      try {
-        deserializeFromArrow(inspectedIterator, encoder, deserializerAllocator)
-      } catch {
-        case NonFatal(e) =>
-          arrowIterator.close()
-          serializerAllocator.close()
-          deserializerAllocator.close()
-          throw e
-      }
+      ArrowDeserializers.deserializeFromArrow(inspectedIterator, encoder, 
deserializerAllocator)
     new CloseableIterator[T] {
       override def close(): Unit = {
         arrowIterator.close()
@@ -117,25 +104,6 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     }
   }
 
-  // Temporary hack until we merge the deserializer.
-  private def deserializeFromArrow[E](
-      batches: Iterator[Array[Byte]],
-      encoder: AgnosticEncoder[E],
-      allocator: BufferAllocator): CloseableIterator[E] = {
-    val responses = batches.map { batch =>
-      val builder = proto.ExecutePlanResponse.newBuilder()
-      builder.getArrowBatchBuilder.setData(ByteString.copyFrom(batch))
-      builder.build()
-    }
-    val result = new SparkResult[E](responses.asJava, allocator, encoder)
-    new CloseableIterator[E] {
-      private val itr = result.iterator
-      override def close(): Unit = itr.close()
-      override def hasNext: Boolean = itr.hasNext
-      override def next(): E = itr.next()
-    }
-  }
-
   private def roundTripAndCheck[T](
       encoder: AgnosticEncoder[T],
       toInputIterator: () => Iterator[Any],
@@ -246,6 +214,15 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     assert(inspector.sizeInBytes > 0)
   }
 
+  test("deserializing empty iterator") {
+    withAllocator { allocator =>
+      // An empty batch stream must produce an empty iterator without allocating Arrow memory.
+      val iterator =
+        ArrowDeserializers.deserializeFromArrow(Iterator.empty, singleIntEncoder, allocator)
+      try {
+        assert(iterator.isEmpty)
+        // Checked before close(): the deserializer itself must not have allocated anything.
+        assert(allocator.getAllocatedMemory == 0)
+      } finally {
+        // The result is a CloseableIterator; release it even when an assertion fails.
+        iterator.close()
+      }
+    }
+  }
+
   test("single batch") {
     val inspector = new CountingBatchInspector
     roundTripAndCheckIdentical(singleIntEncoder, inspectBatch = inspector) { 
() =>
@@ -533,15 +510,22 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
       val maybeNull = MaybeNull(11)
       Iterator.tabulate(100) { i =>
         val bean = new JavaMapData
-        bean.setDummyToDoubleListMap(maybeNull {
-          val map = new util.HashMap[DummyBean, 
java.util.List[java.lang.Double]]
-          (0 until (i % 5)).foreach { j =>
-            val dummy = new DummyBean
-            dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j)))
+        bean.setMetricMap(maybeNull {
+          val map = new util.HashMap[String, util.List[java.lang.Double]]
+          (0 until (i % 20)).foreach { i =>
             val values = Array.tabulate(i % 40) { j =>
               Double.box(j.toDouble)
             }
-            map.put(dummy, maybeNull(util.Arrays.asList(values: _*)))
+            map.put("k" + i, maybeNull(util.Arrays.asList(values: _*)))
+          }
+          map
+        })
+        bean.setDummyToStringMap(maybeNull {
+          val map = new util.HashMap[DummyBean, String]
+          (0 until (i % 5)).foreach { j =>
+            val dummy = new DummyBean
+            dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j)))
+            map.put(dummy, maybeNull("s" + i + "v" + j))
           }
           map
         })
@@ -675,6 +659,57 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
           .add("Ca", "array<int>")
           .add("Cb", "binary")))
 
+  test("bind to schema") {
+    // Serializes with the wide schema and deserializes with the narrow one. The narrow schema
+    // has fewer (nested) fields, a slightly different field order, and differently cased names
+    // in a couple of places.
+    withAllocator { allocator =>
+      // Note: 756.toByte wraps around to -12; the same value is used in `expected`.
+      val input = Row(
+        887,
+        "foo",
+        Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte), 5f),
+        Seq(Row(null, "a", false), Row(javaBigDecimal(57853, 10), "b", false)))
+      val expected = Row(
+        "foo",
+        Seq(Row(null, false), Row(javaBigDecimal(57853, 10), false)),
+        Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte)))
+      val arrowBatches = serializeToArrow(Iterator.single(input), wideSchemaEncoder, allocator)
+      try {
+        val result =
+          ArrowDeserializers.deserializeFromArrow(arrowBatches, narrowSchemaEncoder, allocator)
+        try {
+          val actual = result.next()
+          assert(result.isEmpty)
+          assert(expected === actual)
+        } finally {
+          // Close in the original order (result first, then the batches) even on failure, so
+          // leaked buffers cannot mask the real assertion error.
+          result.close()
+        }
+      } finally {
+        arrowBatches.close()
+      }
+    }
+  }
+
+  test("unknown field") {
+    withAllocator { allocator =>
+      // Data written with the narrow schema lacks fields that the wide encoder expects, so
+      // binding the wide encoder to it must fail.
+      val arrowBatches = serializeToArrow(Iterator.empty, narrowSchemaEncoder, allocator)
+      try {
+        intercept[AnalysisException] {
+          ArrowDeserializers.deserializeFromArrow(arrowBatches, wideSchemaEncoder, allocator)
+        }
+      } finally {
+        // Release the serialized batches even if the expected exception is not thrown.
+        arrowBatches.close()
+      }
+    }
+  }
+
+  test("duplicate fields") {
+    // "foO" and "Foo" differ from "foo" only by case, so resolving "foo" against the serialized
+    // schema is ambiguous and binding is expected to fail.
+    val duplicateSchemaEncoder = toRowEncoder(
+      new StructType()
+        .add("foO", "string")
+        .add("Foo", "string"))
+    val fooSchemaEncoder = toRowEncoder(
+      new StructType()
+        .add("foo", "string"))
+    withAllocator { allocator =>
+      val arrowBatches = serializeToArrow(Iterator.empty, duplicateSchemaEncoder, allocator)
+      try {
+        intercept[AnalysisException] {
+          ArrowDeserializers.deserializeFromArrow(arrowBatches, fooSchemaEncoder, allocator)
+        }
+      } finally {
+        // Release the serialized batches even if the expected exception is not thrown.
+        arrowBatches.close()
+      }
+    }
+  }
+
   /* ******************************************************************** *
    * Arrow serialization/deserialization specific errors
    * ******************************************************************** */
@@ -833,17 +868,23 @@ case class MapData(intStringMap: Map[Int, String], 
metricMap: Map[String, Array[
 
 class JavaMapData {
+  // Java bean test fixture with two map-valued properties: one keyed by a bean (DummyBean) and
+  // one whose values are lists. equals/hashCode are defined over both maps.
   @scala.beans.BeanProperty
-  var dummyToDoubleListMap: java.util.Map[DummyBean, 
java.util.List[java.lang.Double]] = _
+  var dummyToStringMap: java.util.Map[DummyBean, String] = _
+
+  // NOTE(review): declared as the concrete java.util.HashMap rather than java.util.Map,
+  // presumably to exercise deserialization into a concrete map class; confirm.
+  @scala.beans.BeanProperty
+  var metricMap: java.util.HashMap[String, java.util.List[java.lang.Double]] = 
_
 
   def canEqual(other: Any): Boolean = other.isInstanceOf[JavaMapData]
 
   override def equals(other: Any): Boolean = other match {
     case that: JavaMapData if that canEqual this =>
-      dummyToDoubleListMap == that.dummyToDoubleListMap
+      dummyToStringMap == that.dummyToStringMap &&
+      metricMap == that.metricMap
     case _ => false
   }
 
-  override def hashCode(): Int = Objects.hashCode(dummyToDoubleListMap)
+  // Combines the hash codes of both maps (null-safe), matching the fields compared in equals.
+  override def hashCode(): Int = {
+    java.util.Arrays.deepHashCode(Array(dummyToStringMap, metricMap))
+  }
 }
 
 class DummyBean {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to