This is an automated email from the ASF dual-hosted git repository.
yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 729fc8ec95e0 [SPARK-46615][CONNECT] Support s.c.immutable.ArraySeq in
ArrowDeserializers
729fc8ec95e0 is described below
commit 729fc8ec95e017bd6eead283c0b660b9c57a174d
Author: panbingkun <[email protected]>
AuthorDate: Thu Feb 8 14:57:13 2024 +0800
[SPARK-46615][CONNECT] Support s.c.immutable.ArraySeq in ArrowDeserializers
### What changes were proposed in this pull request?
The PR aims to support s.c.immutable.ArraySeq as customCollectionCls in
ArrowDeserializers.
### Why are the changes needed?
Because s.c.immutable.ArraySeq is a commonly used type in Scala 2.13, we
should support it.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Updated the existing UT (SQLImplicitsTestSuite).
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44618 from panbingkun/SPARK-46615.
Authored-by: panbingkun <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
---
.../scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala | 11 +++++++++++
.../spark/sql/connect/client/arrow/ArrowDeserializer.scala | 9 +++++++++
.../spark/sql/connect/client/arrow/ArrowEncoderUtils.scala | 2 ++
3 files changed, 22 insertions(+)
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
index b2c13850a13a..3e4704b6ab8e 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
@@ -52,6 +52,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with
BeforeAndAfterAll {
test("test implicit encoder resolution") {
val spark = session
+ import org.apache.spark.util.ArrayImplicits._
import spark.implicits._
def testImplicit[T: Encoder](expected: T): Unit = {
val encoder = encoderFor[T]
@@ -84,6 +85,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with
BeforeAndAfterAll {
testImplicit(booleans)
testImplicit(booleans.toSeq)
testImplicit(booleans.toSeq)(newBooleanSeqEncoder)
+ testImplicit(booleans.toImmutableArraySeq)
val bytes = Array(76.toByte, 59.toByte, 121.toByte)
testImplicit(bytes.head)
@@ -91,6 +93,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with
BeforeAndAfterAll {
testImplicit(bytes)
testImplicit(bytes.toSeq)
testImplicit(bytes.toSeq)(newByteSeqEncoder)
+ testImplicit(bytes.toImmutableArraySeq)
val shorts = Array(21.toShort, (-213).toShort, 14876.toShort)
testImplicit(shorts.head)
@@ -98,6 +101,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with
BeforeAndAfterAll {
testImplicit(shorts)
testImplicit(shorts.toSeq)
testImplicit(shorts.toSeq)(newShortSeqEncoder)
+ testImplicit(shorts.toImmutableArraySeq)
val ints = Array(4, 6, 5)
testImplicit(ints.head)
@@ -105,6 +109,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with
BeforeAndAfterAll {
testImplicit(ints)
testImplicit(ints.toSeq)
testImplicit(ints.toSeq)(newIntSeqEncoder)
+ testImplicit(ints.toImmutableArraySeq)
val longs = Array(System.nanoTime(), System.currentTimeMillis())
testImplicit(longs.head)
@@ -112,6 +117,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with
BeforeAndAfterAll {
testImplicit(longs)
testImplicit(longs.toSeq)
testImplicit(longs.toSeq)(newLongSeqEncoder)
+ testImplicit(longs.toImmutableArraySeq)
val floats = Array(3f, 10.9f)
testImplicit(floats.head)
@@ -119,6 +125,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with
BeforeAndAfterAll {
testImplicit(floats)
testImplicit(floats.toSeq)
testImplicit(floats.toSeq)(newFloatSeqEncoder)
+ testImplicit(floats.toImmutableArraySeq)
val doubles = Array(23.78d, -329.6d)
testImplicit(doubles.head)
@@ -126,22 +133,26 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with
BeforeAndAfterAll {
testImplicit(doubles)
testImplicit(doubles.toSeq)
testImplicit(doubles.toSeq)(newDoubleSeqEncoder)
+ testImplicit(doubles.toImmutableArraySeq)
val strings = Array("foo", "baz", "bar")
testImplicit(strings.head)
testImplicit(strings)
testImplicit(strings.toSeq)
testImplicit(strings.toSeq)(newStringSeqEncoder)
+ testImplicit(strings.toImmutableArraySeq)
val myTypes = Array(MyType(12L, Math.E, Math.PI), MyType(0, 0, 0))
testImplicit(myTypes.head)
testImplicit(myTypes)
testImplicit(myTypes.toSeq)
testImplicit(myTypes.toSeq)(newProductSeqEncoder[MyType])
+ testImplicit(myTypes.toImmutableArraySeq)
// Others.
val decimal = java.math.BigDecimal.valueOf(3141527000000000000L, 18)
testImplicit(decimal)
+ testImplicit(Array(decimal).toImmutableArraySeq)
testImplicit(BigDecimal(decimal))
testImplicit(Date.valueOf(LocalDate.now()))
testImplicit(LocalDate.now())
diff --git
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 52461d1ebaea..ac9619487f02 100644
---
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -24,6 +24,7 @@ import java.time._
import java.util
import java.util.{List => JList, Locale, Map => JMap}
+import scala.collection.immutable
import scala.collection.mutable
import scala.reflect.ClassTag
@@ -46,6 +47,7 @@ import org.apache.spark.sql.types.Decimal
*/
object ArrowDeserializers {
import ArrowEncoderUtils._
+ import org.apache.spark.util.ArrayImplicits._
/**
* Create an Iterator of `T`. This iterator takes an Iterator of Arrow IPC
Streams, and
@@ -222,6 +224,13 @@ object ArrowDeserializers {
ScalaCollectionUtils.wrap(array)
}
}
+ } else if (isSubClass(Classes.IMMUTABLE_ARRAY_SEQ, tag)) {
+ new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) {
+ def value(i: Int): immutable.ArraySeq[Any] = {
+ val array = getArray(vector, i, deserializer)(element.clsTag)
+ array.asInstanceOf[Array[_]].toImmutableArraySeq
+ }
+ }
} else if (isSubClass(Classes.ITERABLE, tag)) {
val companion = ScalaCollectionUtils.getIterableCompanion(tag)
new VectorFieldDeserializer[Iterable[Any], ListVector](v) {
diff --git
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
index 6d1325b55d41..5b1539e39f4f 100644
---
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
+++
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.connect.client.arrow
+import scala.collection.immutable
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
@@ -26,6 +27,7 @@ import org.apache.arrow.vector.complex.StructVector
private[arrow] object ArrowEncoderUtils {
object Classes {
val MUTABLE_ARRAY_SEQ: Class[_] = classOf[mutable.ArraySeq[_]]
+ val IMMUTABLE_ARRAY_SEQ: Class[_] = classOf[immutable.ArraySeq[_]]
val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]]
val MAP: Class[_] = classOf[scala.collection.Map[_, _]]
val JLIST: Class[_] = classOf[java.util.List[_]]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]