This is an automated email from the ASF dual-hosted git repository.

yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 729fc8ec95e0 [SPARK-46615][CONNECT] Support s.c.immutable.ArraySeq in 
ArrowDeserializers
729fc8ec95e0 is described below

commit 729fc8ec95e017bd6eead283c0b660b9c57a174d
Author: panbingkun <[email protected]>
AuthorDate: Thu Feb 8 14:57:13 2024 +0800

    [SPARK-46615][CONNECT] Support s.c.immutable.ArraySeq in ArrowDeserializers
    
    ### What changes were proposed in this pull request?
    This PR aims to support s.c.immutable.ArraySeq as customCollectionCls in 
ArrowDeserializers.
    
    ### Why are the changes needed?
    Because s.c.immutable.ArraySeq is a commonly used type in Scala 2.13, we 
should support it.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Updated the existing UT (SQLImplicitsTestSuite).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44618 from panbingkun/SPARK-46615.
    
    Authored-by: panbingkun <[email protected]>
    Signed-off-by: yangjie01 <[email protected]>
---
 .../scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala    | 11 +++++++++++
 .../spark/sql/connect/client/arrow/ArrowDeserializer.scala    |  9 +++++++++
 .../spark/sql/connect/client/arrow/ArrowEncoderUtils.scala    |  2 ++
 3 files changed, 22 insertions(+)

diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
index b2c13850a13a..3e4704b6ab8e 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
@@ -52,6 +52,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
 
   test("test implicit encoder resolution") {
     val spark = session
+    import org.apache.spark.util.ArrayImplicits._
     import spark.implicits._
     def testImplicit[T: Encoder](expected: T): Unit = {
       val encoder = encoderFor[T]
@@ -84,6 +85,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     testImplicit(booleans)
     testImplicit(booleans.toSeq)
     testImplicit(booleans.toSeq)(newBooleanSeqEncoder)
+    testImplicit(booleans.toImmutableArraySeq)
 
     val bytes = Array(76.toByte, 59.toByte, 121.toByte)
     testImplicit(bytes.head)
@@ -91,6 +93,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     testImplicit(bytes)
     testImplicit(bytes.toSeq)
     testImplicit(bytes.toSeq)(newByteSeqEncoder)
+    testImplicit(bytes.toImmutableArraySeq)
 
     val shorts = Array(21.toShort, (-213).toShort, 14876.toShort)
     testImplicit(shorts.head)
@@ -98,6 +101,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     testImplicit(shorts)
     testImplicit(shorts.toSeq)
     testImplicit(shorts.toSeq)(newShortSeqEncoder)
+    testImplicit(shorts.toImmutableArraySeq)
 
     val ints = Array(4, 6, 5)
     testImplicit(ints.head)
@@ -105,6 +109,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     testImplicit(ints)
     testImplicit(ints.toSeq)
     testImplicit(ints.toSeq)(newIntSeqEncoder)
+    testImplicit(ints.toImmutableArraySeq)
 
     val longs = Array(System.nanoTime(), System.currentTimeMillis())
     testImplicit(longs.head)
@@ -112,6 +117,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     testImplicit(longs)
     testImplicit(longs.toSeq)
     testImplicit(longs.toSeq)(newLongSeqEncoder)
+    testImplicit(longs.toImmutableArraySeq)
 
     val floats = Array(3f, 10.9f)
     testImplicit(floats.head)
@@ -119,6 +125,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     testImplicit(floats)
     testImplicit(floats.toSeq)
     testImplicit(floats.toSeq)(newFloatSeqEncoder)
+    testImplicit(floats.toImmutableArraySeq)
 
     val doubles = Array(23.78d, -329.6d)
     testImplicit(doubles.head)
@@ -126,22 +133,26 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     testImplicit(doubles)
     testImplicit(doubles.toSeq)
     testImplicit(doubles.toSeq)(newDoubleSeqEncoder)
+    testImplicit(doubles.toImmutableArraySeq)
 
     val strings = Array("foo", "baz", "bar")
     testImplicit(strings.head)
     testImplicit(strings)
     testImplicit(strings.toSeq)
     testImplicit(strings.toSeq)(newStringSeqEncoder)
+    testImplicit(strings.toImmutableArraySeq)
 
     val myTypes = Array(MyType(12L, Math.E, Math.PI), MyType(0, 0, 0))
     testImplicit(myTypes.head)
     testImplicit(myTypes)
     testImplicit(myTypes.toSeq)
     testImplicit(myTypes.toSeq)(newProductSeqEncoder[MyType])
+    testImplicit(myTypes.toImmutableArraySeq)
 
     // Others.
     val decimal = java.math.BigDecimal.valueOf(3141527000000000000L, 18)
     testImplicit(decimal)
+    testImplicit(Array(decimal).toImmutableArraySeq)
     testImplicit(BigDecimal(decimal))
     testImplicit(Date.valueOf(LocalDate.now()))
     testImplicit(LocalDate.now())
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 52461d1ebaea..ac9619487f02 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -24,6 +24,7 @@ import java.time._
 import java.util
 import java.util.{List => JList, Locale, Map => JMap}
 
+import scala.collection.immutable
 import scala.collection.mutable
 import scala.reflect.ClassTag
 
@@ -46,6 +47,7 @@ import org.apache.spark.sql.types.Decimal
  */
 object ArrowDeserializers {
   import ArrowEncoderUtils._
+  import org.apache.spark.util.ArrayImplicits._
 
   /**
    * Create an Iterator of `T`. This iterator takes an Iterator of Arrow IPC 
Streams, and
@@ -222,6 +224,13 @@ object ArrowDeserializers {
               ScalaCollectionUtils.wrap(array)
             }
           }
+        } else if (isSubClass(Classes.IMMUTABLE_ARRAY_SEQ, tag)) {
+          new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) {
+            def value(i: Int): immutable.ArraySeq[Any] = {
+              val array = getArray(vector, i, deserializer)(element.clsTag)
+              array.asInstanceOf[Array[_]].toImmutableArraySeq
+            }
+          }
         } else if (isSubClass(Classes.ITERABLE, tag)) {
           val companion = ScalaCollectionUtils.getIterableCompanion(tag)
           new VectorFieldDeserializer[Iterable[Any], ListVector](v) {
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
index 6d1325b55d41..5b1539e39f4f 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.connect.client.arrow
 
+import scala.collection.immutable
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
@@ -26,6 +27,7 @@ import org.apache.arrow.vector.complex.StructVector
 private[arrow] object ArrowEncoderUtils {
   object Classes {
     val MUTABLE_ARRAY_SEQ: Class[_] = classOf[mutable.ArraySeq[_]]
+    val IMMUTABLE_ARRAY_SEQ: Class[_] = classOf[immutable.ArraySeq[_]]
     val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]]
     val MAP: Class[_] = classOf[scala.collection.Map[_, _]]
     val JLIST: Class[_] = classOf[java.util.List[_]]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to