This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new b71192cb08cb [SPARK-46586][SQL] Support `s.c.immutable.ArraySeq` as `customCollectionCls` in `MapObjects`
b71192cb08cb is described below

commit b71192cb08cb84b5329a63b7856442c00eb9c474
Author: panbingkun <panbing...@baidu.com>
AuthorDate: Thu Jan 4 20:08:48 2024 -0800

    [SPARK-46586][SQL] Support `s.c.immutable.ArraySeq` as `customCollectionCls` in `MapObjects`

    ### What changes were proposed in this pull request?
    This PR adds support for `s.c.immutable.ArraySeq` as `customCollectionCls` in `MapObjects`.

    ### Why are the changes needed?
    `s.c.immutable.ArraySeq` is a commonly used collection type in Scala 2.13, so `MapObjects` should support it.

    ### Does this PR introduce _any_ user-facing change?
    Yes, `s.c.immutable.ArraySeq` is now supported in `MapObjects`.

    ### How was this patch tested?
    - New UT: added a test for `ArraySeq` in `UDFSuite`; also updated `ObjectExpressionsSuite` to cover `MapObjects`.
    - Pass GA.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #44591 from panbingkun/SPARK-46586.

    Lead-authored-by: panbingkun <panbing...@baidu.com>
    Co-authored-by: panbingkun <pbk1...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/catalyst/expressions/objects/objects.scala | 24 +++++++++++++++++++++-
 .../expressions/ObjectExpressionsSuite.scala       |  8 +++++++-
 .../test/scala/org/apache/spark/sql/UDFSuite.scala | 17 +++++++++++++++
 3 files changed, 47 insertions(+), 2 deletions(-)
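A minimal sketch of the user-facing behavior this enables (illustrative only, not part of the patch; assumes a SparkSession `spark` with `import spark.implicits._` in scope on Scala 2.13 — the column name, values, and the `appendTwo` helper mirror the new UDFSuite test):

    import scala.collection.immutable

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.udf

    // A Scala UDF that both receives and returns immutable.ArraySeq.
    val appendTwo = udf((a: immutable.ArraySeq[Int]) =>
      immutable.ArraySeq.unsafeWrapArray[Int](a.appended(5).appended(6).toArray))

    // Deserializing the array column into immutable.ArraySeq goes through
    // MapObjects, which previously had no case for this collection class.
    val row = Seq(Array(1, 2, 3)).toDF("col")
      .select(appendTwo(Column("col")))
      .head()
    // row.getSeq[Int](0) == Seq(1, 2, 3, 5, 6)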
+ + s"MODULE$$.unsafeWrapArray($builder.result());" + ) case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) || classOf[scala.collection.Set[_]].isAssignableFrom(cls) => // Scala sequence or set diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index a54f490dd146..538a7600b02a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import scala.collection.immutable import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag @@ -363,6 +364,9 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { assert(result.asInstanceOf[java.util.List[_]].asScala == expected) case a if classOf[mutable.ArraySeq[Int]].isAssignableFrom(a) => assert(result == mutable.ArraySeq.make[Int](expected.toArray)) + case a if classOf[immutable.ArraySeq[Int]].isAssignableFrom(a) => + assert(result.isInstanceOf[immutable.ArraySeq[_]]) + assert(result == immutable.ArraySeq.unsafeWrapArray[Int](expected.toArray)) case s if classOf[Seq[_]].isAssignableFrom(s) => assert(result.asInstanceOf[Seq[_]] == expected) case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) => @@ -370,7 +374,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - val customCollectionClasses = Seq(classOf[mutable.ArraySeq[Int]], + val customCollectionClasses = Seq( + classOf[mutable.ArraySeq[Int]], classOf[immutable.ArraySeq[Int]], classOf[Seq[Int]], classOf[scala.collection.Set[Int]], classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]], classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]], @@ -392,6 +397,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Seq( (Seq(1, 2, 3), ObjectType(classOf[mutable.ArraySeq[Int]])), + (Seq(1, 2, 3), ObjectType(classOf[immutable.ArraySeq[Int]])), (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])), (Array(1, 2, 3), ObjectType(classOf[Array[Int]])), (Seq(1, 2, 3), ObjectType(classOf[Object])), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 6a597d7ab4fc..2bd649ea85e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -22,6 +22,7 @@ import java.sql.Timestamp import java.time.{Instant, LocalDate} import java.time.format.DateTimeFormatter +import scala.collection.immutable import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -830,6 +831,22 @@ class UDFSuite extends QueryTest with SharedSparkSession { Row(ArrayBuffer(100))) } + test("SPARK-46586: UDF should not fail on immutable.ArraySeq") { + val myUdf1 = udf((a: immutable.ArraySeq[Int]) => + immutable.ArraySeq.unsafeWrapArray[Int](Array(a.head + 99))) + checkAnswer(Seq(Array(1)) + .toDF("col") + .select(myUdf1(Column("col"))), + Row(ArrayBuffer(100))) + + val myUdf2 = udf((a: immutable.ArraySeq[Int]) => + immutable.ArraySeq.unsafeWrapArray[Int](a.appended(5).appended(6).toArray)) + checkAnswer(Seq(Array(1, 2, 3)) + .toDF("col") + 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index a54f490dd146..538a7600b02a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.{Date, Timestamp}
 
+import scala.collection.immutable
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
@@ -363,6 +364,9 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
           assert(result.asInstanceOf[java.util.List[_]].asScala == expected)
         case a if classOf[mutable.ArraySeq[Int]].isAssignableFrom(a) =>
           assert(result == mutable.ArraySeq.make[Int](expected.toArray))
+        case a if classOf[immutable.ArraySeq[Int]].isAssignableFrom(a) =>
+          assert(result.isInstanceOf[immutable.ArraySeq[_]])
+          assert(result == immutable.ArraySeq.unsafeWrapArray[Int](expected.toArray))
         case s if classOf[Seq[_]].isAssignableFrom(s) =>
           assert(result.asInstanceOf[Seq[_]] == expected)
         case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
@@ -370,7 +374,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       }
     }
 
-    val customCollectionClasses = Seq(classOf[mutable.ArraySeq[Int]],
+    val customCollectionClasses = Seq(
+      classOf[mutable.ArraySeq[Int]], classOf[immutable.ArraySeq[Int]],
       classOf[Seq[Int]], classOf[scala.collection.Set[Int]], classOf[java.util.List[Int]],
       classOf[java.util.AbstractList[Int]], classOf[java.util.AbstractSequentialList[Int]],
       classOf[java.util.Vector[Int]],
@@ -392,6 +397,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
     Seq(
       (Seq(1, 2, 3), ObjectType(classOf[mutable.ArraySeq[Int]])),
+      (Seq(1, 2, 3), ObjectType(classOf[immutable.ArraySeq[Int]])),
       (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
       (Array(1, 2, 3), ObjectType(classOf[Array[Int]])),
       (Seq(1, 2, 3), ObjectType(classOf[Object])),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 6a597d7ab4fc..2bd649ea85e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -22,6 +22,7 @@ import java.sql.Timestamp
 import java.time.{Instant, LocalDate}
 import java.time.format.DateTimeFormatter
 
+import scala.collection.immutable
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
@@ -830,6 +831,22 @@ class UDFSuite extends QueryTest with SharedSparkSession {
       Row(ArrayBuffer(100)))
   }
 
+  test("SPARK-46586: UDF should not fail on immutable.ArraySeq") {
+    val myUdf1 = udf((a: immutable.ArraySeq[Int]) =>
+      immutable.ArraySeq.unsafeWrapArray[Int](Array(a.head + 99)))
+    checkAnswer(Seq(Array(1))
+      .toDF("col")
+      .select(myUdf1(Column("col"))),
+      Row(ArrayBuffer(100)))
+
+    val myUdf2 = udf((a: immutable.ArraySeq[Int]) =>
+      immutable.ArraySeq.unsafeWrapArray[Int](a.appended(5).appended(6).toArray))
+    checkAnswer(Seq(Array(1, 2, 3))
+      .toDF("col")
+      .select(myUdf2(Column("col"))),
+      Row(ArrayBuffer(1, 2, 3, 5, 6)))
+  }
+
   test("SPARK-34388: UDF name is propagated with registration for ScalaUDF") {
     spark.udf.register("udf34388", udf((value: Int) => value > 2))
     spark.sessionState.catalog.lookupFunction(

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org