This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 30ec6e358536 [SPARK-45742][CORE][CONNECT][MLLIB][PYTHON] Introduce an implicit function for Scala Array to wrap into `immutable.ArraySeq`
30ec6e358536 is described below

commit 30ec6e358536dfb695fcc1b8c3f084acb576d871
Author: yangjie01 <[email protected]>
AuthorDate: Wed Nov 1 21:08:04 2023 -0700

    [SPARK-45742][CORE][CONNECT][MLLIB][PYTHON] Introduce an implicit function for Scala Array to wrap into `immutable.ArraySeq`
    
    ### What changes were proposed in this pull request?
    Currently, we need to use `immutable.ArraySeq.unsafeWrapArray(array)` to wrap an Array into an `immutable.ArraySeq`, which makes the code look bloated.
    
    So this PR introduces an implicit function `toImmutableArraySeq` to make it easier to wrap a Scala Array into an `immutable.ArraySeq`.
    
    After this PR, an array can be wrapped into an `immutable.ArraySeq` as follows:
    
    ```scala
    import org.apache.spark.util.ArrayImplicits._
    
    val dataArray = ...
    val immutableArraySeq = dataArray.toImmutableArraySeq
    ```
    
    At the same time, this PR replaces existing uses of `immutable.ArraySeq.unsafeWrapArray(array)` with the new method.
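    
    As a rough before/after sketch of that replacement (a minimal example; the variable name and array contents below are illustrative, not taken verbatim from the patch):
    
    ```scala
    import scala.collection.immutable
    import scala.jdk.CollectionConverters._
    
    import org.apache.spark.util.ArrayImplicits._
    
    val classes = Array("java.lang.RuntimeException", "java.lang.Exception")
    
    // Before: wrap explicitly through the factory method.
    val before = immutable.ArraySeq.unsafeWrapArray(classes).asJava
    
    // After: use the implicit extension method introduced by this PR.
    val after = classes.toImmutableArraySeq.asJava
    ```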
    
    In addition, this implicit function will help advance the work on SPARK-45686 and SPARK-45687.
    
    ### Why are the changes needed?
    Makes the code for wrapping a Scala Array into an `immutable.ArraySeq` look less bloated.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Pass GitHub Actions
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43607 from LuciferYang/SPARK-45742.
    
    Authored-by: yangjie01 <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../org/apache/spark/util/ArrayImplicits.scala     | 36 ++++++++++++++++
 .../scala/org/apache/spark/sql/SparkSession.scala  |  4 +-
 .../connect/client/GrpcExceptionConverter.scala    |  4 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 27 ++++++------
 .../spark/sql/connect/utils/ErrorUtils.scala       | 32 +++++++-------
 .../apache/spark/util/ArrayImplicitsSuite.scala    | 50 ++++++++++++++++++++++
 .../api/python/GaussianMixtureModelWrapper.scala   |  4 +-
 .../spark/mllib/api/python/LDAModelWrapper.scala   |  8 ++--
 8 files changed, 126 insertions(+), 39 deletions(-)

diff --git a/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala
new file mode 100644
index 000000000000..08997a800c95
--- /dev/null
+++ b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.immutable
+
+/**
+ * Implicit methods related to Scala Array.
+ */
+private[spark] object ArrayImplicits {
+
+  implicit class SparkArrayOps[T](xs: Array[T]) {
+
+    /**
+     * Wraps an Array[T] as an immutable.ArraySeq[T] without copying.
+     */
+    def toImmutableArraySeq: immutable.ArraySeq[T] =
+      if (xs eq null) null
+      else immutable.ArraySeq.unsafeWrapArray(xs)
+  }
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 1cc1c8400fa8..34756f9a440b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -21,7 +21,6 @@ import java.net.URI
 import java.util.concurrent.TimeUnit._
 import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
 
-import scala.collection.immutable
 import scala.jdk.CollectionConverters._
 import scala.reflect.runtime.universe.TypeTag
 
@@ -45,6 +44,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf}
 import org.apache.spark.sql.streaming.DataStreamReader
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ArrayImplicits._
 
 /**
  * The entry point to programming Spark with the Dataset and DataFrame API.
@@ -248,7 +248,7 @@ class SparkSession private[sql] (
         proto.SqlCommand
           .newBuilder()
           .setSql(sqlText)
-          .addAllPosArguments(immutable.ArraySeq.unsafeWrapArray(args.map(lit(_).expr)).asJava)))
+          .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava)))
     val plan = proto.Plan.newBuilder().setCommand(cmd)
     // .toBuffer forces that the iterator is consumed and closed
     val responseSeq = client.execute(plan.build()).toBuffer.toSeq
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index 3e53722caeb0..652797bc2e40 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.client
 
 import java.time.DateTimeException
 
-import scala.collection.immutable
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 
@@ -37,6 +36,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException,
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.streaming.StreamingQueryException
+import org.apache.spark.util.ArrayImplicits._
 
 /**
  * GrpcExceptionConverter handles the conversion of StatusRuntimeExceptions into Spark exceptions.
@@ -375,7 +375,7 @@ private[client] object GrpcExceptionConverter {
         FetchErrorDetailsResponse.Error
           .newBuilder()
           .setMessage(message)
-          .addAllErrorTypeHierarchy(immutable.ArraySeq.unsafeWrapArray(classes).asJava)
+          .addAllErrorTypeHierarchy(classes.toImmutableArraySeq.asJava)
           .build()))
   }
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index ec57909ad144..018e293795e9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.connect.planner
 
-import scala.collection.immutable
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.util.Try
@@ -80,6 +79,7 @@ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQ
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.storage.CacheId
+import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.Utils
 
 final case class InvalidCommandInput(
@@ -3184,9 +3184,9 @@ class SparkConnectPlanner(
       case StreamingQueryManagerCommand.CommandCase.ACTIVE =>
         val active_queries = session.streams.active
         respBuilder.getActiveBuilder.addAllActiveQueries(
-          immutable.ArraySeq
-            .unsafeWrapArray(active_queries
-              .map(query => buildStreamingQueryInstance(query)))
+          active_queries
+            .map(query => buildStreamingQueryInstance(query))
+            .toImmutableArraySeq
             .asJava)
 
       case StreamingQueryManagerCommand.CommandCase.GET_QUERY =>
@@ -3265,15 +3265,16 @@ class SparkConnectPlanner(
         .setGetResourcesCommandResult(
           proto.GetResourcesCommandResult
             .newBuilder()
-            .putAllResources(session.sparkContext.resources.view
-              .mapValues(resource =>
-                proto.ResourceInformation
-                  .newBuilder()
-                  .setName(resource.name)
-                  .addAllAddresses(immutable.ArraySeq.unsafeWrapArray(resource.addresses).asJava)
-                  .build())
-              .toMap
-              .asJava)
+            .putAllResources(
+              session.sparkContext.resources.view
+                .mapValues(resource =>
+                  proto.ResourceInformation
+                    .newBuilder()
+                    .setName(resource.name)
+                    .addAllAddresses(resource.addresses.toImmutableArraySeq.asJava)
+                    .build())
+                .toMap
+                .asJava)
             .build())
         .build())
   }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 837ee5a00227..744fa3c8aa1a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.connect.utils
 import java.util.UUID
 
 import scala.annotation.tailrec
-import scala.collection.immutable
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
@@ -43,6 +42,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SessionKey, SparkConnectService}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.ArrayImplicits._
 
 private[connect] object ErrorUtils extends Logging {
 
@@ -91,21 +91,21 @@ private[connect] object ErrorUtils extends Logging {
 
       if (serverStackTraceEnabled) {
         builder.addAllStackTrace(
-          immutable.ArraySeq
-            .unsafeWrapArray(currentError.getStackTrace
-              .map { stackTraceElement =>
-                val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement
-                  .newBuilder()
-                  .setDeclaringClass(stackTraceElement.getClassName)
-                  .setMethodName(stackTraceElement.getMethodName)
-                  .setLineNumber(stackTraceElement.getLineNumber)
-
-                if (stackTraceElement.getFileName != null) {
-                  stackTraceBuilder.setFileName(stackTraceElement.getFileName)
-                }
-
-                stackTraceBuilder.build()
-              })
+          currentError.getStackTrace
+            .map { stackTraceElement =>
+              val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement
+                .newBuilder()
+                .setDeclaringClass(stackTraceElement.getClassName)
+                .setMethodName(stackTraceElement.getMethodName)
+                .setLineNumber(stackTraceElement.getLineNumber)
+
+              if (stackTraceElement.getFileName != null) {
+                stackTraceBuilder.setFileName(stackTraceElement.getFileName)
+              }
+
+              stackTraceBuilder.build()
+            }
+            .toImmutableArraySeq
             .asJava)
       }
 
diff --git a/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala
new file mode 100644
index 000000000000..135af550c4b3
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.immutable
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.util.ArrayImplicits._
+
+class ArrayImplicitsSuite extends SparkFunSuite {
+
+  test("Int Array") {
+    val data = Array(1, 2, 3)
+    val arraySeq = data.toImmutableArraySeq
+    assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofInt])
+    assert(arraySeq.length === 3)
+    assert(arraySeq.unsafeArray.sameElements(data))
+  }
+
+  test("TestClass Array") {
+    val data = Array(TestClass(1), TestClass(2), TestClass(3))
+    val arraySeq = data.toImmutableArraySeq
+    assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofRef[TestClass]])
+    assert(arraySeq.length === 3)
+    assert(arraySeq.unsafeArray.sameElements(data))
+  }
+
+  test("Null Array") {
+    val data: Array[Int] = null
+    val arraySeq = data.toImmutableArraySeq
+    assert(arraySeq == null)
+  }
+
+  case class TestClass(i: Int)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
index 1eed97a8d4f6..2f3f396730be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.mllib.api.python
 
-import scala.collection.immutable
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.clustering.GaussianMixtureModel
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.util.ArrayImplicits._
 
 /**
  * Wrapper around GaussianMixtureModel to provide helper methods in Python
@@ -38,7 +38,7 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
     val modelGaussians = model.gaussians.map { gaussian =>
       Array[Any](gaussian.mu, gaussian.sigma)
     }
-    SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(modelGaussians).asJava)
+    SerDe.dumps(modelGaussians.toImmutableArraySeq.asJava)
   }
 
   def predictSoft(point: Vector): Vector = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala
index b919b0a8c3f2..6a6c6cf6bcfb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala
@@ -16,12 +16,12 @@
  */
 package org.apache.spark.mllib.api.python
 
-import scala.collection.immutable
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.clustering.LDAModel
 import org.apache.spark.mllib.linalg.Matrix
+import org.apache.spark.util.ArrayImplicits._
 
 /**
  * Wrapper around LDAModel to provide helper methods in Python
@@ -36,11 +36,11 @@ private[python] class LDAModelWrapper(model: LDAModel) {
 
   def describeTopics(maxTermsPerTopic: Int): Array[Byte] = {
     val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) =>
-      val jTerms = immutable.ArraySeq.unsafeWrapArray(terms).asJava
-      val jTermWeights = immutable.ArraySeq.unsafeWrapArray(termWeights).asJava
+      val jTerms = terms.toImmutableArraySeq.asJava
+      val jTermWeights = termWeights.toImmutableArraySeq.asJava
       Array[Any](jTerms, jTermWeights)
     }
-    SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(topics).asJava)
+    SerDe.dumps(topics.toImmutableArraySeq.asJava)
   }
 
   def save(sc: SparkContext, path: String): Unit = model.save(sc, path)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
