This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git


The following commit(s) were added to refs/heads/main by this push:
     new 209e93e  [SPARK-51560] Support `cache/persist/unpersist` for `DataFrame`
209e93e is described below

commit 209e93e124e008109f5e8d1a6bdbac2556d8beeb
Author: Dongjoon Hyun <dongj...@apache.org>
AuthorDate: Wed Mar 19 11:49:32 2025 -0700

    [SPARK-51560] Support `cache/persist/unpersist` for `DataFrame`
    
    ### What changes were proposed in this pull request?
    
    This PR aims to support `cache`, `persist`, and `unpersist` for `DataFrame`.
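
    A minimal usage sketch of the new API (not part of this patch; it mirrors the tests below and assumes a reachable Spark Connect server):
    
    ```
    let spark = try await SparkSession.builder.getOrCreate()
    let df = try await spark.range(100)
    _ = try await df.cache().count()                  // defaults: useDisk, useMemory, deserialized = true; replication = 1
    _ = try await df.unpersist()
    _ = try await df.persist(useDisk: false).count()  // keep the cached data in memory only
    await spark.stop()
    ```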
    
    ### Why are the changes needed?
    
    For feature parity.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. This is a new addition.
    
    ### How was this patch tested?
    
    Pass the CIs.
    
    ```
    $ swift test --filter DataFrameTests
    ...
    􀟈  Test run started.
    􀄵  Testing Library Version: 102 (arm64e-apple-macos13.0)
    􀟈  Suite DataFrameTests started.
    􀟈  Test orderBy() started.
    􀟈  Test isEmpty() started.
    􀟈  Test show() started.
    􀟈  Test persist() started.
    􀟈  Test showCommand() started.
    􀟈  Test table() started.
    􀟈  Test selectMultipleColumns() started.
    􀟈  Test showNull() started.
    􀟈  Test schema() started.
    􀟈  Test selectNone() started.
    􀟈  Test rdd() started.
    􀟈  Test sort() started.
    􀟈  Test unpersist() started.
    􀟈  Test limit() started.
    􀟈  Test count() started.
    􀟈  Test cache() started.
    􀟈  Test selectInvalidColumn() started.
    􀟈  Test collect() started.
    􀟈  Test countNull() started.
    􀟈  Test select() started.
    􀟈  Test persistInvalidStorageLevel() started.
    􁁛  Test rdd() passed after 0.571 seconds.
    􁁛  Test selectNone() passed after 1.347 seconds.
    􁁛  Test select() passed after 1.354 seconds.
    􁁛  Test selectMultipleColumns() passed after 1.354 seconds.
    􁁛  Test selectInvalidColumn() passed after 1.395 seconds.
    􁁛  Test schema() passed after 1.747 seconds.
    ++
    ||
    ++
    
    ++
    􁁛  Test showCommand() passed after 1.885 seconds.
    +-----------+-----------+-------------+
    | namespace | tableName | isTemporary |
    +-----------+-----------+-------------+
    
    +-----------+-----------+-------------+
    +------+-------+------+
    | col1 | col2  | col3 |
    +------+-------+------+
    | 1    | true  | abc  |
    | NULL | NULL  | NULL |
    | 3    | false | def  |
    +------+-------+------+
    􁁛  Test showNull() passed after 1.890 seconds.
    +------+-------+
    | col1 | col2  |
    +------+-------+
    | true | false |
    +------+-------+
    +------+------+
    | col1 | col2 |
    +------+------+
    | 1    | 2    |
    +------+------+
    +------+------+
    | col1 | col2 |
    +------+------+
    | abc  | def  |
    | ghi  | jkl  |
    +------+------+
    􁁛  Test show() passed after 1.975 seconds.
    􁁛  Test collect() passed after 2.045 seconds.
    􁁛  Test countNull() passed after 2.566 seconds.
    􁁛  Test persistInvalidStorageLevel() passed after 2.578 seconds.
    􁁛  Test cache() passed after 2.683 seconds.
    􁁛  Test isEmpty() passed after 2.778 seconds.
    􁁛  Test count() passed after 2.892 seconds.
    􁁛  Test persist() passed after 2.903 seconds.
    􁁛  Test unpersist() passed after 2.917 seconds.
    􁁛  Test limit() passed after 3.068 seconds.
    􁁛  Test orderBy() passed after 3.101 seconds.
    􁁛  Test sort() passed after 3.102 seconds.
    􁁛  Test table() passed after 3.720 seconds.
    􁁛  Suite DataFrameTests passed after 3.720 seconds.
    􁁛  Test run with 21 tests passed after 3.720 seconds.
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #22 from dongjoon-hyun/SPARK-51560.
    
    Authored-by: Dongjoon Hyun <dongj...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 Sources/SparkConnect/DataFrame.swift          | 39 +++++++++++++++++++++++++++
 Sources/SparkConnect/SparkConnectClient.swift | 35 ++++++++++++++++++++++++
 Sources/SparkConnect/TypeAliases.swift        |  1 +
 Tests/SparkConnectTests/DataFrameTests.swift  | 33 +++++++++++++++++++++++
 4 files changed, 108 insertions(+)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 81b74b1..81c92e4 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -245,4 +245,43 @@ public actor DataFrame: Sendable {
   public func isEmpty() async throws -> Bool {
     return try await select().limit(1).count() == 0
   }
+
+  public func cache() async throws -> DataFrame {
+    return try await persist()
+  }
+
+  public func persist(
+    useDisk: Bool = true, useMemory: Bool = true, useOffHeap: Bool = false,
+    deserialized: Bool = true, replication: Int32 = 1
+  )
+    async throws -> DataFrame
+  {
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: spark.client.host, port: spark.client.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
+      _ = try await service.analyzePlan(
+        spark.client.getPersist(
+          spark.sessionID, plan, useDisk, useMemory, useOffHeap, deserialized, replication))
+    }
+
+    return self
+  }
+
+  public func unpersist(blocking: Bool = false) async throws -> DataFrame {
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: spark.client.host, port: spark.client.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
+      _ = try await service.analyzePlan(spark.client.getUnpersist(spark.sessionID, plan, blocking))
+    }
+
+    return self
+  }
 }
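
For reference (not part of this patch), the `persist` parameters above line up with Spark's standard `StorageLevel` constants; an illustrative sketch, assuming a `spark` session as in the tests:

```
// Illustrative only: parameter combinations that correspond to Spark's named storage levels.
// Each call is shown on its own DataFrame to keep the examples independent.
_ = try await spark.range(1000).persist()                                       // MEMORY_AND_DISK (the defaults)
_ = try await spark.range(1000).persist(useDisk: false)                         // MEMORY_ONLY
_ = try await spark.range(1000).persist(useMemory: false, deserialized: false)  // DISK_ONLY
_ = try await spark.range(1000).persist(useDisk: false, deserialized: false)    // MEMORY_ONLY_SER
_ = try await spark.range(1000).persist(replication: 2)                         // MEMORY_AND_DISK_2
```
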
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index c0c0828..aefd844 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -256,6 +256,41 @@ public actor SparkConnectClient {
     return request
   }
 
+  func getPersist(
+    _ sessionID: String, _ plan: Plan, _ useDisk: Bool = true, _ useMemory: Bool = true,
+    _ useOffHeap: Bool = false, _ deserialized: Bool = true, _ replication: Int32 = 1
+  ) async
+    -> AnalyzePlanRequest
+  {
+    return analyze(
+      sessionID,
+      {
+        var persist = AnalyzePlanRequest.Persist()
+        var level = StorageLevel()
+        level.useDisk = useDisk
+        level.useMemory = useMemory
+        level.useOffHeap = useOffHeap
+        level.deserialized = deserialized
+        level.replication = replication
+        persist.storageLevel = level
+        persist.relation = plan.root
+        return OneOf_Analyze.persist(persist)
+      })
+  }
+
+  func getUnpersist(_ sessionID: String, _ plan: Plan, _ blocking: Bool = false) async
+    -> AnalyzePlanRequest
+  {
+    return analyze(
+      sessionID,
+      {
+        var unpersist = AnalyzePlanRequest.Unpersist()
+        unpersist.relation = plan.root
+        unpersist.blocking = blocking
+        return OneOf_Analyze.unpersist(unpersist)
+      })
+  }
+
   static func getProject(_ child: Relation, _ cols: [String]) -> Plan {
     var project = Project()
     project.input = child
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 8154662..92aa78e 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -30,5 +30,6 @@ typealias Range = Spark_Connect_Range
 typealias Relation = Spark_Connect_Relation
 typealias SparkConnectService = Spark_Connect_SparkConnectService
 typealias Sort = Spark_Connect_Sort
+typealias StorageLevel = Spark_Connect_StorageLevel
 typealias UserContext = Spark_Connect_UserContext
 typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index a1c7e7e..552374d 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -193,5 +193,38 @@ struct DataFrameTests {
     try await spark.sql("DROP TABLE IF EXISTS t").show()
     await spark.stop()
   }
+
+  @Test
+  func cache() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(10).cache().count() == 10)
+    await spark.stop()
+  }
+
+  @Test
+  func persist() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(20).persist().count() == 20)
+    #expect(try await spark.range(21).persist(useDisk: false).count() == 21)
+    await spark.stop()
+  }
+
+  @Test
+  func persistInvalidStorageLevel() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await #require(throws: Error.self) {
+      let _ = try await spark.range(9999).persist(replication: 0).count()
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func unpersist() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(30)
+    #expect(try await df.persist().count() == 30)
+    #expect(try await df.unpersist().count() == 30)
+    await spark.stop()
+  }
 #endif
 }

