This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a8507ceba96 [SPARK-45461][CORE][SQL][MLLIB] Introduce a mapper for
StorageLevel
a8507ceba96 is described below
commit a8507ceba9673f2926b944cd4d9916b0f5927248
Author: Jiaan Geng <[email protected]>
AuthorDate: Sun Oct 8 13:21:59 2023 -0700
[SPARK-45461][CORE][SQL][MLLIB] Introduce a mapper for StorageLevel
### What changes were proposed in this pull request?
Currently, `StorageLevel` provides `fromString` to obtain a `StorageLevel`
instance from its name. As a result, developers and users have to copy the
string literal of a `StorageLevel`'s name in order to set or look up a
`StorageLevel` instance. This forces developers to keep those literals
consistent by hand, which is error-prone and reduces development efficiency.
This PR also addresses the issue raised in:
https://github.com/apache/spark/pull/43259/files#r1349488662
### Why are the changes needed?
Make `StorageLevel` easier for developers to use.
### Does this PR introduce _any_ user-facing change?
'No'.
This PR only introduces a new class, `StorageLevelMapper`.
### How was this patch tested?
Existing test cases.
### Was this patch authored or co-authored using generative AI tooling?
'No'.
Closes #43278 from beliefer/SPARK-45461.
Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../apache/spark/storage/StorageLevelMapper.java | 47 ++++++++++++++++++++++
.../org/apache/spark/storage/StorageLevel.scala | 22 ++++------
.../org/apache/spark/ml/recommendation/ALS.scala | 6 +--
.../apache/spark/ml/recommendation/ALSSuite.scala | 10 ++---
.../spark/sql/catalyst/parser/DDLParserSuite.scala | 3 +-
.../python/AttachDistributedSequenceExec.scala | 4 +-
6 files changed, 66 insertions(+), 26 deletions(-)
diff --git
a/common/utils/src/main/java/org/apache/spark/storage/StorageLevelMapper.java
b/common/utils/src/main/java/org/apache/spark/storage/StorageLevelMapper.java
new file mode 100644
index 00000000000..18fa354a6e0
--- /dev/null
+++
b/common/utils/src/main/java/org/apache/spark/storage/StorageLevelMapper.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage;
+
+/**
+ * A mapper class easy to obtain storage levels based on their names.
+ */
+public enum StorageLevelMapper {
+ NONE(StorageLevel.NONE()),
+ DISK_ONLY(StorageLevel.DISK_ONLY()),
+ DISK_ONLY_2(StorageLevel.DISK_ONLY_2()),
+ DISK_ONLY_3(StorageLevel.DISK_ONLY_3()),
+ MEMORY_ONLY(StorageLevel.MEMORY_ONLY()),
+ MEMORY_ONLY_2(StorageLevel.MEMORY_ONLY_2()),
+ MEMORY_ONLY_SER(StorageLevel.MEMORY_ONLY_SER()),
+ MEMORY_ONLY_SER_2(StorageLevel.MEMORY_ONLY_SER_2()),
+ MEMORY_AND_DISK(StorageLevel.MEMORY_AND_DISK()),
+ MEMORY_AND_DISK_2(StorageLevel.MEMORY_AND_DISK_2()),
+ MEMORY_AND_DISK_SER(StorageLevel.MEMORY_AND_DISK_SER()),
+ MEMORY_AND_DISK_SER_2(StorageLevel.MEMORY_AND_DISK_SER_2()),
+ OFF_HEAP(StorageLevel.OFF_HEAP());
+
+ private final StorageLevel storageLevel;
+
+ StorageLevelMapper(StorageLevel storageLevel) {
+ this.storageLevel = storageLevel;
+ }
+
+ public static StorageLevel fromString(String s) throws
IllegalArgumentException {
+ return StorageLevelMapper.valueOf(s).storageLevel;
+ }
+}
diff --git
a/common/utils/src/main/scala/org/apache/spark/storage/StorageLevel.scala
b/common/utils/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 73bc53dab89..4280c78dc67 100644
--- a/common/utils/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/common/utils/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -165,21 +165,13 @@ object StorageLevel {
* Return the StorageLevel object with the specified name.
*/
@DeveloperApi
- def fromString(s: String): StorageLevel = s match {
- case "NONE" => NONE
- case "DISK_ONLY" => DISK_ONLY
- case "DISK_ONLY_2" => DISK_ONLY_2
- case "DISK_ONLY_3" => DISK_ONLY_3
- case "MEMORY_ONLY" => MEMORY_ONLY
- case "MEMORY_ONLY_2" => MEMORY_ONLY_2
- case "MEMORY_ONLY_SER" => MEMORY_ONLY_SER
- case "MEMORY_ONLY_SER_2" => MEMORY_ONLY_SER_2
- case "MEMORY_AND_DISK" => MEMORY_AND_DISK
- case "MEMORY_AND_DISK_2" => MEMORY_AND_DISK_2
- case "MEMORY_AND_DISK_SER" => MEMORY_AND_DISK_SER
- case "MEMORY_AND_DISK_SER_2" => MEMORY_AND_DISK_SER_2
- case "OFF_HEAP" => OFF_HEAP
- case _ => throw new IllegalArgumentException(s"Invalid StorageLevel: $s")
+ def fromString(s: String): StorageLevel = {
+ try {
+ StorageLevelMapper.fromString(s)
+ } catch {
+ case _: IllegalArgumentException =>
+ throw new IllegalArgumentException(s"Invalid StorageLevel: $s")
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 9e562f26abf..65c7d399a88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -47,7 +47,7 @@ import org.apache.spark.rdd.{DeterministicLevel, RDD}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, StorageLevelMapper}
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet,
SortDataFormat, Sorter}
import org.apache.spark.util.random.XORShiftRandom
@@ -245,8 +245,8 @@ private[recommendation] trait ALSParams extends
ALSModelParams with HasMaxIter w
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10,
numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
- intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel ->
"MEMORY_AND_DISK",
- coldStartStrategy -> "nan")
+ intermediateStorageLevel -> StorageLevelMapper.MEMORY_AND_DISK.name(),
+ finalStorageLevel -> StorageLevelMapper.MEMORY_AND_DISK.name(),
coldStartStrategy -> "nan")
/**
* Validates and transforms the input schema.
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 3f1e4d3887c..7ad26c02c89 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamingQueryException
import org.apache.spark.sql.types._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, StorageLevelMapper}
import org.apache.spark.util.Utils
class ALSSuite extends MLTest with DefaultReadWriteTest with Logging {
@@ -1114,8 +1114,8 @@ class ALSStorageSuite extends SparkFunSuite with
MLlibTestSparkContext with Defa
val nonDefaultListener = new IntermediateRDDStorageListener
sc.addSparkListener(nonDefaultListener)
val nonDefaultModel = als
- .setFinalStorageLevel("MEMORY_ONLY")
- .setIntermediateStorageLevel("DISK_ONLY")
+ .setFinalStorageLevel(StorageLevelMapper.MEMORY_ONLY.name())
+ .setIntermediateStorageLevel(StorageLevelMapper.DISK_ONLY.name())
.fit(data)
// check final factor RDD non-default storage levels
val levels = sc.getPersistentRDDs.collect {
@@ -1168,8 +1168,8 @@ object ALSSuite extends Logging {
"alpha" -> 0.9,
"nonnegative" -> true,
"checkpointInterval" -> 20,
- "intermediateStorageLevel" -> "MEMORY_ONLY",
- "finalStorageLevel" -> "MEMORY_AND_DISK_SER"
+ "intermediateStorageLevel" -> StorageLevelMapper.MEMORY_ONLY.name(),
+ "finalStorageLevel" -> StorageLevelMapper.MEMORY_AND_DISK_SER.name()
)
// Helper functions to generate test data we share between ALS test suites
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index 9644f6ea038..b1f73405044 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -29,6 +29,7 @@ import
org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransfo
import org.apache.spark.sql.connector.expressions.LogicalExpressions.bucket
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{Decimal, IntegerType, LongType,
MetadataBuilder, StringType, StructType, TimestampType}
+import org.apache.spark.storage.StorageLevelMapper
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class DDLParserSuite extends AnalysisTest {
@@ -2341,7 +2342,7 @@ class DDLParserSuite extends AnalysisTest {
UnresolvedRelation(Seq("a", "b", "c")),
Seq("a", "b", "c"),
true,
- Map("storageLevel" -> "DISK_ONLY")))
+ Map("storageLevel" -> StorageLevelMapper.DISK_ONLY.name())))
val sql = "CACHE TABLE a.b.c AS SELECT * FROM testData"
checkError(
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
index a1df89a20cb..e353bf5a51e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, StorageLevelMapper}
/**
* A physical plan that adds a new long column with `sequenceAttr` that
@@ -70,7 +70,7 @@ case class AttachDistributedSequenceExec(
// The string is double quoted because of JSON ser/deser for pandas API on
Spark
val storageLevel = SQLConf.get.getConfString(
"pandas_on_Spark.compute.default_index_cache",
- "MEMORY_AND_DISK_SER"
+ StorageLevelMapper.MEMORY_AND_DISK_SER.name()
).stripPrefix("\"").stripSuffix("\"")
val cachedRDD = storageLevel match {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]