This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git


The following commit(s) were added to refs/heads/main by this push:
     new bb8b9fb  [SPARK-51986] Support `Parameterized SQL queries` in `sql` API
bb8b9fb is described below

commit bb8b9fb88c39a0bbf3ec6bc25388702743315788
Author: Dongjoon Hyun <dongj...@apache.org>
AuthorDate: Fri May 2 17:59:38 2025 -0700

    [SPARK-51986] Support `Parameterized SQL queries` in `sql` API
    
    ### What changes were proposed in this pull request?
    
    This PR aims to support `Parameterized SQL queries` in the `sql` API.
    
    ### Why are the changes needed?
    
    For feature parity, we should support this GA feature, which was developed in the following Apache Spark PRs:
    
    - https://github.com/apache/spark/pull/38864 (Since Spark 3.4.0)
    - https://github.com/apache/spark/pull/40623 (Since Spark 3.4.0)
    - https://github.com/apache/spark/pull/41568 (Since Spark 3.5.0)
    - https://github.com/apache/spark/pull/48965 (GA Since Spark 4.0.0)
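    
    For illustration, a minimal usage sketch of the two new `sql` overloads (the queries and values below are illustrative, not from this patch):
    
        let spark = try await SparkSession.builder.getOrCreate()
    
        // Positional parameters: each `?` binds to the matching variadic argument.
        let df1 = try await spark.sql("SELECT * FROM range(10) WHERE id > ?", 5)
    
        // Named parameters: each `:name` binds to the matching dictionary entry.
        let df2 = try await spark.sql("SELECT :x, :y", args: ["x": true, "y": "a"])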
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Pass the CIs.
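    
    Locally, the new test can also be run with SwiftPM's standard test runner, assuming a Spark Connect server is reachable at the configured endpoint (a sketch, not a documented workflow of this repository):
    
        swift test --filter SparkSessionTests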
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #103 from dongjoon-hyun/SPARK-51986.
    
    Authored-by: Dongjoon Hyun <dongj...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 Sources/SparkConnect/DataFrame.swift            | 14 ++++-
 Sources/SparkConnect/Extension.swift            | 68 +++++++++++++++++++++++++
 Sources/SparkConnect/SparkSession.swift         | 20 ++++++++
 Sources/SparkConnect/TypeAliases.swift          |  1 +
 Tests/SparkConnectTests/SparkSessionTests.swift | 13 +++++
 5 files changed, 114 insertions(+), 2 deletions(-)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 5531917..cbe4793 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -43,9 +43,19 @@ public actor DataFrame: Sendable {
   /// - Parameters:
   ///   - spark: A `SparkSession` instance to use.
   ///   - sqlText: A SQL statement.
-  init(spark: SparkSession, sqlText: String) async throws {
+  ///   - posArgs: An optional array of `Sendable` values for the positional parameters.
+  init(spark: SparkSession, sqlText: String, _ posArgs: [Sendable]? = nil) async throws {
     self.spark = spark
-    self.plan = sqlText.toSparkConnectPlan
+    if let posArgs {
+      self.plan = sqlText.toSparkConnectPlan(posArgs)
+    } else {
+      self.plan = sqlText.toSparkConnectPlan
+    }
+  }
+
+  /// Create a `DataFrame` from a SQL statement with named parameters.
+  /// - Parameters:
+  ///   - spark: A `SparkSession` instance to use.
+  ///   - sqlText: A SQL statement.
+  ///   - args: A dictionary of parameter names to `Sendable` values.
+  init(spark: SparkSession, sqlText: String, _ args: [String: Sendable]) async throws {
+    self.spark = spark
+    self.plan = sqlText.toSparkConnectPlan(args)
   }
 
   public func getPlan() -> Sendable {
diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
index 5d75b3d..e841fa4 100644
--- a/Sources/SparkConnect/Extension.swift
+++ b/Sources/SparkConnect/Extension.swift
@@ -31,6 +31,74 @@ extension String {
     return plan
   }
 
+  func toSparkConnectPlan(_ posArguments: [Sendable]) -> Plan {
+    var sql = Spark_Connect_SQL()
+    sql.query = self
+    sql.posArguments = posArguments.map {
+      var literal = ExpressionLiteral()
+      switch $0 {
+      case let value as Bool:
+        literal.boolean = value
+      case let value as Int8:
+        literal.byte = Int32(value)
+      case let value as Int16:
+        literal.short = Int32(value)
+      case let value as Int32:
+        literal.integer = value
+      case let value as Int64:
+        literal.long = value
+      case let value as Int:
+        literal.long = Int64(value)
+      case let value as String:
+        literal.string = value
+      default:
+        // Avoid a trapping force-cast; render unsupported types via their description.
+        literal.string = String(describing: $0)
+      }
+      var expr = Spark_Connect_Expression()
+      expr.literal = literal
+      return expr
+    }
+    var relation = Relation()
+    relation.sql = sql
+    var plan = Plan()
+    plan.opType = Plan.OneOf_OpType.root(relation)
+    return plan
+  }
+
+  func toSparkConnectPlan(_ namedArguments: [String: Sendable]) -> Plan {
+    var sql = Spark_Connect_SQL()
+    sql.query = self
+    sql.namedArguments = namedArguments.mapValues { value in
+      var literal = ExpressionLiteral()
+      switch value {
+      case let value as Bool:
+        literal.boolean = value
+      case let value as Int8:
+        literal.byte = Int32(value)
+      case let value as Int16:
+        literal.short = Int32(value)
+      case let value as Int32:
+        literal.integer = value
+      case let value as Int64:
+        literal.long = value
+      case let value as Int:
+        literal.long = Int64(value)
+      case let value as String:
+        literal.string = value
+      default:
+        // Avoid a trapping force-cast; render unsupported types via their description.
+        literal.string = String(describing: value)
+      }
+      var expr = Spark_Connect_Expression()
+      expr.literal = literal
+      return expr
+    }
+    var relation = Relation()
+    relation.sql = sql
+    var plan = Plan()
+    plan.opType = Plan.OneOf_OpType.root(relation)
+    return plan
+  }
+
   /// Get a `UserContext` instance from a string.
   var toUserContext: UserContext {
     var context = UserContext()
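
The two `toSparkConnectPlan` overloads above duplicate the same value-to-literal switch. As a design note, the mapping could be factored into one shared helper; a minimal sketch under that assumption (`toLiteralExpression` is hypothetical and not part of this patch; `ExpressionLiteral` and `Spark_Connect_Expression` are the generated protobuf types aliased in TypeAliases.swift):

    private func toLiteralExpression(_ value: Sendable) -> Spark_Connect_Expression {
      var literal = ExpressionLiteral()
      switch value {
      case let v as Bool: literal.boolean = v
      case let v as Int8: literal.byte = Int32(v)
      case let v as Int16: literal.short = Int32(v)
      case let v as Int32: literal.integer = v
      case let v as Int64: literal.long = v
      case let v as Int: literal.long = Int64(v)
      case let v as String: literal.string = v
      default: literal.string = String(describing: value)
      }
      var expr = Spark_Connect_Expression()
      expr.literal = literal
      return expr
    }

With such a helper, `posArguments.map(toLiteralExpression)` and `namedArguments.mapValues(toLiteralExpression)` would replace the duplicated switches.
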
diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift
index b06370e..ebaf190 100644
--- a/Sources/SparkConnect/SparkSession.swift
+++ b/Sources/SparkConnect/SparkSession.swift
@@ -112,6 +112,26 @@ public actor SparkSession {
     return try await DataFrame(spark: self, sqlText: sqlText)
   }
 
+  /// Executes a SQL query substituting positional parameters by the given arguments, returning the
+  /// result as a `DataFrame`.
+  /// - Parameters:
+  ///   - sqlText: A SQL statement with positional parameters to execute.
+  ///   - args: ``Sendable`` values that can be converted to SQL literal expressions.
+  /// - Returns: A ``DataFrame``.
+  public func sql(_ sqlText: String, _ args: Sendable...) async throws -> DataFrame {
+    return try await DataFrame(spark: self, sqlText: sqlText, args)
+  }
+
+  /// Executes a SQL query substituting named parameters by the given arguments, returning the
+  /// result as a `DataFrame`.
+  /// - Parameters:
+  ///   - sqlText: A SQL statement with named parameters to execute.
+  ///   - args: A dictionary with key string and ``Sendable`` value.
+  /// - Returns: A ``DataFrame``.
+  public func sql(_ sqlText: String, args: [String: Sendable]) async throws -> DataFrame {
+    return try await DataFrame(spark: self, sqlText: sqlText, args)
+  }
+
   /// Returns a ``DataFrameReader`` that can be used to read non-streaming data in as a
   /// `DataFrame`
   public var read: DataFrameReader {
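
With the additions above, `sql` now has three overloads: the existing no-argument form, a variadic positional form, and a named-argument form. A quick illustration of how calls resolve (queries are illustrative, not from this patch; Swift prefers the non-variadic overload when no extra arguments are passed):

    // Resolves to sql(_:), since no parameters are passed.
    let a = try await spark.sql("SELECT 1")

    // Resolves to the variadic sql(_:_:) overload.
    let b = try await spark.sql("SELECT ? + ?", 1, 2)

    // Resolves to sql(_:args:); the `args:` label selects the named overload.
    let c = try await spark.sql("SELECT :n", args: ["n": 42])
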
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 60f0fb8..41547f8 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -28,6 +28,7 @@ typealias Drop = Spark_Connect_Drop
 typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
 typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
 typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
+typealias ExpressionLiteral = Spark_Connect_Expression.Literal
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
 typealias GroupType = Spark_Connect_Aggregate.GroupType
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift
index 2bc887e..69f0aee 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -76,6 +76,19 @@ struct SparkSessionTests {
     await spark.stop()
   }
 
+#if !os(Linux)
+  @Test
+  func sql() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let expected = [Row(true, 1, "a")]
+    if await spark.version.starts(with: "4.") {
+      #expect(try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect() == 
expected)
+      #expect(try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y": 
1, "z": "a"]).collect() == expected)
+    }
+    await spark.stop()
+  }
+#endif
+
   @Test
   func table() async throws {
     let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", 
with: "")

