This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git


The following commit(s) were added to refs/heads/main by this push:
     new ebffbbe  [SPARK-51759] Add `ErrorUtils` and `SQLHelper`
ebffbbe is described below

commit ebffbbedc93e881bed1d2b9a4e50ccac6a1af6aa
Author: Dongjoon Hyun <dongj...@apache.org>
AuthorDate: Thu Apr 10 18:48:54 2025 +0900

    [SPARK-51759] Add `ErrorUtils` and `SQLHelper`
    
    ### What changes were proposed in this pull request?
    
    This PR aims to add `ErrorUtils` and `SQLHelper` and use them in the existing tests.
    
    ### Why are the changes needed?
    
    To support a developer experience similar to `tryWithSafeFinally`, `withDatabase`, and `withTable`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Pass the CIs.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49 from dongjoon-hyun/SPARK-51759.
    
    Authored-by: Dongjoon Hyun <dongj...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 Sources/SparkConnect/ErrorUtils.swift              | 36 ++++++++++++++
 Tests/SparkConnectTests/CatalogTests.swift         |  9 +++-
 Tests/SparkConnectTests/DataFrameReaderTests.swift |  9 ++--
 Tests/SparkConnectTests/SQLHelper.swift            | 56 ++++++++++++++++++++++
 Tests/SparkConnectTests/SparkSessionTests.swift    |  9 ++--
 5 files changed, 110 insertions(+), 9 deletions(-)

diff --git a/Sources/SparkConnect/ErrorUtils.swift b/Sources/SparkConnect/ErrorUtils.swift
new file mode 100644
index 0000000..483e781
--- /dev/null
+++ b/Sources/SparkConnect/ErrorUtils.swift
@@ -0,0 +1,50 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+import Foundation
+
+/// Utility functions like `org.apache.spark.util.SparkErrorUtils`.
+public enum ErrorUtils {
+  /// Runs `block`, then runs `finallyBlock` exactly once, and returns `block`'s result.
+  ///
+  /// This is the Swift counterpart of `tryWithSafeFinally` in `SparkErrorUtils`:
+  /// - If `block` throws, `finallyBlock` still runs, any error it throws is discarded,
+  ///   and the original error from `block` is rethrown.
+  /// - If `block` succeeds, an error thrown by `finallyBlock` is propagated to the caller.
+  ///
+  /// - Parameters:
+  ///   - block: The main work to perform.
+  ///   - finallyBlock: Cleanup that must run whether or not `block` throws.
+  /// - Returns: The value produced by `block`.
+  public static func tryWithSafeFinally<T>(
+    _ block: () async throws -> T, _ finallyBlock: () async throws -> Void
+  ) async rethrows -> T {
+    let result: T
+    do {
+      result = try await block()
+    } catch {
+      // Best-effort cleanup: a cleanup failure must not mask the original error.
+      try? await finallyBlock()
+      throw error
+    }
+    // Kept outside the `do` so that a cleanup failure is not caught above,
+    // which would run `finallyBlock` a second time.
+    try await finallyBlock()
+    return result
+  }
+}
diff --git a/Tests/SparkConnectTests/CatalogTests.swift b/Tests/SparkConnectTests/CatalogTests.swift
index f49f2db..3b8feca 100644
--- a/Tests/SparkConnectTests/CatalogTests.swift
+++ b/Tests/SparkConnectTests/CatalogTests.swift
@@ -100,7 +100,14 @@ struct CatalogTests {
   func databaseExists() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
     #expect(try await spark.catalog.databaseExists("default"))
-    #expect(try await spark.catalog.databaseExists("not_exist_database") == false)
+
+    let dbName = "DB_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")  // random name per run
+    #expect(try await spark.catalog.databaseExists(dbName) == false)
+    try await SQLHelper.withDatabase(spark, dbName)({  // drops dbName after the body, even on throw
+      _ = try await spark.sql("CREATE DATABASE \(dbName)").count()
+      #expect(try await spark.catalog.databaseExists(dbName))
+    })
+    #expect(try await spark.catalog.databaseExists(dbName) == false)  // cleanup really dropped it
     await spark.stop()
   }
 #endif
