This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b71192cb08cb [SPARK-46586][SQL] Support `s.c.immutable.ArraySeq` as 
`customCollectionCls` in `MapObjects`
b71192cb08cb is described below

commit b71192cb08cb84b5329a63b7856442c00eb9c474
Author: panbingkun <panbing...@baidu.com>
AuthorDate: Thu Jan 4 20:08:48 2024 -0800

    [SPARK-46586][SQL] Support `s.c.immutable.ArraySeq` as 
`customCollectionCls` in `MapObjects`
    
    ### What changes were proposed in this pull request?
    The PR aims to support `s.c.immutable.ArraySeq` as `customCollectionCls` in 
`MapObjects`.
    
    ### Why are the changes needed?
    Because `s.c.immutable.ArraySeq` is a commonly used type in Scala 2.13, we 
should support it.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, We support `s.c.immutable.ArraySeq` in `MapObjects`.
    
    ### How was this patch tested?
    - Add new UT: Added a new test for `ArraySeq` in UDFSuite; also updated 
`ObjectExpressionsSuite` for `MapObjects`.
    - Pass GA.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44591 from panbingkun/SPARK-46586.
    
    Lead-authored-by: panbingkun <panbing...@baidu.com>
    Co-authored-by: panbingkun <pbk1...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/catalyst/expressions/objects/objects.scala | 24 +++++++++++++++++++++-
 .../expressions/ObjectExpressionsSuite.scala       |  8 +++++++-
 .../test/scala/org/apache/spark/sql/UDFSuite.scala | 17 +++++++++++++++
 3 files changed, 47 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index eb3568e43f70..bae2922cf921 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
 
 import java.lang.reflect.{Method, Modifier}
 
+import scala.collection.immutable
 import scala.collection.mutable
 import scala.collection.mutable.Builder
 import scala.jdk.CollectionConverters._
@@ -938,6 +939,14 @@ case class MapObjects private(
         executeFuncOnCollection(input).foreach(builder += _)
         mutable.ArraySeq.make(builder.result())
       }
+    case Some(cls) if classOf[immutable.ArraySeq[_]].isAssignableFrom(cls) =>
+      implicit val tag: ClassTag[Any] = elementClassTag()
+      input => {
+        val builder = mutable.ArrayBuilder.make[Any]
+        builder.sizeHint(input.size)
+        executeFuncOnCollection(input).foreach(builder += _)
+        immutable.ArraySeq.unsafeWrapArray(builder.result())
+      }
     case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) =>
       // Scala sequence
       executeFuncOnCollection(_).toSeq
@@ -1108,7 +1117,20 @@ case class MapObjects private(
             s"(${cls.getName}) ${classOf[mutable.ArraySeq[_]].getName}$$." +
               s"MODULE$$.make($builder.result());"
           )
-
+        case Some(cls) if classOf[immutable.ArraySeq[_]].isAssignableFrom(cls) 
=>
+          val tag = ctx.addReferenceObj("tag", elementClassTag())
+          val builderClassName = classOf[mutable.ArrayBuilder[_]].getName
+          val getBuilder = s"$builderClassName$$.MODULE$$.make($tag)"
+          val builder = ctx.freshName("collectionBuilder")
+          (
+            s"""
+               ${classOf[Builder[_, _]].getName} $builder = $getBuilder;
+               $builder.sizeHint($dataLength);
+             """,
+            (genValue: String) => s"$builder.$$plus$$eq($genValue);",
+            s"(${cls.getName}) ${classOf[immutable.ArraySeq[_]].getName}$$." +
+              s"MODULE$$.unsafeWrapArray($builder.result());"
+          )
         case Some(cls) if 
classOf[scala.collection.Seq[_]].isAssignableFrom(cls) ||
           classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
           // Scala sequence or set
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index a54f490dd146..538a7600b02a 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.{Date, Timestamp}
 
+import scala.collection.immutable
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
@@ -363,6 +364,9 @@ class ObjectExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
           assert(result.asInstanceOf[java.util.List[_]].asScala == expected)
         case a if classOf[mutable.ArraySeq[Int]].isAssignableFrom(a) =>
           assert(result == mutable.ArraySeq.make[Int](expected.toArray))
+        case a if classOf[immutable.ArraySeq[Int]].isAssignableFrom(a) =>
+          assert(result.isInstanceOf[immutable.ArraySeq[_]])
+          assert(result == 
immutable.ArraySeq.unsafeWrapArray[Int](expected.toArray))
         case s if classOf[Seq[_]].isAssignableFrom(s) =>
           assert(result.asInstanceOf[Seq[_]] == expected)
         case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
@@ -370,7 +374,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       }
     }
 
-    val customCollectionClasses = Seq(classOf[mutable.ArraySeq[Int]],
+    val customCollectionClasses = Seq(
+      classOf[mutable.ArraySeq[Int]], classOf[immutable.ArraySeq[Int]],
       classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
       classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
       classOf[java.util.AbstractSequentialList[Int]], 
classOf[java.util.Vector[Int]],
@@ -392,6 +397,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
 
     Seq(
       (Seq(1, 2, 3), ObjectType(classOf[mutable.ArraySeq[Int]])),
+      (Seq(1, 2, 3), ObjectType(classOf[immutable.ArraySeq[Int]])),
       (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
       (Array(1, 2, 3), ObjectType(classOf[Array[Int]])),
       (Seq(1, 2, 3), ObjectType(classOf[Object])),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 6a597d7ab4fc..2bd649ea85e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -22,6 +22,7 @@ import java.sql.Timestamp
 import java.time.{Instant, LocalDate}
 import java.time.format.DateTimeFormatter
 
+import scala.collection.immutable
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
@@ -830,6 +831,22 @@ class UDFSuite extends QueryTest with SharedSparkSession {
       Row(ArrayBuffer(100)))
   }
 
+  test("SPARK-46586: UDF should not fail on immutable.ArraySeq") {
+    val myUdf1 = udf((a: immutable.ArraySeq[Int]) =>
+      immutable.ArraySeq.unsafeWrapArray[Int](Array(a.head + 99)))
+    checkAnswer(Seq(Array(1))
+      .toDF("col")
+      .select(myUdf1(Column("col"))),
+    Row(ArrayBuffer(100)))
+
+     val myUdf2 = udf((a: immutable.ArraySeq[Int]) =>
+      
immutable.ArraySeq.unsafeWrapArray[Int](a.appended(5).appended(6).toArray))
+    checkAnswer(Seq(Array(1, 2, 3))
+      .toDF("col")
+      .select(myUdf2(Column("col"))),
+    Row(ArrayBuffer(1, 2, 3, 5, 6)))
+  }
+
   test("SPARK-34388: UDF name is propagated with registration for ScalaUDF") {
     spark.udf.register("udf34388", udf((value: Int) => value > 2))
     spark.sessionState.catalog.lookupFunction(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to