This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git
The following commit(s) were added to refs/heads/main by this push:
     new 209e93e  [SPARK-51560] Support `cache/persist/unpersist` for `DataFrame`
209e93e is described below

commit 209e93e124e008109f5e8d1a6bdbac2556d8beeb
Author: Dongjoon Hyun <dongj...@apache.org>
AuthorDate: Wed Mar 19 11:49:32 2025 -0700

    [SPARK-51560] Support `cache/persist/unpersist` for `DataFrame`

    ### What changes were proposed in this pull request?

    This PR aims to support `cache`, `persist`, and `unpersist` for `DataFrame`.

    ### Why are the changes needed?

    For feature parity.

    ### Does this PR introduce _any_ user-facing change?

    No. This is a new addition.

    ### How was this patch tested?

    Pass the CIs.

    ```
    $ swift test --filter DataFrameTests
    ...
    Test run started.
    Testing Library Version: 102 (arm64e-apple-macos13.0)
    Suite DataFrameTests started.
    Test orderBy() started.
    Test isEmpty() started.
    Test show() started.
    Test persist() started.
    Test showCommand() started.
    Test table() started.
    Test selectMultipleColumns() started.
    Test showNull() started.
    Test schema() started.
    Test selectNone() started.
    Test rdd() started.
    Test sort() started.
    Test unpersist() started.
    Test limit() started.
    Test count() started.
    Test cache() started.
    Test selectInvalidColumn() started.
    Test collect() started.
    Test countNull() started.
    Test select() started.
    Test persistInvalidStorageLevel() started.
    Test rdd() passed after 0.571 seconds.
    Test selectNone() passed after 1.347 seconds.
    Test select() passed after 1.354 seconds.
    Test selectMultipleColumns() passed after 1.354 seconds.
    Test selectInvalidColumn() passed after 1.395 seconds.
    Test schema() passed after 1.747 seconds.
    ++
    ||
    ++
    ++
    Test showCommand() passed after 1.885 seconds.
    +-----------+-----------+-------------+
    | namespace | tableName | isTemporary |
    +-----------+-----------+-------------+
    +-----------+-----------+-------------+
    +------+-------+------+
    | col1 | col2  | col3 |
    +------+-------+------+
    | 1    | true  | abc  |
    | NULL | NULL  | NULL |
    | 3    | false | def  |
    +------+-------+------+
    Test showNull() passed after 1.890 seconds.
    +------+-------+
    | col1 | col2  |
    +------+-------+
    | true | false |
    +------+-------+
    +------+------+
    | col1 | col2 |
    +------+------+
    | 1    | 2    |
    +------+------+
    +------+------+
    | col1 | col2 |
    +------+------+
    | abc  | def  |
    | ghi  | jkl  |
    +------+------+
    Test show() passed after 1.975 seconds.
    Test collect() passed after 2.045 seconds.
    Test countNull() passed after 2.566 seconds.
    Test persistInvalidStorageLevel() passed after 2.578 seconds.
    Test cache() passed after 2.683 seconds.
    Test isEmpty() passed after 2.778 seconds.
    Test count() passed after 2.892 seconds.
    Test persist() passed after 2.903 seconds.
    Test unpersist() passed after 2.917 seconds.
    Test limit() passed after 3.068 seconds.
    Test orderBy() passed after 3.101 seconds.
    Test sort() passed after 3.102 seconds.
    Test table() passed after 3.720 seconds.
    Suite DataFrameTests passed after 3.720 seconds.
    Test run with 21 tests passed after 3.720 seconds.
    ```

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #22 from dongjoon-hyun/SPARK-51560.
Authored-by: Dongjoon Hyun <dongj...@apache.org>
Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 Sources/SparkConnect/DataFrame.swift          | 39 +++++++++++++++++++++++++++
 Sources/SparkConnect/SparkConnectClient.swift | 35 ++++++++++++++++++++++++
 Sources/SparkConnect/TypeAliases.swift        |  1 +
 Tests/SparkConnectTests/DataFrameTests.swift  | 33 +++++++++++++++++++++++
 4 files changed, 108 insertions(+)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 81b74b1..81c92e4 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -245,4 +245,43 @@ public actor DataFrame: Sendable {
   public func isEmpty() async throws -> Bool {
     return try await select().limit(1).count() == 0
   }
+
+  public func cache() async throws -> DataFrame {
+    return try await persist()
+  }
+
+  public func persist(
+    useDisk: Bool = true, useMemory: Bool = true, useOffHeap: Bool = false,
+    deserialized: Bool = true, replication: Int32 = 1
+  )
+    async throws -> DataFrame
+  {
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: spark.client.host, port: spark.client.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
+      _ = try await service.analyzePlan(
+        spark.client.getPersist(
+          spark.sessionID, plan, useDisk, useMemory, useOffHeap, deserialized, replication))
+    }
+
+    return self
+  }
+
+  public func unpersist(blocking: Bool = false) async throws -> DataFrame {
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: spark.client.host, port: spark.client.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
+      _ = try await service.analyzePlan(spark.client.getUnpersist(spark.sessionID, plan, blocking))
+    }
+
+    return self
+  }
 }
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index c0c0828..aefd844 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -256,6 +256,41 @@ public actor SparkConnectClient {
     return request
   }
 
+  func getPersist(
+    _ sessionID: String, _ plan: Plan, _ useDisk: Bool = true, _ useMemory: Bool = true,
+    _ useOffHeap: Bool = false, _ deserialized: Bool = true, _ replication: Int32 = 1
+  ) async
+    -> AnalyzePlanRequest
+  {
+    return analyze(
+      sessionID,
+      {
+        var persist = AnalyzePlanRequest.Persist()
+        var level = StorageLevel()
+        level.useDisk = useDisk
+        level.useMemory = useMemory
+        level.useOffHeap = useOffHeap
+        level.deserialized = deserialized
+        level.replication = replication
+        persist.storageLevel = level
+        persist.relation = plan.root
+        return OneOf_Analyze.persist(persist)
+      })
+  }
+
+  func getUnpersist(_ sessionID: String, _ plan: Plan, _ blocking: Bool = false) async
+    -> AnalyzePlanRequest
+  {
+    return analyze(
+      sessionID,
+      {
+        var unpersist = AnalyzePlanRequest.Unpersist()
+        unpersist.relation = plan.root
+        unpersist.blocking = blocking
+        return OneOf_Analyze.unpersist(unpersist)
+      })
+  }
+
   static func getProject(_ child: Relation, _ cols: [String]) -> Plan {
     var project = Project()
     project.input = child
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 8154662..92aa78e 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -30,5 +30,6 @@ typealias Range = Spark_Connect_Range
 typealias Relation = Spark_Connect_Relation
 typealias SparkConnectService = Spark_Connect_SparkConnectService
 typealias Sort = Spark_Connect_Sort
+typealias StorageLevel = Spark_Connect_StorageLevel
 typealias UserContext = Spark_Connect_UserContext
 typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index a1c7e7e..552374d 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -193,5 +193,38 @@ struct DataFrameTests {
     try await spark.sql("DROP TABLE IF EXISTS t").show()
     await spark.stop()
   }
+
+  @Test
+  func cache() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(10).cache().count() == 10)
+    await spark.stop()
+  }
+
+  @Test
+  func persist() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(20).persist().count() == 20)
+    #expect(try await spark.range(21).persist(useDisk: false).count() == 21)
+    await spark.stop()
+  }
+
+  @Test
+  func persistInvalidStorageLevel() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await #require(throws: Error.self) {
+      let _ = try await spark.range(9999).persist(replication: 0).count()
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func unpersist() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(30)
+    #expect(try await df.persist().count() == 30)
+    #expect(try await df.unpersist().count() == 30)
+    await spark.stop()
+  }
 #endif
}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
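[Editor's note] For readers browsing the archive, a minimal sketch of how the methods added by this commit are used, based on the test suite above. It assumes a reachable Spark Connect server and an async context (it will not run without one):

```swift
import SparkConnect

// Connect and build a small DataFrame, as the tests above do.
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(100)

_ = try await df.cache()                  // shorthand for persist() with default storage level
print(try await df.count())               // an action materializes the cached data
_ = try await df.persist(useDisk: false)  // memory-only storage level
_ = try await df.unpersist()              // blocking defaults to false

await spark.stop()
```

Note that `cache`, `persist`, and `unpersist` all return the `DataFrame`, so the calls chain, mirroring the Scala and Python DataFrame APIs.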