This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git
The following commit(s) were added to refs/heads/main by this push: new ebffbbe [SPARK-51759] Add `ErrorUtils` and `SQLHelper` ebffbbe is described below commit ebffbbedc93e881bed1d2b9a4e50ccac6a1af6aa Author: Dongjoon Hyun <dongj...@apache.org> AuthorDate: Thu Apr 10 18:48:54 2025 +0900 [SPARK-51759] Add `ErrorUtils` and `SQLHelper` ### What changes were proposed in this pull request? This PR aims to add `ErrorUtils` and `SQLHelper` and use them. ### Why are the changes needed? To provide a developer experience similar to Apache Spark's `tryWithSafeFinally`, `withDatabase`, and `withTable` helpers. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49 from dongjoon-hyun/SPARK-51759. Authored-by: Dongjoon Hyun <dongj...@apache.org> Signed-off-by: Dongjoon Hyun <dongj...@apache.org> --- Sources/SparkConnect/ErrorUtils.swift | 36 ++++++++++++++ Tests/SparkConnectTests/CatalogTests.swift | 9 +++- Tests/SparkConnectTests/DataFrameReaderTests.swift | 9 ++-- Tests/SparkConnectTests/SQLHelper.swift | 56 ++++++++++++++++++++++ Tests/SparkConnectTests/SparkSessionTests.swift | 9 ++-- 5 files changed, 110 insertions(+), 9 deletions(-) diff --git a/Sources/SparkConnect/ErrorUtils.swift b/Sources/SparkConnect/ErrorUtils.swift new file mode 100644 index 0000000..483e781 --- /dev/null +++ b/Sources/SparkConnect/ErrorUtils.swift @@ -0,0 +1,36 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +import Foundation + +/// Utility functions like `org.apache.spark.util.SparkErrorUtils`. +public enum ErrorUtils { + public static func tryWithSafeFinally<T>( + _ block: () async throws -> T, _ finallyBlock: () async throws -> Void + ) async rethrows -> T { + let result: T + do { + result = try await block() + try await finallyBlock() + } catch { + try? await finallyBlock() + throw error + } + return result + } +} diff --git a/Tests/SparkConnectTests/CatalogTests.swift b/Tests/SparkConnectTests/CatalogTests.swift index f49f2db..3b8feca 100644 --- a/Tests/SparkConnectTests/CatalogTests.swift +++ b/Tests/SparkConnectTests/CatalogTests.swift @@ -100,7 +100,14 @@ struct CatalogTests { func databaseExists() async throws { let spark = try await SparkSession.builder.getOrCreate() #expect(try await spark.catalog.databaseExists("default")) - #expect(try await spark.catalog.databaseExists("not_exist_database") == false) + + let dbName = "DB_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + #expect(try await spark.catalog.databaseExists(dbName) == false) + try await SQLHelper.withDatabase(spark, dbName) ({ + _ = try await spark.sql("CREATE DATABASE \(dbName)").count() + #expect(try await spark.catalog.databaseExists(dbName)) + }) + #expect(try await spark.catalog.databaseExists(dbName) == false) await spark.stop() } #endif diff --git a/Tests/SparkConnectTests/DataFrameReaderTests.swift b/Tests/SparkConnectTests/DataFrameReaderTests.swift index 21d61e8..101d842 100644 --- a/Tests/SparkConnectTests/DataFrameReaderTests.swift +++ 
b/Tests/SparkConnectTests/DataFrameReaderTests.swift @@ -67,11 +67,12 @@ struct DataFrameReaderTests { @Test func table() async throws { - let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "") + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0) - #expect(try await spark.read.table(tableName).count() == 2) - #expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0) + try await SQLHelper.withTable(spark, tableName)({ + _ = try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() + #expect(try await spark.read.table(tableName).count() == 2) + }) await spark.stop() } } diff --git a/Tests/SparkConnectTests/SQLHelper.swift b/Tests/SparkConnectTests/SQLHelper.swift new file mode 100644 index 0000000..c552119 --- /dev/null +++ b/Tests/SparkConnectTests/SQLHelper.swift @@ -0,0 +1,56 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+// + +import Foundation +import Testing + +@testable import SparkConnect + +/// A test utility +struct SQLHelper { + public static func withDatabase(_ spark: SparkSession, _ dbNames: String...) -> ( + () async throws -> Void + ) async throws -> Void { + func body(_ f: () async throws -> Void) async throws { + try await ErrorUtils.tryWithSafeFinally( + f, + { + for name in dbNames { + _ = try await spark.sql("DROP DATABASE IF EXISTS \(name) CASCADE").count() + } + }) + } + return body + } + + public static func withTable(_ spark: SparkSession, _ tableNames: String...) -> ( + () async throws -> Void + ) async throws -> Void { + func body(_ f: () async throws -> Void) async throws { + try await ErrorUtils.tryWithSafeFinally( + f, + { + for name in tableNames { + _ = try await spark.sql("DROP TABLE IF EXISTS \(name)").count() + } + }) + } + return body + } +} diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift index cba57e4..24d0537 100644 --- a/Tests/SparkConnectTests/SparkSessionTests.swift +++ b/Tests/SparkConnectTests/SparkSessionTests.swift @@ -77,11 +77,12 @@ struct SparkSessionTests { @Test func table() async throws { - let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "") + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0) - #expect(try await spark.table(tableName).count() == 2) - #expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0) + try await SQLHelper.withTable(spark, tableName)({ + _ = try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() + #expect(try await spark.table(tableName).count() == 2) + }) await spark.stop() } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For 
additional commands, e-mail: commits-h...@spark.apache.org