This is an automated email from the ASF dual-hosted git repository.

dongjoon-hyun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git


The following commit(s) were added to refs/heads/main by this push:
     new 10a3455  [SPARK-57309] Support `stat.sampleBy` for `DataFrame`
10a3455 is described below

commit 10a345563547936f2deb4692821c73a8fdbf0df2
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Sun Jun 7 18:47:58 2026 -0700

    [SPARK-57309] Support `stat.sampleBy` for `DataFrame`
    
    ### What changes were proposed in this pull request?
    
    This PR aims to support `sampleBy` for `DataFrame` by wiring the `StatSampleBy`
    Spark Connect relation through `DataFrameStatFunctions`, exposed via `DataFrame.stat`
    like PySpark/Scala.
    
    ```swift
    public func sampleBy<T: Sendable & Hashable>(_ col: String, _ fractions: [T: Double], _ seed: Int64) async -> DataFrame
    public func sampleBy<T: Sendable & Hashable>(_ col: String, _ fractions: [T: Double]) async -> DataFrame
    ```
    
    `sampleBy` returns a stratified sample without replacement based on the fraction
    given for each stratum. A stratum that is not specified is treated as having a
    fraction of zero. The seed is optional; a random seed is used when it is omitted.
    
    ### Why are the changes needed?
    
    To improve API coverage by mirroring PySpark/Scala `DataFrameStatFunctions`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this PR adds a new API, `DataFrame.stat.sampleBy`.
    
    ### How was this patch tested?
    
    Pass the CIs with a new test case, `sampleBy`, in `DataFrameStatFunctionsTests`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code (Claude Opus 4.8)
    
    This patch had conflicts when merged, resolved by
    Committer: Dongjoon Hyun <[email protected]>
    
    Closes #411 from dongjoon-hyun/SPARK-57309.
    
    Authored-by: Dongjoon Hyun <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 Sources/SparkConnect/DataFrameStatFunctions.swift  | 57 ++++++++++++++++++++++
 Sources/SparkConnect/SparkConnectClient.swift      | 18 +++++++
 .../DataFrameStatFunctionsTests.swift              | 15 ++++++
 3 files changed, 90 insertions(+)

diff --git a/Sources/SparkConnect/DataFrameStatFunctions.swift b/Sources/SparkConnect/DataFrameStatFunctions.swift
index 29a9ecc..cfdf356 100644
--- a/Sources/SparkConnect/DataFrameStatFunctions.swift
+++ b/Sources/SparkConnect/DataFrameStatFunctions.swift
@@ -100,6 +100,33 @@ public actor DataFrameStatFunctions: Sendable {
     return quantilesPerColumn.map { ($0 as! [any Sendable]).map { $0 as! 
Double } }
   }
 
+  /// Returns a stratified sample without replacement based on the fraction 
given on each stratum.
+  /// - Parameters:
+  ///   - col: The name of the column that defines the strata.
+  ///   - fractions: The sampling fraction for each stratum. If a stratum is 
not specified, its
+  ///   fraction is treated as zero. Each fraction must be in `[0, 1]`.
+  ///   - seed: The random seed.
+  /// - Returns: A ``DataFrame`` representing the stratified sample.
+  public func sampleBy<T: Sendable & Hashable>(
+    _ col: String, _ fractions: [T: Double], _ seed: Int64
+  ) async -> DataFrame {
+    let fractionLiterals = fractions.map { (stratumLiteral($0.key), $0.value) }
+    return await transform { SparkConnectClient.getStatSampleBy($0, col, 
fractionLiterals, seed) }
+  }
+
+  /// Returns a stratified sample without replacement based on the fraction 
given on each stratum,
+  /// using a random seed.
+  /// - Parameters:
+  ///   - col: The name of the column that defines the strata.
+  ///   - fractions: The sampling fraction for each stratum. If a stratum is 
not specified, its
+  ///   fraction is treated as zero. Each fraction must be in `[0, 1]`.
+  /// - Returns: A ``DataFrame`` representing the stratified sample.
+  public func sampleBy<T: Sendable & Hashable>(
+    _ col: String, _ fractions: [T: Double]
+  ) async -> DataFrame {
+    return await sampleBy(col, fractions, Int64.random(in: 
Int64.min...Int64.max))
+  }
+
   // MARK: - Helpers
 
   /// Builds a single-value ``DataFrame`` from this ``DataFrame``'s plan using 
