This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 30ec6e358536 [SPARK-45742][CORE][CONNECT][MLLIB][PYTHON] Introduce an implicit function for Scala Array to wrap into `immutable.ArraySeq`
30ec6e358536 is described below
commit 30ec6e358536dfb695fcc1b8c3f084acb576d871
Author: yangjie01 <[email protected]>
AuthorDate: Wed Nov 1 21:08:04 2023 -0700
[SPARK-45742][CORE][CONNECT][MLLIB][PYTHON] Introduce an implicit function for Scala Array to wrap into `immutable.ArraySeq`
### What changes were proposed in this pull request?
Currently, we need to use `immutable.ArraySeq.unsafeWrapArray(array)` to
wrap an Array into an `immutable.ArraySeq`, which makes the code look bloated.
So this PR introduces an implicit function `toImmutableArraySeq` to make it
easier to wrap a Scala Array into an `immutable.ArraySeq`.
After this PR, we can wrap an array into an `immutable.ArraySeq` as follows:
```scala
import org.apache.spark.util.ArrayImplicits._
val dataArray = ...
val immutableArraySeq = dataArray.toImmutableArraySeq
```
At the same time, this PR replaces the existing uses of
`immutable.ArraySeq.unsafeWrapArray(array)` with the new method.
This implicit function will also help the ongoing work on SPARK-45686 and
SPARK-45687.
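For illustration, here is a minimal before/after sketch of that replacement
(variable names are hypothetical, and since `ArrayImplicits` is
`private[spark]`, the snippet assumes it is compiled inside the
`org.apache.spark` namespace):
```scala
import scala.collection.immutable

import org.apache.spark.util.ArrayImplicits._

val values: Array[String] = Array("a", "b", "c")
// Before this PR: explicit, verbose wrapping.
val wrappedBefore: immutable.ArraySeq[String] = immutable.ArraySeq.unsafeWrapArray(values)
// After this PR: the implicit extension wraps the same backing array without copying.
val wrappedAfter: immutable.ArraySeq[String] = values.toImmutableArraySeq
assert(wrappedBefore == wrappedAfter)
```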
### Why are the changes needed?
Makes the code for wrapping a Scala Array into an `immutable.ArraySeq` look
less bloated.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Pass GitHub Actions
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43607 from LuciferYang/SPARK-45742.
Authored-by: yangjie01 <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../org/apache/spark/util/ArrayImplicits.scala | 36 ++++++++++++++++
.../scala/org/apache/spark/sql/SparkSession.scala | 4 +-
.../connect/client/GrpcExceptionConverter.scala | 4 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 27 ++++++------
.../spark/sql/connect/utils/ErrorUtils.scala | 32 +++++++-------
.../apache/spark/util/ArrayImplicitsSuite.scala | 50 ++++++++++++++++++++++
.../api/python/GaussianMixtureModelWrapper.scala | 4 +-
.../spark/mllib/api/python/LDAModelWrapper.scala | 8 ++--
8 files changed, 126 insertions(+), 39 deletions(-)
diff --git a/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala
new file mode 100644
index 000000000000..08997a800c95
--- /dev/null
+++ b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.immutable
+
+/**
+ * Implicit methods related to Scala Array.
+ */
+private[spark] object ArrayImplicits {
+
+ implicit class SparkArrayOps[T](xs: Array[T]) {
+
+ /**
+ * Wraps an Array[T] as an immutable.ArraySeq[T] without copying.
+ */
+ def toImmutableArraySeq: immutable.ArraySeq[T] =
+ if (xs eq null) null
+ else immutable.ArraySeq.unsafeWrapArray(xs)
+ }
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 1cc1c8400fa8..34756f9a440b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -21,7 +21,6 @@ import java.net.URI
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
-import scala.collection.immutable
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
@@ -45,6 +44,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf}
import org.apache.spark.sql.streaming.DataStreamReader
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ArrayImplicits._
/**
* The entry point to programming Spark with the Dataset and DataFrame API.
@@ -248,7 +248,7 @@ class SparkSession private[sql] (
proto.SqlCommand
.newBuilder()
.setSql(sqlText)
- .addAllPosArguments(immutable.ArraySeq.unsafeWrapArray(args.map(lit(_).expr)).asJava)))
+ .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava)))
val plan = proto.Plan.newBuilder().setCommand(cmd)
// .toBuffer forces that the iterator is consumed and closed
val responseSeq = client.execute(plan.build()).toBuffer.toSeq
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index 3e53722caeb0..652797bc2e40 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.client
import java.time.DateTimeException
-import scala.collection.immutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
@@ -37,6 +36,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException,
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.streaming.StreamingQueryException
+import org.apache.spark.util.ArrayImplicits._
/**
* GrpcExceptionConverter handles the conversion of StatusRuntimeExceptions into Spark exceptions.
@@ -375,7 +375,7 @@ private[client] object GrpcExceptionConverter {
FetchErrorDetailsResponse.Error
.newBuilder()
.setMessage(message)
- .addAllErrorTypeHierarchy(immutable.ArraySeq.unsafeWrapArray(classes).asJava)
+ .addAllErrorTypeHierarchy(classes.toImmutableArraySeq.asJava)
.build()))
}
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index ec57909ad144..018e293795e9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.connect.planner
-import scala.collection.immutable
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.Try
@@ -80,6 +79,7 @@ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQ
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.storage.CacheId
+import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.Utils
final case class InvalidCommandInput(
@@ -3184,9 +3184,9 @@ class SparkConnectPlanner(
case StreamingQueryManagerCommand.CommandCase.ACTIVE =>
val active_queries = session.streams.active
respBuilder.getActiveBuilder.addAllActiveQueries(
- immutable.ArraySeq
- .unsafeWrapArray(active_queries
- .map(query => buildStreamingQueryInstance(query)))
+ active_queries
+ .map(query => buildStreamingQueryInstance(query))
+ .toImmutableArraySeq
.asJava)
case StreamingQueryManagerCommand.CommandCase.GET_QUERY =>
@@ -3265,15 +3265,16 @@ class SparkConnectPlanner(
.setGetResourcesCommandResult(
proto.GetResourcesCommandResult
.newBuilder()
- .putAllResources(session.sparkContext.resources.view
- .mapValues(resource =>
- proto.ResourceInformation
- .newBuilder()
- .setName(resource.name)
- .addAllAddresses(immutable.ArraySeq.unsafeWrapArray(resource.addresses).asJava)
- .build())
- .toMap
- .asJava)
+ .putAllResources(
+ session.sparkContext.resources.view
+ .mapValues(resource =>
+ proto.ResourceInformation
+ .newBuilder()
+ .setName(resource.name)
+ .addAllAddresses(resource.addresses.toImmutableArraySeq.asJava)
+ .build())
+ .toMap
+ .asJava)
.build())
.build())
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 837ee5a00227..744fa3c8aa1a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.connect.utils
import java.util.UUID
import scala.annotation.tailrec
-import scala.collection.immutable
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
@@ -43,6 +42,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SessionKey, SparkConnectService}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.ArrayImplicits._
private[connect] object ErrorUtils extends Logging {
@@ -91,21 +91,21 @@ private[connect] object ErrorUtils extends Logging {
if (serverStackTraceEnabled) {
builder.addAllStackTrace(
- immutable.ArraySeq
- .unsafeWrapArray(currentError.getStackTrace
- .map { stackTraceElement =>
- val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement
- .newBuilder()
- .setDeclaringClass(stackTraceElement.getClassName)
- .setMethodName(stackTraceElement.getMethodName)
- .setLineNumber(stackTraceElement.getLineNumber)
-
- if (stackTraceElement.getFileName != null) {
- stackTraceBuilder.setFileName(stackTraceElement.getFileName)
- }
-
- stackTraceBuilder.build()
- })
+ currentError.getStackTrace
+ .map { stackTraceElement =>
+ val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement
+ .newBuilder()
+ .setDeclaringClass(stackTraceElement.getClassName)
+ .setMethodName(stackTraceElement.getMethodName)
+ .setLineNumber(stackTraceElement.getLineNumber)
+
+ if (stackTraceElement.getFileName != null) {
+ stackTraceBuilder.setFileName(stackTraceElement.getFileName)
+ }
+
+ stackTraceBuilder.build()
+ }
+ .toImmutableArraySeq
.asJava)
}
diff --git a/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala
new file mode 100644
index 000000000000..135af550c4b3
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.immutable
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.util.ArrayImplicits._
+
+class ArrayImplicitsSuite extends SparkFunSuite {
+
+ test("Int Array") {
+ val data = Array(1, 2, 3)
+ val arraySeq = data.toImmutableArraySeq
+ assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofInt])
+ assert(arraySeq.length === 3)
+ assert(arraySeq.unsafeArray.sameElements(data))
+ }
+
+ test("TestClass Array") {
+ val data = Array(TestClass(1), TestClass(2), TestClass(3))
+ val arraySeq = data.toImmutableArraySeq
+ assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofRef[TestClass]])
+ assert(arraySeq.length === 3)
+ assert(arraySeq.unsafeArray.sameElements(data))
+ }
+
+ test("Null Array") {
+ val data: Array[Int] = null
+ val arraySeq = data.toImmutableArraySeq
+ assert(arraySeq == null)
+ }
+
+ case class TestClass(i: Int)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
index 1eed97a8d4f6..2f3f396730be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
@@ -17,12 +17,12 @@
package org.apache.spark.mllib.api.python
-import scala.collection.immutable
import scala.jdk.CollectionConverters._
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.util.ArrayImplicits._
/**
* Wrapper around GaussianMixtureModel to provide helper methods in Python
@@ -38,7 +38,7 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
val modelGaussians = model.gaussians.map { gaussian =>
Array[Any](gaussian.mu, gaussian.sigma)
}
- SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(modelGaussians).asJava)
+ SerDe.dumps(modelGaussians.toImmutableArraySeq.asJava)
}
def predictSoft(point: Vector): Vector = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala
index b919b0a8c3f2..6a6c6cf6bcfb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala
@@ -16,12 +16,12 @@
*/
package org.apache.spark.mllib.api.python
-import scala.collection.immutable
import scala.jdk.CollectionConverters._
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.LDAModel
import org.apache.spark.mllib.linalg.Matrix
+import org.apache.spark.util.ArrayImplicits._
/**
* Wrapper around LDAModel to provide helper methods in Python
@@ -36,11 +36,11 @@ private[python] class LDAModelWrapper(model: LDAModel) {
def describeTopics(maxTermsPerTopic: Int): Array[Byte] = {
val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) =>
- val jTerms = immutable.ArraySeq.unsafeWrapArray(terms).asJava
- val jTermWeights = immutable.ArraySeq.unsafeWrapArray(termWeights).asJava
+ val jTerms = terms.toImmutableArraySeq.asJava
+ val jTermWeights = termWeights.toImmutableArraySeq.asJava
Array[Any](jTerms, jTermWeights)
}
- SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(topics).asJava)
+ SerDe.dumps(topics.toImmutableArraySeq.asJava)
}
def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]