This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git
The following commit(s) were added to refs/heads/main by this push: new d5856c6 [SPARK-51785] Support `addTag/removeTag/getTags/clearTags` in `SparkSession` d5856c6 is described below commit d5856c69a781c49d3723e125620ea0b2baaef74a Author: Dongjoon Hyun <dongj...@apache.org> AuthorDate: Mon Apr 14 08:55:34 2025 +0900 [SPARK-51785] Support `addTag/removeTag/getTags/clearTags` in `SparkSession` ### What changes were proposed in this pull request? This PR aims to support the following `SparkSession` APIs. - `addTag` - `removeTag` - `getTags` - `clearTags` Note that `interrupt`-related operations will be supported later as an independent PR. ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? No. This is a new addition. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #54 from dongjoon-hyun/SPARK-51785. Authored-by: Dongjoon Hyun <dongj...@apache.org> Signed-off-by: Dongjoon Hyun <dongj...@apache.org> --- .../SparkConnect/ProtoUtils.swift | 38 ++++++++-------------- Sources/SparkConnect/SparkConnectClient.swift | 29 +++++++++++++++++ Sources/SparkConnect/SparkConnectError.swift | 1 + Sources/SparkConnect/SparkSession.swift | 25 ++++++++++++++ .../SparkConnectClientTests.swift | 17 ++++++++++ Tests/SparkConnectTests/SparkSessionTests.swift | 26 +++++++++++++++ 6 files changed, 112 insertions(+), 24 deletions(-) diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Sources/SparkConnect/ProtoUtils.swift similarity index 50% copy from Tests/SparkConnectTests/SparkConnectClientTests.swift copy to Sources/SparkConnect/ProtoUtils.swift index f50ae5d..738213f 100644 --- a/Tests/SparkConnectTests/SparkConnectClientTests.swift +++ b/Sources/SparkConnect/ProtoUtils.swift @@ -16,34 +16,24 @@ // specific language governing permissions and limitations // under the License. 
// - import Foundation -import Testing -@testable import SparkConnect +/// Utility functions like `org.apache.spark.sql.connect.common.ProtoUtils`. +public enum ProtoUtils { -/// A test suite for `SparkConnectClient` -@Suite(.serialized) -struct SparkConnectClientTests { - @Test - func createAndStop() async throws { - let client = SparkConnectClient(remote: "sc://localhost", user: "test") - await client.stop() - } + private static let SPARK_JOB_TAGS_SEP = "," // SparkContext.SPARK_JOB_TAGS_SEP - @Test - func connectWithInvalidUUID() async throws { - let client = SparkConnectClient(remote: "sc://localhost", user: "test") - try await #require(throws: SparkConnectError.InvalidSessionIDException) { - let _ = try await client.connect("not-a-uuid-format") + /// Validate if a tag for ExecutePlanRequest.tags is valid. Throws `SparkConnectError.InvalidArgumentException` if not. + /// - Parameter tag: A tag string. + public static func throwIfInvalidTag(_ tag: String) throws { + // Same format rules apply to Spark Connect execution tags as to SparkContext job tags, + // because the Spark Connect job tag is also used as part of SparkContext job tag. + // See SparkContext.throwIfInvalidTag and ExecuteHolderSessionTag + if tag.isEmpty { + throw SparkConnectError.InvalidArgumentException + } + if tag.contains(SPARK_JOB_TAGS_SEP) { + throw SparkConnectError.InvalidArgumentException } - await client.stop() - } - - @Test - func connect() async throws { - let client = SparkConnectClient(remote: "sc://localhost", user: "test") - let _ = try await client.connect(UUID().uuidString) - await client.stop() } } diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 4e14077..6001ee8 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -30,6 +30,7 @@ public actor SparkConnectClient { let port: Int let userContext: UserContext var sessionID: String?
= nil + var tags = Set<String>() /// Create a client to use GRPCClient. /// - Parameters: @@ -229,6 +230,7 @@ public actor SparkConnectClient { request.userContext = userContext request.sessionID = self.sessionID! request.operationID = UUID().uuidString + request.tags = Array(tags) request.plan = plan return request } @@ -409,4 +411,31 @@ public actor SparkConnectClient { } return result } + + /// Add a tag to be assigned to all the operations started by this thread in this session. + /// - Parameter tag: The tag to be added. Cannot contain ',' (comma) character or be an empty string. + public func addTag(tag: String) throws { + try ProtoUtils.throwIfInvalidTag(tag) + tags.insert(tag) + } + + /// Remove a tag previously added to be assigned to all the operations started by this thread in this session. + /// Noop if such a tag was not added earlier. + /// - Parameter tag: The tag to be removed. Cannot contain ',' (comma) character or be an empty string. + public func removeTag(tag: String) throws { + try ProtoUtils.throwIfInvalidTag(tag) + tags.remove(tag) + } + + /// Get the operation tags that are currently set to be assigned to all the operations started by + /// this thread in this session. + /// - Returns: A set of strings. + public func getTags() -> Set<String> { + return tags + } + + /// Clear the current thread's operation tags.
+ public func clearTags() { + tags.removeAll() + } } diff --git a/Sources/SparkConnect/SparkConnectError.swift b/Sources/SparkConnect/SparkConnectError.swift index df293e2..4434b6d 100644 --- a/Sources/SparkConnect/SparkConnectError.swift +++ b/Sources/SparkConnect/SparkConnectError.swift @@ -20,6 +20,7 @@ /// An enum for ``SparkConnect`` package errors public enum SparkConnectError: Error { case UnsupportedOperationException + case InvalidArgumentException case InvalidSessionIDException case InvalidTypeException } diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift index 3b07c27..1b943a4 100644 --- a/Sources/SparkConnect/SparkSession.swift +++ b/Sources/SparkConnect/SparkSession.swift @@ -158,6 +158,31 @@ public actor SparkSession { return ret } + /// Add a tag to be assigned to all the operations started by this thread in this session. + /// - Parameter tag: The tag to be added. Cannot contain ',' (comma) character or be an empty string. + public func addTag(_ tag: String) async throws { + try await client.addTag(tag: tag) + } + + /// Remove a tag previously added to be assigned to all the operations started by this thread in this session. + /// Noop if such a tag was not added earlier. + /// - Parameter tag: The tag to be removed. Cannot contain ',' (comma) character or be an empty string. + public func removeTag(_ tag: String) async throws { + try await client.removeTag(tag: tag) + } + + /// Get the operation tags that are currently set to be assigned to all the operations started by + /// this thread in this session. + /// - Returns: A set of strings. + public func getTags() async -> Set<String> { + return await client.getTags() + } + + /// Clear the current thread's operation tags. + public func clearTags() async { + await client.clearTags() + } + /// This is defined as the return type of `SparkSession.sparkContext` method.
/// This is an empty `Struct` type because `sparkContext` method is designed to throw /// `UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT`. diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Tests/SparkConnectTests/SparkConnectClientTests.swift index f50ae5d..399e497 100644 --- a/Tests/SparkConnectTests/SparkConnectClientTests.swift +++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift @@ -46,4 +46,21 @@ struct SparkConnectClientTests { let _ = try await client.connect(UUID().uuidString) await client.stop() } + + @Test + func tags() async throws { + let client = SparkConnectClient(remote: "sc://localhost", user: "test") + let sessionID = UUID().uuidString + let _ = try await client.connect(sessionID) + let plan = await client.getPlanRange(0, 1, 1) + + #expect(await client.getExecutePlanRequest(sessionID, plan).tags.isEmpty) + try await client.addTag(tag: "tag1") + + #expect(await client.getExecutePlanRequest(sessionID, plan).tags == ["tag1"]) + await client.clearTags() + + #expect(await client.getExecutePlanRequest(sessionID, plan).tags.isEmpty) + await client.stop() + } } diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift index 24d0537..901a683 100644 --- a/Tests/SparkConnectTests/SparkSessionTests.swift +++ b/Tests/SparkConnectTests/SparkSessionTests.swift @@ -96,4 +96,30 @@ struct SparkSessionTests { #endif await spark.stop() } + + @Test + func tag() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.addTag("tag1") + #expect(await spark.getTags() == Set(["tag1"])) + try await spark.addTag("tag2") + #expect(await spark.getTags() == Set(["tag1", "tag2"])) + try await spark.removeTag("tag1") + #expect(await spark.getTags() == Set(["tag2"])) + await spark.clearTags() + #expect(await spark.getTags().isEmpty) + await spark.stop() + } + + @Test + func invalidTags() async throws { + let spark = try await SparkSession.builder.getOrCreate() + 
await #expect(throws: SparkConnectError.InvalidArgumentException) { + try await spark.addTag("") + } + await #expect(throws: SparkConnectError.InvalidArgumentException) { + try await spark.addTag(",") + } + await spark.stop() + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org