This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/spark-connect-swift.git


The following commit(s) were added to refs/heads/main by this push:
     new 9c8b0eb  [SPARK-51911] Support `lateralJoin` in `DataFrame`
9c8b0eb is described below

commit 9c8b0eb4001a88099a332848079c531523acc445
Author: Dongjoon Hyun <dongj...@apache.org>
AuthorDate: Fri Apr 25 16:29:10 2025 +0900

    [SPARK-51911] Support `lateralJoin` in `DataFrame`
    
    ### What changes were proposed in this pull request?
    
    This PR aims to support `lateralJoin` API in `DataFrame`.
    
    ### Why are the changes needed?
    
    To provide a foundation for the `lateralJoin` API, although `column` is not 
supported yet.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Pass the CIs.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #88 from dongjoon-hyun/SPARK-51911.
    
    Authored-by: Dongjoon Hyun <dongj...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 Sources/SparkConnect/DataFrame.swift          | 76 +++++++++++++++++++++++++++
 Sources/SparkConnect/SparkConnectClient.swift | 18 +++++++
 Sources/SparkConnect/TypeAliases.swift        |  1 +
 Tests/SparkConnectTests/DataFrameTests.swift  | 20 +++++++
 4 files changed, 115 insertions(+)

diff --git a/Sources/SparkConnect/DataFrame.swift 
b/Sources/SparkConnect/DataFrame.swift
index 3e3b484..7c81105 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -609,6 +609,82 @@ public actor DataFrame: Sendable {
     return DataFrame(spark: self.spark, plan: plan)
   }
 
+  /// Lateral join with another ``DataFrame``.
+  ///
+  /// Behaves as a JOIN LATERAL.
+  ///
+  /// - Parameters:
+  ///   - right: Right side of the join operation.
+  /// - Returns: A ``DataFrame``.
+  public func lateralJoin(_ right: DataFrame) async -> DataFrame {
+    let rightPlan = await (right.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getLateralJoin(
+      self.plan.root,
+      rightPlan,
+      JoinType.inner
+    )
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Lateral join with another ``DataFrame``.
+  ///
+  /// Behaves as a JOIN LATERAL.
+  ///
+  /// - Parameters:
+  ///   - right: Right side of the join operation.
+  ///   - joinType: One of `inner` (default), `cross`, `left`, `leftouter`, 
`left_outer`.
+  /// - Returns: A ``DataFrame``.
+  public func lateralJoin(_ right: DataFrame, joinType: String) async -> 
DataFrame {
+    let rightPlan = await (right.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getLateralJoin(
+      self.plan.root,
+      rightPlan,
+      joinType.toJoinType
+    )
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Lateral join with another ``DataFrame``.
+  ///
+  /// Behaves as a JOIN LATERAL.
+  ///
+  /// - Parameters:
+  ///   - right: Right side of the join operation.
+  ///   - joinExprs: A join expression string.
+  /// - Returns: A ``DataFrame``.
+  public func lateralJoin(_ right: DataFrame, joinExprs: String) async -> 
DataFrame {
+    let rightPlan = await (right.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getLateralJoin(
+      self.plan.root,
+      rightPlan,
+      JoinType.inner,
+      joinCondition: joinExprs
+    )
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Lateral join with another ``DataFrame``.
+  ///
+  /// Behaves as a JOIN LATERAL.
+  ///
+  /// - Parameters:
+  ///   - right: Right side of the join operation.
+  ///   - joinExprs: A join expression string.
+  ///   - joinType: One of `inner` (default), `cross`, `left`, `leftouter`, 
`left_outer`.
+  /// - Returns: A ``DataFrame``.
+  public func lateralJoin(
+    _ right: DataFrame, joinExprs: String, joinType: String = "inner"
+  ) async -> DataFrame {
+    let rightPlan = await (right.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getLateralJoin(
+      self.plan.root,
+      rightPlan,
+      joinType.toJoinType,
+      joinCondition: joinExprs
+    )
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
   /// Returns a new `DataFrame` containing rows in this `DataFrame` but not in 
another `DataFrame`.
   /// This is equivalent to `EXCEPT DISTINCT` in SQL.
   /// - Parameter other: A `DataFrame` to exclude.
diff --git a/Sources/SparkConnect/SparkConnectClient.swift 
b/Sources/SparkConnect/SparkConnectClient.swift
index 2d4e2a9..fa7c392 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -617,6 +617,24 @@ public actor SparkConnectClient {
     return plan
   }
 
+  static func getLateralJoin(
+    _ left: Relation, _ right: Relation, _ joinType: JoinType,
+    joinCondition: String? = nil
+  ) -> Plan {
+    var lateralJoin = LateralJoin()
+    lateralJoin.left = left
+    lateralJoin.right = right
+    lateralJoin.joinType = joinType
+    if let joinCondition {
+      lateralJoin.joinCondition.expressionString = 
joinCondition.toExpressionString
+    }
+    var relation = Relation()
+    relation.lateralJoin = lateralJoin
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
+
   static func getSetOperation(
     _ left: Relation, _ right: Relation, _ opType: SetOpType, isAll: Bool = 
false,
     byName: Bool = false, allowMissingColumns: Bool = false
diff --git a/Sources/SparkConnect/TypeAliases.swift 
b/Sources/SparkConnect/TypeAliases.swift
index f0dcc04..ba35fb6 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -34,6 +34,7 @@ typealias GroupType = Spark_Connect_Aggregate.GroupType
 typealias Join = Spark_Connect_Join
 typealias JoinType = Spark_Connect_Join.JoinType
 typealias KeyValue = Spark_Connect_KeyValue
+typealias LateralJoin = Spark_Connect_LateralJoin
 typealias Limit = Spark_Connect_Limit
 typealias MapType = Spark_Connect_DataType.Map
 typealias NamedTable = Spark_Connect_Read.NamedTable
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift 
b/Tests/SparkConnectTests/DataFrameTests.swift
index 7f06925..f8673e5 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -479,6 +479,26 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  @Test
+  func lateralJoin() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') 
AS T(a, b)")
+    let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') 
AS S(c, b)")
+    let expectedCross = [
+      Row("a", "1", "c", "2"),
+      Row("a", "1", "d", "3"),
+      Row("b", "2", "c", "2"),
+      Row("b", "2", "d", "3"),
+    ]
+    #expect(try await df1.lateralJoin(df2).collect() == expectedCross)
+    #expect(try await df1.lateralJoin(df2, joinType: "inner").collect() == 
expectedCross)
+
+    let expected = [Row("b", "2", "c", "2")]
+    #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b").collect() 
== expected)
+    #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType: 
"inner").collect() == expected)
+    await spark.stop()
+  }
+
   @Test
   func except() async throws {
     let spark = try await SparkSession.builder.getOrCreate()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to