the given plan
@@ -109,6 +136,36 @@ public actor DataFrameStatFunctions: Sendable {
     let result = DataFrame(spark: await df.spark, plan: f(plan.root))
     return try await result.collect()[0].get(0) as! Double
   }
+
+  /// Builds a new ``DataFrame`` from this ``DataFrame``'s plan using the 
given plan builder.
+  private func transform(_ f: (Relation) -> Plan) async -> DataFrame {
+    let plan = await df.getPlan() as! Plan
+    return DataFrame(spark: await df.spark, plan: f(plan.root))
+  }
+
+  /// Converts a `sampleBy` stratum value to an ``ExpressionLiteral``.
+  private func stratumLiteral(_ value: Sendable) -> ExpressionLiteral {
+    var literal = ExpressionLiteral()
+    switch value {
+    case let value as Bool:
+      literal.boolean = value
+    case let value as Int:
+      literal.long = Int64(value)
+    case let value as Int32:
+      literal.integer = value
+    case let value as Int64:
+      literal.long = value
+    case let value as Float:
+      literal.float = value
+    case let value as Double:
+      literal.double = value
+    case let value as String:
+      literal.string = value
+    default:
+      literal.string = value as! String
+    }
+    return literal
+  }
 }
 
 extension DataFrame {
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index a491c42..89f32eb 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -650,6 +650,24 @@ public actor SparkConnectClient {
     return createPlan { $0.approxQuantile = approxQuantile }
   }
 
+  static func getStatSampleBy(
+    _ child: Relation, _ col: String, _ fractions: [(ExpressionLiteral, 
Double)], _ seed: Int64
+  ) -> Plan {
+    var sampleBy = Spark_Connect_StatSampleBy()
+    sampleBy.input = child
+    var colExpr = Spark_Connect_Expression()
+    colExpr.exprType = .unresolvedAttribute(col.toUnresolvedAttribute)
+    sampleBy.col = colExpr
+    sampleBy.fractions = fractions.map {
+      var fraction = Spark_Connect_StatSampleBy.Fraction()
+      fraction.stratum = $0.0
+      fraction.fraction = $0.1
+      return fraction
+    }
+    sampleBy.seed = seed
+    return createPlan { $0.sampleBy = sampleBy }
+  }
+
   static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
     var sort = Sort()
     sort.input = child
diff --git a/Tests/SparkConnectTests/DataFrameStatFunctionsTests.swift b/Tests/SparkConnectTests/DataFrameStatFunctionsTests.swift
index bda3b55..578a354 100644
--- a/Tests/SparkConnectTests/DataFrameStatFunctionsTests.swift
+++ b/Tests/SparkConnectTests/DataFrameStatFunctionsTests.swift
@@ -70,4 +70,19 @@ struct DataFrameStatFunctionsTests {
     #expect(quantiles == [[1.0, 3.0, 5.0], [10.0, 30.0, 50.0]])
     await spark.stop()
   }
+
+  @Test
+  func sampleBy() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    // Strata 0, 1, 2 each have 33 rows.
+    let df = try await spark.sql("SELECT id % 3 AS key FROM range(0, 99)")
+    // A fraction of 1.0 keeps every row of a stratum; an unspecified stratum 
(or 0.0) keeps none,
+    // so the result count is deterministic regardless of the seed.
+    #expect(try await df.stat.sampleBy("key", [0: 1.0, 1: 0.0], 0).count() == 
33)
+    // `Int64` strata are also supported.
+    #expect(try await df.stat.sampleBy("key", [Int64(0): 1.0, Int64(2): 1.0], 
0).count() == 66)
+    // The seed is optional.
+    #expect(try await df.stat.sampleBy("key", [0: 1.0, 1: 1.0, 2: 
1.0]).count() == 99)
+    await spark.stop()
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to