This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d67ca731f19c [SPARK-50132][SQL][PYTHON] Add DataFrame API for Lateral Joins
d67ca731f19c is described below
commit d67ca731f19cc571e8d69245f4837c0cf28b83ae
Author: Takuya Ueshin <[email protected]>
AuthorDate: Fri Dec 6 10:40:50 2024 +0900
[SPARK-50132][SQL][PYTHON] Add DataFrame API for Lateral Joins
### What changes were proposed in this pull request?
Adds DataFrame API for Lateral Joins.
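The Python side adds a single `lateralJoin` method; its signature, as added to `python/pyspark/sql/dataframe.py` in this patch (the Scala `Dataset` gains four matching overloads):
```py
def lateralJoin(
    self,
    other: "DataFrame",
    on: Optional[Column] = None,
    how: Optional[str] = None,
) -> "DataFrame":
    ...
```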
#### Examples:
For the following DataFrames `customers` and `orders`:
```py
>>> customers.printSchema()
root
|-- customer_id: long (nullable = true)
|-- name: string (nullable = true)
>>> orders.printSchema()
root
|-- order_id: long (nullable = true)
|-- customer_id: long (nullable = true)
|-- order_date: string (nullable = true)
|-- items: array (nullable = true)
| |-- element: struct (containsNull = true)
| | |-- product: string (nullable = true)
| | |-- quantity: long (nullable = true)
```
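A minimal sketch for building such DataFrames, abridged from the docstring example added in this patch (row values are illustrative):
```py
>>> from pyspark.sql import Row, functions as sf
>>> customers = spark.createDataFrame([
...     Row(customer_id=1, name="Alice"), Row(customer_id=2, name="Bob"),
... ])
>>> orders = spark.createDataFrame([
...     Row(order_id=101, customer_id=1, order_date="2024-01-10",
...         items=[Row(product="laptop", quantity=5), Row(product="mouse", quantity=12)]),
... ])
```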
##### Using TVF
```py
# select customer_id, name, order_id, order_date, product, quantity
# from customers join orders using (customer_id) join lateral (select col.* from explode(items))
# order by customer_id, order_id, product
customers.join(orders, "customer_id").lateralJoin(
    spark.tvf.explode(sf.col("items").outer()).select("col.*")
).select(
    "customer_id", "name", "order_id", "order_date", "product", "quantity"
).orderBy("customer_id", "order_id", "product").show()
```
##### Using Subquery
```py
# select c.customer_id, name, order_id, order_date
# from customers c left join lateral (
#   select * from orders o where o.customer_id = c.customer_id order by order_date desc limit 2
# )
# order by customer_id, order_id
customers.alias("c").lateralJoin(
    orders.alias("o")
    .where(sf.col("o.customer_id") == sf.col("c.customer_id").outer())
    .orderBy(sf.col("order_date").desc())
    .limit(2),
    how="left"
).select(
    "c.customer_id", "name", "order_id", "order_date"
).orderBy("customer_id", "order_id").show()
```
### Why are the changes needed?
Lateral join APIs are missing from the DataFrame API.
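Before this change, the equivalent of Example 1 could only be written against the SQL API, e.g. (a sketch assuming `customers` and `orders` are registered as temp views; the query mirrors the SQL comments above):
```py
customers.createOrReplaceTempView("customers")
orders.createOrReplaceTempView("orders")
spark.sql("""
    SELECT customer_id, name, order_id, order_date, product, quantity
    FROM customers JOIN orders USING (customer_id)
    JOIN LATERAL (SELECT col.* FROM explode(items))
    ORDER BY customer_id, order_id, product
""").show()
```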
### Does this PR introduce _any_ user-facing change?
Yes, new DataFrame APIs for lateral join will be available.
### How was this patch tested?
Added the related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49033 from ueshin/issues/SPARK-50132/lateral_join.
Lead-authored-by: Takuya Ueshin <[email protected]>
Co-authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 15 +
python/pyspark/sql/classic/dataframe.py | 16 +
python/pyspark/sql/connect/dataframe.py | 13 +
python/pyspark/sql/dataframe.py | 103 ++++++
.../pyspark/sql/tests/connect/test_parity_tvf.py | 40 ++-
.../pyspark/sql/tests/connect/test_parity_udtf.py | 8 +
python/pyspark/sql/tests/test_subquery.py | 332 ++++++++++++++++++++
python/pyspark/sql/tests/test_tvf.py | 349 +++++++++++++++++++++
python/pyspark/sql/tests/test_udtf.py | 31 ++
.../scala/org/apache/spark/sql/api/Dataset.scala | 54 ++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 32 ++
.../apache/spark/sql/DataFrameSubquerySuite.scala | 287 +++++++++++++++++
.../sql/DataFrameTableValuedFunctionsSuite.scala | 260 +++++++++++++++
13 files changed, 1539 insertions(+), 1 deletion(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 631e9057f8d1..eb166a1e8003 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -383,6 +383,21 @@ class Dataset[T] private[sql] (
}
}
+ // TODO(SPARK-50134): Support Lateral Join API in Spark Connect
+ // scalastyle:off not.implemented.error.usage
+ /** @inheritdoc */
+ def lateralJoin(right: DS[_]): DataFrame = ???
+
+ /** @inheritdoc */
+ def lateralJoin(right: DS[_], joinExprs: Column): DataFrame = ???
+
+ /** @inheritdoc */
+ def lateralJoin(right: DS[_], joinType: String): DataFrame = ???
+
+ /** @inheritdoc */
+ def lateralJoin(right: DS[_], joinExprs: Column, joinType: String): DataFrame = ???
+ // scalastyle:on not.implemented.error.usage
+
override protected def sortInternal(global: Boolean, sortCols: Seq[Column]): Dataset[T] = {
val sortExprs = sortCols.map { c =>
ColumnNodeToProtoConverter(c.sortOrder).getSortOrder
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 169755c75390..05c19913adf3 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -715,6 +715,22 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
jdf = self._jdf.join(other._jdf, on, how)
return DataFrame(jdf, self.sparkSession)
+ def lateralJoin(
+ self,
+ other: ParentDataFrame,
+ on: Optional[Column] = None,
+ how: Optional[str] = None,
+ ) -> ParentDataFrame:
+ if on is None and how is None:
+ jdf = self._jdf.lateralJoin(other._jdf)
+ elif on is None:
+ jdf = self._jdf.lateralJoin(other._jdf, how)
+ elif how is None:
+ jdf = self._jdf.lateralJoin(other._jdf, on._jc)
+ else:
+ jdf = self._jdf.lateralJoin(other._jdf, on._jc, how)
+ return DataFrame(jdf, self.sparkSession)
+
# TODO(SPARK-22947): Fix the DataFrame API.
def _joinAsOf(
self,
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index e85efeb592df..124ce5e0d39a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -686,6 +686,18 @@ class DataFrame(ParentDataFrame):
session=self._session,
)
+ def lateralJoin(
+ self,
+ other: ParentDataFrame,
+ on: Optional[Column] = None,
+ how: Optional[str] = None,
+ ) -> ParentDataFrame:
+ # TODO(SPARK-50134): Implement this method
+ raise PySparkNotImplementedError(
+ errorClass="NOT_IMPLEMENTED",
+ messageParameters={"feature": "lateralJoin()"},
+ )
+
def _joinAsOf(
self,
other: ParentDataFrame,
@@ -2265,6 +2277,7 @@ def _test() -> None:
# TODO(SPARK-50134): Support subquery in connect
del pyspark.sql.dataframe.DataFrame.scalar.__doc__
del pyspark.sql.dataframe.DataFrame.exists.__doc__
+ del pyspark.sql.dataframe.DataFrame.lateralJoin.__doc__
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.dataframe tests")
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 0ea0eef50c0f..ccb9806cc76d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2629,6 +2629,109 @@ class DataFrame:
"""
...
+ def lateralJoin(
+ self,
+ other: "DataFrame",
+ on: Optional[Column] = None,
+ how: Optional[str] = None,
+ ) -> "DataFrame":
+ """
+ Lateral joins with another :class:`DataFrame`, using the given join expression.
+
+ A lateral join (also known as a correlated join) is a type of join where each row from
+ one DataFrame is used as input to a subquery or a derived table that computes a result
+ specific to that row. The right side `DataFrame` can reference columns from the current
+ row of the left side `DataFrame`, allowing for more complex and context-dependent results
+ than a standard join.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ other : :class:`DataFrame`
+ Right side of the join
+ on : :class:`Column`, optional
+ a join expression (Column).
+ how : str, optional
+ default ``inner``. Must be one of: ``inner``, ``cross``, ``left``, ``leftouter``,
+ and ``left_outer``.
+
+ Returns
+ -------
+ :class:`DataFrame`
+ Joined DataFrame.
+
+ Examples
+ --------
+ Setup a sample DataFrame.
+
+ >>> from pyspark.sql import functions as sf
+ >>> from pyspark.sql import Row
+ >>> customers_data = [
+ ... Row(customer_id=1, name="Alice"), Row(customer_id=2, name="Bob"),
+ ... Row(customer_id=3, name="Charlie"), Row(customer_id=4, name="Diana")
+ ... ]
+ >>> customers = spark.createDataFrame(customers_data)
+ >>> orders_data = [
+ ... Row(order_id=101, customer_id=1, order_date="2024-01-10",
+ ... items=[Row(product="laptop", quantity=5),
Row(product="mouse", quantity=12)]),
+ ... Row(order_id=102, customer_id=1, order_date="2024-02-15",
+ ... items=[Row(product="phone", quantity=2),
Row(product="charger", quantity=15)]),
+ ... Row(order_id=105, customer_id=1, order_date="2024-03-20",
+ ... items=[Row(product="tablet", quantity=4)]),
+ ... Row(order_id=103, customer_id=2, order_date="2024-01-12",
+ ... items=[Row(product="tablet", quantity=8)]),
+ ... Row(order_id=104, customer_id=2, order_date="2024-03-05",
+ ... items=[Row(product="laptop", quantity=7)]),
+ ... Row(order_id=106, customer_id=3, order_date="2024-04-05",
+ ... items=[Row(product="monitor", quantity=1)]),
+ ... ]
+ >>> orders = spark.createDataFrame(orders_data)
+
+ Example 1 (use TVF): Expanding Items in Each Order into Separate Rows
+
+ >>> customers.join(orders, "customer_id").lateralJoin(
+ ... spark.tvf.explode(sf.col("items").outer()).select("col.*")
+ ... ).select(
+ ... "customer_id", "name", "order_id", "order_date", "product",
"quantity"
+ ... ).orderBy("customer_id", "order_id", "product").show()
+ +-----------+-------+--------+----------+-------+--------+
+ |customer_id| name|order_id|order_date|product|quantity|
+ +-----------+-------+--------+----------+-------+--------+
+ | 1| Alice| 101|2024-01-10| laptop| 5|
+ | 1| Alice| 101|2024-01-10| mouse| 12|
+ | 1| Alice| 102|2024-02-15|charger| 15|
+ | 1| Alice| 102|2024-02-15| phone| 2|
+ | 1| Alice| 105|2024-03-20| tablet| 4|
+ | 2| Bob| 103|2024-01-12| tablet| 8|
+ | 2| Bob| 104|2024-03-05| laptop| 7|
+ | 3|Charlie| 106|2024-04-05|monitor| 1|
+ +-----------+-------+--------+----------+-------+--------+
+
+ Example 2 (use subquery): Finding the Two Most Recent Orders for Each Customer
+
+ >>> customers.alias("c").lateralJoin(
+ ... orders.alias("o")
+ ... .where(sf.col("o.customer_id") ==
sf.col("c.customer_id").outer())
+ ... .orderBy(sf.col("order_date").desc())
+ ... .limit(2),
+ ... how="left"
+ ... ).select(
+ ... "c.customer_id", "name", "order_id", "order_date"
+ ... ).orderBy("customer_id", "order_id").show()
+ +-----------+-------+--------+----------+
+ |customer_id| name|order_id|order_date|
+ +-----------+-------+--------+----------+
+ | 1| Alice| 102|2024-02-15|
+ | 1| Alice| 105|2024-03-20|
+ | 2| Bob| 103|2024-01-12|
+ | 2| Bob| 104|2024-03-05|
+ | 3|Charlie| 106|2024-04-05|
+ | 4| Diana| NULL| NULL|
+ +-----------+-------+--------+----------+
+ """
+ ...
+
# TODO(SPARK-22947): Fix the DataFrame API.
@dispatch_df_method
def _joinAsOf(
diff --git a/python/pyspark/sql/tests/connect/test_parity_tvf.py b/python/pyspark/sql/tests/connect/test_parity_tvf.py
index 61e3decf562c..c5edff02810f 100644
--- a/python/pyspark/sql/tests/connect/test_parity_tvf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_tvf.py
@@ -21,7 +21,45 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
class TVFParityTestsMixin(TVFTestsMixin, ReusedConnectTestCase):
- pass
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_explode_with_lateral_join(self):
+ super().test_explode_with_lateral_join()
+
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_explode_outer_with_lateral_join(self):
+ super().test_explode_outer_with_lateral_join()
+
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_inline_with_lateral_join(self):
+ super().test_inline_with_lateral_join()
+
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_inline_outer_with_lateral_join(self):
+ super().test_inline_outer_with_lateral_join()
+
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_json_tuple_with_lateral_join(self):
+ super().test_json_tuple_with_lateral_join()
+
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_posexplode_with_lateral_join(self):
+ super().test_posexplode_with_lateral_join()
+
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_posexplode_outer_with_lateral_join(self):
+ super().test_posexplode_outer_with_lateral_join()
+
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_stack_with_lateral_join(self):
+ super().test_stack_with_lateral_join()
+
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_variant_explode_with_lateral_join(self):
+ super().test_variant_explode_with_lateral_join()
+
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_variant_explode_outer_with_lateral_join(self):
+ super().test_variant_explode_outer_with_lateral_join()
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py
index 6955e7377b4c..29d1718fe378 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udtf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -85,6 +85,14 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
def _add_file(self, path):
self.spark.addArtifacts(path, file=True)
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_udtf_with_lateral_join_dataframe(self):
+ super().test_udtf_with_lateral_join_dataframe()
+
+ @unittest.skip("SPARK-50134: Support Spark Connect")
+ def test_udtf_with_conditional_return_dataframe(self):
+ super().test_udtf_with_conditional_return_dataframe()
+
class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests):
@classmethod
diff --git a/python/pyspark/sql/tests/test_subquery.py b/python/pyspark/sql/tests/test_subquery.py
index 7cc0360c3942..1b657e075c59 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -484,6 +484,338 @@ class SubqueryTestsMixin:
fragment="col",
)
+ def table1(self):
+ t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+ t1.createOrReplaceTempView("t1")
+ return self.spark.table("t1")
+
+ def table2(self):
+ t2 = self.spark.sql("VALUES (0, 2), (0, 3) AS t2(c1, c2)")
+ t2.createOrReplaceTempView("t2")
+ return self.spark.table("t2")
+
+ def table3(self):
+ t3 = self.spark.sql(
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null,
ARRAY(4)) AS t3(c1, c2)"
+ )
+ t3.createOrReplaceTempView("t3")
+ return self.spark.table("t3")
+
+ def test_lateral_join_with_single_column_select(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.table1()
+ t2 = self.table2()
+
+ assertDataFrameEqual(
+ t1.lateralJoin(self.spark.range(1).select(sf.col("c1").outer())),
+ self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT c1)"""),
+ )
+ assertDataFrameEqual(
+ t1.lateralJoin(t2.select(sf.col("t1.c1").outer())),
+ self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT t1.c1 FROM
t2)"""),
+ )
+ assertDataFrameEqual(
+ t1.lateralJoin(t2.select(sf.col("t1.c1").outer() +
sf.col("t2.c1"))),
+ self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT t1.c1 +
t2.c1 FROM t2)"""),
+ )
+
+ def test_lateral_join_with_different_join_types(self):
+ with self.tempView("t1"):
+ t1 = self.table1()
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.range(1).select(
+ (sf.col("c1").outer() +
sf.col("c2").outer()).alias("c3")
+ ),
+ sf.col("c2") == sf.col("c3"),
+ ),
+ self.spark.sql(
+ """SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c3) ON
c2 = c3"""
+ ),
+ )
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.range(1).select(
+ (sf.col("c1").outer() +
sf.col("c2").outer()).alias("c3")
+ ),
+ sf.col("c2") == sf.col("c3"),
+ "left",
+ ),
+ self.spark.sql(
+ """SELECT * FROM t1 LEFT JOIN LATERAL (SELECT c1 + c2 AS
c3) ON c2 = c3"""
+ ),
+ )
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.range(1).select(
+ (sf.col("c1").outer() +
sf.col("c2").outer()).alias("c3")
+ ),
+ how="cross",
+ ),
+ self.spark.sql("""SELECT * FROM t1 CROSS JOIN LATERAL (SELECT
c1 + c2 AS c3)"""),
+ )
+
+ def test_lateral_join_with_correlated_predicates(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.table1()
+ t2 = self.table2()
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ t2.where(sf.col("t1.c1").outer() ==
sf.col("t2.c1")).select(sf.col("c2"))
+ ),
+ self.spark.sql(
+ """SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE
t1.c1 = t2.c1)"""
+ ),
+ )
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ t2.where(sf.col("t1.c1").outer() <
sf.col("t2.c1")).select(sf.col("c2"))
+ ),
+ self.spark.sql(
+ """SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE
t1.c1 < t2.c1)"""
+ ),
+ )
+
+ def test_lateral_join_with_aggregation_and_correlated_predicates(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.table1()
+ t2 = self.table2()
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ t2.where(sf.col("t1.c2").outer() < sf.col("t2.c2")).select(
+ sf.max(sf.col("c2")).alias("m")
+ )
+ ),
+ self.spark.sql(
+ """
+ SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE t1.c2 < t2.c2)
+ """
+ ),
+ )
+
+ def test_lateral_join_reference_preceding_from_clause_items(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.table1()
+ t2 = self.table2()
+
+ assertDataFrameEqual(
+ t1.join(t2).lateralJoin(
+ self.spark.range(1).select(sf.col("t1.c2").outer() +
sf.col("t2.c2").outer())
+ ),
+ self.spark.sql("""SELECT * FROM t1 JOIN t2 JOIN LATERAL
(SELECT t1.c2 + t2.c2)"""),
+ )
+
+ def test_multiple_lateral_joins(self):
+ with self.tempView("t1"):
+ t1 = self.table1()
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.range(1).select(
+ (sf.col("c1").outer() +
sf.col("c2").outer()).alias("a")
+ )
+ )
+ .lateralJoin(
+ self.spark.range(1).select(
+ (sf.col("c1").outer() -
sf.col("c2").outer()).alias("b")
+ )
+ )
+ .lateralJoin(
+ self.spark.range(1).select(
+ (sf.col("a").outer() * sf.col("b").outer()).alias("c")
+ )
+ ),
+ self.spark.sql(
+ """
+ SELECT * FROM t1,
+ LATERAL (SELECT c1 + c2 AS a),
+ LATERAL (SELECT c1 - c2 AS b),
+ LATERAL (SELECT a * b AS c)
+ """
+ ),
+ )
+
+ def test_lateral_join_in_between_regular_joins(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.table1()
+ t2 = self.table2()
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ t2.where(sf.col("t1.c1").outer() ==
sf.col("t2.c1")).select(sf.col("c2")),
+ how="left",
+ ).join(t1.alias("t3"), sf.col("t2.c2") == sf.col("t3.c2"),
how="left"),
+ self.spark.sql(
+ """
+ SELECT * FROM t1
+ LEFT OUTER JOIN LATERAL (SELECT c2 FROM t2 WHERE t1.c1 = t2.c1) s
+ LEFT OUTER JOIN t1 t3 ON s.c2 = t3.c2
+ """
+ ),
+ )
+
+ def test_nested_lateral_joins(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.table1()
+ t2 = self.table2()
+
+ assertDataFrameEqual(
+ t1.lateralJoin(t2.lateralJoin(self.spark.range(1).select(sf.col("c1").outer()))),
+ self.spark.sql(
+ """SELECT * FROM t1, LATERAL (SELECT * FROM t2, LATERAL
(SELECT c1))"""
+ ),
+ )
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.range(1)
+ .select((sf.col("c1").outer() + sf.lit(1)).alias("c1"))
+ .lateralJoin(self.spark.range(1).select(sf.col("c1").outer()))
+ ),
+ self.spark.sql(
+ """
+ SELECT * FROM t1,
+ LATERAL (SELECT * FROM (SELECT c1 + 1 AS c1), LATERAL (SELECT c1))
+ """
+ ),
+ )
+
+ def test_scalar_subquery_inside_lateral_join(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.table1()
+ t2 = self.table2()
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.range(1).select(
+ sf.col("c2").outer(),
t2.select(sf.min(sf.col("c2"))).scalar()
+ )
+ ),
+ self.spark.sql(
+ """SELECT * FROM t1, LATERAL (SELECT c2, (SELECT MIN(c2)
FROM t2))"""
+ ),
+ )
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.range(1)
+ .select(sf.col("c1").outer().alias("a"))
+ .select(
+ t2.where(sf.col("c1") == sf.col("a").outer())
+ .select(sf.sum(sf.col("c2")))
+ .scalar()
+ )
+ ),
+ self.spark.sql(
+ """
+ SELECT * FROM t1, LATERAL (
+ SELECT (SELECT SUM(c2) FROM t2 WHERE c1 = a) FROM (SELECT c1 AS a)
+ )
+ """
+ ),
+ )
+
+ def test_lateral_join_inside_subquery(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.table1()
+ t2 = self.table2()
+
+ assertDataFrameEqual(
+ t1.where(
+ sf.col("c1")
+ == (
+ t2.lateralJoin(self.spark.range(1).select(sf.col("c1").outer().alias("a")))
+ .select(sf.min(sf.col("a")))
+ .scalar()
+ )
+ ),
+ self.spark.sql(
+ """
+ SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a))
+ """
+ ),
+ )
+ assertDataFrameEqual(
+ t1.where(
+ sf.col("c1")
+ == (
+ t2.lateralJoin(self.spark.range(1).select(sf.col("c1").outer().alias("a")))
+ .where(sf.col("c1") == sf.col("t1.c1").outer())
+ .select(sf.min(sf.col("a")))
+ .scalar()
+ )
+ ),
+ self.spark.sql(
+ """
+ SELECT * FROM t1
+ WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a) WHERE c1 = t1.c1)
+ """
+ ),
+ )
+
+ def test_lateral_join_with_table_valued_functions(self):
+ with self.tempView("t1", "t3"):
+ t1 = self.table1()
+ t3 = self.table3()
+
+ assertDataFrameEqual(
+ t1.lateralJoin(self.spark.tvf.range(3)),
+ self.spark.sql("""SELECT * FROM t1, LATERAL RANGE(3)"""),
+ )
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.tvf.explode(sf.array(sf.col("c1").outer(),
sf.col("c2").outer()))
+ ).toDF("c1", "c2", "c3"),
+ self.spark.sql("""SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1,
c2)) t2(c3)"""),
+ )
+ assertDataFrameEqual(
+ t3.lateralJoin(self.spark.tvf.explode_outer(sf.col("c2").outer())).toDF(
+ "c1", "c2", "v"
+ ),
+ self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2)
t2(v)"""),
+ )
+ assertDataFrameEqual(
+ self.spark.tvf.explode(sf.array(sf.lit(1), sf.lit(2)))
+ .toDF("v")
+ .lateralJoin(self.spark.range(1).select((sf.col("v").outer() +
1).alias("v"))),
+ self.spark.sql(
+ """SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL
(SELECT v + 1 AS v)"""
+ ),
+ )
+
+ def test_lateral_join_with_table_valued_functions_and_join_conditions(self):
+ with self.tempView("t1", "t3"):
+ t1 = self.table1()
+ t3 = self.table3()
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.tvf.explode(sf.array(sf.col("c1").outer(),
sf.col("c2").outer())),
+ sf.col("c1") == sf.col("col"),
+ ).toDF("c1", "c2", "c3"),
+ self.spark.sql(
+ """SELECT * FROM t1 JOIN LATERAL EXPLODE(ARRAY(c1, c2))
t(c3) ON t1.c1 = c3"""
+ ),
+ )
+ assertDataFrameEqual(
+ t3.lateralJoin(
+ self.spark.tvf.explode(sf.col("c2").outer()),
+ sf.col("c1") == sf.col("col"),
+ ).toDF("c1", "c2", "c3"),
+ self.spark.sql("""SELECT * FROM t3 JOIN LATERAL EXPLODE(c2)
t(c3) ON t3.c1 = c3"""),
+ )
+ assertDataFrameEqual(
+ t3.lateralJoin(
+ self.spark.tvf.explode(sf.col("c2").outer()),
+ sf.col("c1") == sf.col("col"),
+ "left",
+ ).toDF("c1", "c2", "c3"),
+ self.spark.sql(
+ """SELECT * FROM t3 LEFT JOIN LATERAL EXPLODE(c2) t(c3) ON
t3.c1 = c3"""
+ ),
+ )
+
class SubqueryTests(SubqueryTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/test_tvf.py b/python/pyspark/sql/tests/test_tvf.py
index 5c709437fc4d..ea20cbf9b8f3 100644
--- a/python/pyspark/sql/tests/test_tvf.py
+++ b/python/pyspark/sql/tests/test_tvf.py
@@ -52,6 +52,37 @@ class TVFTestsMixin:
expected = self.spark.sql("""SELECT * FROM explode(null :: map<string,
int>)""")
assertDataFrameEqual(actual=actual, expected=expected)
+ def test_explode_with_lateral_join(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+ t1.createOrReplaceTempView("t1")
+ t3 = self.spark.sql(
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null,
ARRAY(4)) "
+ "AS t3(c1, c2)"
+ )
+ t3.createOrReplaceTempView("t3")
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.tvf.explode(sf.array(sf.col("c1").outer(),
sf.col("c2").outer()))
+ ).toDF("c1", "c2", "c3"),
+ self.spark.sql("""SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1,
c2)) t2(c3)"""),
+ )
+ assertDataFrameEqual(
+ t3.lateralJoin(self.spark.tvf.explode(sf.col("c2").outer())).toDF("c1", "c2", "v"),
+ self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE(c2)
t2(v)"""),
+ )
+ assertDataFrameEqual(
+ self.spark.tvf.explode(sf.array(sf.lit(1), sf.lit(2)))
+ .toDF("v")
+ .lateralJoin(
+ self.spark.range(1).select((sf.col("v").outer() +
sf.lit(1)).alias("v2"))
+ ),
+ self.spark.sql(
+ """SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL
(SELECT v + 1 AS v2)"""
+ ),
+ )
+
def test_explode_outer(self):
actual = self.spark.tvf.explode_outer(sf.array(sf.lit(1), sf.lit(2)))
expected = self.spark.sql("""SELECT * FROM explode_outer(array(1,
2))""")
@@ -81,6 +112,43 @@ class TVFTestsMixin:
expected = self.spark.sql("""SELECT * FROM explode_outer(null ::
map<string, int>)""")
assertDataFrameEqual(actual=actual, expected=expected)
+ def test_explode_outer_with_lateral_join(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+ t1.createOrReplaceTempView("t1")
+ t3 = self.spark.sql(
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null,
ARRAY(4)) "
+ "AS t3(c1, c2)"
+ )
+ t3.createOrReplaceTempView("t3")
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.tvf.explode_outer(
+ sf.array(sf.col("c1").outer(), sf.col("c2").outer())
+ )
+ ).toDF("c1", "c2", "c3"),
+ self.spark.sql("""SELECT * FROM t1, LATERAL
EXPLODE_OUTER(ARRAY(c1, c2)) t2(c3)"""),
+ )
+ assertDataFrameEqual(
+ t3.lateralJoin(self.spark.tvf.explode_outer(sf.col("c2").outer())).toDF(
+ "c1", "c2", "v"
+ ),
+ self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2)
t2(v)"""),
+ )
+ assertDataFrameEqual(
+ self.spark.tvf.explode_outer(sf.array(sf.lit(1), sf.lit(2)))
+ .toDF("v")
+ .lateralJoin(
+ self.spark.range(1).select((sf.col("v").outer() +
sf.lit(1)).alias("v2"))
+ ),
+ self.spark.sql(
+ """
+ SELECT * FROM EXPLODE_OUTER(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1 AS v2)
+ """
+ ),
+ )
+
def test_inline(self):
actual = self.spark.tvf.inline(
sf.array(sf.struct(sf.lit(1), sf.lit("a")), sf.struct(sf.lit(2),
sf.lit("b")))
@@ -107,6 +175,35 @@ class TVFTestsMixin:
)
assertDataFrameEqual(actual=actual, expected=expected)
+ def test_inline_with_lateral_join(self):
+ with self.tempView("array_struct"):
+ array_struct = self.spark.sql(
+ """
+ VALUES
+ (1, ARRAY(STRUCT(1, 'a'), STRUCT(2, 'b'))),
+ (2, ARRAY()),
+ (3, ARRAY(STRUCT(3, 'c'))) AS array_struct(id, arr)
+ """
+ )
+ array_struct.createOrReplaceTempView("array_struct")
+
+ assertDataFrameEqual(
+ array_struct.lateralJoin(self.spark.tvf.inline(sf.col("arr").outer())),
+ self.spark.sql("""SELECT * FROM array_struct JOIN LATERAL
INLINE(arr)"""),
+ )
+ assertDataFrameEqual(
+ array_struct.lateralJoin(
+ self.spark.tvf.inline(sf.col("arr").outer()),
+ sf.col("id") == sf.col("col1"),
+ "left",
+ ).toDF("id", "arr", "k", "v"),
+ self.spark.sql(
+ """
+ SELECT * FROM array_struct LEFT JOIN LATERAL INLINE(arr) t(k, v) ON id = k
+ """
+ ),
+ )
+
def test_inline_outer(self):
actual = self.spark.tvf.inline_outer(
sf.array(sf.struct(sf.lit(1), sf.lit("a")), sf.struct(sf.lit(2),
sf.lit("b")))
@@ -137,6 +234,35 @@ class TVFTestsMixin:
)
assertDataFrameEqual(actual=actual, expected=expected)
+ def test_inline_outer_with_lateral_join(self):
+ with self.tempView("array_struct"):
+ array_struct = self.spark.sql(
+ """
+ VALUES
+ (1, ARRAY(STRUCT(1, 'a'), STRUCT(2, 'b'))),
+ (2, ARRAY()),
+ (3, ARRAY(STRUCT(3, 'c'))) AS array_struct(id, arr)
+ """
+ )
+ array_struct.createOrReplaceTempView("array_struct")
+
+ assertDataFrameEqual(
+ array_struct.lateralJoin(self.spark.tvf.inline_outer(sf.col("arr").outer())),
+ self.spark.sql("""SELECT * FROM array_struct JOIN LATERAL
INLINE_OUTER(arr)"""),
+ )
+ assertDataFrameEqual(
+ array_struct.lateralJoin(
+ self.spark.tvf.inline_outer(sf.col("arr").outer()),
+ sf.col("id") == sf.col("col1"),
+ "left",
+ ).toDF("id", "arr", "k", "v"),
+ self.spark.sql(
+ """
+ SELECT * FROM array_struct LEFT JOIN LATERAL INLINE_OUTER(arr) t(k, v) ON id = k
+ """
+ ),
+ )
+
def test_json_tuple(self):
actual = self.spark.tvf.json_tuple(sf.lit('{"a":1, "b":2}'),
sf.lit("a"), sf.lit("b"))
expected = self.spark.sql("""SELECT json_tuple('{"a":1, "b":2}', 'a',
'b')""")
@@ -151,6 +277,64 @@ class TVFTestsMixin:
messageParameters={"item": "field"},
)
+ def test_json_tuple_with_lateral_join(self):
+ with self.tempView("json_table"):
+ json_table = self.spark.sql(
+ """
+ VALUES
+ ('1', '{"f1": "1", "f2": "2", "f3": 3, "f5": 5.23}'),
+ ('2', '{"f1": "1", "f3": "3", "f2": 2, "f4": 4.01}'),
+ ('3', '{"f1": 3, "f4": "4", "f3": "3", "f2": 2, "f5": 5.01}'),
+ ('4', cast(null as string)),
+ ('5', '{"f1": null, "f5": ""}'),
+ ('6', '[invalid JSON string]') AS json_table(key, jstring)
+ """
+ )
+ json_table.createOrReplaceTempView("json_table")
+
+ assertDataFrameEqual(
+ json_table.alias("t1")
+ .lateralJoin(
+ self.spark.tvf.json_tuple(
+ sf.col("jstring").outer(),
+ sf.lit("f1"),
+ sf.lit("f2"),
+ sf.lit("f3"),
+ sf.lit("f4"),
+ sf.lit("f5"),
+ )
+ )
+ .select("key", "c0", "c1", "c2", "c3", "c4"),
+ self.spark.sql(
+ """
+ SELECT t1.key, t2.* FROM json_table t1,
+ LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2
+ """
+ ),
+ )
+ assertDataFrameEqual(
+ json_table.alias("t1")
+ .lateralJoin(
+ self.spark.tvf.json_tuple(
+ sf.col("jstring").outer(),
+ sf.lit("f1"),
+ sf.lit("f2"),
+ sf.lit("f3"),
+ sf.lit("f4"),
+ sf.lit("f5"),
+ )
+ )
+ .where(sf.col("c0").isNotNull())
+ .select("key", "c0", "c1", "c2", "c3", "c4"),
+ self.spark.sql(
+ """
+ SELECT t1.key, t2.* FROM json_table t1,
+ LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2
+ WHERE t2.c0 IS NOT NULL
+ """
+ ),
+ )
+
def test_posexplode(self):
actual = self.spark.tvf.posexplode(sf.array(sf.lit(1), sf.lit(2)))
expected = self.spark.sql("""SELECT * FROM posexplode(array(1, 2))""")
@@ -180,6 +364,39 @@ class TVFTestsMixin:
expected = self.spark.sql("""SELECT * FROM posexplode(null ::
map<string, int>)""")
assertDataFrameEqual(actual=actual, expected=expected)
+ def test_posexplode_with_lateral_join(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+ t1.createOrReplaceTempView("t1")
+ t3 = self.spark.sql(
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null,
ARRAY(4)) "
+ "AS t3(c1, c2)"
+ )
+ t3.createOrReplaceTempView("t3")
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.tvf.posexplode(sf.array(sf.col("c1").outer(),
sf.col("c2").outer()))
+ ),
+ self.spark.sql("""SELECT * FROM t1, LATERAL
POSEXPLODE(ARRAY(c1, c2))"""),
+ )
+ assertDataFrameEqual(
+ t3.lateralJoin(self.spark.tvf.posexplode(sf.col("c2").outer())),
+ self.spark.sql("""SELECT * FROM t3, LATERAL POSEXPLODE(c2)"""),
+ )
+ assertDataFrameEqual(
+ self.spark.tvf.posexplode(sf.array(sf.lit(1), sf.lit(2)))
+ .toDF("p", "v")
+ .lateralJoin(
+ self.spark.range(1).select((sf.col("v").outer() +
sf.lit(1)).alias("v2"))
+ ),
+ self.spark.sql(
+ """
+ SELECT * FROM POSEXPLODE(ARRAY(1, 2)) t(p, v), LATERAL (SELECT v + 1 AS v2)
+ """
+ ),
+ )
+
def test_posexplode_outer(self):
actual = self.spark.tvf.posexplode_outer(sf.array(sf.lit(1), sf.lit(2)))
expected = self.spark.sql("""SELECT * FROM posexplode_outer(array(1, 2))""")
@@ -209,11 +426,93 @@ class TVFTestsMixin:
expected = self.spark.sql("""SELECT * FROM posexplode_outer(null ::
map<string, int>)""")
assertDataFrameEqual(actual=actual, expected=expected)
+ def test_posexplode_outer_with_lateral_join(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+ t1.createOrReplaceTempView("t1")
+ t3 = self.spark.sql(
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null,
ARRAY(4)) "
+ "AS t3(c1, c2)"
+ )
+ t3.createOrReplaceTempView("t3")
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.tvf.posexplode_outer(
+ sf.array(sf.col("c1").outer(), sf.col("c2").outer())
+ )
+ ),
+ self.spark.sql("""SELECT * FROM t1, LATERAL
POSEXPLODE_OUTER(ARRAY(c1, c2))"""),
+ )
+ assertDataFrameEqual(
+ t3.lateralJoin(self.spark.tvf.posexplode_outer(sf.col("c2").outer())),
+ self.spark.sql("""SELECT * FROM t3, LATERAL POSEXPLODE_OUTER(c2)"""),
+ )
+ assertDataFrameEqual(
+ self.spark.tvf.posexplode_outer(sf.array(sf.lit(1), sf.lit(2)))
+ .toDF("p", "v")
+ .lateralJoin(
+ self.spark.range(1).select((sf.col("v").outer() +
sf.lit(1)).alias("v2"))
+ ),
+ self.spark.sql(
+ """
+ SELECT * FROM POSEXPLODE_OUTER(ARRAY(1, 2)) t(p, v),
+ LATERAL (SELECT v + 1 AS v2)
+ """
+ ),
+ )
+
def test_stack(self):
actual = self.spark.tvf.stack(sf.lit(2), sf.lit(1), sf.lit(2), sf.lit(3))
expected = self.spark.sql("""SELECT * FROM stack(2, 1, 2, 3)""")
assertDataFrameEqual(actual=actual, expected=expected)
+ def test_stack_with_lateral_join(self):
+ with self.tempView("t1", "t3"):
+ t1 = self.spark.sql("VALUES (0, 1), (1, 2) AS t1(c1, c2)")
+ t1.createOrReplaceTempView("t1")
+ t3 = self.spark.sql(
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null,
ARRAY(4)) "
+ "AS t3(c1, c2)"
+ )
+ t3.createOrReplaceTempView("t3")
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.tvf.stack(
+ sf.lit(2),
+ sf.lit("Key"),
+ sf.col("c1").outer(),
+ sf.lit("Value"),
+ sf.col("c2").outer(),
+ )
+ ).select("col0", "col1"),
+ self.spark.sql(
+ """SELECT t.* FROM t1, LATERAL stack(2, 'Key', c1,
'Value', c2) t"""
+ ),
+ )
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.tvf.stack(sf.lit(1), sf.col("c1").outer(),
sf.col("c2").outer())
+ ).select("col0", "col1"),
+ self.spark.sql("""SELECT t.* FROM t1 JOIN LATERAL stack(1, c1,
c2) t"""),
+ )
+ assertDataFrameEqual(
+ t1.join(t3, sf.col("t1.c1") == sf.col("t3.c1"))
+ .lateralJoin(
+ self.spark.tvf.stack(
+ sf.lit(1), sf.col("t1.c2").outer(),
sf.col("t3.c2").outer()
+ )
+ )
+ .select("col0", "col1"),
+ self.spark.sql(
+ """
+ SELECT t.* FROM t1 JOIN t3 ON t1.c1 = t3.c1
+ JOIN LATERAL stack(1, t1.c2, t3.c2) t
+ """
+ ),
+ )
+
def test_collations(self):
actual = self.spark.tvf.collations()
expected = self.spark.sql("""SELECT * FROM collations()""")
@@ -256,6 +555,31 @@ class TVFTestsMixin:
expected = self.spark.sql("""SELECT * FROM
variant_explode(parse_json('1'))""")
assertDataFrameEqual(actual=actual, expected=expected)
+ def test_variant_explode_with_lateral_join(self):
+ with self.tempView("variant_table"):
+ variant_table = self.spark.sql(
+ """
+ SELECT id, parse_json(v) AS v FROM VALUES
+ (0, '["hello", "world"]'), (1, '{"a": true, "b": 3.14}'),
+ (2, '[]'), (3, '{}'),
+ (4, NULL), (5, '1')
+ AS t(id, v)
+ """
+ )
+ variant_table.createOrReplaceTempView("variant_table")
+
+ assertDataFrameEqual(
+ variant_table.alias("t1")
+ .lateralJoin(self.spark.tvf.variant_explode(sf.col("v").outer()))
+ .select("id", "pos", "key", "value"),
+ self.spark.sql(
+ """
+ SELECT t1.id, t.* FROM variant_table AS t1,
+ LATERAL variant_explode(v) AS t
+ """
+ ),
+ )
+
def test_variant_explode_outer(self):
actual = self.spark.tvf.variant_explode_outer(sf.parse_json(sf.lit('["hello", "world"]')))
expected = self.spark.sql(
@@ -290,6 +614,31 @@ class TVFTestsMixin:
expected = self.spark.sql("""SELECT * FROM
variant_explode_outer(parse_json('1'))""")
assertDataFrameEqual(actual=actual, expected=expected)
+ def test_variant_explode_outer_with_lateral_join(self):
+ with self.tempView("variant_table"):
+ variant_table = self.spark.sql(
+ """
+ SELECT id, parse_json(v) AS v FROM VALUES
+ (0, '["hello", "world"]'), (1, '{"a": true, "b": 3.14}'),
+ (2, '[]'), (3, '{}'),
+ (4, NULL), (5, '1')
+ AS t(id, v)
+ """
+ )
+ variant_table.createOrReplaceTempView("variant_table")
+
+ assertDataFrameEqual(
+ variant_table.alias("t1")
+ .lateralJoin(self.spark.tvf.variant_explode_outer(sf.col("v").outer()))
+ .select("id", "pos", "key", "value"),
+ self.spark.sql(
+ """
+ SELECT t1.id, t.* FROM variant_table AS t1,
+ LATERAL variant_explode_outer(v) AS t
+ """
+ ),
+ )
+
class TVFTests(TVFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 8447edfbbb15..31cd4c80370e 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -31,6 +31,7 @@ from pyspark.errors import (
from pyspark.util import PythonEvalType
from pyspark.sql.functions import (
array,
+ col,
create_map,
array,
lit,
@@ -155,6 +156,22 @@ class BaseUDTFTestsMixin:
)
assertDataFrameEqual(df, expected)
+ def test_udtf_with_lateral_join_dataframe(self):
+ @udtf(returnType="a: int, b: int, c: int")
+ class TestUDTF:
+ def eval(self, a: int, b: int) -> Iterator:
+ yield a, b, a + b
+ yield a, b, a - b
+
+ self.spark.udtf.register("testUDTF", TestUDTF)
+
+ assertDataFrameEqual(
+ self.spark.sql("values (0, 1), (1, 2) t(a, b)").lateralJoin(
+ TestUDTF(col("a").outer(), col("b").outer())
+ ),
+ self.spark.sql("SELECT * FROM values (0, 1), (1, 2) t(a, b),
LATERAL testUDTF(a, b)"),
+ )
+
def test_udtf_eval_with_return_stmt(self):
class TestUDTF:
def eval(self, a: int, b: int):
@@ -239,6 +256,20 @@ class BaseUDTFTestsMixin:
[Row(id=6, a=6), Row(id=7, a=7)],
)
+ def test_udtf_with_conditional_return_dataframe(self):
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self, a: int):
+ if a > 5:
+ yield a,
+
+ self.spark.udtf.register("test_udtf", TestUDTF)
+
+ assertDataFrameEqual(
+ self.spark.range(8).lateralJoin(TestUDTF(col("id").outer())),
+ self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL
test_udtf(id)"),
+ )
+
def test_udtf_with_empty_yield(self):
@udtf(returnType="a: int")
class TestUDTF:
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
index 9d41998f11dc..20c181e7b9cf 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -859,6 +859,60 @@ abstract class Dataset[T] extends Serializable {
joinWith(other, condition, "inner")
}
+ /**
+ * Lateral join with another `DataFrame`.
+ *
+ * Behaves as a JOIN LATERAL.
+ *
+ * @param right
+ * Right side of the join operation.
+ * @group untypedrel
+ * @since 4.0.0
+ */
+ def lateralJoin(right: DS[_]): Dataset[Row]
+
+ /**
+ * Lateral join with another `DataFrame`.
+ *
+ * Behaves as a JOIN LATERAL.
+ *
+ * @param right
+ * Right side of the join operation.
+ * @param joinExprs
+ * Join expression.
+ * @group untypedrel
+ * @since 4.0.0
+ */
+ def lateralJoin(right: DS[_], joinExprs: Column): Dataset[Row]
+
+ /**
+ * Lateral join with another `DataFrame`.
+ *
+ * @param right
+ * Right side of the join operation.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `left`,
+ * `leftouter`, `left_outer`.
+ * @group untypedrel
+ * @since 4.0.0
+ */
+ def lateralJoin(right: DS[_], joinType: String): Dataset[Row]
+
+ /**
+ * Lateral join with another `DataFrame`.
+ *
+ * @param right
+ * Right side of the join operation.
+ * @param joinExprs
+ * Join expression.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `left`,
+ * `leftouter`, `left_outer`.
+ * @group untypedrel
+ * @since 4.0.0
+ */
+ def lateralJoin(right: DS[_], joinExprs: Column, joinType: String): Dataset[Row]
+
protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T]
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 846d97b25786..8726ee268a47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -709,6 +709,38 @@ class Dataset[T] private[sql](
new Dataset(sparkSession, joinWith, joinEncoder)
}
+ private[sql] def lateralJoin(
+ right: DS[_], joinExprs: Option[Column], joinType: JoinType): DataFrame = {
+ withPlan {
+ LateralJoin(
+ logicalPlan,
+ LateralSubquery(right.logicalPlan),
+ joinType,
+ joinExprs.map(_.expr)
+ )
+ }
+ }
+
+ /** @inheritdoc */
+ def lateralJoin(right: DS[_]): DataFrame = {
+ lateralJoin(right, None, Inner)
+ }
+
+ /** @inheritdoc */
+ def lateralJoin(right: DS[_], joinExprs: Column): DataFrame = {
+ lateralJoin(right, Some(joinExprs), Inner)
+ }
+
+ /** @inheritdoc */
+ def lateralJoin(right: DS[_], joinType: String): DataFrame = {
+ lateralJoin(right, None, JoinType(joinType))
+ }
+
+ /** @inheritdoc */
+ def lateralJoin(right: DS[_], joinExprs: Column, joinType: String): DataFrame = {
+ lateralJoin(right, Some(joinExprs), JoinType(joinType))
+ }
+
// TODO(SPARK-22947): Fix the DataFrame API.
private[sql] def joinAsOf(
other: Dataset[_],
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index 2420ad34d9ba..cd425162fb01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -377,4 +377,291 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
Array(ExpectedContext(fragment = "$", callSitePattern =
getCurrentClassCallSitePattern))
)
}
+
+ private def table1() = {
+ sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+ spark.table("t1")
+ }
+
+ private def table2() = {
+ sql("CREATE VIEW t2(c1, c2) AS VALUES (0, 2), (0, 3)")
+ spark.table("t2")
+ }
+
+ private def table3() = {
+ sql("CREATE VIEW t3(c1, c2) AS " +
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null, ARRAY(4))")
+ spark.table("t3")
+ }
+
+ test("lateral join with single column select") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkAnswer(
+ t1.lateralJoin(spark.range(1).select($"c1".outer())),
+ sql("SELECT * FROM t1, LATERAL (SELECT c1)")
+ )
+ checkAnswer(
+ t1.lateralJoin(t2.select($"c1")),
+ sql("SELECT * FROM t1, LATERAL (SELECT c1 FROM t2)")
+ )
+ checkAnswer(
+ t1.lateralJoin(t2.select($"t1.c1".outer())),
+ sql("SELECT * FROM t1, LATERAL (SELECT t1.c1 FROM t2)")
+ )
+ checkAnswer(
+ t1.lateralJoin(t2.select($"t1.c1".outer() + $"t2.c1")),
+ sql("SELECT * FROM t1, LATERAL (SELECT t1.c1 + t2.c1 FROM t2)")
+ )
+ }
+ }
+
+ test("lateral join with different join types") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(
+ t1.lateralJoin(
+ spark.range(1).select(($"c1".outer() + $"c2".outer()).as("c3")),
+ $"c2" === $"c3"),
+ sql("SELECT * FROM t1 JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 = c3")
+ )
+ checkAnswer(
+ t1.lateralJoin(
+ spark.range(1).select(($"c1".outer() + $"c2".outer()).as("c3")),
+ $"c2" === $"c3",
+ "left"),
+ sql("SELECT * FROM t1 LEFT JOIN LATERAL (SELECT c1 + c2 AS c3) ON c2 =
c3")
+ )
+ checkAnswer(
+ t1.lateralJoin(
+ spark.range(1).select(($"c1".outer() + $"c2".outer()).as("c3")),
+ "cross"),
+ sql("SELECT * FROM t1 CROSS JOIN LATERAL (SELECT c1 + c2 AS c3)")
+ )
+ }
+ }
+
+ test("lateral join with correlated equality / non-equality predicates") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkAnswer(
+ t1.lateralJoin(t2.where($"t1.c1".outer() === $"t2.c1").select($"c2")),
+ sql("SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c1 =
t2.c1)")
+ )
+ checkAnswer(
+ t1.lateralJoin(t2.where($"t1.c1".outer() < $"t2.c1").select($"c2")),
+ sql("SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c1 <
t2.c1)")
+ )
+ }
+ }
+
+ test("lateral join with aggregation and correlated non-equality predicates")
{
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkAnswer(
+ t1.lateralJoin(t2.where($"t1.c2".outer() <
$"t2.c2").select(max($"c2").as("m"))),
+ sql("SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE
t1.c2 < t2.c2)")
+ )
+ }
+ }
+
+ test("lateral join can reference preceding FROM clause items") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkAnswer(
+ t1.join(t2).lateralJoin(
+ spark.range(1).select($"t1.c2".outer() + $"t2.c2".outer())
+ ),
+ sql("SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2)")
+ )
+ }
+ }
+
+ test("multiple lateral joins") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(
+ t1.lateralJoin(
+ spark.range(1).select(($"c1".outer() + $"c2".outer()).as("a"))
+ ).lateralJoin(
+ spark.range(1).select(($"c1".outer() - $"c2".outer()).as("b"))
+ ).lateralJoin(
+ spark.range(1).select(($"a".outer() * $"b".outer()).as("c"))
+ ),
+ sql(
+ """
+ |SELECT * FROM t1,
+ |LATERAL (SELECT c1 + c2 AS a),
+ |LATERAL (SELECT c1 - c2 AS b),
+ |LATERAL (SELECT a * b AS c)
+ |""".stripMargin)
+ )
+ }
+ }
+
+ test("lateral join in between regular joins") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkAnswer(
+ t1.lateralJoin(
+ t2.where($"t1.c1".outer() === $"t2.c1").select($"c2"), "left"
+ ).join(t1.as("t3"), $"t2.c2" === $"t3.c2", "left"),
+ sql(
+ """
+ |SELECT * FROM t1
+ |LEFT OUTER JOIN LATERAL (SELECT c2 FROM t2 WHERE t1.c1 = t2.c1) s
+ |LEFT OUTER JOIN t1 t3 ON s.c2 = t3.c2
+ |""".stripMargin)
+ )
+ }
+ }
+
+ test("nested lateral joins") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkAnswer(
+ t1.lateralJoin(
+ t2.lateralJoin(spark.range(1).select($"c1".outer()))
+ ),
+ sql("SELECT * FROM t1, LATERAL (SELECT * FROM t2, LATERAL (SELECT
c1))")
+ )
+ checkAnswer(
+ t1.lateralJoin(
+ spark.range(1).select(($"c1".outer() + lit(1)).as("c1"))
+ .lateralJoin(spark.range(1).select($"c1".outer()))
+ ),
+ sql("SELECT * FROM t1, LATERAL (SELECT * FROM (SELECT c1 + 1 AS c1),
LATERAL (SELECT c1))")
+ )
+ }
+ }
+
+ test("scalar subquery inside lateral join") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ // uncorrelated
+ checkAnswer(
+ t1.lateralJoin(
+ spark.range(1).select(
+ $"c2".outer(),
+ t2.select(min($"c2")).scalar()
+ )
+ ),
+ sql("SELECT * FROM t1, LATERAL (SELECT c2, (SELECT MIN(c2) FROM t2))")
+ )
+
+ // correlated
+ checkAnswer(
+ t1.lateralJoin(
+ spark.range(1).select($"c1".outer().as("a"))
+ .select(t2.where($"c1" ===
$"a".outer()).select(sum($"c2")).scalar())
+ ),
+ sql(
+ """
+ |SELECT * FROM t1, LATERAL (
+ | SELECT (SELECT SUM(c2) FROM t2 WHERE c1 = a) FROM (SELECT c1 AS a)
+ |)
+ |""".stripMargin)
+ )
+ }
+ }
+
+ test("lateral join inside subquery") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ // uncorrelated
+ checkAnswer(
+ t1.where(
+ $"c1" === t2.lateralJoin(
+ spark.range(1).select($"c1".outer().as("a"))).select(min($"a")
+ ).scalar()
+ ),
+ sql("SELECT * FROM t1 WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL
(SELECT c1 AS a))")
+ )
+ // correlated
+ checkAnswer(
+ t1.where(
+ $"c1" === t2.lateralJoin(
+ spark.range(1).select($"c1".outer().as("a")))
+ .where($"c1" === $"t1.c1".outer())
+ .select(min($"a"))
+ .scalar()
+ ),
+ sql("SELECT * FROM t1 " +
+ "WHERE c1 = (SELECT MIN(a) FROM t2, LATERAL (SELECT c1 AS a) WHERE
c1 = t1.c1)")
+ )
+ }
+ }
+
+ test("lateral join with table-valued functions") {
+ withView("t1", "t3") {
+ val t1 = table1()
+ val t3 = table3()
+
+ checkAnswer(
+ t1.lateralJoin(spark.tvf.range(3)),
+ sql("SELECT * FROM t1, LATERAL RANGE(3)")
+ )
+ checkAnswer(
+ t1.lateralJoin(spark.tvf.explode(array($"c1".outer(), $"c2".outer()))),
+ sql("SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3)")
+ )
+ checkAnswer(
+ t3.lateralJoin(spark.tvf.explode_outer($"c2".outer())),
+ sql("SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v)")
+ )
+ checkAnswer(
+ spark.tvf.explode(array(lit(1), lit(2))).toDF("v")
+ .lateralJoin(spark.range(1).select($"v".outer() + 1)),
+ sql("SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1)")
+ )
+ }
+ }
+
+ test("lateral join with table-valued functions and join conditions") {
+ withView("t1", "t3") {
+ val t1 = table1()
+ val t3 = table3()
+
+ checkAnswer(
+ t1.lateralJoin(
+ spark.tvf.explode(array($"c1".outer(), $"c2".outer())),
+ $"c1" === $"col"
+ ),
+ sql("SELECT * FROM t1 JOIN LATERAL EXPLODE(ARRAY(c1, c2)) t(c3) ON
t1.c1 = c3")
+ )
+ checkAnswer(
+ t3.lateralJoin(
+ spark.tvf.explode($"c2".outer()),
+ $"c1" === $"col"
+ ),
+ sql("SELECT * FROM t3 JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = c3")
+ )
+ checkAnswer(
+ t3.lateralJoin(
+ spark.tvf.explode($"c2".outer()),
+ $"c1" === $"col",
+ "left"
+ ),
+ sql("SELECT * FROM t3 LEFT JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 =
c3")
+ )
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
index c2f53ff56d1a..4f2cd275ffdf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
test("explode") {
val actual1 = spark.tvf.explode(array(lit(1), lit(2)))
@@ -50,6 +51,30 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
checkAnswer(actual6, expected6)
}
+ test("explode - lateral join") {
+ withView("t1", "t3") {
+ sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+ sql("CREATE VIEW t3(c1, c2) AS " +
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null,
ARRAY(4))")
+ val t1 = spark.table("t1")
+ val t3 = spark.table("t3")
+
+ checkAnswer(
+ t1.lateralJoin(spark.tvf.explode(array($"c1".outer(), $"c2".outer()))),
+ sql("SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3)")
+ )
+ checkAnswer(
+ t3.lateralJoin(spark.tvf.explode($"c2".outer())),
+ sql("SELECT * FROM t3, LATERAL EXPLODE(c2) t2(v)")
+ )
+ checkAnswer(
+ spark.tvf.explode(array(lit(1), lit(2))).toDF("v")
+ .lateralJoin(spark.range(1).select($"v".outer() + lit(1))),
+ sql("SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1)")
+ )
+ }
+ }
+
test("explode_outer") {
val actual1 = spark.tvf.explode_outer(array(lit(1), lit(2)))
val expected1 = spark.sql("SELECT * FROM explode_outer(array(1, 2))")
@@ -78,6 +103,30 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
checkAnswer(actual6, expected6)
}
+ test("explode_outer - lateral join") {
+ withView("t1", "t3") {
+ sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+ sql("CREATE VIEW t3(c1, c2) AS " +
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null,
ARRAY(4))")
+ val t1 = spark.table("t1")
+ val t3 = spark.table("t3")
+
+ checkAnswer(
+ t1.lateralJoin(spark.tvf.explode_outer(array($"c1".outer(),
$"c2".outer()))),
+ sql("SELECT * FROM t1, LATERAL EXPLODE_OUTER(ARRAY(c1, c2)) t2(c3)")
+ )
+ checkAnswer(
+ t3.lateralJoin(spark.tvf.explode_outer($"c2".outer())),
+ sql("SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v)")
+ )
+ checkAnswer(
+ spark.tvf.explode_outer(array(lit(1), lit(2))).toDF("v")
+ .lateralJoin(spark.range(1).select($"v".outer() + lit(1))),
+ sql("SELECT * FROM EXPLODE_OUTER(ARRAY(1, 2)) t(v), LATERAL (SELECT v
+ 1)")
+ )
+ }
+ }
+
test("inline") {
val actual1 = spark.tvf.inline(array(struct(lit(1), lit("a")),
struct(lit(2), lit("b"))))
val expected1 = spark.sql("SELECT * FROM inline(array(struct(1, 'a'),
struct(2, 'b')))")
@@ -98,6 +147,32 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
checkAnswer(actual3, expected3)
}
+ test("inline - lateral join") {
+ withView("array_struct") {
+ sql(
+ """
+ |CREATE VIEW array_struct(id, arr) AS VALUES
+ | (1, ARRAY(STRUCT(1, 'a'), STRUCT(2, 'b'))),
+ | (2, ARRAY()),
+ | (3, ARRAY(STRUCT(3, 'c')))
+ |""".stripMargin)
+ val arrayStruct = spark.table("array_struct")
+
+ checkAnswer(
+ arrayStruct.lateralJoin(spark.tvf.inline($"arr".outer())),
+ sql("SELECT * FROM array_struct JOIN LATERAL INLINE(arr)")
+ )
+ checkAnswer(
+ arrayStruct.lateralJoin(
+ spark.tvf.inline($"arr".outer()),
+ $"id" === $"col1",
+ "left"
+ ),
+ sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE(arr) t(k, v)
ON id = k")
+ )
+ }
+ }
+
test("inline_outer") {
val actual1 = spark.tvf.inline_outer(array(struct(lit(1), lit("a")),
struct(lit(2), lit("b"))))
val expected1 = spark.sql("SELECT * FROM inline_outer(array(struct(1,
'a'), struct(2, 'b')))")
@@ -118,6 +193,32 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
checkAnswer(actual3, expected3)
}
+ test("inline_outer - lateral join") {
+ withView("array_struct") {
+ sql(
+ """
+ |CREATE VIEW array_struct(id, arr) AS VALUES
+ | (1, ARRAY(STRUCT(1, 'a'), STRUCT(2, 'b'))),
+ | (2, ARRAY()),
+ | (3, ARRAY(STRUCT(3, 'c')))
+ |""".stripMargin)
+ val arrayStruct = spark.table("array_struct")
+
+ checkAnswer(
+ arrayStruct.lateralJoin(spark.tvf.inline_outer($"arr".outer())),
+ sql("SELECT * FROM array_struct JOIN LATERAL INLINE_OUTER(arr)")
+ )
+ checkAnswer(
+ arrayStruct.lateralJoin(
+ spark.tvf.inline_outer($"arr".outer()),
+ $"id" === $"col1",
+ "left"
+ ),
+ sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE_OUTER(arr)
t(k, v) ON id = k")
+ )
+ }
+ }
+
test("json_tuple") {
val actual = spark.tvf.json_tuple(lit("""{"a":1,"b":2}"""), lit("a"),
lit("b"))
val expected = spark.sql("""SELECT * FROM json_tuple('{"a":1,"b":2}', 'a',
'b')""")
@@ -130,6 +231,43 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
assert(ex.messageParameters("functionName") == "`json_tuple`")
}
+ test("json_tuple - lateral join") {
+ withView("json_table") {
+ sql(
+ """
+ |CREATE OR REPLACE TEMP VIEW json_table(key, jstring) AS VALUES
+ | ('1', '{"f1": "1", "f2": "2", "f3": 3, "f5": 5.23}'),
+ | ('2', '{"f1": "1", "f3": "3", "f2": 2, "f4": 4.01}'),
+ | ('3', '{"f1": 3, "f4": "4", "f3": "3", "f2": 2, "f5": 5.01}'),
+ | ('4', cast(null as string)),
+ | ('5', '{"f1": null, "f5": ""}'),
+ | ('6', '[invalid JSON string]')
+ |""".stripMargin)
+ val jsonTable = spark.table("json_table")
+
+ checkAnswer(
+ jsonTable.as("t1").lateralJoin(
+ spark.tvf.json_tuple(
+ $"t1.jstring".outer(),
+ lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5"))
+ ).select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+ sql("SELECT t1.key, t2.* FROM json_table t1, " +
+ "LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2")
+ )
+ checkAnswer(
+ jsonTable.as("t1").lateralJoin(
+ spark.tvf.json_tuple(
+ $"jstring".outer(),
+ lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5"))
+ ).where($"c0".isNotNull)
+ .select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+ sql("SELECT t1.key, t2.* FROM json_table t1, " +
+ "LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2 " +
+ "WHERE t2.c0 IS NOT NULL")
+ )
+ }
+ }
+
test("posexplode") {
val actual1 = spark.tvf.posexplode(array(lit(1), lit(2)))
val expected1 = spark.sql("SELECT * FROM posexplode(array(1, 2))")
@@ -158,6 +296,30 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
checkAnswer(actual6, expected6)
}
+ test("posexplode - lateral join") {
+ withView("t1", "t3") {
+ sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+ sql("CREATE VIEW t3(c1, c2) AS " +
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null,
ARRAY(4))")
+ val t1 = spark.table("t1")
+ val t3 = spark.table("t3")
+
+ checkAnswer(
+ t1.lateralJoin(spark.tvf.posexplode(array($"c1".outer(),
$"c2".outer()))),
+ sql("SELECT * FROM t1, LATERAL POSEXPLODE(ARRAY(c1, c2))")
+ )
+ checkAnswer(
+ t3.lateralJoin(spark.tvf.posexplode($"c2".outer())),
+ sql("SELECT * FROM t3, LATERAL POSEXPLODE(c2)")
+ )
+ checkAnswer(
+ spark.tvf.posexplode(array(lit(1), lit(2))).toDF("p", "v")
+ .lateralJoin(spark.range(1).select($"v".outer() + lit(1))),
+ sql("SELECT * FROM POSEXPLODE(ARRAY(1, 2)) t(p, v), LATERAL (SELECT v
+ 1)")
+ )
+ }
+ }
+
test("posexplode_outer") {
val actual1 = spark.tvf.posexplode_outer(array(lit(1), lit(2)))
val expected1 = spark.sql("SELECT * FROM posexplode_outer(array(1, 2))")
@@ -186,12 +348,66 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
checkAnswer(actual6, expected6)
}
+ test("posexplode_outer - lateral join") {
+ withView("t1", "t3") {
+ sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+ sql("CREATE VIEW t3(c1, c2) AS " +
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null,
ARRAY(4))")
+ val t1 = spark.table("t1")
+ val t3 = spark.table("t3")
+
+ checkAnswer(
+ t1.lateralJoin(spark.tvf.posexplode_outer(array($"c1".outer(),
$"c2".outer()))),
+ sql("SELECT * FROM t1, LATERAL POSEXPLODE_OUTER(ARRAY(c1, c2))")
+ )
+ checkAnswer(
+ t3.lateralJoin(spark.tvf.posexplode_outer($"c2".outer())),
+ sql("SELECT * FROM t3, LATERAL POSEXPLODE_OUTER(c2)")
+ )
+ checkAnswer(
+ spark.tvf.posexplode_outer(array(lit(1), lit(2))).toDF("p", "v")
+ .lateralJoin(spark.range(1).select($"v".outer() + lit(1))),
+ sql("SELECT * FROM POSEXPLODE_OUTER(ARRAY(1, 2)) t(p, v), LATERAL
(SELECT v + 1)")
+ )
+ }
+ }
+
test("stack") {
val actual = spark.tvf.stack(lit(2), lit(1), lit(2), lit(3))
val expected = spark.sql("SELECT * FROM stack(2, 1, 2, 3)")
checkAnswer(actual, expected)
}
+ test("stack - lateral join") {
+ withView("t1", "t3") {
+ sql("CREATE VIEW t1(c1, c2) AS VALUES (0, 1), (1, 2)")
+ sql("CREATE VIEW t3(c1, c2) AS " +
+ "VALUES (0, ARRAY(0, 1)), (1, ARRAY(2)), (2, ARRAY()), (null,
ARRAY(4))")
+ val t1 = spark.table("t1")
+ val t3 = spark.table("t3")
+
+ checkAnswer(
+ t1.lateralJoin(
+ spark.tvf.stack(lit(2), lit("Key"), $"c1".outer(), lit("Value"),
$"c2".outer())
+ ).select($"col0", $"col1"),
+ sql("SELECT t.* FROM t1, LATERAL stack(2, 'Key', c1, 'Value', c2) t")
+ )
+ checkAnswer(
+ t1.lateralJoin(
+ spark.tvf.stack(lit(1), $"c1".outer(), $"c2".outer())
+ ).select($"col0".as("x"), $"col1".as("y")),
+ sql("SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, c2) t(x, y)")
+ )
+ checkAnswer(
+ t1.join(t3, $"t1.c1" === $"t3.c1")
+ .lateralJoin(
+ spark.tvf.stack(lit(1), $"t1.c2".outer(), $"t3.c2".outer())
+ ).select($"col0", $"col1"),
+ sql("SELECT t.* FROM t1 JOIN t3 ON t1.c1 = t3.c1 JOIN LATERAL stack(1,
t1.c2, t3.c2) t")
+ )
+ }
+ }
+
test("collations") {
val actual = spark.tvf.collations()
val expected = spark.sql("SELECT * FROM collations()")
@@ -235,6 +451,28 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
checkAnswer(actual6, expected6)
}
+ test("variant_explode - lateral join") {
+ withView("variant_table") {
+ sql(
+ """
+ |CREATE VIEW variant_table(id, v) AS
+ |SELECT id, parse_json(v) AS v FROM VALUES
+ |(0, '["hello", "world"]'), (1, '{"a": true, "b": 3.14}'),
+ |(2, '[]'), (3, '{}'),
+ |(4, NULL), (5, '1')
+ |AS t(id, v)
+ |""".stripMargin)
+ val variantTable = spark.table("variant_table")
+
+ checkAnswer(
+ variantTable.as("t1").lateralJoin(
+ spark.tvf.variant_explode($"v".outer())
+ ).select($"id", $"pos", $"key", $"value"),
+ sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL
variant_explode(v) AS t")
+ )
+ }
+ }
+
test("variant_explode_outer") {
val actual1 = spark.tvf.variant_explode_outer(parse_json(lit("""["hello",
"world"]""")))
val expected1 = spark.sql(
@@ -265,4 +503,26 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
val expected6 = spark.sql("SELECT * FROM
variant_explode_outer(parse_json('1'))")
checkAnswer(actual6, expected6)
}
+
+ test("variant_explode_outer - lateral join") {
+ withView("variant_table") {
+ sql(
+ """
+ |CREATE VIEW variant_table(id, v) AS
+ |SELECT id, parse_json(v) AS v FROM VALUES
+ |(0, '["hello", "world"]'), (1, '{"a": true, "b": 3.14}'),
+ |(2, '[]'), (3, '{}'),
+ |(4, NULL), (5, '1')
+ |AS t(id, v)
+ |""".stripMargin)
+ val variantTable = spark.table("variant_table")
+
+ checkAnswer(
+ variantTable.as("t1").lateralJoin(
+ spark.tvf.variant_explode_outer($"v".outer())
+ ).select($"id", $"pos", $"key", $"value"),
+ sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL
variant_explode_outer(v) AS t")
+ )
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]