This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git


The following commit(s) were added to refs/heads/main by this push:
     new 2fb55a0  [SPARK-51879] Support `groupBy/rollup/cube` in `DataFrame`
2fb55a0 is described below

commit 2fb55a0468c9f9346560428fa0017885390cac62
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Thu Apr 24 01:49:50 2025 +0900

    [SPARK-51879] Support `groupBy/rollup/cube` in `DataFrame`
    
    ### What changes were proposed in this pull request?
    
    This PR aims to support the `groupBy`, `rollup`, and `cube` APIs in `DataFrame`.
    
    ### Why are the changes needed?
    
    For feature parity.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, these are additional APIs.
    
    ### How was this patch tested?
    
    Pass the CIs.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #87 from dongjoon-hyun/SPARK-51879.
    
    Authored-by: Dongjoon Hyun <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 Sources/SparkConnect/DataFrame.swift         | 23 ++++++++++
 Sources/SparkConnect/Extension.swift         | 11 +++++
 Sources/SparkConnect/GroupedData.swift       | 51 ++++++++++++++++++++
 Sources/SparkConnect/TypeAliases.swift       |  2 +
 Tests/SparkConnectTests/DataFrameTests.swift | 69 ++++++++++++++++++++++++++++
 5 files changed, 156 insertions(+)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 83dbba1..8679362 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -733,6 +733,29 @@ public actor DataFrame: Sendable {
     return buildRepartition(numPartitions: numPartitions, shuffle: false)
   }
 
+  /// Groups the ``DataFrame`` using the specified columns, so we can run aggregation on them.
+  /// - Parameter cols: Grouping column names.
+  /// - Returns: A ``GroupedData``.
+  public func groupBy(_ cols: String...) -> GroupedData {
+    return GroupedData(self, GroupType.groupby, cols)
+  }
+
+  /// Create a multi-dimensional rollup for the current ``DataFrame`` using the specified columns, so we
+  /// can run aggregation on them.
+  /// - Parameter cols: Grouping column names.
+  /// - Returns: A ``GroupedData``.
+  public func rollup(_ cols: String...) -> GroupedData {
+    return GroupedData(self, GroupType.rollup, cols)
+  }
+
+  /// Create a multi-dimensional cube for the current ``DataFrame`` using the specified columns, so we
+  /// can run aggregation on them.
+  /// - Parameter cols: Grouping column names.
+  /// - Returns: A ``GroupedData``.
+  public func cube(_ cols: String...) -> GroupedData {
+    return GroupedData(self, GroupType.cube, cols)
+  }
+
   /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
   public var write: DataFrameWriter {
     get {
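
A minimal usage sketch of the new methods (not part of this patch; it assumes an
already-created `SparkSession` named `spark`, as in the tests below):

    // groupBy/rollup/cube each return a GroupedData; its agg(...) yields a new DataFrame.
    let rows = try await spark.range(3).groupBy("id").agg("count(*)").collect()
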
diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
index d41b5b1..5d75b3d 100644
--- a/Sources/SparkConnect/Extension.swift
+++ b/Sources/SparkConnect/Extension.swift
@@ -93,6 +93,17 @@ extension String {
     default: JoinType.inner
     }
   }
+
+  var toGroupType: GroupType {
+    return switch self.lowercased() {
+    case "groupby": .groupby
+    case "rollup": .rollup
+    case "cube": .cube
+    case "pivot": .pivot
+    case "groupingsets": .groupingSets
+    default: .UNRECOGNIZED(-1)
+    }
+  }
 }
 
 extension [String: String] {
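
For reference, `toGroupType` lowercases its input first, so the mapping is
case-insensitive, and any unknown name falls back to `.UNRECOGNIZED(-1)`. An
illustrative sketch (the values follow directly from the switch above):

    // Illustrative only; exercises the lowercased() switch in toGroupType.
    let a: GroupType = "GroupBy".toGroupType        // .groupby
    let b: GroupType = "groupingsets".toGroupType   // .groupingSets
    let c: GroupType = "median".toGroupType         // .UNRECOGNIZED(-1)
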
diff --git a/Sources/SparkConnect/GroupedData.swift b/Sources/SparkConnect/GroupedData.swift
new file mode 100644
index 0000000..a460832
--- /dev/null
+++ b/Sources/SparkConnect/GroupedData.swift
@@ -0,0 +1,51 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+public actor GroupedData {
+  let df: DataFrame
+  let groupType: GroupType
+  let groupingCols: [String]
+
+  init(_ df: DataFrame, _ groupType: GroupType, _ groupingCols: [String]) {
+    self.df = df
+    self.groupType = groupType
+    self.groupingCols = groupingCols
+  }
+
+  public func agg(_ exprs: String...) async -> DataFrame {
+    var aggregate = Aggregate()
+    aggregate.input = await (self.df.getPlan() as! Plan).root
+    aggregate.groupType = self.groupType
+    aggregate.groupingExpressions = self.groupingCols.map {
+      var expr = Spark_Connect_Expression()
+      expr.expressionString = $0.toExpressionString
+      return expr
+    }
+    aggregate.aggregateExpressions = exprs.map {
+      var expr = Spark_Connect_Expression()
+      expr.expressionString = $0.toExpressionString
+      return expr
+    }
+    var relation = Relation()
+    relation.aggregate = aggregate
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return await DataFrame(spark: df.spark, plan: plan)
+  }
+}
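
In short, `agg` wraps the parent `DataFrame`'s plan in an `Aggregate` relation,
sending both the grouping columns and the aggregates as expression strings, so
SQL-style aliases are accepted. A hedged usage sketch (assuming `df` is an
existing `DataFrame`):

    // "sum(quantity) sum" aliases the aggregate column to `sum`, as in the tests below.
    let grouped = await df.rollup("city", "car_model")
    let result = try await grouped.agg("sum(quantity) sum").collect()
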
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 2858de2..f0dcc04 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -16,6 +16,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+typealias Aggregate = Spark_Connect_Aggregate
 typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
 typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
 typealias Command = Spark_Connect_Command
@@ -29,6 +30,7 @@ typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
 typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
+typealias GroupType = Spark_Connect_Aggregate.GroupType
 typealias Join = Spark_Connect_Join
 typealias JoinType = Spark_Connect_Join.JoinType
 typealias KeyValue = Spark_Connect_KeyValue
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index 5772120..7fd1403 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -24,6 +24,20 @@ import SparkConnect
 
 /// A test suite for `DataFrame`
 struct DataFrameTests {
+  let DEALER_TABLE =
+    """
+    VALUES
+      (100, 'Fremont', 'Honda Civic', 10),
+      (100, 'Fremont', 'Honda Accord', 15),
+      (100, 'Fremont', 'Honda CRV', 7),
+      (200, 'Dublin', 'Honda Civic', 20),
+      (200, 'Dublin', 'Honda Accord', 10),
+      (200, 'Dublin', 'Honda CRV', 3),
+      (300, 'San Jose', 'Honda Civic', 5),
+      (300, 'San Jose', 'Honda Accord', 8)
+    dealer (id, city, car_model, quantity)
+    """
+
   @Test
   func sparkSession() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
@@ -577,6 +591,61 @@ struct DataFrameTests {
     #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
     await spark.stop()
   }
+
+  @Test
+  func groupBy() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let rows = try await spark.range(3).groupBy("id").agg("count(*)", "sum(*)", "avg(*)").collect()
+    #expect(rows == [Row("0", "1", "0", "0.0"), Row("1", "1", "1", "1.0"), Row("2", "1", "2", "2.0")])
+    await spark.stop()
+  }
+
+  @Test
+  func rollup() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model")
+      .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
+    #expect(rows == [
+      Row("Dublin", "Honda Accord", "10"),
+      Row("Dublin", "Honda CRV", "3"),
+      Row("Dublin", "Honda Civic", "20"),
+      Row("Dublin", nil, "33"),
+      Row("Fremont", "Honda Accord", "15"),
+      Row("Fremont", "Honda CRV", "7"),
+      Row("Fremont", "Honda Civic", "10"),
+      Row("Fremont", nil, "32"),
+      Row("San Jose", "Honda Accord", "8"),
+      Row("San Jose", "Honda Civic", "5"),
+      Row("San Jose", nil, "13"),
+      Row(nil, nil, "78"),
+    ])
+    await spark.stop()
+  }
+
+  @Test
+  func cube() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model")
+      .agg("sum(quantity) sum").orderBy("city", "car_model").collect()
+    #expect(rows == [
+      Row("Dublin", "Honda Accord", "10"),
+      Row("Dublin", "Honda CRV", "3"),
+      Row("Dublin", "Honda Civic", "20"),
+      Row("Dublin", nil, "33"),
+      Row("Fremont", "Honda Accord", "15"),
+      Row("Fremont", "Honda CRV", "7"),
+      Row("Fremont", "Honda Civic", "10"),
+      Row("Fremont", nil, "32"),
+      Row("San Jose", "Honda Accord", "8"),
+      Row("San Jose", "Honda Civic", "5"),
+      Row("San Jose", nil, "13"),
+      Row(nil, "Honda Accord", "33"),
+      Row(nil, "Honda CRV", "10"),
+      Row(nil, "Honda Civic", "35"),
+      Row(nil, nil, "78"),
+    ])
+    await spark.stop()
+  }
 #endif
 
   @Test


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]