This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d67ca731f19c [SPARK-50132][SQL][PYTHON] Add DataFrame API for Lateral Joins
d67ca731f19c is described below

commit d67ca731f19cc571e8d69245f4837c0cf28b83ae
Author: Takuya Ueshin <[email protected]>
AuthorDate: Fri Dec 6 10:40:50 2024 +0900

    [SPARK-50132][SQL][PYTHON] Add DataFrame API for Lateral Joins
    
    ### What changes were proposed in this pull request?
    
    Adds a `lateralJoin` DataFrame API for lateral joins.
    
    #### Examples:
    
    For the following DataFrames `customers` and `orders`:
    
    ```py
    >>> customers.printSchema()
    root
     |-- customer_id: long (nullable = true)
     |-- name: string (nullable = true)
    
    >>> orders.printSchema()
    root
     |-- order_id: long (nullable = true)
     |-- customer_id: long (nullable = true)
     |-- order_date: string (nullable = true)
     |-- items: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- product: string (nullable = true)
     |    |    |-- quantity: long (nullable = true)
    ```
    
    ##### Using a table-valued function (TVF)
    
    ```py
    # select customer_id, name, order_id, order_date, product, quantity
    # from customers join orders using (customer_id)
    # join lateral (select col.* from explode(items))
    # order by customer_id, order_id, product
    customers.join(orders, "customer_id").lateralJoin(
        spark.tvf.explode(sf.col("items").outer()).select("col.*")
    ).select(
        "customer_id", "name", "order_id", "order_date", "product", "quantity"
    ).orderBy("customer_id", "order_id", "product").show()
    ```
    
    ##### Using a subquery
    
    ```py
    # select c.customer_id, name, order_id, order_date
    # from customers c left join lateral (
    #     select * from orders o where o.customer_id = c.customer_id
    #     order by order_date desc limit 2
    # )
    # order by customer_id, order_id
    customers.alias("c").lateralJoin(
        orders.alias("o")
        .where(sf.col("o.customer_id") == sf.col("c.customer_id").outer())
        .orderBy(sf.col("order_date").desc())
        .limit(2),
        how="left"
    ).select(
        "c.customer_id", "name", "order_id", "order_date"
    ).orderBy("customer_id", "order_id").show()
    ```
    
    ### Why are the changes needed?
    
    The DataFrame API has no way to express lateral joins; they are currently only available through SQL `LATERAL`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. A new `DataFrame.lateralJoin` API is added for Scala/Java and PySpark. Spark Connect does not support it yet (tracked in SPARK-50134).
    
    ### How was this patch tested?
    
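    Added tests that compare `lateralJoin` results against equivalent SQL `LATERAL` queries (subqueries, table-valued functions, and UDTFs).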
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49033 from ueshin/issues/SPARK-50132/lateral_join.
    
    Lead-authored-by: Takuya Ueshin <[email protected]>
    Co-authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  15 +
 python/pyspark/sql/classic/dataframe.py            |  16 +
 python/pyspark/sql/connect/dataframe.py            |  13 +
 python/pyspark/sql/dataframe.py                    | 103 ++++++
 .../pyspark/sql/tests/connect/test_parity_tvf.py   |  40 ++-
 .../pyspark/sql/tests/connect/test_parity_udtf.py  |   8 +
 python/pyspark/sql/tests/test_subquery.py          | 332 ++++++++++++++++++++
 python/pyspark/sql/tests/test_tvf.py               | 349 +++++++++++++++++++++
 python/pyspark/sql/tests/test_udtf.py              |  31 ++
 .../scala/org/apache/spark/sql/api/Dataset.scala   |  54 ++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  32 ++
 .../apache/spark/sql/DataFrameSubquerySuite.scala  | 287 +++++++++++++++++
 .../sql/DataFrameTableValuedFunctionsSuite.scala   | 260 +++++++++++++++
 13 files changed, 1539 insertions(+), 1 deletion(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 631e9057f8d1..eb166a1e8003 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -383,6 +383,21 @@ class Dataset[T] private[sql] (
     }
   }
 
+  // TODO(SPARK-50134): Support Lateral Join API in Spark Connect
+  // scalastyle:off not.implemented.error.usage
+  /** @inheritdoc */
+  def lateralJoin(right: DS[_]): DataFrame = ???
+
+  /** @inheritdoc */
+  def lateralJoin(right: DS[_], joinExprs: Column): DataFrame = ???
+
+  /** @inheritdoc */
+  def lateralJoin(right: DS[_], joinType: String): DataFrame = ???
+
+  /** @inheritdoc */
+  def lateralJoin(right: DS[_], joinExprs: Column, joinType: String): 
DataFrame = ???
+  // scalastyle:on not.implemented.error.usage
+
   override protected def sortInternal(global: Boolean, sortCols: Seq[Column]): 
Dataset[T] = {
     val sortExprs = sortCols.map { c =>
       ColumnNodeToProtoConverter(c.sortOrder).getSortOrder
diff --git a/python/pyspark/sql/classic/dataframe.py 
b/python/pyspark/sql/classic/dataframe.py
index 169755c75390..05c19913adf3 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -715,6 +715,22 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, 
PandasConversionMixin):
             jdf = self._jdf.join(other._jdf, on, how)
         return DataFrame(jdf, self.sparkSession)
 
+    def lateralJoin(
+        self,
+        other: ParentDataFrame,
+        on: Optional[Column] = None,
+        how: Optional[str] = None,
+    ) -> ParentDataFrame:
+        if on is None and how is None:
+            jdf = self._jdf.lateralJoin(other._jdf)
+        elif on is None:
+            jdf = self._jdf.lateralJoin(other._jdf, how)
+        elif how is None:
+            jdf = self._jdf.lateralJoin(other._jdf, on._jc)
+        else:
+            jdf = self._jdf.lateralJoin(other._jdf, on._jc, how)
+        return DataFrame(jdf, self.sparkSession)
+
     # TODO(SPARK-22947): Fix the DataFrame API.
     def _joinAsOf(
         self,
diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index e85efeb592df..124ce5e0d39a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -686,6 +686,18 @@ class DataFrame(ParentDataFrame):
             session=self._session,
         )
 
+    def lateralJoin(
+        self,
+        other: ParentDataFrame,
+        on: Optional[Column] = None,
+        how: Optional[str] = None,
+    ) -> ParentDataFrame:
+        # TODO(SPARK-50134): Implement this method
+        raise PySparkNotImplementedError(
+            errorClass="NOT_IMPLEMENTED",
+            messageParameters={"feature": "lateralJoin()"},
+        )
+
     def _joinAsOf(
         self,
         other: ParentDataFrame,
@@ -2265,6 +2277,7 @@ def _test() -> None:
     # TODO(SPARK-50134): Support subquery in connect
     del pyspark.sql.dataframe.DataFrame.scalar.__doc__
     del pyspark.sql.dataframe.DataFrame.exists.__doc__
+    del pyspark.sql.dataframe.DataFrame.lateralJoin.__doc__
 
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.dataframe tests")
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 0ea0eef50c0f..ccb9806cc76d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2629,6 +2629,109 @@ class DataFrame:
         """
         ...
 
+    def lateralJoin(
+        self,
+        other: "DataFrame",
+        on: Optional[Column] = None,
+        how: Optional[str] = None,
+    ) -> "DataFrame":
+        """
+        Lateral joins with another :class:`DataFrame`, using the given join 
expression.
+
+        A lateral join (also known as a correlated join) is a type of join 
where each row from
+        one DataFrame is used as input to a subquery or a derived table that 
computes a result
+        specific to that row. The right side `DataFrame` can reference columns 
from the current
+        row of the left side `DataFrame`, allowing for more complex and 
context-dependent results
+        than a standard join.
+
+        .. versionadded:: 4.0.0
+
+        Parameters
+        ----------
+        other : :class:`DataFrame`
+            Right side of the join
+        on : :class:`Column`, optional
+            a join expression (Column).
+        how : str, optional
+            default ``inner``. Must be one of: ``inner``, ``cross``, ``left``, 
``leftouter``,
+            and ``left_outer``.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            Joined DataFrame.
+
+        Examples
+        --------
+        Setup a sample DataFrame.
+
+        >>> from pyspark.sql import functions as sf
+        >>> from pyspark.sql import Row
+        >>> customers_data = [
+        ...     Row(customer_id=1, name="Alice"), Row(customer_id=2, 
name="Bob"),
+        ...     Row(customer_id=3, name="Charlie"), Row(customer_id=4, 
name="Diana")
+        ... ]
+        >>> customers = spark.createDataFrame(customers_data)
+        >>> orders_data = [
+        ...     Row(order_id=101, customer_id=1, order_date="2024-01-10",
+        ...         items=[Row(product="laptop", quantity=5), 
Row(product="mouse", quantity=12)]),
+        ...     Row(order_id=102, customer_id=1, order_date="2024-02-15",
+        ...         items=[Row(product="phone", quantity=2), 
Row(product="charger", quantity=15)]),
+        ...     Row(order_id=105, customer_id=1, order_date="2024-03-20",
+        ...         items=[Row(product="tablet", quantity=4)]),
+        ...     Row(order_id=103, customer_id=2, order_date="2024-01-12",
+        ...         items=[Row(product="tablet", quantity=8)]),
+        ...     Row(order_id=104, customer_id=2, order_date="2024-03-05",
+        ...         items=[Row(product="laptop", quantity=7)]),
+        ...     Row(order_id=106, customer_id=3, order_date="2024-04-05",
+        ...         items=[Row(product="monitor", quantity=1)]),
+        ... ]
+        >>> orders = spark.createDataFrame(orders_data)
+
+        Example 1 (use TVF): Expanding Items in Each Order into Separate Rows
+
+        >>> customers.join(orders, "customer_id").lateralJoin(
+        ...     spark.tvf.explode(sf.col("items").outer()).select("col.*")
+        ... ).select(
+        ...     "customer_id", "name", "order_id", "order_date", "product", 
"quantity"
+        ... ).orderBy("customer_id", "order_id", "product").show()
+        +-----------+-------+--------+----------+-------+--------+
+        |customer_id|   name|order_id|order_date|product|quantity|
+        +-----------+-------+--------+----------+-------+--------+
+        |          1|  Alice|     101|2024-01-10| laptop|       5|
+        |          1|  Alice|     101|2024-01-10|  mouse|      12|
+        |          1|  Alice|     102|2024-02-15|charger|      15|
+        |          1|  Alice|     102|2024-02-15|  phone|       2|
+        |          1|  Alice|     105|2024-03-20| tablet|       4|
+        |          2|    Bob|     103|2024-01-12| tablet|       8|
+        |          2|    Bob|     104|2024-03-05| laptop|       7|
+        |          3|Charlie|     106|2024-04-05|monitor|       1|
+        +-----------+-------+--------+----------+-------+--------+
+
+        Example 2 (use subquery): Finding the Two Most Recent Orders for 
Customer
+
+        >>> customers.alias("c").lateralJoin(
+        ...     orders.alias("o")
+        ...     .where(sf.col("o.customer_id") == 
sf.col("c.customer_id").outer())
+        ...     .orderBy(sf.col("order_date").desc())
+        ...     .limit(2),
+        ...     how="left"
+        ... ).select(
+        ...     "c.customer_id", "name", "order_id", "order_date"
+        ... ).orderBy("customer_id", "order_id").show()
+        +-----------+-------+--------+----------+
+        |customer_id|   name|order_id|order_date|
+        +-----------+-------+--------+----------+
+        |          1|  Alice|     102|2024-02-15|
+        |          1|  Alice|     105|2024-03-20|
+        |          2|    Bob|     103|2024-01-12|
+        |          2|    Bob|     104|2024-03-05|
+        |          3|Charlie|     106|2024-04-05|
+        |          4|  Diana|    NULL|      NULL|
+        +-----------+-------+--------+----------+
+        """
+        ...
+
     # TODO(SPARK-22947): Fix the DataFrame API.
     @dispatch_df_method
     def _joinAsOf(
diff --git a/python/pyspark/sql/tests/connect/test_parity_tvf.py 
b/python/pyspark/sql/tests/connect/test_parity_tvf.py
index 61e3decf562c..c5edff02810f 100644
--- a/python/pyspark/sql/tests/connect/test_parity_tvf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_tvf.py
@@ -21,7 +21,45 @@ from pyspark.testing.connectutils import 
ReusedConnectTestCase
 
 
 class TVFParityTestsMixin(TVFTestsMixin, ReusedConnectTestCase):
-    pass
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_explode_with_lateral_join(self):
+        super().test_explode_with_lateral_join()
+
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_explode_outer_with_lateral_join(self):
+        super().test_explode_outer_with_lateral_join()
+
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_inline_with_lateral_join(self):
+        super().test_inline_with_lateral_join()
+
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_inline_outer_with_lateral_join(self):
+        super().test_inline_outer_with_lateral_join()
+
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_json_tuple_with_lateral_join(self):
+        super().test_json_tuple_with_lateral_join()
+
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_posexplode_with_lateral_join(self):
+        super().test_posexplode_with_lateral_join()
+
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_posexplode_outer_with_lateral_join(self):
+        super().test_posexplode_outer_with_lateral_join()
+
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_stack_with_lateral_join(self):
+        super().test_stack_with_lateral_join()
+
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_variant_explode_with_lateral_join(self):
+        super().test_variant_explode_with_lateral_join()
+
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_variant_explode_outer_with_lateral_join(self):
+        super().test_variant_explode_outer_with_lateral_join()
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py 
b/python/pyspark/sql/tests/connect/test_parity_udtf.py
index 6955e7377b4c..29d1718fe378 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udtf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -85,6 +85,14 @@ class UDTFParityTests(BaseUDTFTestsMixin, 
ReusedConnectTestCase):
     def _add_file(self, path):
         self.spark.addArtifacts(path, file=True)
 
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_udtf_with_lateral_join_dataframe(self):
+        super().test_udtf_with_lateral_join_dataframe()
+
+    @unittest.skip("SPARK-50134: Support Spark Connect")
+    def test_udtf_with_conditional_return_dataframe(self):
+        super().test_udtf_with_conditional_return_dataframe()
+
 
 class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests):
     @classmethod
diff --git a/python/pyspark/sql/tests/test_subquery.py 
b/python/pyspark/sql/tests/test_subquery.py
index 7cc0360c3942..1b657e075c59 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -484,6 +484,338 @@ class SubqueryTestsMixin:
                     fragment="col",
                 )
 
+    def table1(self):
+        t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+        t1.createOrReplaceTempView("t1")
+        return self.spark.table("t1")
+
+    def table2(self):
+        t2 = self.spark.sql("VALUES (0, 2), (0, 3) AS t2(c1, c2)")
+        t2.createOrReplaceTempView("t2")
+        return self.spark.table("t2")
+
+    def table3(self):
+        t3 = self.spark.sql(
+            "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, 
ARRAY(4)) AS t3(c1, c2)"
+        )
+        t3.createOrReplaceTempView("t3")
+        return self.spark.table("t3")
+
+    def test_lateral_join_with_single_column_select(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.table1()
+            t2 = self.table2()
+
+            assertDataFrameEqual(
+                
t1.lateralJoin(self.spark.range(1).select(sf.col("c1").outer())),
+                self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT c1)"""),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(t2.select(sf.col("t1.c1").outer())),
+                self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT t1.c1 FROM 
t2)"""),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(t2.select(sf.col("t1.c1").outer() + 
sf.col("t2.c1"))),
+                self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT t1.c1 + 
t2.c1 FROM t2)"""),
+            )
+
+    def test_lateral_join_with_different_join_types(self):
+        with self.tempView("t1"):
+            t1 = self.table1()
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.range(1).select(
+                        (sf.col("c1").outer() + 
sf.col("c2").outer()).alias("c3")
+                    ),
+                    sf.col("c2") == sf.col("c3"),
+                ),
+                self.spark.sql(
+                    """SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c3) ON 
c2 = c3"""
+                ),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.range(1).select(
+                        (sf.col("c1").outer() + 
sf.col("c2").outer()).alias("c3")
+                    ),
+                    sf.col("c2") == sf.col("c3"),
+                    "left",
+                ),
+                self.spark.sql(
+                    """SELECT * FROM t1 LEFT JOIN LATERAL (SELECT c1 + c2 AS 
c3) ON c2 = c3"""
+                ),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.range(1).select(
+                        (sf.col("c1").outer() + 
sf.col("c2").outer()).alias("c3")
+                    ),
+                    how="cross",
+                ),
+                self.spark.sql("""SELECT * FROM t1 CROSS JOIN LATERAL (SELECT 
c1 + c2 AS c3)"""),
+            )
+
+    def test_lateral_join_with_correlated_predicates(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.table1()
+            t2 = self.table2()
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    t2.where(sf.col("t1.c1").outer() == 
sf.col("t2.c1")).select(sf.col("c2"))
+                ),
+                self.spark.sql(
+                    """SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE 
t1.c1 = t2.c1)"""
+                ),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    t2.where(sf.col("t1.c1").outer() < 
sf.col("t2.c1")).select(sf.col("c2"))
+                ),
+                self.spark.sql(
+                    """SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE 
t1.c1 < t2.c1)"""
+                ),
+            )
+
+    def test_lateral_join_with_aggregation_and_correlated_predicates(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.table1()
+            t2 = self.table2()
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    t2.where(sf.col("t1.c2").outer() < sf.col("t2.c2")).select(
+                        sf.max(sf.col("c2")).alias("m")
+                    )
+                ),
+                self.spark.sql(
+                    """
+                    SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 
WHERE t1.c2 < t2.c2)
+                    """
+                ),
+            )
+
+    def test_lateral_join_reference_preceding_from_clause_items(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.table1()
+            t2 = self.table2()
+
+            assertDataFrameEqual(
+                t1.join(t2).lateralJoin(
+                    self.spark.range(1).select(sf.col("t1.c2").outer() + 
sf.col("t2.c2").outer())
+                ),
+                self.spark.sql("""SELECT * FROM t1 JOIN t2 JOIN LATERAL 
(SELECT t1.c2 + t2.c2)"""),
+            )
+
+    def test_multiple_lateral_joins(self):
+        with self.tempView("t1"):
+            t1 = self.table1()
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.range(1).select(
+                        (sf.col("c1").outer() + 
sf.col("c2").outer()).alias("a")
+                    )
+                )
+                .lateralJoin(
+                    self.spark.range(1).select(
+                        (sf.col("c1").outer() - 
sf.col("c2").outer()).alias("b")
+                    )
+                )
+                .lateralJoin(
+                    self.spark.range(1).select(
+                        (sf.col("a").outer() * sf.col("b").outer()).alias("c")
+                    )
+                ),
+                self.spark.sql(
+                    """
+                    SELECT * FROM t1,
+                    LATERAL (SELECT c1 + c2 AS a),
+                    LATERAL (SELECT c1 - c2 AS b),
+                    LATERAL (SELECT a * b AS c)
+                    """
+                ),
+            )
+
+    def test_lateral_join_in_between_regular_joins(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.table1()
+            t2 = self.table2()
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    t2.where(sf.col("t1.c1").outer() == 
sf.col("t2.c1")).select(sf.col("c2")),
+                    how="left",
+                ).join(t1.alias("t3"), sf.col("t2.c2") == sf.col("t3.c2"), 
how="left"),
+                self.spark.sql(
+                    """
+                    SELECT * FROM t1
+                    LEFT OUTER JOIN LATERAL (SELECT c2 FROM t2 WHERE t1.c1 = 
t2.c1) s
+                    LEFT OUTER JOIN t1 t3 ON s.c2 = t3.c2
+                    """
+                ),
+            )
+
+    def test_nested_lateral_joins(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.table1()
+            t2 = self.table2()
+
+            assertDataFrameEqual(
+                
t1.lateralJoin(t2.lateralJoin(self.spark.range(1).select(sf.col("c1").outer()))),
+                self.spark.sql(
+                    """SELECT * FROM t1, LATERAL (SELECT * FROM t2, LATERAL 
(SELECT c1))"""
+                ),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.range(1)
+                    .select((sf.col("c1").outer() + sf.lit(1)).alias("c1"))
+                    
.lateralJoin(self.spark.range(1).select(sf.col("c1").outer()))
+                ),
+                self.spark.sql(
+                    """
+                    SELECT * FROM t1,
+                    LATERAL (SELECT * FROM (SELECT c1 + 1 AS c1), LATERAL 
(SELECT c1))
+                    """
+                ),
+            )
+
+    def test_scalar_subquery_inside_lateral_join(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.table1()
+            t2 = self.table2()
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.range(1).select(
+                        sf.col("c2").outer(), 
t2.select(sf.min(sf.col("c2"))).scalar()
+                    )
+                ),
+                self.spark.sql(
+                    """SELECT * FROM t1, LATERAL (SELECT c2, (SELECT MIN(c2) 
FROM t2))"""
+                ),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.range(1)
+                    .select(sf.col("c1").outer().alias("a"))
+                    .select(
+                        t2.where(sf.col("c1") == sf.col("a").outer())
+                        .select(sf.sum(sf.col("c2")))
+                        .scalar()
+                    )
+                ),
+                self.spark.sql(
+                    """
+                    SELECT * FROM t1, LATERAL (
+                        SELECT (SELECT SUM(c2) FROM t2 WHERE c1 = a) FROM 
(SELECT c1 AS a)
+                    )
+                    """
+                ),
+            )
+
+    def test_lateral_join_inside_subquery(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.table1()
+            t2 = self.table2()
+
+            assertDataFrameEqual(
+                t1.where(
+                    sf.col("c1")
+                    == (
+                        
t2.lateralJoin(self.spark.range(1).select(sf.col("c1").outer().alias("a")))
+                        .select(sf.min(sf.col("a")))
+                        .scalar()
+                    )
+                ),
+                self.spark.sql(
+                    """
+                    SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, 
LATERAL (SELECT c1 AS a))
+                    """
+                ),
+            )
+            assertDataFrameEqual(
+                t1.where(
+                    sf.col("c1")
+                    == (
+                        
t2.lateralJoin(self.spark.range(1).select(sf.col("c1").outer().alias("a")))
+                        .where(sf.col("c1") == sf.col("t1.c1").outer())
+                        .select(sf.min(sf.col("a")))
+                        .scalar()
+                    )
+                ),
+                self.spark.sql(
+                    """
+                    SELECT * FROM t1
+                    WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS 
a) WHERE c1 = t1.c1)
+                    """
+                ),
+            )
+
+    def test_lateral_join_with_table_valued_functions(self):
+        with self.tempView("t1", "t3"):
+            t1 = self.table1()
+            t3 = self.table3()
+
+            assertDataFrameEqual(
+                t1.lateralJoin(self.spark.tvf.range(3)),
+                self.spark.sql("""SELECT * FROM t1, LATERAL RANGE(3)"""),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.tvf.explode(sf.array(sf.col("c1").outer(), 
sf.col("c2").outer()))
+                ).toDF("c1", "c2", "c3"),
+                self.spark.sql("""SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, 
c2)) t2(c3)"""),
+            )
+            assertDataFrameEqual(
+                
t3.lateralJoin(self.spark.tvf.explode_outer(sf.col("c2").outer())).toDF(
+                    "c1", "c2", "v"
+                ),
+                self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) 
t2(v)"""),
+            )
+            assertDataFrameEqual(
+                self.spark.tvf.explode(sf.array(sf.lit(1), sf.lit(2)))
+                .toDF("v")
+                .lateralJoin(self.spark.range(1).select((sf.col("v").outer() + 
1).alias("v"))),
+                self.spark.sql(
+                    """SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL 
(SELECT v + 1 AS v)"""
+                ),
+            )
+
+    def 
test_lateral_join_with_table_valued_functions_and_join_conditions(self):
+        with self.tempView("t1", "t3"):
+            t1 = self.table1()
+            t3 = self.table3()
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.tvf.explode(sf.array(sf.col("c1").outer(), 
sf.col("c2").outer())),
+                    sf.col("c1") == sf.col("col"),
+                ).toDF("c1", "c2", "c3"),
+                self.spark.sql(
+                    """SELECT * FROM t1 JOIN LATERAL EXPLODE(ARRAY(c1, c2)) 
t(c3) ON t1.c1 = c3"""
+                ),
+            )
+            assertDataFrameEqual(
+                t3.lateralJoin(
+                    self.spark.tvf.explode(sf.col("c2").outer()),
+                    sf.col("c1") == sf.col("col"),
+                ).toDF("c1", "c2", "c3"),
+                self.spark.sql("""SELECT * FROM t3 JOIN LATERAL EXPLODE(c2) 
t(c3) ON t3.c1 = c3"""),
+            )
+            assertDataFrameEqual(
+                t3.lateralJoin(
+                    self.spark.tvf.explode(sf.col("c2").outer()),
+                    sf.col("c1") == sf.col("col"),
+                    "left",
+                ).toDF("c1", "c2", "c3"),
+                self.spark.sql(
+                    """SELECT * FROM t3 LEFT JOIN LATERAL EXPLODE(c2) t(c3) ON 
t3.c1 = c3"""
+                ),
+            )
+
 
 class SubqueryTests(SubqueryTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/test_tvf.py 
b/python/pyspark/sql/tests/test_tvf.py
index 5c709437fc4d..ea20cbf9b8f3 100644
--- a/python/pyspark/sql/tests/test_tvf.py
+++ b/python/pyspark/sql/tests/test_tvf.py
@@ -52,6 +52,37 @@ class TVFTestsMixin:
         expected = self.spark.sql("""SELECT * FROM explode(null :: map<string, 
int>)""")
         assertDataFrameEqual(actual=actual, expected=expected)
 
+    def test_explode_with_lateral_join(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+            t1.createOrReplaceTempView("t1")
+            t3 = self.spark.sql(
+                "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, 
ARRAY(4)) "
+                "AS t3(c1, c2)"
+            )
+            t3.createOrReplaceTempView("t3")
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.tvf.explode(sf.array(sf.col("c1").outer(), 
sf.col("c2").outer()))
+                ).toDF("c1", "c2", "c3"),
+                self.spark.sql("""SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, 
c2)) t2(c3)"""),
+            )
+            assertDataFrameEqual(
+                
t3.lateralJoin(self.spark.tvf.explode(sf.col("c2").outer())).toDF("c1", "c2", 
"v"),
+                self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE(c2) 
t2(v)"""),
+            )
+            assertDataFrameEqual(
+                self.spark.tvf.explode(sf.array(sf.lit(1), sf.lit(2)))
+                .toDF("v")
+                .lateralJoin(
+                    self.spark.range(1).select((sf.col("v").outer() + 
sf.lit(1)).alias("v2"))
+                ),
+                self.spark.sql(
+                    """SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL 
(SELECT v + 1 AS v2)"""
+                ),
+            )
+
     def test_explode_outer(self):
         actual = self.spark.tvf.explode_outer(sf.array(sf.lit(1), sf.lit(2)))
         expected = self.spark.sql("""SELECT * FROM explode_outer(array(1, 
2))""")
@@ -81,6 +112,43 @@ class TVFTestsMixin:
         expected = self.spark.sql("""SELECT * FROM explode_outer(null :: 
map<string, int>)""")
         assertDataFrameEqual(actual=actual, expected=expected)
 
+    def test_explode_outer_with_lateral_join(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+            t1.createOrReplaceTempView("t1")
+            t3 = self.spark.sql(
+                "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, 
ARRAY(4)) "
+                "AS t3(c1, c2)"
+            )
+            t3.createOrReplaceTempView("t3")
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.tvf.explode_outer(
+                        sf.array(sf.col("c1").outer(), sf.col("c2").outer())
+                    )
+                ).toDF("c1", "c2", "c3"),
+                self.spark.sql("""SELECT * FROM t1, LATERAL 
EXPLODE_OUTER(ARRAY(c1, c2)) t2(c3)"""),
+            )
+            assertDataFrameEqual(
+                
t3.lateralJoin(self.spark.tvf.explode_outer(sf.col("c2").outer())).toDF(
+                    "c1", "c2", "v"
+                ),
+                self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) 
t2(v)"""),
+            )
+            assertDataFrameEqual(
+                self.spark.tvf.explode_outer(sf.array(sf.lit(1), sf.lit(2)))
+                .toDF("v")
+                .lateralJoin(
+                    self.spark.range(1).select((sf.col("v").outer() + 
sf.lit(1)).alias("v2"))
+                ),
+                self.spark.sql(
+                    """
+                    SELECT * FROM EXPLODE_OUTER(ARRAY(1, 2)) t(v), LATERAL 
(SELECT v + 1 AS v2)
+                    """
+                ),
+            )
+
     def test_inline(self):
         actual = self.spark.tvf.inline(
             sf.array(sf.struct(sf.lit(1), sf.lit("a")), sf.struct(sf.lit(2), 
sf.lit("b")))
@@ -107,6 +175,35 @@ class TVFTestsMixin:
         )
         assertDataFrameEqual(actual=actual, expected=expected)
 
+    def test_inline_with_lateral_join(self):
+        with self.tempView("array_struct"):
+            array_struct = self.spark.sql(
+                """
+                VALUES
+                (1, ARRAY(STRUCT(1, 'a'), STRUCT(2, 'b'))),
+                (2, ARRAY()),
+                (3, ARRAY(STRUCT(3, 'c'))) AS array_struct(id, arr)
+                """
+            )
+            array_struct.createOrReplaceTempView("array_struct")
+
+            assertDataFrameEqual(
+                
array_struct.lateralJoin(self.spark.tvf.inline(sf.col("arr").outer())),
+                self.spark.sql("""SELECT * FROM array_struct JOIN LATERAL 
INLINE(arr)"""),
+            )
+            assertDataFrameEqual(
+                array_struct.lateralJoin(
+                    self.spark.tvf.inline(sf.col("arr").outer()),
+                    sf.col("id") == sf.col("col1"),
+                    "left",
+                ).toDF("id", "arr", "k", "v"),
+                self.spark.sql(
+                    """
+                    SELECT * FROM array_struct LEFT JOIN LATERAL INLINE(arr) 
t(k, v) ON id = k
+                    """
+                ),
+            )
+
     def test_inline_outer(self):
         actual = self.spark.tvf.inline_outer(
             sf.array(sf.struct(sf.lit(1), sf.lit("a")), sf.struct(sf.lit(2), 
sf.lit("b")))
@@ -137,6 +234,35 @@ class TVFTestsMixin:
         )
         assertDataFrameEqual(actual=actual, expected=expected)
 
+    def test_inline_outer_with_lateral_join(self):
+        with self.tempView("array_struct"):
+            array_struct = self.spark.sql(
+                """
+                VALUES
+                (1, ARRAY(STRUCT(1, 'a'), STRUCT(2, 'b'))),
+                (2, ARRAY()),
+                (3, ARRAY(STRUCT(3, 'c'))) AS array_struct(id, arr)
+                """
+            )
+            array_struct.createOrReplaceTempView("array_struct")
+
+            assertDataFrameEqual(
+                
array_struct.lateralJoin(self.spark.tvf.inline_outer(sf.col("arr").outer())),
+                self.spark.sql("""SELECT * FROM array_struct JOIN LATERAL 
INLINE_OUTER(arr)"""),
+            )
+            assertDataFrameEqual(
+                array_struct.lateralJoin(
+                    self.spark.tvf.inline_outer(sf.col("arr").outer()),
+                    sf.col("id") == sf.col("col1"),
+                    "left",
+                ).toDF("id", "arr", "k", "v"),
+                self.spark.sql(
+                    """
+                    SELECT * FROM array_struct LEFT JOIN LATERAL 
INLINE_OUTER(arr) t(k, v) ON id = k
+                    """
+                ),
+            )
+
     def test_json_tuple(self):
         actual = self.spark.tvf.json_tuple(sf.lit('{"a":1, "b":2}'), 
sf.lit("a"), sf.lit("b"))
         expected = self.spark.sql("""SELECT json_tuple('{"a":1, "b":2}', 'a', 
'b')""")
@@ -151,6 +277,64 @@ class TVFTestsMixin:
             messageParameters={"item": "field"},
         )
 
+    def test_json_tuple_with_lateral_join(self):
+        with self.tempView("json_table"):
+            json_table = self.spark.sql(
+                """
+                VALUES
+                ('1', '{"f1": "1", "f2": "2", "f3": 3, "f5": 5.23}'),
+                ('2', '{"f1": "1", "f3": "3", "f2": 2, "f4": 4.01}'),
+                ('3', '{"f1": 3, "f4": "4", "f3": "3", "f2": 2, "f5": 5.01}'),
+                ('4', cast(null as string)),
+                ('5', '{"f1": null, "f5": ""}'),
+                ('6', '[invalid JSON string]') AS json_table(key, jstring)
+                """
+            )
+            json_table.createOrReplaceTempView("json_table")
+
+            assertDataFrameEqual(
+                json_table.alias("t1")
+                .lateralJoin(
+                    self.spark.tvf.json_tuple(
+                        sf.col("jstring").outer(),
+                        sf.lit("f1"),
+                        sf.lit("f2"),
+                        sf.lit("f3"),
+                        sf.lit("f4"),
+                        sf.lit("f5"),
+                    )
+                )
+                .select("key", "c0", "c1", "c2", "c3", "c4"),
+                self.spark.sql(
+                    """
+                    SELECT t1.key, t2.* FROM json_table t1,
+                    LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 
'f5') t2
+                    """
+                ),
+            )
+            assertDataFrameEqual(
+                json_table.alias("t1")
+                .lateralJoin(
+                    self.spark.tvf.json_tuple(
+                        sf.col("jstring").outer(),
+                        sf.lit("f1"),
+                        sf.lit("f2"),
+                        sf.lit("f3"),
+                        sf.lit("f4"),
+                        sf.lit("f5"),
+                    )
+                )
+                .where(sf.col("c0").isNotNull())
+                .select("key", "c0", "c1", "c2", "c3", "c4"),
+                self.spark.sql(
+                    """
+                    SELECT t1.key, t2.* FROM json_table t1,
+                    LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 
'f5') t2
+                    WHERE t2.c0 IS NOT NULL
+                    """
+                ),
+            )
+
     def test_posexplode(self):
         actual = self.spark.tvf.posexplode(sf.array(sf.lit(1), sf.lit(2)))
         expected = self.spark.sql("""SELECT * FROM posexplode(array(1, 2))""")
@@ -180,6 +364,39 @@ class TVFTestsMixin:
         expected = self.spark.sql("""SELECT * FROM posexplode(null :: 
map<string, int>)""")
         assertDataFrameEqual(actual=actual, expected=expected)
 
+    def test_posexplode_with_lateral_join(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+            t1.createOrReplaceTempView("t1")
+            t3 = self.spark.sql(
+                "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, 
ARRAY(4)) "
+                "AS t3(c1, c2)"
+            )
+            t3.createOrReplaceTempView("t3")
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.tvf.posexplode(sf.array(sf.col("c1").outer(), 
sf.col("c2").outer()))
+                ),
+                self.spark.sql("""SELECT * FROM t1, LATERAL 
POSEXPLODE(ARRAY(c1, c2))"""),
+            )
+            assertDataFrameEqual(
+                
t3.lateralJoin(self.spark.tvf.posexplode(sf.col("c2").outer())),
+                self.spark.sql("""SELECT * FROM t3, LATERAL POSEXPLODE(c2)"""),
+            )
+            assertDataFrameEqual(
+                self.spark.tvf.posexplode(sf.array(sf.lit(1), sf.lit(2)))
+                .toDF("p", "v")
+                .lateralJoin(
+                    self.spark.range(1).select((sf.col("v").outer() + 
sf.lit(1)).alias("v2"))
+                ),
+                self.spark.sql(
+                    """
+                    SELECT * FROM POSEXPLODE(ARRAY(1, 2)) t(p, v), LATERAL 
(SELECT v + 1 AS v2)
+                    """
+                ),
+            )
+
     def test_posexplode_outer(self):
         actual = self.spark.tvf.posexplode_outer(sf.array(sf.lit(1), 
sf.lit(2)))
         expected = self.spark.sql("""SELECT * FROM posexplode_outer(array(1, 
2))""")
@@ -209,11 +426,93 @@ class TVFTestsMixin:
         expected = self.spark.sql("""SELECT * FROM posexplode_outer(null :: 
map<string, int>)""")
         assertDataFrameEqual(actual=actual, expected=expected)
 
+    def test_posexplode_outer_with_lateral_join(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+            t1.createOrReplaceTempView("t1")
+            t3 = self.spark.sql(
+                "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, 
ARRAY(4)) "
+                "AS t3(c1, c2)"
+            )
+            t3.createOrReplaceTempView("t3")
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.tvf.posexplode_outer(
+                        sf.array(sf.col("c1").outer(), sf.col("c2").outer())
+                    )
+                ),
+                self.spark.sql("""SELECT * FROM t1, LATERAL 
POSEXPLODE_OUTER(ARRAY(c1, c2))"""),
+            )
+            assertDataFrameEqual(
+                
t3.lateralJoin(self.spark.tvf.posexplode_outer(sf.col("c2").outer())),
+                self.spark.sql("""SELECT * FROM t3, LATERAL 
POSEXPLODE_OUTER(c2)"""),
+            )
+            assertDataFrameEqual(
+                self.spark.tvf.posexplode_outer(sf.array(sf.lit(1), sf.lit(2)))
+                .toDF("p", "v")
+                .lateralJoin(
+                    self.spark.range(1).select((sf.col("v").outer() + 
sf.lit(1)).alias("v2"))
+                ),
+                self.spark.sql(
+                    """
+                    SELECT * FROM POSEXPLODE_OUTER(ARRAY(1, 2)) t(p, v),
+                        LATERAL (SELECT v + 1 AS v2)
+                    """
+                ),
+            )
+
     def test_stack(self):
         actual = self.spark.tvf.stack(sf.lit(2), sf.lit(1), sf.lit(2), 
sf.lit(3))
         expected = self.spark.sql("""SELECT * FROM stack(2, 1, 2, 3)""")
         assertDataFrameEqual(actual=actual, expected=expected)
 
+    def test_stack_with_lateral_join(self):
+        with self.tempView("t1", "t3"):
+            t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+            t1.createOrReplaceTempView("t1")
+            t3 = self.spark.sql(
+                "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, 
ARRAY(4)) "
+                "AS t3(c1, c2)"
+            )
+            t3.createOrReplaceTempView("t3")
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.tvf.stack(
+                        sf.lit(2),
+                        sf.lit("Key"),
+                        sf.col("c1").outer(),
+                        sf.lit("Value"),
+                        sf.col("c2").outer(),
+                    )
+                ).select("col0", "col1"),
+                self.spark.sql(
+                    """SELECT t.* FROM t1, LATERAL stack(2, 'Key', c1, 
'Value', c2) t"""
+                ),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.tvf.stack(sf.lit(1), sf.col("c1").outer(), 
sf.col("c2").outer())
+                ).select("col0", "col1"),
+                self.spark.sql("""SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, 
c2) t"""),
+            )
+            assertDataFrameEqual(
+                t1.join(t3, sf.col("t1.c1") == sf.col("t3.c1"))
+                .lateralJoin(
+                    self.spark.tvf.stack(
+                        sf.lit(1), sf.col("t1.c2").outer(), 
sf.col("t3.c2").outer()
+                    )
+                )
+                .select("col0", "col1"),
+                self.spark.sql(
+                    """
+                    SELECT t.* FROM t1 JOIN t3 ON t1.c1 = t3.c1
+                        JOIN LATERAL stack(1, t1.c2, t3.c2) t
+                    """
+                ),
+            )
+
     def test_collations(self):
         actual = self.spark.tvf.collations()
         expected = self.spark.sql("""SELECT * FROM collations()""")
@@ -256,6 +555,31 @@ class TVFTestsMixin:
         expected = self.spark.sql("""SELECT * FROM 
variant_explode(parse_json('1'))""")
         assertDataFrameEqual(actual=actual, expected=expected)
 
+    def test_variant_explode_with_lateral_join(self):
+        with self.tempView("variant_table"):
+            variant_table = self.spark.sql(
+                """
+                SELECT id, parse_json(v) AS v FROM VALUES
+                    (0, '["hello", "world"]'), (1, '{"a": true, "b": 3.14}'),
+                    (2, '[]'), (3, '{}'),
+                    (4, NULL), (5, '1')
+                    AS t(id, v)
+                """
+            )
+            variant_table.createOrReplaceTempView("variant_table")
+
+            assertDataFrameEqual(
+                variant_table.alias("t1")
+                
.lateralJoin(self.spark.tvf.variant_explode(sf.col("v").outer()))
+                .select("id", "pos", "key", "value"),
+                self.spark.sql(
+                    """
+                    SELECT t1.id, t.* FROM variant_table AS t1,
+                        LATERAL variant_explode(v) AS t
+                    """
+                ),
+            )
+
     def test_variant_explode_outer(self):
         actual = 
self.spark.tvf.variant_explode_outer(sf.parse_json(sf.lit('["hello", 
"world"]')))
         expected = self.spark.sql(
@@ -290,6 +614,31 @@ class TVFTestsMixin:
         expected = self.spark.sql("""SELECT * FROM 
variant_explode_outer(parse_json('1'))""")
         assertDataFrameEqual(actual=actual, expected=expected)
 
+    def test_variant_explode_outer_with_lateral_join(self):
+        with self.tempView("variant_table"):
+            variant_table = self.spark.sql(
+                """
+                SELECT id, parse_json(v) AS v FROM VALUES
+                    (0, '["hello", "world"]'), (1, '{"a": true, "b": 3.14}'),
+                    (2, '[]'), (3, '{}'),
+                    (4, NULL), (5, '1')
+                    AS t(id, v)
+                """
+            )
+            variant_table.createOrReplaceTempView("variant_table")
+
+            assertDataFrameEqual(
+                variant_table.alias("t1")
+                
.lateralJoin(self.spark.tvf.variant_explode_outer(sf.col("v").outer()))
+                .select("id", "pos", "key", "value"),
+                self.spark.sql(
+                    """
+                    SELECT t1.id, t.* FROM variant_table AS t1,
+                        LATERAL variant_explode_outer(v) AS t
+                    """
+                ),
+            )
+
 
 class TVFTests(TVFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index 8447edfbbb15..31cd4c80370e 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -31,6 +31,7 @@ from pyspark.errors import (
 from pyspark.util import PythonEvalType
 from pyspark.sql.functions import (
     array,
+    col,
     create_map,
     array,
     lit,
@@ -155,6 +156,22 @@ class BaseUDTFTestsMixin:
         )
         assertDataFrameEqual(df, expected)
 
+    def test_udtf_with_lateral_join_dataframe(self):
+        @udtf(returnType="a: int, b: int, c: int")
+        class TestUDTF:
+            def eval(self, a: int, b: int) -> Iterator:
+                yield a, b, a + b
+                yield a, b, a - b
+
+        self.spark.udtf.register("testUDTF", TestUDTF)
+
+        assertDataFrameEqual(
+            self.spark.sql("values (0, 1), (1, 2) t(a, b)").lateralJoin(
+                TestUDTF(col("a").outer(), col("b").outer())
+            ),
+            self.spark.sql("SELECT * FROM values (0, 1), (1, 2) t(a, b), 
LATERAL testUDTF(a, b)"),
+        )
+
     def test_udtf_eval_with_return_stmt(self):
         class TestUDTF:
             def eval(self, a: int, b: int):
@@ -239,6 +256,20 @@ class BaseUDTFTestsMixin:
             [Row(id=6, a=6), Row(id=7, a=7)],
         )
 
+    def test_udtf_with_conditional_return_dataframe(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                if a > 5:
+                    yield a,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        assertDataFrameEqual(
+            self.spark.range(8).lateralJoin(TestUDTF(col("id").outer())),
+            self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL 
test_udtf(id)"),
+        )
+
     def test_udtf_with_empty_yield(self):
         @udtf(returnType="a: int")
         class TestUDTF:
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
index 9d41998f11dc..20c181e7b9cf 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -859,6 +859,60 @@ abstract class Dataset[T] extends Serializable {
     joinWith(other, condition, "inner")
   }
 
+  /**
+   * Lateral join with another `DataFrame`.
+   *
+   * Behaves as an JOIN LATERAL.
+   *
+   * @param right
+   *   Right side of the join operation.
+   * @group untypedrel
+   * @since 4.0.0
+   */
+  def lateralJoin(right: DS[_]): Dataset[Row]
+
+  /**
+   * Lateral join with another `DataFrame`.
+   *
+   * Behaves as an JOIN LATERAL.
+   *
+   * @param right
+   *   Right side of the join operation.
+   * @param joinExprs
+   *   Join expression.
+   * @group untypedrel
+   * @since 4.0.0
+   */
+  def lateralJoin(right: DS[_], joinExprs: Column): Dataset[Row]
+
+  /**
+   * Lateral join with another `DataFrame`.
+   *
+   * @param right
+   *   Right side of the join operation.
+   * @param joinType
+   *   Type of join to perform. Default `inner`. Must be one of: `inner`, 
`cross`, `left`,
+   *   `leftouter`, `left_outer`.
+   * @group untypedrel
+   * @since 4.0.0
+   */
+  def lateralJoin(right: DS[_], joinType: String): Dataset[Row]
+
+  /**
+   * Lateral join with another `DataFrame`.
+   *
+   * @param right
+   *   Right side of the join operation.
+   * @param joinExprs
+   *   Join expression.
+   * @param joinType
+   *   Type of join to perform. Default `inner`. Must be one of: `inner`, 
`cross`, `left`,
+   *   `leftouter`, `left_outer`.
+   * @group untypedrel
+   * @since 4.0.0
+   */
+  def lateralJoin(right: DS[_], joinExprs: Column, joinType: String): 
Dataset[Row]
+
   protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): 
Dataset[T]
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 846d97b25786..8726ee268a47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -709,6 +709,38 @@ class Dataset[T] private[sql](
     new Dataset(sparkSession, joinWith, joinEncoder)
   }
 
+  private[sql] def lateralJoin(
+      right: DS[_], joinExprs: Option[Column], joinType: JoinType): DataFrame 
= {
+    withPlan {
+      LateralJoin(
+        logicalPlan,
+        LateralSubquery(right.logicalPlan),
+        joinType,
+        joinExprs.map(_.expr)
+      )
+    }
+  }
+
+  /** @inheritdoc */
+  def lateralJoin(right: DS[_]): DataFrame = {
+    lateralJoin(right, None, Inner)
+  }
+
+  /** @inheritdoc */
+  def lateralJoin(right: DS[_], joinExprs: Column): DataFrame = {
+    lateralJoin(right, Some(joinExprs), Inner)
+  }
+
+  /** @inheritdoc */
+  def lateralJoin(right: DS[_], joinType: String): DataFrame = {
+    lateralJoin(right, None, JoinType(joinType))
+  }
+
+  /** @inheritdoc */
+  def lateralJoin(right: DS[_], joinExprs: Column, joinType: String): 
DataFrame = {
+    lateralJoin(right, Some(joinExprs), JoinType(joinType))
+  }
+
   // TODO(SPARK-22947): Fix the DataFrame API.
   private[sql] def joinAsOf(
       other: Dataset[_],
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index 2420ad34d9ba..cd425162fb01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -377,4 +377,291 @@ class DataFrameSubquerySuite extends QueryTest with 
SharedSparkSession {
         Array(ExpectedContext(fragment = "$", callSitePattern = 
getCurrentClassCallSitePattern))
     )
   }
+
+  private def table1() = {
+    sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+    spark.table("t1")
+  }
+
+  private def table2() = {
+    sql("CREATE VIEW t2(c1, c2) AS VALUES (0, 2), (0, 3)")
+    spark.table("t2")
+  }
+
+  private def table3() = {
+    sql("CREATE VIEW t3(c1, c2) AS " +
+      "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, ARRAY(4))")
+    spark.table("t3")
+  }
+
+  test("lateral join with single column select") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkAnswer(
+        t1.lateralJoin(spark.range(1).select($"c1".outer())),
+        sql("SELECT * FROM t1, LATERAL (SELECT c1)")
+      )
+      checkAnswer(
+        t1.lateralJoin(t2.select($"c1")),
+        sql("SELECT * FROM t1, LATERAL (SELECT c1 FROM t2)")
+      )
+      checkAnswer(
+        t1.lateralJoin(t2.select($"t1.c1".outer())),
+        sql("SELECT * FROM t1, LATERAL (SELECT t1.c1 FROM t2)")
+      )
+      checkAnswer(
+        t1.lateralJoin(t2.select($"t1.c1".outer() + $"t2.c1")),
+        sql("SELECT * FROM t1, LATERAL (SELECT t1.c1 + t2.c1 FROM t2)")
+      )
+    }
+  }
+
+  test("lateral join with different join types") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(
+        t1.lateralJoin(
+          spark.range(1).select(($"c1".outer() + $"c2".outer()).as("c3")),
+          $"c2" === $"c3"),
+        sql("SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3")
+      )
+      checkAnswer(
+        t1.lateralJoin(
+          spark.range(1).select(($"c1".outer() + $"c2".outer()).as("c3")),
+          $"c2" === $"c3",
+          "left"),
+        sql("SELECT * FROM t1 LEFT JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = 
c3")
+      )
+      checkAnswer(
+        t1.lateralJoin(
+          spark.range(1).select(($"c1".outer() + $"c2".outer()).as("c3")),
+          "cross"),
+        sql("SELECT * FROM t1 CROSS JOIN LATERAL (SELECT c1 + c2 AS c3)")
+      )
+    }
+  }
+
+  test("lateral join with correlated equality / non-equality predicates") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkAnswer(
+        t1.lateralJoin(t2.where($"t1.c1".outer() === $"t2.c1").select($"c2")),
+        sql("SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c1 = 
t2.c1)")
+      )
+      checkAnswer(
+        t1.lateralJoin(t2.where($"t1.c1".outer() < $"t2.c1").select($"c2")),
+        sql("SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c1 < 
t2.c1)")
+      )
+    }
+  }
+
+  test("lateral join with aggregation and correlated non-equality predicates") 
{
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkAnswer(
+        t1.lateralJoin(t2.where($"t1.c2".outer() < 
$"t2.c2").select(max($"c2").as("m"))),
+        sql("SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE 
t1.c2 < t2.c2)")
+      )
+    }
+  }
+
+  test("lateral join can reference preceding FROM clause items") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkAnswer(
+        t1.join(t2).lateralJoin(
+          spark.range(1).select($"t1.c2".outer() + $"t2.c2".outer())
+        ),
+        sql("SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2)")
+      )
+    }
+  }
+
+  test("multiple lateral joins") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(
+        t1.lateralJoin(
+          spark.range(1).select(($"c1".outer() + $"c2".outer()).as("a"))
+        ).lateralJoin(
+          spark.range(1).select(($"c1".outer() - $"c2".outer()).as("b"))
+        ).lateralJoin(
+          spark.range(1).select(($"a".outer() * $"b".outer()).as("c"))
+        ),
+        sql(
+          """
+            |SELECT * FROM t1,
+            |LATERAL (SELECT c1 + c2 AS a),
+            |LATERAL (SELECT c1 - c2 AS b),
+            |LATERAL (SELECT a * b AS c)
+            |""".stripMargin)
+      )
+    }
+  }
+
+  test("lateral join in between regular joins") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkAnswer(
+        t1.lateralJoin(
+          t2.where($"t1.c1".outer() === $"t2.c1").select($"c2"), "left"
+        ).join(t1.as("t3"), $"t2.c2" === $"t3.c2", "left"),
+        sql(
+          """
+            |SELECT * FROM t1
+            |LEFT OUTER JOIN LATERAL (SELECT c2 FROM t2 WHERE t1.c1 = t2.c1) s
+            |LEFT OUTER JOIN t1 t3 ON s.c2 = t3.c2
+            |""".stripMargin)
+      )
+    }
+  }
+
+  test("nested lateral joins") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkAnswer(
+        t1.lateralJoin(
+          t2.lateralJoin(spark.range(1).select($"c1".outer()))
+        ),
+        sql("SELECT * FROM t1, LATERAL (SELECT * FROM t2, LATERAL (SELECT 
c1))")
+      )
+      checkAnswer(
+        t1.lateralJoin(
+          spark.range(1).select(($"c1".outer() + lit(1)).as("c1"))
+            .lateralJoin(spark.range(1).select($"c1".outer()))
+        ),
+        sql("SELECT * FROM t1, LATERAL (SELECT * FROM (SELECT c1 + 1 AS c1), 
LATERAL (SELECT c1))")
+      )
+    }
+  }
+
+  test("scalar subquery inside lateral join") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      // uncorrelated
+      checkAnswer(
+        t1.lateralJoin(
+          spark.range(1).select(
+            $"c2".outer(),
+            t2.select(min($"c2")).scalar()
+          )
+        ),
+        sql("SELECT * FROM t1, LATERAL (SELECT c2, (SELECT MIN(c2) FROM t2))")
+      )
+
+      // correlated
+      checkAnswer(
+        t1.lateralJoin(
+          spark.range(1).select($"c1".outer().as("a"))
+            .select(t2.where($"c1" === 
$"a".outer()).select(sum($"c2")).scalar())
+        ),
+        sql(
+          """
+            |SELECT * FROM t1, LATERAL (
+            |    SELECT (SELECT SUM(c2) FROM t2 WHERE c1 = a) FROM (SELECT c1 
AS a)
+            |)
+            |""".stripMargin)
+      )
+    }
+  }
+
+  test("lateral join inside subquery") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      // uncorrelated
+      checkAnswer(
+        t1.where(
+          $"c1" === t2.lateralJoin(
+            spark.range(1).select($"c1".outer().as("a"))).select(min($"a")
+          ).scalar()
+        ),
+        sql("SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL 
(SELECT c1 AS a))")
+      )
+      // correlated
+      checkAnswer(
+        t1.where(
+          $"c1" === t2.lateralJoin(
+              spark.range(1).select($"c1".outer().as("a")))
+            .where($"c1" === $"t1.c1".outer())
+            .select(min($"a"))
+            .scalar()
+        ),
+        sql("SELECT * FROM t1 " +
+          "WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a) WHERE 
c1 = t1.c1)")
+      )
+    }
+  }
+
+  test("lateral join with table-valued functions") {
+    withView("t1", "t3") {
+      val t1 = table1()
+      val t3 = table3()
+
+      checkAnswer(
+        t1.lateralJoin(spark.tvf.range(3)),
+        sql("SELECT * FROM t1, LATERAL RANGE(3)")
+      )
+      checkAnswer(
+        t1.lateralJoin(spark.tvf.explode(array($"c1".outer(), $"c2".outer()))),
+        sql("SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3)")
+      )
+      checkAnswer(
+        t3.lateralJoin(spark.tvf.explode_outer($"c2".outer())),
+        sql("SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v)")
+      )
+      checkAnswer(
+        spark.tvf.explode(array(lit(1), lit(2))).toDF("v")
+          .lateralJoin(spark.range(1).select($"v".outer() + 1)),
+        sql("SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1)")
+      )
+    }
+  }
+
+  test("lateral join with table-valued functions and join conditions") {
+    withView("t1", "t3") {
+      val t1 = table1()
+      val t3 = table3()
+
+      checkAnswer(
+        t1.lateralJoin(
+          spark.tvf.explode(array($"c1".outer(), $"c2".outer())),
+          $"c1" === $"col"
+        ),
+        sql("SELECT * FROM t1 JOIN LATERAL EXPLODE(ARRAY(c1, c2)) t(c3) ON 
t1.c1 = c3")
+      )
+      checkAnswer(
+        t3.lateralJoin(
+          spark.tvf.explode($"c2".outer()),
+          $"c1" === $"col"
+        ),
+        sql("SELECT * FROM t3 JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = c3")
+      )
+      checkAnswer(
+        t3.lateralJoin(
+          spark.tvf.explode($"c2".outer()),
+          $"c1" === $"col",
+          "left"
+        ),
+        sql("SELECT * FROM t3 LEFT JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = 
c3")
+      )
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
index c2f53ff56d1a..4f2cd275ffdf 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSparkSession
 
 class DataFrameTableValuedFunctionsSuite extends QueryTest with 
SharedSparkSession {
+  import testImplicits._
 
   test("explode") {
     val actual1 = spark.tvf.explode(array(lit(1), lit(2)))
@@ -50,6 +51,30 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with SharedSparkSessi
     checkAnswer(actual6, expected6)
   }
 
+  test("explode - lateral join") {
+    withView("t1", "t3") {
+      sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+      sql("CREATE VIEW t3(c1, c2) AS " +
+        "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, 
ARRAY(4))")
+      val t1 = spark.table("t1")
+      val t3 = spark.table("t3")
+
+      checkAnswer(
+        t1.lateralJoin(spark.tvf.explode(array($"c1".outer(), $"c2".outer()))),
+        sql("SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3)")
+      )
+      checkAnswer(
+        t3.lateralJoin(spark.tvf.explode($"c2".outer())),
+        sql("SELECT * FROM t3, LATERAL EXPLODE(c2) t2(v)")
+      )
+      checkAnswer(
+        spark.tvf.explode(array(lit(1), lit(2))).toDF("v")
+          .lateralJoin(spark.range(1).select($"v".outer() + lit(1))),
+        sql("SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1)")
+      )
+    }
+  }
+
   test("explode_outer") {
     val actual1 = spark.tvf.explode_outer(array(lit(1), lit(2)))
     val expected1 = spark.sql("SELECT * FROM explode_outer(array(1, 2))")
@@ -78,6 +103,30 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
     checkAnswer(actual6, expected6)
   }
 
+  test("explode_outer - lateral join") {
+    withView("t1", "t3") {
+      sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+      sql("CREATE VIEW t3(c1, c2) AS " +
+        "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, ARRAY(4))")
+      val t1 = spark.table("t1")
+      val t3 = spark.table("t3")
+
+      checkAnswer(
+        t1.lateralJoin(spark.tvf.explode_outer(array($"c1".outer(), $"c2".outer()))),
+        sql("SELECT * FROM t1, LATERAL EXPLODE_OUTER(ARRAY(c1, c2)) t2(c3)")
+      )
+      checkAnswer(
+        t3.lateralJoin(spark.tvf.explode_outer($"c2".outer())),
+        sql("SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v)")
+      )
+      checkAnswer(
+        spark.tvf.explode_outer(array(lit(1), lit(2))).toDF("v")
+          .lateralJoin(spark.range(1).select($"v".outer() + lit(1))),
+        sql("SELECT * FROM EXPLODE_OUTER(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1)")
+      )
+    }
+  }
+
   test("inline") {
     val actual1 = spark.tvf.inline(array(struct(lit(1), lit("a")), struct(lit(2), lit("b"))))
     val expected1 = spark.sql("SELECT * FROM inline(array(struct(1, 'a'), struct(2, 'b')))")
@@ -98,6 +147,32 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
     checkAnswer(actual3, expected3)
   }
 
+  test("inline - lateral join") {
+    withView("array_struct") {
+      sql(
+        """
+          |CREATE VIEW array_struct(id, arr) AS VALUES
+          |    (1, ARRAY(STRUCT(1, 'a'), STRUCT(2, 'b'))),
+          |    (2, ARRAY()),
+          |    (3, ARRAY(STRUCT(3, 'c')))
+          |""".stripMargin)
+      val arrayStruct = spark.table("array_struct")
+
+      checkAnswer(
+        arrayStruct.lateralJoin(spark.tvf.inline($"arr".outer())),
+        sql("SELECT * FROM array_struct JOIN LATERAL INLINE(arr)")
+      )
+      checkAnswer(
+        arrayStruct.lateralJoin(
+          spark.tvf.inline($"arr".outer()),
+          $"id" === $"col1",
+          "left"
+        ),
+        sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE(arr) t(k, v) ON id = k")
+      )
+    }
+  }
+
   test("inline_outer") {
     val actual1 = spark.tvf.inline_outer(array(struct(lit(1), lit("a")), struct(lit(2), lit("b"))))
     val expected1 = spark.sql("SELECT * FROM inline_outer(array(struct(1, 'a'), struct(2, 'b')))")
@@ -118,6 +193,32 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
     checkAnswer(actual3, expected3)
   }
 
+  test("inline_outer - lateral join") {
+    withView("array_struct") {
+      sql(
+        """
+          |CREATE VIEW array_struct(id, arr) AS VALUES
+          |    (1, ARRAY(STRUCT(1, 'a'), STRUCT(2, 'b'))),
+          |    (2, ARRAY()),
+          |    (3, ARRAY(STRUCT(3, 'c')))
+          |""".stripMargin)
+      val arrayStruct = spark.table("array_struct")
+
+      checkAnswer(
+        arrayStruct.lateralJoin(spark.tvf.inline_outer($"arr".outer())),
+        sql("SELECT * FROM array_struct JOIN LATERAL INLINE_OUTER(arr)")
+      )
+      checkAnswer(
+        arrayStruct.lateralJoin(
+          spark.tvf.inline_outer($"arr".outer()),
+          $"id" === $"col1",
+          "left"
+        ),
+        sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE_OUTER(arr) t(k, v) ON id = k")
+      )
+    }
+  }
+
   test("json_tuple") {
     val actual = spark.tvf.json_tuple(lit("""{"a":1,"b":2}"""), lit("a"), lit("b"))
     val expected = spark.sql("""SELECT * FROM json_tuple('{"a":1,"b":2}', 'a', 'b')""")
@@ -130,6 +231,43 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
     assert(ex.messageParameters("functionName") == "`json_tuple`")
   }
 
+  test("json_tuple - lateral join") {
+    withView("json_table") {
+      sql(
+        """
+          |CREATE OR REPLACE TEMP VIEW json_table(key, jstring) AS VALUES
+          |    ('1', '{"f1": "1", "f2": "2", "f3": 3, "f5": 5.23}'),
+          |    ('2', '{"f1": "1", "f3": "3", "f2": 2, "f4": 4.01}'),
+          |    ('3', '{"f1": 3, "f4": "4", "f3": "3", "f2": 2, "f5": 5.01}'),
+          |    ('4', cast(null as string)),
+          |    ('5', '{"f1": null, "f5": ""}'),
+          |    ('6', '[invalid JSON string]')
+          |""".stripMargin)
+      val jsonTable = spark.table("json_table")
+
+      checkAnswer(
+        jsonTable.as("t1").lateralJoin(
+          spark.tvf.json_tuple(
+            $"t1.jstring".outer(),
+            lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5"))
+        ).select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+        sql("SELECT t1.key, t2.* FROM json_table t1, " +
+          "LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2")
+      )
+      checkAnswer(
+        jsonTable.as("t1").lateralJoin(
+          spark.tvf.json_tuple(
+            $"jstring".outer(),
+            lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5"))
+        ).where($"c0".isNotNull)
+          .select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+        sql("SELECT t1.key, t2.* FROM json_table t1, " +
+          "LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2 " +
+          "WHERE t2.c0 IS NOT NULL")
+      )
+    }
+  }
+
   test("posexplode") {
     val actual1 = spark.tvf.posexplode(array(lit(1), lit(2)))
     val expected1 = spark.sql("SELECT * FROM posexplode(array(1, 2))")
@@ -158,6 +296,30 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
     checkAnswer(actual6, expected6)
   }
 
+  test("posexplode - lateral join") {
+    withView("t1", "t3") {
+      sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+      sql("CREATE VIEW t3(c1, c2) AS " +
+        "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, ARRAY(4))")
+      val t1 = spark.table("t1")
+      val t3 = spark.table("t3")
+
+      checkAnswer(
+        t1.lateralJoin(spark.tvf.posexplode(array($"c1".outer(), $"c2".outer()))),
+        sql("SELECT * FROM t1, LATERAL POSEXPLODE(ARRAY(c1, c2))")
+      )
+      checkAnswer(
+        t3.lateralJoin(spark.tvf.posexplode($"c2".outer())),
+        sql("SELECT * FROM t3, LATERAL POSEXPLODE(c2)")
+      )
+      checkAnswer(
+        spark.tvf.posexplode(array(lit(1), lit(2))).toDF("p", "v")
+          .lateralJoin(spark.range(1).select($"v".outer() + lit(1))),
+        sql("SELECT * FROM POSEXPLODE(ARRAY(1, 2)) t(p, v), LATERAL (SELECT v + 1)")
+      )
+    }
+  }
+
   test("posexplode_outer") {
     val actual1 = spark.tvf.posexplode_outer(array(lit(1), lit(2)))
     val expected1 = spark.sql("SELECT * FROM posexplode_outer(array(1, 2))")
@@ -186,12 +348,66 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
     checkAnswer(actual6, expected6)
   }
 
+  test("posexplode_outer - lateral join") {
+    withView("t1", "t3") {
+      sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+      sql("CREATE VIEW t3(c1, c2) AS " +
+        "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, ARRAY(4))")
+      val t1 = spark.table("t1")
+      val t3 = spark.table("t3")
+
+      checkAnswer(
+        t1.lateralJoin(spark.tvf.posexplode_outer(array($"c1".outer(), $"c2".outer()))),
+        sql("SELECT * FROM t1, LATERAL POSEXPLODE_OUTER(ARRAY(c1, c2))")
+      )
+      checkAnswer(
+        t3.lateralJoin(spark.tvf.posexplode_outer($"c2".outer())),
+        sql("SELECT * FROM t3, LATERAL POSEXPLODE_OUTER(c2)")
+      )
+      checkAnswer(
+        spark.tvf.posexplode_outer(array(lit(1), lit(2))).toDF("p", "v")
+          .lateralJoin(spark.range(1).select($"v".outer() + lit(1))),
+        sql("SELECT * FROM POSEXPLODE_OUTER(ARRAY(1, 2)) t(p, v), LATERAL (SELECT v + 1)")
+      )
+    }
+  }
+
   test("stack") {
     val actual = spark.tvf.stack(lit(2), lit(1), lit(2), lit(3))
     val expected = spark.sql("SELECT * FROM stack(2, 1, 2, 3)")
     checkAnswer(actual, expected)
   }
 
+  test("stack - lateral join") {
+    withView("t1", "t3") {
+      sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+      sql("CREATE VIEW t3(c1, c2) AS " +
+        "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, ARRAY(4))")
+      val t1 = spark.table("t1")
+      val t3 = spark.table("t3")
+
+      checkAnswer(
+        t1.lateralJoin(
+          spark.tvf.stack(lit(2), lit("Key"), $"c1".outer(), lit("Value"), $"c2".outer())
+        ).select($"col0", $"col1"),
+        sql("SELECT t.* FROM t1, LATERAL stack(2, 'Key', c1, 'Value', c2) t")
+      )
+      checkAnswer(
+        t1.lateralJoin(
+          spark.tvf.stack(lit(1), $"c1".outer(), $"c2".outer())
+        ).select($"col0".as("x"), $"col1".as("y")),
+        sql("SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, c2) t(x, y)")
+      )
+      checkAnswer(
+        t1.join(t3, $"t1.c1" === $"t3.c1")
+          .lateralJoin(
+            spark.tvf.stack(lit(1), $"t1.c2".outer(), $"t3.c2".outer())
+          ).select($"col0", $"col1"),
+        sql("SELECT t.* FROM t1 JOIN t3 ON t1.c1 = t3.c1 JOIN LATERAL stack(1, t1.c2, t3.c2) t")
+      )
+    }
+  }
+
   test("collations") {
     val actual = spark.tvf.collations()
     val expected = spark.sql("SELECT * FROM collations()")
@@ -235,6 +451,28 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
     checkAnswer(actual6, expected6)
   }
 
+  test("variant_explode - lateral join") {
+    withView("variant_table") {
+      sql(
+        """
+          |CREATE VIEW variant_table(id, v) AS
+          |SELECT id, parse_json(v) AS v FROM VALUES
+          |(0, '["hello", "world"]'), (1, '{"a": true, "b": 3.14}'),
+          |(2, '[]'), (3, '{}'),
+          |(4, NULL), (5, '1')
+          |AS t(id, v)
+          |""".stripMargin)
+      val variantTable = spark.table("variant_table")
+
+      checkAnswer(
+        variantTable.as("t1").lateralJoin(
+          spark.tvf.variant_explode($"v".outer())
+        ).select($"id", $"pos", $"key", $"value"),
+        sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL variant_explode(v) AS t")
+      )
+    }
+  }
+
   test("variant_explode_outer") {
     val actual1 = spark.tvf.variant_explode_outer(parse_json(lit("""["hello", "world"]""")))
     val expected1 = spark.sql(
@@ -265,4 +503,26 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
     val expected6 = spark.sql("SELECT * FROM variant_explode_outer(parse_json('1'))")
     checkAnswer(actual6, expected6)
   }
+
+  test("variant_explode_outer - lateral join") {
+    withView("variant_table") {
+      sql(
+        """
+          |CREATE VIEW variant_table(id, v) AS
+          |SELECT id, parse_json(v) AS v FROM VALUES
+          |(0, '["hello", "world"]'), (1, '{"a": true, "b": 3.14}'),
+          |(2, '[]'), (3, '{}'),
+          |(4, NULL), (5, '1')
+          |AS t(id, v)
+          |""".stripMargin)
+      val variantTable = spark.table("variant_table")
+
+      checkAnswer(
+        variantTable.as("t1").lateralJoin(
+          spark.tvf.variant_explode_outer($"v".outer())
+        ).select($"id", $"pos", $"key", $"value"),
+        sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL variant_explode_outer(v) AS t")
+      )
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to