This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git


The following commit(s) were added to refs/heads/main by this push:
     new bb8b9fb  [SPARK-51986] Support `Parameterized SQL queries` in `sql` API
bb8b9fb is described below

commit bb8b9fb88c39a0bbf3ec6bc25388702743315788
Author: Dongjoon Hyun <dongj...@apache.org>
AuthorDate: Fri May 2 17:59:38 2025 -0700

    [SPARK-51986] Support `Parameterized SQL queries` in `sql` API

    ### What changes were proposed in this pull request?

    This PR aims to support `Parameterized SQL queries` in the `sql` API.

    ### Why are the changes needed?

    For feature parity, we should support this GA feature:
    - https://github.com/apache/spark/pull/38864 (Since Spark 3.4.0)
    - https://github.com/apache/spark/pull/40623 (Since Spark 3.4.0)
    - https://github.com/apache/spark/pull/41568 (Since Spark 3.5.0)
    - https://github.com/apache/spark/pull/48965 (GA Since Spark 4.0.0)

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Pass the CIs.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #103 from dongjoon-hyun/SPARK-51986.

    Authored-by: Dongjoon Hyun <dongj...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
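A minimal usage sketch of the two new overloads (assuming an async context and a reachable Spark Connect server; the query and values mirror the new test below):

```swift
import SparkConnect

let spark = try await SparkSession.builder.getOrCreate()

// Positional parameters: each `?` marker is bound from the variadic
// arguments, in order.
let byPosition = try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect()

// Named parameters: each `:name` marker is bound from the dictionary.
let byName = try await spark.sql(
  "SELECT :x, :y, :z", args: ["x": true, "y": 1, "z": "a"]).collect()

print(byPosition)  // [Row(true, 1, "a")]
print(byName)      // [Row(true, 1, "a")]

await spark.stop()
```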
---
 Sources/SparkConnect/DataFrame.swift            | 14 ++++-
 Sources/SparkConnect/Extension.swift            | 68 +++++++++++++++++++++++++
 Sources/SparkConnect/SparkSession.swift         | 20 ++++++++
 Sources/SparkConnect/TypeAliases.swift          |  1 +
 Tests/SparkConnectTests/SparkSessionTests.swift | 13 +++++
 5 files changed, 114 insertions(+), 2 deletions(-)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 5531917..cbe4793 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -43,9 +43,19 @@ public actor DataFrame: Sendable {
   /// - Parameters:
   ///   - spark: A `SparkSession` instance to use.
   ///   - sqlText: A SQL statement.
-  init(spark: SparkSession, sqlText: String) async throws {
+  ///   - posArgs: An array of strings.
+  init(spark: SparkSession, sqlText: String, _ posArgs: [Sendable]? = nil) async throws {
     self.spark = spark
-    self.plan = sqlText.toSparkConnectPlan
+    if let posArgs {
+      self.plan = sqlText.toSparkConnectPlan(posArgs)
+    } else {
+      self.plan = sqlText.toSparkConnectPlan
+    }
+  }
+
+  init(spark: SparkSession, sqlText: String, _ args: [String: Sendable]) async throws {
+    self.spark = spark
+    self.plan = sqlText.toSparkConnectPlan(args)
   }
 
   public func getPlan() -> Sendable {
diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
index 5d75b3d..e841fa4 100644
--- a/Sources/SparkConnect/Extension.swift
+++ b/Sources/SparkConnect/Extension.swift
@@ -31,6 +31,74 @@ extension String {
     return plan
   }
 
+  func toSparkConnectPlan(_ posArguments: [Sendable]) -> Plan {
+    var sql = Spark_Connect_SQL()
+    sql.query = self
+    sql.posArguments = posArguments.map {
+      var literal = ExpressionLiteral()
+      switch $0 {
+      case let value as Bool:
+        literal.boolean = value
+      case let value as Int8:
+        literal.byte = Int32(value)
+      case let value as Int16:
+        literal.short = Int32(value)
+      case let value as Int32:
+        literal.integer = value
+      case let value as Int64:
+        literal.long = value
+      case let value as Int:
+        literal.long = Int64(value)
+      case let value as String:
+        literal.string = value
+      default:
+        literal.string = $0 as! String
+      }
+      var expr = Spark_Connect_Expression()
+      expr.literal = literal
+      return expr
+    }
+    var relation = Relation()
+    relation.sql = sql
+    var plan = Plan()
+    plan.opType = Plan.OneOf_OpType.root(relation)
+    return plan
+  }
+
+  func toSparkConnectPlan(_ namedArguments: [String: Sendable]) -> Plan {
+    var sql = Spark_Connect_SQL()
+    sql.query = self
+    sql.namedArguments = namedArguments.mapValues { value in
+      var literal = ExpressionLiteral()
+      switch value {
+      case let value as Bool:
+        literal.boolean = value
+      case let value as Int8:
+        literal.byte = Int32(value)
+      case let value as Int16:
+        literal.short = Int32(value)
+      case let value as Int32:
+        literal.integer = value
+      case let value as Int64:
+        literal.long = value
+      case let value as Int:
+        literal.long = Int64(value)
+      case let value as String:
+        literal.string = value
+      default:
+        literal.string = value as! String
+      }
+      var expr = Spark_Connect_Expression()
+      expr.literal = literal
+      return expr
+    }
+    var relation = Relation()
+    relation.sql = sql
+    var plan = Plan()
+    plan.opType = Plan.OneOf_OpType.root(relation)
+    return plan
+  }
+
   /// Get a `UserContext` instance from a string.
   var toUserContext: UserContext {
     var context = UserContext()
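Both helpers share the same value-to-literal mapping. Note that the `default:` branch force-casts to `String`, so an argument outside the handled types (a `Double`, for example) would trap at runtime rather than fail gracefully. A standalone sketch of the mapping (`literalField` is a hypothetical name; the real code populates `Spark_Connect_Expression.Literal` fields):

```swift
// Hypothetical mirror of the switch in Extension.swift: returns the name of
// the literal field a given Swift argument would populate on the wire.
func literalField(for value: Any) -> String {
  switch value {
  case is Bool: return "boolean"
  case is Int8: return "byte"       // widened to Int32, as in the real code
  case is Int16: return "short"     // widened to Int32, as in the real code
  case is Int32: return "integer"
  case is Int64: return "long"
  case is Int: return "long"        // converted to Int64, as in the real code
  case is String: return "string"
  default: return "string"          // the real code force-casts: `value as! String`
  }
}

// literalField(for: true) == "boolean"; literalField(for: 1) == "long"
```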
diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift
index b06370e..ebaf190 100644
--- a/Sources/SparkConnect/SparkSession.swift
+++ b/Sources/SparkConnect/SparkSession.swift
@@ -112,6 +112,26 @@ public actor SparkSession {
     return try await DataFrame(spark: self, sqlText: sqlText)
   }
 
+  /// Executes a SQL query substituting positional parameters by the given arguments, returning the
+  /// result as a `DataFrame`.
+  /// - Parameters:
+  ///   - sqlText: A SQL statement with positional parameters to execute.
+  ///   - args: ``Sendable`` values that can be converted to SQL literal expressions.
+  /// - Returns: A ``DataFrame``.
+  public func sql(_ sqlText: String, _ args: Sendable...) async throws -> DataFrame {
+    return try await DataFrame(spark: self, sqlText: sqlText, args)
+  }
+
+  /// Executes a SQL query substituting named parameters by the given arguments, returning the
+  /// result as a `DataFrame`.
+  /// - Parameters:
+  ///   - sqlText: A SQL statement with named parameters to execute.
+  ///   - args: A dictionary with key string and ``Sendable`` value.
+  /// - Returns: A ``DataFrame``.
+  public func sql(_ sqlText: String, args: [String: Sendable]) async throws -> DataFrame {
+    return try await DataFrame(spark: self, sqlText: sqlText, args)
+  }
+
   /// Returns a ``DataFrameReader`` that can be used to read non-streaming data in as a
   /// `DataFrame`
   public var read: DataFrameReader {
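One behavioral note on these overloads (a general Swift overload-resolution rule, not something this diff changes): a zero-argument call still binds to the original non-variadic `sql(_:)`, because Swift prefers non-variadic candidates over variadic ones, so existing call sites are unaffected. Illustrative only:

```swift
import SparkConnect

// Which overload each call form resolves to.
func demo(_ spark: SparkSession) async throws {
  _ = try await spark.sql("SELECT 1")                   // original sql(_:); non-variadic wins
  _ = try await spark.sql("SELECT ?", 1)                // new positional (variadic) overload
  _ = try await spark.sql("SELECT :n", args: ["n": 1])  // new named-argument overload
}
```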
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 60f0fb8..41547f8 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -28,6 +28,7 @@ typealias Drop = Spark_Connect_Drop
 typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
 typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
 typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
+typealias ExpressionLiteral = Spark_Connect_Expression.Literal
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
 typealias GroupType = Spark_Connect_Aggregate.GroupType
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift
index 2bc887e..69f0aee 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -76,6 +76,19 @@ struct SparkSessionTests {
     await spark.stop()
   }
 
+#if !os(Linux)
+  @Test
+  func sql() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let expected = [Row(true, 1, "a")]
+    if await spark.version.starts(with: "4.") {
+      #expect(try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect() == expected)
+      #expect(try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y": 1, "z": "a"]).collect() == expected)
+    }
+    await spark.stop()
+  }
+#endif
+
   @Test
   func table() async throws {
     let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
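The new test is skipped on Linux and gated on a `4.`-prefixed server version, since parameterized SQL reached GA only in Spark 4.0.0 (per the PRs linked in the description). A sketch of the same guard in application code, under the same server assumptions as above:

```swift
import SparkConnect

let spark = try await SparkSession.builder.getOrCreate()
if await spark.version.starts(with: "4.") {
  // Named parameter markers are GA on Spark 4 servers.
  let rows = try await spark.sql("SELECT :x", args: ["x": 42]).collect()
  print(rows)
} else {
  // Fall back to inlining the value for older servers.
  let rows = try await spark.sql("SELECT 42").collect()
  print(rows)
}
await spark.stop()
```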