This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git
The following commit(s) were added to refs/heads/main by this push:
new 2fb55a0 [SPARK-51879] Support `groupBy/rollup/cube` in `DataFrame`
2fb55a0 is described below
commit 2fb55a0468c9f9346560428fa0017885390cac62
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Thu Apr 24 01:49:50 2025 +0900
[SPARK-51879] Support `groupBy/rollup/cube` in `DataFrame`
### What changes were proposed in this pull request?
This PR aims to support `groupBy`, `rollup`, and `cube` APIs in `DataFrame`.
### Why are the changes needed?
For feature parity.
### Does this PR introduce _any_ user-facing change?
No, these are additional APIs.
### How was this patch tested?
Pass the CIs.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #87 from dongjoon-hyun/SPARK-51879.
Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
Sources/SparkConnect/DataFrame.swift | 23 ++++++++++
Sources/SparkConnect/Extension.swift | 11 +++++
Sources/SparkConnect/GroupedData.swift | 51 ++++++++++++++++++++
Sources/SparkConnect/TypeAliases.swift | 2 +
Tests/SparkConnectTests/DataFrameTests.swift | 69 ++++++++++++++++++++++++++++
5 files changed, 156 insertions(+)
diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 83dbba1..8679362 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -733,6 +733,29 @@ public actor DataFrame: Sendable {
return buildRepartition(numPartitions: numPartitions, shuffle: false)
}
+ /// Groups the ``DataFrame`` using the specified columns, so we can run aggregation on them.
+ /// - Parameter cols: Grouping column names.
+ /// - Returns: A ``GroupedData``.
+ public func groupBy(_ cols: String...) -> GroupedData {
+ return GroupedData(self, GroupType.groupby, cols)
+ }
+
+ /// Create a multi-dimensional rollup for the current ``DataFrame`` using the specified columns, so we
+ /// can run aggregation on them.
+ /// - Parameter cols: Grouping column names.
+ /// - Returns: A ``GroupedData``.
+ public func rollup(_ cols: String...) -> GroupedData {
+ return GroupedData(self, GroupType.rollup, cols)
+ }
+
+ /// Create a multi-dimensional cube for the current ``DataFrame`` using the specified columns, so we
+ /// can run aggregation on them.
+ /// - Parameter cols: Grouping column names.
+ /// - Returns: A ``GroupedData``.
+ public func cube(_ cols: String...) -> GroupedData {
+ return GroupedData(self, GroupType.cube, cols)
+ }
+
/// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
public var write: DataFrameWriter {
get {
diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
index d41b5b1..5d75b3d 100644
--- a/Sources/SparkConnect/Extension.swift
+++ b/Sources/SparkConnect/Extension.swift
@@ -93,6 +93,17 @@ extension String {
default: JoinType.inner
}
}
+
+ var toGroupType: GroupType {
+ return switch self.lowercased() {
+ case "groupby": .groupby
+ case "rollup": .rollup
+ case "cube": .cube
+ case "pivot": .pivot
+ case "groupingsets": .groupingSets
+ default: .UNRECOGNIZED(-1)
+ }
+ }
}
extension [String: String] {
diff --git a/Sources/SparkConnect/GroupedData.swift b/Sources/SparkConnect/GroupedData.swift
new file mode 100644
index 0000000..a460832
--- /dev/null
+++ b/Sources/SparkConnect/GroupedData.swift
@@ -0,0 +1,51 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+public actor GroupedData {
+ let df: DataFrame
+ let groupType: GroupType
+ let groupingCols: [String]
+
+ init(_ df: DataFrame, _ groupType: GroupType, _ groupingCols: [String]) {
+ self.df = df
+ self.groupType = groupType
+ self.groupingCols = groupingCols
+ }
+
+ public func agg(_ exprs: String...) async -> DataFrame {
+ var aggregate = Aggregate()
+ aggregate.input = await (self.df.getPlan() as! Plan).root
+ aggregate.groupType = self.groupType
+ aggregate.groupingExpressions = self.groupingCols.map {
+ var expr = Spark_Connect_Expression()
+ expr.expressionString = $0.toExpressionString
+ return expr
+ }
+ aggregate.aggregateExpressions = exprs.map {
+ var expr = Spark_Connect_Expression()
+ expr.expressionString = $0.toExpressionString
+ return expr
+ }
+ var relation = Relation()
+ relation.aggregate = aggregate
+ var plan = Plan()
+ plan.opType = .root(relation)
+ return await DataFrame(spark: df.spark, plan: plan)
+ }
+}
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 2858de2..f0dcc04 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -16,6 +16,7 @@
// specific language governing permissions and limitations
// under the License.
+typealias Aggregate = Spark_Connect_Aggregate
typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
typealias Command = Spark_Connect_Command
@@ -29,6 +30,7 @@ typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
typealias ExpressionString = Spark_Connect_Expression.ExpressionString
typealias Filter = Spark_Connect_Filter
+typealias GroupType = Spark_Connect_Aggregate.GroupType
typealias Join = Spark_Connect_Join
typealias JoinType = Spark_Connect_Join.JoinType
typealias KeyValue = Spark_Connect_KeyValue
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index 5772120..7fd1403 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -24,6 +24,20 @@ import SparkConnect
/// A test suite for `DataFrame`
struct DataFrameTests {
+ let DEALER_TABLE =
+ """
+ VALUES
+ (100, 'Fremont', 'Honda Civic', 10),
+ (100, 'Fremont', 'Honda Accord', 15),
+ (100, 'Fremont', 'Honda CRV', 7),
+ (200, 'Dublin', 'Honda Civic', 20),
+ (200, 'Dublin', 'Honda Accord', 10),
+ (200, 'Dublin', 'Honda CRV', 3),
+ (300, 'San Jose', 'Honda Civic', 5),
+ (300, 'San Jose', 'Honda Accord', 8)
+ dealer (id, city, car_model, quantity)
+ """
+
@Test
func sparkSession() async throws {
let spark = try await SparkSession.builder.getOrCreate()
@@ -577,6 +591,61 @@ struct DataFrameTests {
#expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
await spark.stop()
}
+
+ @Test
+ func groupBy() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let rows = try await spark.range(3).groupBy("id").agg("count(*)", "sum(*)", "avg(*)").collect()
+ #expect(rows == [Row("0", "1", "0", "0.0"), Row("1", "1", "1", "1.0"), Row("2", "1", "2", "2.0")])
+ await spark.stop()
+ }
+
+ @Test
+ func rollup() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model")
+ .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
+ #expect(rows == [
+ Row("Dublin", "Honda Accord", "10"),
+ Row("Dublin", "Honda CRV", "3"),
+ Row("Dublin", "Honda Civic", "20"),
+ Row("Dublin", nil, "33"),
+ Row("Fremont", "Honda Accord", "15"),
+ Row("Fremont", "Honda CRV", "7"),
+ Row("Fremont", "Honda Civic", "10"),
+ Row("Fremont", nil, "32"),
+ Row("San Jose", "Honda Accord", "8"),
+ Row("San Jose", "Honda Civic", "5"),
+ Row("San Jose", nil, "13"),
+ Row(nil, nil, "78"),
+ ])
+ await spark.stop()
+ }
+
+ @Test
+ func cube() async throws {
+ let spark = try await SparkSession.builder.getOrCreate()
+ let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model")
+ .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
+ #expect(rows == [
+ Row("Dublin", "Honda Accord", "10"),
+ Row("Dublin", "Honda CRV", "3"),
+ Row("Dublin", "Honda Civic", "20"),
+ Row("Dublin", nil, "33"),
+ Row("Fremont", "Honda Accord", "15"),
+ Row("Fremont", "Honda CRV", "7"),
+ Row("Fremont", "Honda Civic", "10"),
+ Row("Fremont", nil, "32"),
+ Row("San Jose", "Honda Accord", "8"),
+ Row("San Jose", "Honda Civic", "5"),
+ Row("San Jose", nil, "13"),
+ Row(nil, "Honda Accord", "33"),
+ Row(nil, "Honda CRV", "10"),
+ Row(nil, "Honda Civic", "35"),
+ Row(nil, nil, "78"),
+ ])
+ await spark.stop()
+ }
#endif
@Test
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]