diff --git a/Tests/SparkConnectTests/DataFrameReaderTests.swift b/Tests/SparkConnectTests/DataFrameReaderTests.swift
index 21d61e8..101d842 100644
--- a/Tests/SparkConnectTests/DataFrameReaderTests.swift
+++ b/Tests/SparkConnectTests/DataFrameReaderTests.swift
@@ -67,11 +67,12 @@ struct DataFrameReaderTests {
 
   @Test
   func table() async throws {
-    let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
     let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0)
-    #expect(try await spark.read.table(tableName).count() == 2)
-    #expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0)
+    try await SQLHelper.withTable(spark, tableName)({  // drops tableName after the body, even on throw
+      _ = try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count()
+      #expect(try await spark.read.table(tableName).count() == 2)
+    })
     await spark.stop()
   }
 }
diff --git a/Tests/SparkConnectTests/SQLHelper.swift b/Tests/SparkConnectTests/SQLHelper.swift
new file mode 100644
index 0000000..c552119
--- /dev/null
+++ b/Tests/SparkConnectTests/SQLHelper.swift
@@ -0,0 +1,62 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+import Foundation
+import Testing
+
+@testable import SparkConnect
+
+/// Test utilities that run a test body and then drop the SQL objects it created,
+/// modeled on the `withDatabase` and `withTable` helpers of Spark's Scala test suites.
+///
+/// Usage: `try await SQLHelper.withTable(spark, tableName)({ ... })`
+struct SQLHelper {
+  /// Returns a runner that executes a body and then drops every database in `dbNames`
+  /// with `DROP DATABASE IF EXISTS <name> CASCADE`, even if the body throws.
+  public static func withDatabase(_ spark: SparkSession, _ dbNames: String...) -> (
+    () async throws -> Void
+  ) async throws -> Void {
+    withCleanup(spark, dbNames.map { "DROP DATABASE IF EXISTS \($0) CASCADE" })
+  }
+
+  /// Returns a runner that executes a body and then drops every table in `tableNames`
+  /// with `DROP TABLE IF EXISTS <name>`, even if the body throws.
+  public static func withTable(_ spark: SparkSession, _ tableNames: String...) -> (
+    () async throws -> Void
+  ) async throws -> Void {
+    withCleanup(spark, tableNames.map { "DROP TABLE IF EXISTS \($0)" })
+  }
+
+  /// Builds a runner that executes a body via `ErrorUtils.tryWithSafeFinally` and then runs
+  /// each statement in `cleanupStatements` in order, stopping at the first one that throws.
+  private static func withCleanup(
+    _ spark: SparkSession, _ cleanupStatements: [String]
+  ) -> (() async throws -> Void) async throws -> Void {
+    func body(_ f: () async throws -> Void) async throws {
+      try await ErrorUtils.tryWithSafeFinally(
+        f,
+        {
+          for statement in cleanupStatements {
+            _ = try await spark.sql(statement).count()
+          }
+        })
+    }
+    return body
+  }
+}
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift
index cba57e4..24d0537 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -77,11 +77,12 @@ struct SparkSessionTests {
 
   @Test
   func table() async throws {
-    let tableName = UUID().uuidString.replacingOccurrences(of: "-", with: "")
+    let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
     let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() == 0)
-    #expect(try await spark.table(tableName).count() == 2)
-    #expect(try await spark.sql("DROP TABLE \(tableName)").count() == 0)
+    try await SQLHelper.withTable(spark, tableName)({  // drops tableName after the body, even on throw
+      _ = try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count()
+      #expect(try await spark.table(tableName).count() == 2)
+    })
     await spark.stop()
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to