This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new f73b5c5b55a [SPARK-42894][CONNECT] Support `cache`/`persist`/`unpersist`/`storageLevel` for Spark connect jvm client
f73b5c5b55a is described below

commit f73b5c5b55a6dc64300233f8fbcd7c60e16eeb99
Author: yangjie01 <[email protected]>
AuthorDate: Wed Mar 22 09:52:58 2023 -0700

    [SPARK-42894][CONNECT] Support `cache`/`persist`/`unpersist`/`storageLevel` for Spark connect jvm client
    
    ### What changes were proposed in this pull request?
    This PR follows SPARK-42889 to support `cache`/`persist`/`unpersist`/`storageLevel` for the Spark Connect JVM client.
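
    For illustration, a minimal usage sketch of the new APIs from the Scala client (the `remote`/`build` session setup and connection string are assumptions for the sketch, not part of this patch):

    ```scala
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.storage.StorageLevel

    // Hypothetical connection string; assumes a reachable Spark Connect server.
    val spark = SparkSession.builder().remote("sc://localhost").build()

    val df = spark.range(100)
    df.persist(StorageLevel.DISK_ONLY) // ask the server to cache this relation
    println(df.storageLevel)           // DISK_ONLY here; NONE if not persisted
    df.unpersist(blocking = true)      // drop cached blocks, waiting for deletion
    df.cache()                         // shorthand for persist() with MEMORY_AND_DISK
    ```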
    
    ### Why are the changes needed?
    Adds Spark Connect JVM client API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    
    - Added a new test
    
    Closes #40516 from LuciferYang/SPARK-42894.
    
    Authored-by: yangjie01 <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
    (cherry picked from commit b58df45d97fce6997d42063f9536a4fceb58125b)
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 76 ++++++++++++++++++++--
 .../scala/org/apache/spark/sql/SparkSession.scala  |  7 ++
 .../sql/connect/client/SparkConnectClient.scala    |  3 +-
 .../CheckConnectJvmClientCompatibility.scala       |  1 -
 4 files changed, 79 insertions(+), 8 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index fdc994b2d90..ca90afa14cf 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveLongEncoder, StringEncoder, UnboundRowEncoder}
 import org.apache.spark.sql.catalyst.expressions.RowOrdering
 import org.apache.spark.sql.connect.client.SparkResult
-import org.apache.spark.sql.connect.common.DataTypeProtoConverter
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
 import org.apache.spark.sql.functions.{struct, to_json}
 import org.apache.spark.sql.types.{Metadata, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -2704,22 +2704,86 @@ class Dataset[T] private[sql] (
     new DataFrameWriterV2[T](table, this)
   }
 
+  /**
+   * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
+   *
+   * @group basic
+   * @since 3.4.0
+   */
   def persist(): this.type = {
-    throw new UnsupportedOperationException("persist is not implemented.")
+    sparkSession.analyze { builder =>
+      builder.getPersistBuilder.setRelation(plan.getRoot)
+    }
+    this
   }
 
+  /**
+   * Persist this Dataset with the given storage level.
+   *
+   * @param newLevel
+   *   One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, `MEMORY_AND_DISK_SER`,
+   *   `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc.
+   * @group basic
+   * @since 3.4.0
+   */
   def persist(newLevel: StorageLevel): this.type = {
-    throw new UnsupportedOperationException("persist is not implemented.")
+    sparkSession.analyze { builder =>
+      builder.getPersistBuilder
+        .setRelation(plan.getRoot)
+        .setStorageLevel(StorageLevelProtoConverter.toConnectProtoType(newLevel))
+    }
+    this
   }
 
+  /**
+   * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. This
+   * will not un-persist any cached data that is built upon this Dataset.
+   *
+   * @param blocking
+   *   Whether to block until all blocks are deleted.
+   * @group basic
+   * @since 3.4.0
+   */
   def unpersist(blocking: Boolean): this.type = {
-    throw new UnsupportedOperationException("unpersist() is not implemented.")
+    sparkSession.analyze { builder =>
+      builder.getUnpersistBuilder
+        .setRelation(plan.getRoot)
+        .setBlocking(blocking)
+    }
+    this
   }
 
+  /**
+   * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. This
+   * will not un-persist any cached data that is built upon this Dataset.
+   *
+   * @group basic
+   * @since 3.4.0
+   */
   def unpersist(): this.type = unpersist(blocking = false)
 
-  def cache(): this.type = {
-    throw new UnsupportedOperationException("cache() is not implemented.")
+  /**
+   * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
+   *
+   * @group basic
+   * @since 3.4.0
+   */
+  def cache(): this.type = persist()
+
+  /**
+   * Get the Dataset's current storage level, or StorageLevel.NONE if not persisted.
+   *
+   * @group basic
+   * @since 3.4.0
+   */
+  def storageLevel: StorageLevel = {
+    StorageLevelProtoConverter.toStorageLevel(
+      sparkSession
+        .analyze { builder =>
+          builder.getGetStorageLevelBuilder.setRelation(plan.getRoot)
+        }
+        .getGetStorageLevel
+        .getStorageLevel)
   }
 
   def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = {
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 141bb637e15..f1e82507393 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -400,6 +400,13 @@ class SparkSession private[sql] (
     client.analyze(method, Some(plan), explainMode)
   }
 
+  private[sql] def analyze(
+      f: proto.AnalyzePlanRequest.Builder => Unit): proto.AnalyzePlanResponse = {
+    val builder = proto.AnalyzePlanRequest.newBuilder()
+    f(builder)
+    client.analyze(builder)
+  }
+
   private[sql] def sameSemantics(plan: proto.Plan, otherPlan: proto.Plan): Boolean = {
     client.sameSemantics(plan, otherPlan).getSameSemantics.getResult
   }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index a298c526883..a508f3c067f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -166,7 +166,8 @@ private[sql] class SparkConnectClient(
     analyze(builder)
   }
 
-  private def analyze(builder: proto.AnalyzePlanRequest.Builder): proto.AnalyzePlanResponse = {
+  private[sql] def analyze(
+      builder: proto.AnalyzePlanRequest.Builder): proto.AnalyzePlanResponse = {
     val request = builder
       .setUserContext(userContext)
       .setSessionId(sessionId)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index a2b4762f0a9..68369512fb7 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -161,7 +161,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.flatMap"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.foreach"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.foreachPartition"),
-      
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.storageLevel"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"),

