This is an automated email from the ASF dual-hosted git repository.
dongjoon-hyun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git
The following commit(s) were added to refs/heads/main by this push:
new e035f3d [SPARK-57305] Support `stat.(cov|corr)` for `DataFrame`
e035f3d is described below
commit e035f3d9cb564d9a7b09b8f70ec08babcb83cb59
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Sun Jun 7 14:10:10 2026 -0700
[SPARK-57305] Support `stat.(cov|corr)` for `DataFrame`
### What changes were proposed in this pull request?
This PR aims to support `cov` and `corr` for `DataFrame` by wiring the
`StatCov`/`StatCorr` Spark Connect relations through a new
`DataFrameStatFunctions`, exposed via `DataFrame.stat` like PySpark/Scala.
```swift
public func cov(_ col1: String, _ col2: String) async throws -> Double
public func corr(_ col1: String, _ col2: String, method: String = "pearson") async throws -> Double
```
### Why are the changes needed?
To improve API coverage by mirroring PySpark/Scala `DataFrameStatFunctions`.
### Does this PR introduce _any_ user-facing change?
Yes, this PR adds new APIs, `DataFrame.stat.cov` and `DataFrame.stat.corr`.
### How was this patch tested?
Pass the CIs with a newly added test suite, `DataFrameStatFunctionsTests`.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (Claude Opus 4.8)
Closes #407 from dongjoon-hyun/SPARK-57305.
Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
Sources/SparkConnect/DataFrameStatFunctions.swift | 69 ++++++++++++++++++++++
Sources/SparkConnect/SparkConnectClient.swift | 19 ++++++
.../DataFrameStatFunctionsTests.swift | 45 ++++++++++++++
3 files changed, 133 insertions(+)
diff --git a/Sources/SparkConnect/DataFrameStatFunctions.swift b/Sources/SparkConnect/DataFrameStatFunctions.swift
new file mode 100644
index 0000000..8c9bf72
--- /dev/null
+++ b/Sources/SparkConnect/DataFrameStatFunctions.swift
@@ -0,0 +1,69 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+/// Statistic functions for ``DataFrame``s.
+///
+/// Use ``DataFrame/stat`` to access this. It mirrors PySpark's `DataFrameStatFunctions`
+/// (`df.stat.cov`, `df.stat.corr`).
+public actor DataFrameStatFunctions: Sendable {
+ let df: DataFrame
+
+ init(df: DataFrame) {
+ self.df = df
+ }
+
+ /// Calculates the sample covariance of two numerical columns of a ``DataFrame``.
+ /// - Parameters:
+ /// - col1: The name of the first column.
+ /// - col2: The name of the second column.
+ /// - Returns: The sample covariance of the two columns.
+ public func cov(_ col1: String, _ col2: String) async throws -> Double {
+ return try await collectDouble { SparkConnectClient.getStatCov($0, col1, col2) }
+ }
+
+ /// Calculates the correlation of two columns of a ``DataFrame``. Currently only supports the
+ /// Pearson Correlation Coefficient.
+ /// - Parameters:
+ /// - col1: The name of the first column.
+ /// - col2: The name of the second column.
+ /// - method: The correlation method. Currently only `pearson` is supported.
+ /// - Returns: The Pearson Correlation Coefficient of the two columns.
+ public func corr(
+ _ col1: String, _ col2: String, method: String = "pearson"
+ ) async throws -> Double {
+ return try await collectDouble { SparkConnectClient.getStatCorr($0, col1, col2, method) }
+ }
+
+ // MARK: - Helpers
+
+ /// Builds a single-value ``DataFrame`` from this ``DataFrame``'s plan using the given plan
+ /// builder, executes it, and returns the resulting `Double`.
+ private func collectDouble(_ f: (Relation) -> Plan) async throws -> Double {
+ let plan = await df.getPlan() as! Plan
+ let result = DataFrame(spark: await df.spark, plan: f(plan.root))
+ return try await result.collect()[0].get(0) as! Double
+ }
+}
+
+extension DataFrame {
+ /// Returns a ``DataFrameStatFunctions`` for working with statistic functions.
+ public var stat: DataFrameStatFunctions {
+ DataFrameStatFunctions(df: self)
+ }
+}
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 6793302..e43fe34 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -612,6 +612,25 @@ public actor SparkConnectClient {
return createPlan { $0.summary = summary }
}
+ static func getStatCov(_ child: Relation, _ col1: String, _ col2: String) -> Plan {
+ var cov = Spark_Connect_StatCov()
+ cov.input = child
+ cov.col1 = col1
+ cov.col2 = col2
+ return createPlan { $0.cov = cov }
+ }
+
+ static func getStatCorr(
+ _ child: Relation, _ col1: String, _ col2: String, _ method: String
+ ) -> Plan {
+ var corr = Spark_Connect_StatCorr()
+ corr.input = child
+ corr.col1 = col1
+ corr.col2 = col2
+ corr.method = method
+ return createPlan { $0.corr = corr }
+ }
+
static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
var sort = Sort()
sort.input = child
diff --git a/Tests/SparkConnectTests/DataFrameStatFunctionsTests.swift b/Tests/SparkConnectTests/DataFrameStatFunctionsTests.swift
new file mode 100644
index 0000000..54a0315
--- /dev/null
+++ b/Tests/SparkConnectTests/DataFrameStatFunctionsTests.swift
@@ -0,0 +1,45 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+import SparkConnect
+import Testing
+
+/// A test suite for `DataFrameStatFunctions`
+@Suite(.serialized)
+struct DataFrameStatFunctionsTests {
+ @Test
+ func cov() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.sql("SELECT * FROM VALUES (1, 2), (2, 4), (3, 6) AS T(c1, c2)")
+ #expect(try await df.stat.cov("c1", "c2") == 2.0)
+ #expect(try await df.stat.cov("c1", "c1") == 1.0)
+ await spark.stop()
+ }
+
+ @Test
+ func corr() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let df = try await spark.sql("SELECT * FROM VALUES (1, 2), (2, 4), (3, 6) AS T(c1, c2)")
+ // Perfectly positively correlated columns.
+ #expect(try await df.stat.corr("c1", "c2") == 1.0)
+ // `method` defaults to `pearson`.
+ #expect(try await df.stat.corr("c1", "c2", method: "pearson") == 1.0)
+ await spark.stop()
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]