This is an automated email from the ASF dual-hosted git repository.
allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cb5938363ff5 [SPARK-50075][SQL][PYTHON][CONNECT] Add DataFrame APIs for table-valued functions
cb5938363ff5 is described below
commit cb5938363ff582b5c32d81f1ec972fdbc6eb98e9
Author: Takuya Ueshin <[email protected]>
AuthorDate: Wed Oct 30 16:43:20 2024 -0700
[SPARK-50075][SQL][PYTHON][CONNECT] Add DataFrame APIs for table-valued functions
### What changes were proposed in this pull request?
Adds DataFrame APIs for table-valued functions.
For example:
```py
spark.tvf.range(10)
spark.tvf.explode(array(lit(1), lit(2)))
spark.tvf.explode_outer(array(lit(1), lit(2)))
spark.tvf.inline(array(struct(lit(1), lit("a")), struct(lit(2), lit("b"))))
spark.tvf.inline_outer(array(struct(lit(1), lit("a")), struct(lit(2), lit("b"))))
spark.tvf.json_tuple(lit("""{"a":1,"b":2}"""), lit("a"), lit("b"))
spark.tvf.posexplode(array(lit(1), lit(2)))
spark.tvf.posexplode_outer(array(lit(1), lit(2)))
spark.tvf.stack(lit(2), lit(1), lit(2), lit(3))
spark.tvf.collations()
spark.tvf.sql_keywords()
spark.tvf.variant_explode(parse_json(lit("""["hello", "world"]""")))
spark.tvf.variant_explode_outer(parse_json(lit("""["hello", "world"]""")))
```
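The same entry point is added to the Scala clients as well (both the Classic `SparkSession` and the Connect one, per the file list below). A minimal Scala sketch of the equivalent calls, assuming `org.apache.spark.sql.functions._` is imported:
```scala
import org.apache.spark.sql.functions._

spark.tvf.range(10)
spark.tvf.explode(array(lit(1), lit(2)))
spark.tvf.json_tuple(lit("""{"a":1,"b":2}"""), lit("a"), lit("b"))
spark.tvf.stack(lit(2), lit(1), lit(2), lit(3))
spark.tvf.collations()
```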
### Why are the changes needed?
DataFrame APIs for table-valued functions are missing.
### Does this PR introduce _any_ user-facing change?
Yes, this provides DataFrame APIs for table-valued functions.
### How was this patch tested?
Added the related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48603 from ueshin/issues/SPARK-50075/tvf.
Lead-authored-by: Takuya Ueshin <[email protected]>
Co-authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Allison Wang <[email protected]>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 3 +
.../org/apache/spark/sql/TableValuedFunction.scala | 104 +++
.../sql/DataFrameTableValuedFunctionsSuite.scala | 271 ++++++++
dev/sparktestsupport/modules.py | 4 +
.../source/reference/pyspark.sql/spark_session.rst | 1 +
python/pyspark/sql/connect/plan.py | 14 +
python/pyspark/sql/connect/proto/relations_pb2.py | 306 ++++-----
python/pyspark/sql/connect/proto/relations_pb2.pyi | 39 ++
python/pyspark/sql/connect/session.py | 9 +
python/pyspark/sql/connect/tvf.py | 148 +++++
python/pyspark/sql/session.py | 36 ++
.../pyspark/sql/tests/connect/test_parity_tvf.py | 36 ++
python/pyspark/sql/tests/test_tvf.py | 307 +++++++++
python/pyspark/sql/tvf.py | 712 +++++++++++++++++++++
python/pyspark/testing/utils.py | 4 +-
.../org/apache/spark/sql/api/SparkSession.scala | 7 +
.../apache/spark/sql/api/TableValuedFunction.scala | 177 +++++
.../main/protobuf/spark/connect/relations.proto | 9 +
.../sql/connect/planner/SparkConnectPlanner.scala | 11 +-
.../scala/org/apache/spark/sql/SparkSession.scala | 3 +
.../org/apache/spark/sql/TableValuedFunction.scala | 98 +++
.../sql/DataFrameTableValuedFunctionsSuite.scala | 268 ++++++++
22 files changed, 2413 insertions(+), 154 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index c0590fbd1728..366a9bc3b559 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -243,6 +243,9 @@ class SparkSession private[sql] (
/** @inheritdoc */
def readStream: DataStreamReader = new DataStreamReader(this)
+ /** @inheritdoc */
+ def tvf: TableValuedFunction = new TableValuedFunction(this)
+
/** @inheritdoc */
lazy val streams: StreamingQueryManager = new StreamingQueryManager(this)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
new file mode 100644
index 000000000000..4f2687b53786
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr
+
+class TableValuedFunction(sparkSession: SparkSession) extends api.TableValuedFunction {
+
+ /** @inheritdoc */
+ override def range(end: Long): Dataset[java.lang.Long] = {
+ sparkSession.range(end)
+ }
+
+ /** @inheritdoc */
+ override def range(start: Long, end: Long): Dataset[java.lang.Long] = {
+ sparkSession.range(start, end)
+ }
+
+ /** @inheritdoc */
+ override def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
+ sparkSession.range(start, end, step)
+ }
+
+ /** @inheritdoc */
+ override def range(
+ start: Long,
+ end: Long,
+ step: Long,
+ numPartitions: Int): Dataset[java.lang.Long] = {
+ sparkSession.range(start, end, step, numPartitions)
+ }
+
+ private def fn(name: String, args: Seq[Column]): Dataset[Row] = {
+ sparkSession.newDataFrame { builder =>
+ builder.getUnresolvedTableValuedFunctionBuilder
+ .setFunctionName(name)
+ .addAllArguments(args.map(toExpr).asJava)
+ }
+ }
+
+ /** @inheritdoc */
+ override def explode(collection: Column): Dataset[Row] =
+ fn("explode", Seq(collection))
+
+ /** @inheritdoc */
+ override def explode_outer(collection: Column): Dataset[Row] =
+ fn("explode_outer", Seq(collection))
+
+ /** @inheritdoc */
+ override def inline(input: Column): Dataset[Row] =
+ fn("inline", Seq(input))
+
+ /** @inheritdoc */
+ override def inline_outer(input: Column): Dataset[Row] =
+ fn("inline_outer", Seq(input))
+
+ /** @inheritdoc */
+ override def json_tuple(input: Column, fields: Column*): Dataset[Row] =
+ fn("json_tuple", input +: fields)
+
+ /** @inheritdoc */
+ override def posexplode(collection: Column): Dataset[Row] =
+ fn("posexplode", Seq(collection))
+
+ /** @inheritdoc */
+ override def posexplode_outer(collection: Column): Dataset[Row] =
+ fn("posexplode_outer", Seq(collection))
+
+ /** @inheritdoc */
+ override def stack(n: Column, fields: Column*): Dataset[Row] =
+ fn("stack", n +: fields)
+
+ /** @inheritdoc */
+ override def collations(): Dataset[Row] =
+ fn("collations", Seq.empty)
+
+ /** @inheritdoc */
+ override def sql_keywords(): Dataset[Row] =
+ fn("sql_keywords", Seq.empty)
+
+ /** @inheritdoc */
+ override def variant_explode(input: Column): Dataset[Row] =
+ fn("variant_explode", Seq(input))
+
+ /** @inheritdoc */
+ override def variant_explode_outer(input: Column): Dataset[Row] =
+ fn("variant_explode_outer", Seq(input))
+}
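Each method above is expected to return the same rows as the corresponding SQL TVF invocation, which the new suite below verifies with `checkAnswer`. A minimal equivalence sketch (assuming a Connect `SparkSession` named `spark`):
```scala
import org.apache.spark.sql.functions._

// The private `fn` helper builds an UnresolvedTableValuedFunction relation that
// the server resolves by function name, so both forms should yield the same rows.
val viaApi = spark.tvf.explode(array(lit(1), lit(2)))
val viaSql = spark.sql("SELECT * FROM explode(array(1, 2))")
```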
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
new file mode 100644
index 000000000000..4c0357a3ed98
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession}
+
+class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSession {
+
+ test("explode") {
+ val actual1 = spark.tvf.explode(array(lit(1), lit(2)))
+ val expected1 = spark.sql("SELECT * FROM explode(array(1, 2))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.explode(map(lit("a"), lit(1), lit("b"), lit(2)))
+ val expected2 = spark.sql("SELECT * FROM explode(map('a', 1, 'b', 2))")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.explode(array())
+ val expected3 = spark.sql("SELECT * FROM explode(array())")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.explode(map())
+ val expected4 = spark.sql("SELECT * FROM explode(map())")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.explode(lit(null).cast("array<int>"))
+ val expected5 = spark.sql("SELECT * FROM explode(null :: array<int>)")
+ checkAnswer(actual5, expected5)
+
+ val actual6 = spark.tvf.explode(lit(null).cast("map<string, int>"))
+ val expected6 = spark.sql("SELECT * FROM explode(null :: map<string, int>)")
+ checkAnswer(actual6, expected6)
+ }
+
+ test("explode_outer") {
+ val actual1 = spark.tvf.explode_outer(array(lit(1), lit(2)))
+ val expected1 = spark.sql("SELECT * FROM explode_outer(array(1, 2))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.explode_outer(map(lit("a"), lit(1), lit("b"), lit(2)))
+ val expected2 = spark.sql("SELECT * FROM explode_outer(map('a', 1, 'b', 2))")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.explode_outer(array())
+ val expected3 = spark.sql("SELECT * FROM explode_outer(array())")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.explode_outer(map())
+ val expected4 = spark.sql("SELECT * FROM explode_outer(map())")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.explode_outer(lit(null).cast("array<int>"))
+ val expected5 = spark.sql("SELECT * FROM explode_outer(null :: array<int>)")
+ checkAnswer(actual5, expected5)
+
+ val actual6 = spark.tvf.explode_outer(lit(null).cast("map<string, int>"))
+ val expected6 = spark.sql("SELECT * FROM explode_outer(null :: map<string, int>)")
+ checkAnswer(actual6, expected6)
+ }
+
+ test("inline") {
+ val actual1 = spark.tvf.inline(array(struct(lit(1), lit("a")), struct(lit(2), lit("b"))))
+ val expected1 = spark.sql("SELECT * FROM inline(array(struct(1, 'a'), struct(2, 'b')))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.inline(array().cast("array<struct<a:int,b:int>>"))
+ val expected2 = spark.sql("SELECT * FROM inline(array() :: array<struct<a:int,b:int>>)")
+ checkAnswer(actual2, expected2)
+
+ val actual3 = spark.tvf.inline(
+ array(
+ named_struct(lit("a"), lit(1), lit("b"), lit(2)),
+ lit(null),
+ named_struct(lit("a"), lit(3), lit("b"), lit(4))))
+ val expected3 = spark.sql(
+ "SELECT * FROM " +
+ "inline(array(named_struct('a', 1, 'b', 2), null, named_struct('a', 3,
'b', 4)))")
+ checkAnswer(actual3, expected3)
+ }
+
+ test("inline_outer") {
+ val actual1 =
+ spark.tvf.inline_outer(array(struct(lit(1), lit("a")), struct(lit(2), lit("b"))))
+ val expected1 = spark.sql("SELECT * FROM inline_outer(array(struct(1, 'a'), struct(2, 'b')))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.inline_outer(array().cast("array<struct<a:int,b:int>>"))
+ val expected2 = spark.sql("SELECT * FROM inline_outer(array() :: array<struct<a:int,b:int>>)")
+ checkAnswer(actual2, expected2)
+
+ val actual3 = spark.tvf.inline_outer(
+ array(
+ named_struct(lit("a"), lit(1), lit("b"), lit(2)),
+ lit(null),
+ named_struct(lit("a"), lit(3), lit("b"), lit(4))))
+ val expected3 = spark.sql(
+ "SELECT * FROM " +
+ "inline_outer(array(named_struct('a', 1, 'b', 2), null,
named_struct('a', 3, 'b', 4)))")
+ checkAnswer(actual3, expected3)
+ }
+
+ test("json_tuple") {
+ val actual = spark.tvf.json_tuple(lit("""{"a":1,"b":2}"""), lit("a"), lit("b"))
+ val expected = spark.sql("""SELECT * FROM json_tuple('{"a":1,"b":2}', 'a', 'b')""")
+ checkAnswer(actual, expected)
+
+ val ex = intercept[AnalysisException] {
+ spark.tvf.json_tuple(lit("""{"a":1,"b":2}""")).collect()
+ }
+ assert(ex.errorClass.get == "WRONG_NUM_ARGS.WITHOUT_SUGGESTION")
+ assert(ex.messageParameters("functionName") == "`json_tuple`")
+ }
+
+ test("posexplode") {
+ val actual1 = spark.tvf.posexplode(array(lit(1), lit(2)))
+ val expected1 = spark.sql("SELECT * FROM posexplode(array(1, 2))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.posexplode(map(lit("a"), lit(1), lit("b"), lit(2)))
+ val expected2 = spark.sql("SELECT * FROM posexplode(map('a', 1, 'b', 2))")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.posexplode(array())
+ val expected3 = spark.sql("SELECT * FROM posexplode(array())")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.posexplode(map())
+ val expected4 = spark.sql("SELECT * FROM posexplode(map())")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.posexplode(lit(null).cast("array<int>"))
+ val expected5 = spark.sql("SELECT * FROM posexplode(null :: array<int>)")
+ checkAnswer(actual5, expected5)
+
+ val actual6 = spark.tvf.posexplode(lit(null).cast("map<string, int>"))
+ val expected6 = spark.sql("SELECT * FROM posexplode(null :: map<string, int>)")
+ checkAnswer(actual6, expected6)
+ }
+
+ test("posexplode_outer") {
+ val actual1 = spark.tvf.posexplode_outer(array(lit(1), lit(2)))
+ val expected1 = spark.sql("SELECT * FROM posexplode_outer(array(1, 2))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.posexplode_outer(map(lit("a"), lit(1), lit("b"), lit(2)))
+ val expected2 = spark.sql("SELECT * FROM posexplode_outer(map('a', 1, 'b', 2))")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.posexplode_outer(array())
+ val expected3 = spark.sql("SELECT * FROM posexplode_outer(array())")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.posexplode_outer(map())
+ val expected4 = spark.sql("SELECT * FROM posexplode_outer(map())")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.posexplode_outer(lit(null).cast("array<int>"))
+ val expected5 = spark.sql("SELECT * FROM posexplode_outer(null :: array<int>)")
+ checkAnswer(actual5, expected5)
+
+ val actual6 = spark.tvf.posexplode_outer(lit(null).cast("map<string, int>"))
+ val expected6 = spark.sql("SELECT * FROM posexplode_outer(null :: map<string, int>)")
+ checkAnswer(actual6, expected6)
+ }
+
+ test("stack") {
+ val actual = spark.tvf.stack(lit(2), lit(1), lit(2), lit(3))
+ val expected = spark.sql("SELECT * FROM stack(2, 1, 2, 3)")
+ checkAnswer(actual, expected)
+ }
+
+ test("collations") {
+ val actual = spark.tvf.collations()
+ val expected = spark.sql("SELECT * FROM collations()")
+ checkAnswer(actual, expected)
+ }
+
+ test("sql_keywords") {
+ val actual = spark.tvf.sql_keywords()
+ val expected = spark.sql("SELECT * FROM sql_keywords()")
+ checkAnswer(actual, expected)
+ }
+
+ // TODO(SPARK-50063): Support VARIANT in Spark Connect Scala client
+ ignore("variant_explode") {
+ val actual1 = spark.tvf.variant_explode(parse_json(lit("""["hello", "world"]""")))
+ val expected1 =
+ spark.sql("""SELECT * FROM variant_explode(parse_json('["hello",
"world"]'))""")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.variant_explode(parse_json(lit("""{"a": true, "b": 3.14}""")))
+ val expected2 =
+ spark.sql("""SELECT * FROM variant_explode(parse_json('{"a": true, "b":
3.14}'))""")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.variant_explode(parse_json(lit("[]")))
+ val expected3 = spark.sql("SELECT * FROM variant_explode(parse_json('[]'))")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.variant_explode(parse_json(lit("{}")))
+ val expected4 = spark.sql("SELECT * FROM variant_explode(parse_json('{}'))")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.variant_explode(lit(null).cast("variant"))
+ val expected5 = spark.sql("SELECT * FROM variant_explode(null :: variant)")
+ checkAnswer(actual5, expected5)
+
+ // not a variant object/array
+ val actual6 = spark.tvf.variant_explode(parse_json(lit("1")))
+ val expected6 = spark.sql("SELECT * FROM variant_explode(parse_json('1'))")
+ checkAnswer(actual6, expected6)
+ }
+
+ // TODO(SPARK-50063): Support VARIANT in Spark Connect Scala client
+ ignore("variant_explode_outer") {
+ val actual1 = spark.tvf.variant_explode_outer(parse_json(lit("""["hello", "world"]""")))
+ val expected1 =
+ spark.sql("""SELECT * FROM variant_explode_outer(parse_json('["hello",
"world"]'))""")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.variant_explode_outer(parse_json(lit("""{"a": true, "b": 3.14}""")))
+ val expected2 =
+ spark.sql("""SELECT * FROM variant_explode_outer(parse_json('{"a": true,
"b": 3.14}'))""")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.variant_explode_outer(parse_json(lit("[]")))
+ val expected3 = spark.sql("SELECT * FROM variant_explode_outer(parse_json('[]'))")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.variant_explode_outer(parse_json(lit("{}")))
+ val expected4 = spark.sql("SELECT * FROM variant_explode_outer(parse_json('{}'))")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.variant_explode_outer(lit(null).cast("variant"))
+ val expected5 = spark.sql("SELECT * FROM variant_explode_outer(null :: variant)")
+ checkAnswer(actual5, expected5)
+
+ // not a variant object/array
+ val actual6 = spark.tvf.variant_explode_outer(parse_json(lit("1")))
+ val expected6 = spark.sql("SELECT * FROM variant_explode_outer(parse_json('1'))")
+ checkAnswer(actual6, expected6)
+ }
+}
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 92b7d9aa25c0..6849ce1f3590 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -500,6 +500,7 @@ pyspark_sql = Module(
"pyspark.sql.pandas.typehints",
"pyspark.sql.pandas.utils",
"pyspark.sql.observation",
+ "pyspark.sql.tvf",
# unittests
"pyspark.sql.tests.test_arrow",
"pyspark.sql.tests.test_arrow_cogrouped_map",
@@ -547,6 +548,7 @@ pyspark_sql = Module(
"pyspark.sql.tests.test_udf",
"pyspark.sql.tests.test_udf_profiler",
"pyspark.sql.tests.test_udtf",
+ "pyspark.sql.tests.test_tvf",
"pyspark.sql.tests.test_utils",
"pyspark.sql.tests.test_resources",
"pyspark.sql.tests.plot.test_frame_plot",
@@ -1012,6 +1014,7 @@ pyspark_connect = Module(
"pyspark.sql.connect.protobuf.functions",
"pyspark.sql.connect.streaming.readwriter",
"pyspark.sql.connect.streaming.query",
+ "pyspark.sql.connect.tvf",
# sql unittests
"pyspark.sql.tests.connect.test_connect_plan",
"pyspark.sql.tests.connect.test_connect_basic",
@@ -1047,6 +1050,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_parity_udf_profiler",
"pyspark.sql.tests.connect.test_parity_memory_profiler",
"pyspark.sql.tests.connect.test_parity_udtf",
+ "pyspark.sql.tests.connect.test_parity_tvf",
"pyspark.sql.tests.connect.test_parity_pandas_udf",
"pyspark.sql.tests.connect.test_parity_pandas_map",
"pyspark.sql.tests.connect.test_parity_arrow_map",
diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst
index 4e679da59c16..859332fa5e42 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -59,6 +59,7 @@ See also :class:`SparkSession`.
SparkSession.stop
SparkSession.streams
SparkSession.table
+ SparkSession.tvf
SparkSession.udf
SparkSession.udtf
SparkSession.version
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index b8268d46b332..b387ca1d4e50 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1350,6 +1350,20 @@ class Transpose(LogicalPlan):
return plan
+class UnresolvedTableValuedFunction(LogicalPlan):
+ def __init__(self, name: str, args: Sequence[Column]):
+ super().__init__(None)
+ self._name = name
+ self._args = args
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ plan = self._create_proto_relation()
+ plan.unresolved_table_valued_function.function_name = self._name
+ for arg in self._args:
+ plan.unresolved_table_valued_function.arguments.append(arg.to_plan(session))
+ return plan
+
+
class CollectMetrics(LogicalPlan):
"""Logical plan object for a CollectMetrics operation."""
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index ee625241600f..9c3766a3552d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import common_pb2 as spark_dot_connect_dot_common
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto"\xa3\x1b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.Project [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto"\x9c\x1c\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.Project [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -69,155 +69,157 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_PARSE_OPTIONSENTRY._options = None
_PARSE_OPTIONSENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 193
- _RELATION._serialized_end = 3684
- _UNKNOWN._serialized_start = 3686
- _UNKNOWN._serialized_end = 3695
- _RELATIONCOMMON._serialized_start = 3698
- _RELATIONCOMMON._serialized_end = 3840
- _SQL._serialized_start = 3843
- _SQL._serialized_end = 4321
- _SQL_ARGSENTRY._serialized_start = 4137
- _SQL_ARGSENTRY._serialized_end = 4227
- _SQL_NAMEDARGUMENTSENTRY._serialized_start = 4229
- _SQL_NAMEDARGUMENTSENTRY._serialized_end = 4321
- _WITHRELATIONS._serialized_start = 4323
- _WITHRELATIONS._serialized_end = 4440
- _READ._serialized_start = 4443
- _READ._serialized_end = 5106
- _READ_NAMEDTABLE._serialized_start = 4621
- _READ_NAMEDTABLE._serialized_end = 4813
- _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 4755
- _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 4813
- _READ_DATASOURCE._serialized_start = 4816
- _READ_DATASOURCE._serialized_end = 5093
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 4755
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 4813
- _PROJECT._serialized_start = 5108
- _PROJECT._serialized_end = 5225
- _FILTER._serialized_start = 5227
- _FILTER._serialized_end = 5339
- _JOIN._serialized_start = 5342
- _JOIN._serialized_end = 6003
- _JOIN_JOINDATATYPE._serialized_start = 5681
- _JOIN_JOINDATATYPE._serialized_end = 5773
- _JOIN_JOINTYPE._serialized_start = 5776
- _JOIN_JOINTYPE._serialized_end = 5984
- _SETOPERATION._serialized_start = 6006
- _SETOPERATION._serialized_end = 6485
- _SETOPERATION_SETOPTYPE._serialized_start = 6322
- _SETOPERATION_SETOPTYPE._serialized_end = 6436
- _LIMIT._serialized_start = 6487
- _LIMIT._serialized_end = 6563
- _OFFSET._serialized_start = 6565
- _OFFSET._serialized_end = 6644
- _TAIL._serialized_start = 6646
- _TAIL._serialized_end = 6721
- _AGGREGATE._serialized_start = 6724
- _AGGREGATE._serialized_end = 7490
- _AGGREGATE_PIVOT._serialized_start = 7139
- _AGGREGATE_PIVOT._serialized_end = 7250
- _AGGREGATE_GROUPINGSETS._serialized_start = 7252
- _AGGREGATE_GROUPINGSETS._serialized_end = 7328
- _AGGREGATE_GROUPTYPE._serialized_start = 7331
- _AGGREGATE_GROUPTYPE._serialized_end = 7490
- _SORT._serialized_start = 7493
- _SORT._serialized_end = 7653
- _DROP._serialized_start = 7656
- _DROP._serialized_end = 7797
- _DEDUPLICATE._serialized_start = 7800
- _DEDUPLICATE._serialized_end = 8040
- _LOCALRELATION._serialized_start = 8042
- _LOCALRELATION._serialized_end = 8131
- _CACHEDLOCALRELATION._serialized_start = 8133
- _CACHEDLOCALRELATION._serialized_end = 8205
- _CACHEDREMOTERELATION._serialized_start = 8207
- _CACHEDREMOTERELATION._serialized_end = 8262
- _SAMPLE._serialized_start = 8265
- _SAMPLE._serialized_end = 8538
- _RANGE._serialized_start = 8541
- _RANGE._serialized_end = 8686
- _SUBQUERYALIAS._serialized_start = 8688
- _SUBQUERYALIAS._serialized_end = 8802
- _REPARTITION._serialized_start = 8805
- _REPARTITION._serialized_end = 8947
- _SHOWSTRING._serialized_start = 8950
- _SHOWSTRING._serialized_end = 9092
- _HTMLSTRING._serialized_start = 9094
- _HTMLSTRING._serialized_end = 9208
- _STATSUMMARY._serialized_start = 9210
- _STATSUMMARY._serialized_end = 9302
- _STATDESCRIBE._serialized_start = 9304
- _STATDESCRIBE._serialized_end = 9385
- _STATCROSSTAB._serialized_start = 9387
- _STATCROSSTAB._serialized_end = 9488
- _STATCOV._serialized_start = 9490
- _STATCOV._serialized_end = 9586
- _STATCORR._serialized_start = 9589
- _STATCORR._serialized_end = 9726
- _STATAPPROXQUANTILE._serialized_start = 9729
- _STATAPPROXQUANTILE._serialized_end = 9893
- _STATFREQITEMS._serialized_start = 9895
- _STATFREQITEMS._serialized_end = 10020
- _STATSAMPLEBY._serialized_start = 10023
- _STATSAMPLEBY._serialized_end = 10332
- _STATSAMPLEBY_FRACTION._serialized_start = 10224
- _STATSAMPLEBY_FRACTION._serialized_end = 10323
- _NAFILL._serialized_start = 10335
- _NAFILL._serialized_end = 10469
- _NADROP._serialized_start = 10472
- _NADROP._serialized_end = 10606
- _NAREPLACE._serialized_start = 10609
- _NAREPLACE._serialized_end = 10905
- _NAREPLACE_REPLACEMENT._serialized_start = 10764
- _NAREPLACE_REPLACEMENT._serialized_end = 10905
- _TODF._serialized_start = 10907
- _TODF._serialized_end = 10995
- _WITHCOLUMNSRENAMED._serialized_start = 10998
- _WITHCOLUMNSRENAMED._serialized_end = 11380
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 11242
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 11309
- _WITHCOLUMNSRENAMED_RENAME._serialized_start = 11311
- _WITHCOLUMNSRENAMED_RENAME._serialized_end = 11380
- _WITHCOLUMNS._serialized_start = 11382
- _WITHCOLUMNS._serialized_end = 11501
- _WITHWATERMARK._serialized_start = 11504
- _WITHWATERMARK._serialized_end = 11638
- _HINT._serialized_start = 11641
- _HINT._serialized_end = 11773
- _UNPIVOT._serialized_start = 11776
- _UNPIVOT._serialized_end = 12103
- _UNPIVOT_VALUES._serialized_start = 12033
- _UNPIVOT_VALUES._serialized_end = 12092
- _TRANSPOSE._serialized_start = 12105
- _TRANSPOSE._serialized_end = 12227
- _TOSCHEMA._serialized_start = 12229
- _TOSCHEMA._serialized_end = 12335
- _REPARTITIONBYEXPRESSION._serialized_start = 12338
- _REPARTITIONBYEXPRESSION._serialized_end = 12541
- _MAPPARTITIONS._serialized_start = 12544
- _MAPPARTITIONS._serialized_end = 12776
- _GROUPMAP._serialized_start = 12779
- _GROUPMAP._serialized_end = 13414
- _COGROUPMAP._serialized_start = 13417
- _COGROUPMAP._serialized_end = 13943
- _APPLYINPANDASWITHSTATE._serialized_start = 13946
- _APPLYINPANDASWITHSTATE._serialized_end = 14303
- _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 14306
- _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 14550
- _PYTHONUDTF._serialized_start = 14553
- _PYTHONUDTF._serialized_end = 14730
- _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_start = 14733
- _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_end = 14884
- _PYTHONDATASOURCE._serialized_start = 14886
- _PYTHONDATASOURCE._serialized_end = 14961
- _COLLECTMETRICS._serialized_start = 14964
- _COLLECTMETRICS._serialized_end = 15100
- _PARSE._serialized_start = 15103
- _PARSE._serialized_end = 15491
- _PARSE_OPTIONSENTRY._serialized_start = 4755
- _PARSE_OPTIONSENTRY._serialized_end = 4813
- _PARSE_PARSEFORMAT._serialized_start = 15392
- _PARSE_PARSEFORMAT._serialized_end = 15480
- _ASOFJOIN._serialized_start = 15494
- _ASOFJOIN._serialized_end = 15969
+ _RELATION._serialized_end = 3805
+ _UNKNOWN._serialized_start = 3807
+ _UNKNOWN._serialized_end = 3816
+ _RELATIONCOMMON._serialized_start = 3819
+ _RELATIONCOMMON._serialized_end = 3961
+ _SQL._serialized_start = 3964
+ _SQL._serialized_end = 4442
+ _SQL_ARGSENTRY._serialized_start = 4258
+ _SQL_ARGSENTRY._serialized_end = 4348
+ _SQL_NAMEDARGUMENTSENTRY._serialized_start = 4350
+ _SQL_NAMEDARGUMENTSENTRY._serialized_end = 4442
+ _WITHRELATIONS._serialized_start = 4444
+ _WITHRELATIONS._serialized_end = 4561
+ _READ._serialized_start = 4564
+ _READ._serialized_end = 5227
+ _READ_NAMEDTABLE._serialized_start = 4742
+ _READ_NAMEDTABLE._serialized_end = 4934
+ _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 4876
+ _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 4934
+ _READ_DATASOURCE._serialized_start = 4937
+ _READ_DATASOURCE._serialized_end = 5214
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 4876
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 4934
+ _PROJECT._serialized_start = 5229
+ _PROJECT._serialized_end = 5346
+ _FILTER._serialized_start = 5348
+ _FILTER._serialized_end = 5460
+ _JOIN._serialized_start = 5463
+ _JOIN._serialized_end = 6124
+ _JOIN_JOINDATATYPE._serialized_start = 5802
+ _JOIN_JOINDATATYPE._serialized_end = 5894
+ _JOIN_JOINTYPE._serialized_start = 5897
+ _JOIN_JOINTYPE._serialized_end = 6105
+ _SETOPERATION._serialized_start = 6127
+ _SETOPERATION._serialized_end = 6606
+ _SETOPERATION_SETOPTYPE._serialized_start = 6443
+ _SETOPERATION_SETOPTYPE._serialized_end = 6557
+ _LIMIT._serialized_start = 6608
+ _LIMIT._serialized_end = 6684
+ _OFFSET._serialized_start = 6686
+ _OFFSET._serialized_end = 6765
+ _TAIL._serialized_start = 6767
+ _TAIL._serialized_end = 6842
+ _AGGREGATE._serialized_start = 6845
+ _AGGREGATE._serialized_end = 7611
+ _AGGREGATE_PIVOT._serialized_start = 7260
+ _AGGREGATE_PIVOT._serialized_end = 7371
+ _AGGREGATE_GROUPINGSETS._serialized_start = 7373
+ _AGGREGATE_GROUPINGSETS._serialized_end = 7449
+ _AGGREGATE_GROUPTYPE._serialized_start = 7452
+ _AGGREGATE_GROUPTYPE._serialized_end = 7611
+ _SORT._serialized_start = 7614
+ _SORT._serialized_end = 7774
+ _DROP._serialized_start = 7777
+ _DROP._serialized_end = 7918
+ _DEDUPLICATE._serialized_start = 7921
+ _DEDUPLICATE._serialized_end = 8161
+ _LOCALRELATION._serialized_start = 8163
+ _LOCALRELATION._serialized_end = 8252
+ _CACHEDLOCALRELATION._serialized_start = 8254
+ _CACHEDLOCALRELATION._serialized_end = 8326
+ _CACHEDREMOTERELATION._serialized_start = 8328
+ _CACHEDREMOTERELATION._serialized_end = 8383
+ _SAMPLE._serialized_start = 8386
+ _SAMPLE._serialized_end = 8659
+ _RANGE._serialized_start = 8662
+ _RANGE._serialized_end = 8807
+ _SUBQUERYALIAS._serialized_start = 8809
+ _SUBQUERYALIAS._serialized_end = 8923
+ _REPARTITION._serialized_start = 8926
+ _REPARTITION._serialized_end = 9068
+ _SHOWSTRING._serialized_start = 9071
+ _SHOWSTRING._serialized_end = 9213
+ _HTMLSTRING._serialized_start = 9215
+ _HTMLSTRING._serialized_end = 9329
+ _STATSUMMARY._serialized_start = 9331
+ _STATSUMMARY._serialized_end = 9423
+ _STATDESCRIBE._serialized_start = 9425
+ _STATDESCRIBE._serialized_end = 9506
+ _STATCROSSTAB._serialized_start = 9508
+ _STATCROSSTAB._serialized_end = 9609
+ _STATCOV._serialized_start = 9611
+ _STATCOV._serialized_end = 9707
+ _STATCORR._serialized_start = 9710
+ _STATCORR._serialized_end = 9847
+ _STATAPPROXQUANTILE._serialized_start = 9850
+ _STATAPPROXQUANTILE._serialized_end = 10014
+ _STATFREQITEMS._serialized_start = 10016
+ _STATFREQITEMS._serialized_end = 10141
+ _STATSAMPLEBY._serialized_start = 10144
+ _STATSAMPLEBY._serialized_end = 10453
+ _STATSAMPLEBY_FRACTION._serialized_start = 10345
+ _STATSAMPLEBY_FRACTION._serialized_end = 10444
+ _NAFILL._serialized_start = 10456
+ _NAFILL._serialized_end = 10590
+ _NADROP._serialized_start = 10593
+ _NADROP._serialized_end = 10727
+ _NAREPLACE._serialized_start = 10730
+ _NAREPLACE._serialized_end = 11026
+ _NAREPLACE_REPLACEMENT._serialized_start = 10885
+ _NAREPLACE_REPLACEMENT._serialized_end = 11026
+ _TODF._serialized_start = 11028
+ _TODF._serialized_end = 11116
+ _WITHCOLUMNSRENAMED._serialized_start = 11119
+ _WITHCOLUMNSRENAMED._serialized_end = 11501
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 11363
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 11430
+ _WITHCOLUMNSRENAMED_RENAME._serialized_start = 11432
+ _WITHCOLUMNSRENAMED_RENAME._serialized_end = 11501
+ _WITHCOLUMNS._serialized_start = 11503
+ _WITHCOLUMNS._serialized_end = 11622
+ _WITHWATERMARK._serialized_start = 11625
+ _WITHWATERMARK._serialized_end = 11759
+ _HINT._serialized_start = 11762
+ _HINT._serialized_end = 11894
+ _UNPIVOT._serialized_start = 11897
+ _UNPIVOT._serialized_end = 12224
+ _UNPIVOT_VALUES._serialized_start = 12154
+ _UNPIVOT_VALUES._serialized_end = 12213
+ _TRANSPOSE._serialized_start = 12226
+ _TRANSPOSE._serialized_end = 12348
+ _UNRESOLVEDTABLEVALUEDFUNCTION._serialized_start = 12350
+ _UNRESOLVEDTABLEVALUEDFUNCTION._serialized_end = 12475
+ _TOSCHEMA._serialized_start = 12477
+ _TOSCHEMA._serialized_end = 12583
+ _REPARTITIONBYEXPRESSION._serialized_start = 12586
+ _REPARTITIONBYEXPRESSION._serialized_end = 12789
+ _MAPPARTITIONS._serialized_start = 12792
+ _MAPPARTITIONS._serialized_end = 13024
+ _GROUPMAP._serialized_start = 13027
+ _GROUPMAP._serialized_end = 13662
+ _COGROUPMAP._serialized_start = 13665
+ _COGROUPMAP._serialized_end = 14191
+ _APPLYINPANDASWITHSTATE._serialized_start = 14194
+ _APPLYINPANDASWITHSTATE._serialized_end = 14551
+ _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 14554
+ _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 14798
+ _PYTHONUDTF._serialized_start = 14801
+ _PYTHONUDTF._serialized_end = 14978
+ _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_start = 14981
+ _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_end = 15132
+ _PYTHONDATASOURCE._serialized_start = 15134
+ _PYTHONDATASOURCE._serialized_end = 15209
+ _COLLECTMETRICS._serialized_start = 15212
+ _COLLECTMETRICS._serialized_end = 15348
+ _PARSE._serialized_start = 15351
+ _PARSE._serialized_end = 15739
+ _PARSE_OPTIONSENTRY._serialized_start = 4876
+ _PARSE_OPTIONSENTRY._serialized_end = 4934
+ _PARSE_PARSEFORMAT._serialized_start = 15640
+ _PARSE_PARSEFORMAT._serialized_end = 15728
+ _ASOFJOIN._serialized_start = 15742
+ _ASOFJOIN._serialized_end = 16217
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index b1cd2e184d08..03753056c6bf 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -105,6 +105,7 @@ class Relation(google.protobuf.message.Message):
COMMON_INLINE_USER_DEFINED_DATA_SOURCE_FIELD_NUMBER: builtins.int
WITH_RELATIONS_FIELD_NUMBER: builtins.int
TRANSPOSE_FIELD_NUMBER: builtins.int
+ UNRESOLVED_TABLE_VALUED_FUNCTION_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -208,6 +209,8 @@ class Relation(google.protobuf.message.Message):
@property
def transpose(self) -> global___Transpose: ...
@property
+ def unresolved_table_valued_function(self) -> global___UnresolvedTableValuedFunction: ...
+ @property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -288,6 +291,7 @@ class Relation(google.protobuf.message.Message):
| None = ...,
with_relations: global___WithRelations | None = ...,
transpose: global___Transpose | None = ...,
+ unresolved_table_valued_function: global___UnresolvedTableValuedFunction | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -412,6 +416,8 @@ class Relation(google.protobuf.message.Message):
b"unknown",
"unpivot",
b"unpivot",
+ "unresolved_table_valued_function",
+ b"unresolved_table_valued_function",
"with_columns",
b"with_columns",
"with_columns_renamed",
@@ -531,6 +537,8 @@ class Relation(google.protobuf.message.Message):
b"unknown",
"unpivot",
b"unpivot",
+ "unresolved_table_valued_function",
+ b"unresolved_table_valued_function",
"with_columns",
b"with_columns",
"with_columns_renamed",
@@ -586,6 +594,7 @@ class Relation(google.protobuf.message.Message):
"common_inline_user_defined_data_source",
"with_relations",
"transpose",
+ "unresolved_table_valued_function",
"fill_na",
"drop_na",
"replace",
@@ -3191,6 +3200,36 @@ class Transpose(google.protobuf.message.Message):
global___Transpose = Transpose
+class UnresolvedTableValuedFunction(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ FUNCTION_NAME_FIELD_NUMBER: builtins.int
+ ARGUMENTS_FIELD_NUMBER: builtins.int
+ function_name: builtins.str
+ """(Required) name (or unparsed name for user defined function) for the
unresolved function."""
+ @property
+ def arguments(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]:
+ """(Optional) Function arguments. Empty arguments are allowed."""
+ def __init__(
+ self,
+ *,
+ function_name: builtins.str = ...,
+ arguments: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
+ | None = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "arguments", b"arguments", "function_name", b"function_name"
+ ],
+ ) -> None: ...
+
+global___UnresolvedTableValuedFunction = UnresolvedTableValuedFunction
+
class ToSchema(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index a4047f09401e..e9984fae9ddb 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -109,6 +109,7 @@ if TYPE_CHECKING:
from pyspark.sql.connect.catalog import Catalog
from pyspark.sql.connect.udf import UDFRegistration
from pyspark.sql.connect.udtf import UDTFRegistration
+ from pyspark.sql.connect.tvf import TableValuedFunction
from pyspark.sql.connect.shell.progress import ProgressHandler
from pyspark.sql.connect.datasource import DataSourceRegistration
@@ -382,6 +383,14 @@ class SparkSession:
readStream.__doc__ = PySparkSession.readStream.__doc__
+ @property
+ def tvf(self) -> "TableValuedFunction":
+ from pyspark.sql.connect.tvf import TableValuedFunction
+
+ return TableValuedFunction(self)
+
+ tvf.__doc__ = PySparkSession.tvf.__doc__
+
def registerProgressHandler(self, handler: "ProgressHandler") -> None:
self._client.register_progress_handler(handler)
diff --git a/python/pyspark/sql/connect/tvf.py b/python/pyspark/sql/connect/tvf.py
new file mode 100644
index 000000000000..2fca610a5fe3
--- /dev/null
+++ b/python/pyspark/sql/connect/tvf.py
@@ -0,0 +1,148 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Optional
+
+from pyspark.errors import PySparkValueError
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.dataframe import DataFrame
+from pyspark.sql.connect.functions.builtin import _to_col
+from pyspark.sql.connect.plan import UnresolvedTableValuedFunction
+from pyspark.sql.connect.session import SparkSession
+from pyspark.sql.tvf import TableValuedFunction as PySparkTableValuedFunction
+
+
+class TableValuedFunction:
+ __doc__ = PySparkTableValuedFunction.__doc__
+
+ def __init__(self, sparkSession: SparkSession):
+ self._sparkSession = sparkSession
+
+ def range(
+ self,
+ start: int,
+ end: Optional[int] = None,
+ step: int = 1,
+ numPartitions: Optional[int] = None,
+ ) -> DataFrame:
+ return self._sparkSession.range( # type: ignore[return-value]
+ start, end, step, numPartitions
+ )
+
+ range.__doc__ = PySparkTableValuedFunction.range.__doc__
+
+ def explode(self, collection: Column) -> DataFrame:
+ return self._fn("explode", collection)
+
+ explode.__doc__ = PySparkTableValuedFunction.explode.__doc__
+
+ def explode_outer(self, collection: Column) -> DataFrame:
+ return self._fn("explode_outer", collection)
+
+ explode_outer.__doc__ = PySparkTableValuedFunction.explode_outer.__doc__
+
+ def inline(self, input: Column) -> DataFrame:
+ return self._fn("inline", input)
+
+ inline.__doc__ = PySparkTableValuedFunction.inline.__doc__
+
+ def inline_outer(self, input: Column) -> DataFrame:
+ return self._fn("inline_outer", input)
+
+ inline_outer.__doc__ = PySparkTableValuedFunction.inline_outer.__doc__
+
+ def json_tuple(self, input: Column, *fields: Column) -> DataFrame:
+ if len(fields) == 0:
+ raise PySparkValueError(
+ errorClass="CANNOT_BE_EMPTY",
+ messageParameters={"item": "field"},
+ )
+ return self._fn("json_tuple", input, *fields)
+
+ json_tuple.__doc__ = PySparkTableValuedFunction.json_tuple.__doc__
+
+ def posexplode(self, collection: Column) -> DataFrame:
+ return self._fn("posexplode", collection)
+
+ posexplode.__doc__ = PySparkTableValuedFunction.posexplode.__doc__
+
+ def posexplode_outer(self, collection: Column) -> DataFrame:
+ return self._fn("posexplode_outer", collection)
+
+ posexplode_outer.__doc__ = PySparkTableValuedFunction.posexplode_outer.__doc__
+
+ def stack(self, n: Column, *fields: Column) -> DataFrame:
+ return self._fn("stack", n, *fields)
+
+ stack.__doc__ = PySparkTableValuedFunction.stack.__doc__
+
+ def collations(self) -> DataFrame:
+ return self._fn("collations")
+
+ collations.__doc__ = PySparkTableValuedFunction.collations.__doc__
+
+ def sql_keywords(self) -> DataFrame:
+ return self._fn("sql_keywords")
+
+ sql_keywords.__doc__ = PySparkTableValuedFunction.sql_keywords.__doc__
+
+ def variant_explode(self, input: Column) -> DataFrame:
+ return self._fn("variant_explode", input)
+
+ variant_explode.__doc__ = PySparkTableValuedFunction.variant_explode.__doc__
+
+ def variant_explode_outer(self, input: Column) -> DataFrame:
+ return self._fn("variant_explode_outer", input)
+
+ variant_explode_outer.__doc__ = PySparkTableValuedFunction.variant_explode_outer.__doc__
+
+ def _fn(self, name: str, *args: Column) -> DataFrame:
+ return DataFrame(
+ UnresolvedTableValuedFunction(name, [_to_col(arg) for arg in args]), self._sparkSession
+ )
+
+
+def _test() -> None:
+ import os
+ import doctest
+ import sys
+ from pyspark.sql import SparkSession as PySparkSession
+ import pyspark.sql.connect.tvf
+
+ globs = pyspark.sql.connect.tvf.__dict__.copy()
+
+ globs["spark"] = (
+ PySparkSession.builder.appName("sql.connect.tvf tests")
+ .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
+ .getOrCreate()
+ )
+
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.connect.tvf,
+ globs=globs,
+ optionflags=doctest.ELLIPSIS
+ | doctest.NORMALIZE_WHITESPACE
+ | doctest.IGNORE_EXCEPTION_DETAIL,
+ )
+
+ globs["spark"].stop()
+
+ if failure_count:
+ sys.exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 96344efba2d2..748dd2cafa7c 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -74,6 +74,7 @@ if TYPE_CHECKING:
from pyspark.sql.catalog import Catalog
from pyspark.sql.pandas._typing import ArrayLike, DataFrameLike as PandasDataFrameLike
from pyspark.sql.streaming import StreamingQueryManager
+ from pyspark.sql.tvf import TableValuedFunction
from pyspark.sql.udf import UDFRegistration
from pyspark.sql.udtf import UDTFRegistration
from pyspark.sql.datasource import DataSourceRegistration
@@ -1963,6 +1964,41 @@ class SparkSession(SparkConversionMixin):
self._sqm: StreamingQueryManager = StreamingQueryManager(self._jsparkSession.streams())
return self._sqm
+ @property
+ def tvf(self) -> "TableValuedFunction":
+ """
+ Returns a :class:`TableValuedFunction` that can be used to call a table-valued function
+ (TVF).
+
+ .. versionadded:: 4.0.0
+
+ Notes
+ -----
+ This API is evolving.
+
+ Returns
+ -------
+ :class:`TableValuedFunction`
+
+ Examples
+ --------
+ >>> spark.tvf
+ <pyspark...TableValuedFunction object ...>
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.explode(sf.array(sf.lit(1), sf.lit(2), sf.lit(3))).show()
+ +---+
+ |col|
+ +---+
+ | 1|
+ | 2|
+ | 3|
+ +---+
+ """
+ from pyspark.sql.tvf import TableValuedFunction
+
+ return TableValuedFunction(self)
+
def stop(self) -> None:
"""
Stop the underlying :class:`SparkContext`.
diff --git a/python/pyspark/sql/tests/connect/test_parity_tvf.py b/python/pyspark/sql/tests/connect/test_parity_tvf.py
new file mode 100644
index 000000000000..61e3decf562c
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_tvf.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.sql.tests.test_tvf import TVFTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class TVFParityTestsMixin(TVFTestsMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.test_parity_tvf import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_tvf.py b/python/pyspark/sql/tests/test_tvf.py
new file mode 100644
index 000000000000..5c709437fc4d
--- /dev/null
+++ b/python/pyspark/sql/tests/test_tvf.py
@@ -0,0 +1,307 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.errors import PySparkValueError
+from pyspark.sql import functions as sf
+from pyspark.testing import assertDataFrameEqual
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class TVFTestsMixin:
+ def test_explode(self):
+ actual = self.spark.tvf.explode(sf.array(sf.lit(1), sf.lit(2)))
+ expected = self.spark.sql("""SELECT * FROM explode(array(1, 2))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.explode(
+ sf.create_map(sf.lit("a"), sf.lit(1), sf.lit("b"), sf.lit(2))
+ )
+ expected = self.spark.sql("""SELECT * FROM explode(map('a', 1, 'b', 2))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # empty
+ actual = self.spark.tvf.explode(sf.array())
+ expected = self.spark.sql("""SELECT * FROM explode(array())""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.explode(sf.create_map())
+ expected = self.spark.sql("""SELECT * FROM explode(map())""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # null
+ actual = self.spark.tvf.explode(sf.lit(None).astype("array<int>"))
+ expected = self.spark.sql("""SELECT * FROM explode(null :: array<int>)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.explode(sf.lit(None).astype("map<string, int>"))
+ expected = self.spark.sql("""SELECT * FROM explode(null :: map<string, int>)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ def test_explode_outer(self):
+ actual = self.spark.tvf.explode_outer(sf.array(sf.lit(1), sf.lit(2)))
+ expected = self.spark.sql("""SELECT * FROM explode_outer(array(1, 2))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.explode_outer(
+ sf.create_map(sf.lit("a"), sf.lit(1), sf.lit("b"), sf.lit(2))
+ )
+ expected = self.spark.sql("""SELECT * FROM explode_outer(map('a', 1, 'b', 2))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # empty
+ actual = self.spark.tvf.explode_outer(sf.array())
+ expected = self.spark.sql("""SELECT * FROM explode_outer(array())""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.explode_outer(sf.create_map())
+ expected = self.spark.sql("""SELECT * FROM explode_outer(map())""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # null
+ actual = self.spark.tvf.explode_outer(sf.lit(None).astype("array<int>"))
+ expected = self.spark.sql("""SELECT * FROM explode_outer(null :: array<int>)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.explode_outer(sf.lit(None).astype("map<string, int>"))
+ expected = self.spark.sql("""SELECT * FROM explode_outer(null :: map<string, int>)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ def test_inline(self):
+ actual = self.spark.tvf.inline(
+ sf.array(sf.struct(sf.lit(1), sf.lit("a")), sf.struct(sf.lit(2), sf.lit("b")))
+ )
+ expected = self.spark.sql("""SELECT * FROM inline(array(struct(1, 'a'), struct(2, 'b')))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.inline(sf.array().astype("array<struct<a:int,b:int>>"))
+ expected = self.spark.sql("""SELECT * FROM inline(array() :: array<struct<a:int,b:int>>)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.inline(
+ sf.array(
+ sf.named_struct(sf.lit("a"), sf.lit(1), sf.lit("b"),
sf.lit(2)),
+ sf.lit(None),
+ sf.named_struct(sf.lit("a"), sf.lit(3), sf.lit("b"),
sf.lit(4)),
+ )
+ )
+ expected = self.spark.sql(
+ """
+ SELECT * FROM
+ inline(array(named_struct('a', 1, 'b', 2), null, named_struct('a', 3, 'b', 4)))
+ """
+ )
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ def test_inline_outer(self):
+ actual = self.spark.tvf.inline_outer(
+ sf.array(sf.struct(sf.lit(1), sf.lit("a")), sf.struct(sf.lit(2), sf.lit("b")))
+ )
+ expected = self.spark.sql(
+ """SELECT * FROM inline_outer(array(struct(1, 'a'), struct(2,
'b')))"""
+ )
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.inline_outer(sf.array().astype("array<struct<a:int,b:int>>"))
+ expected = self.spark.sql(
+ """SELECT * FROM inline_outer(array() ::
array<struct<a:int,b:int>>)"""
+ )
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.inline_outer(
+ sf.array(
+ sf.named_struct(sf.lit("a"), sf.lit(1), sf.lit("b"),
sf.lit(2)),
+ sf.lit(None),
+ sf.named_struct(sf.lit("a"), sf.lit(3), sf.lit("b"),
sf.lit(4)),
+ )
+ )
+ expected = self.spark.sql(
+ """
+ SELECT * FROM
+ inline_outer(array(named_struct('a', 1, 'b', 2), null, named_struct('a', 3, 'b', 4)))
+ """
+ )
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ def test_json_tuple(self):
+ actual = self.spark.tvf.json_tuple(sf.lit('{"a":1, "b":2}'), sf.lit("a"), sf.lit("b"))
+ expected = self.spark.sql("""SELECT json_tuple('{"a":1, "b":2}', 'a', 'b')""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ with self.assertRaises(PySparkValueError) as pe:
+ self.spark.tvf.json_tuple(sf.lit('{"a":1, "b":2}'))
+
+ self.check_error(
+ exception=pe.exception,
+ errorClass="CANNOT_BE_EMPTY",
+ messageParameters={"item": "field"},
+ )
+
+ def test_posexplode(self):
+ actual = self.spark.tvf.posexplode(sf.array(sf.lit(1), sf.lit(2)))
+ expected = self.spark.sql("""SELECT * FROM posexplode(array(1, 2))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.posexplode(
+ sf.create_map(sf.lit("a"), sf.lit(1), sf.lit("b"), sf.lit(2))
+ )
+ expected = self.spark.sql("""SELECT * FROM posexplode(map('a', 1, 'b', 2))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # empty
+ actual = self.spark.tvf.posexplode(sf.array())
+ expected = self.spark.sql("""SELECT * FROM posexplode(array())""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.posexplode(sf.create_map())
+ expected = self.spark.sql("""SELECT * FROM posexplode(map())""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # null
+ actual = self.spark.tvf.posexplode(sf.lit(None).astype("array<int>"))
+ expected = self.spark.sql("""SELECT * FROM posexplode(null :: array<int>)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.posexplode(sf.lit(None).astype("map<string, int>"))
+ expected = self.spark.sql("""SELECT * FROM posexplode(null :: map<string, int>)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ def test_posexplode_outer(self):
+ actual = self.spark.tvf.posexplode_outer(sf.array(sf.lit(1), sf.lit(2)))
+ expected = self.spark.sql("""SELECT * FROM posexplode_outer(array(1, 2))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.posexplode_outer(
+ sf.create_map(sf.lit("a"), sf.lit(1), sf.lit("b"), sf.lit(2))
+ )
+ expected = self.spark.sql("""SELECT * FROM posexplode_outer(map('a', 1, 'b', 2))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # empty
+ actual = self.spark.tvf.posexplode_outer(sf.array())
+ expected = self.spark.sql("""SELECT * FROM posexplode_outer(array())""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.posexplode_outer(sf.create_map())
+ expected = self.spark.sql("""SELECT * FROM posexplode_outer(map())""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # null
+ actual = self.spark.tvf.posexplode_outer(sf.lit(None).astype("array<int>"))
+ expected = self.spark.sql("""SELECT * FROM posexplode_outer(null :: array<int>)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.posexplode_outer(sf.lit(None).astype("map<string, int>"))
+ expected = self.spark.sql("""SELECT * FROM posexplode_outer(null :: map<string, int>)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ def test_stack(self):
+ actual = self.spark.tvf.stack(sf.lit(2), sf.lit(1), sf.lit(2), sf.lit(3))
+ expected = self.spark.sql("""SELECT * FROM stack(2, 1, 2, 3)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ def test_collations(self):
+ actual = self.spark.tvf.collations()
+ expected = self.spark.sql("""SELECT * FROM collations()""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ def test_sql_keywords(self):
+ actual = self.spark.tvf.sql_keywords()
+ expected = self.spark.sql("""SELECT * FROM sql_keywords()""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ def test_variant_explode(self):
+ actual = self.spark.tvf.variant_explode(sf.parse_json(sf.lit('["hello", "world"]')))
+ expected = self.spark.sql(
+ """SELECT * FROM variant_explode(parse_json('["hello", "world"]'))"""
+ )
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.variant_explode(sf.parse_json(sf.lit('{"a":
true, "b": 3.14}')))
+ expected = self.spark.sql(
+ """SELECT * FROM variant_explode(parse_json('{"a": true, "b":
3.14}'))"""
+ )
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # empty
+ actual = self.spark.tvf.variant_explode(sf.parse_json(sf.lit("[]")))
+ expected = self.spark.sql("""SELECT * FROM
variant_explode(parse_json('[]'))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.variant_explode(sf.parse_json(sf.lit("{}")))
+ expected = self.spark.sql("""SELECT * FROM
variant_explode(parse_json('{}'))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # null
+ actual = self.spark.tvf.variant_explode(sf.lit(None).astype("variant"))
+ expected = self.spark.sql("""SELECT * FROM variant_explode(null ::
variant)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # not a variant object/array
+ actual = self.spark.tvf.variant_explode(sf.parse_json(sf.lit("1")))
+ expected = self.spark.sql("""SELECT * FROM
variant_explode(parse_json('1'))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ def test_variant_explode_outer(self):
+ actual = self.spark.tvf.variant_explode_outer(sf.parse_json(sf.lit('["hello", "world"]')))
+ expected = self.spark.sql(
+ """SELECT * FROM variant_explode_outer(parse_json('["hello", "world"]'))"""
+ )
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.variant_explode_outer(
+ sf.parse_json(sf.lit('{"a": true, "b": 3.14}'))
+ )
+ expected = self.spark.sql(
+ """SELECT * FROM variant_explode_outer(parse_json('{"a": true,
"b": 3.14}'))"""
+ )
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # empty
+ actual = self.spark.tvf.variant_explode_outer(sf.parse_json(sf.lit("[]")))
+ expected = self.spark.sql("""SELECT * FROM variant_explode_outer(parse_json('[]'))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ actual = self.spark.tvf.variant_explode_outer(sf.parse_json(sf.lit("{}")))
+ expected = self.spark.sql("""SELECT * FROM variant_explode_outer(parse_json('{}'))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # null
+ actual = self.spark.tvf.variant_explode_outer(sf.lit(None).astype("variant"))
+ expected = self.spark.sql("""SELECT * FROM variant_explode_outer(null :: variant)""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+ # not a variant object/array
+ actual = self.spark.tvf.variant_explode_outer(sf.parse_json(sf.lit("1")))
+ expected = self.spark.sql("""SELECT * FROM variant_explode_outer(parse_json('1'))""")
+ assertDataFrameEqual(actual=actual, expected=expected)
+
+
+class TVFTests(TVFTestsMixin, ReusedSQLTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_tvf import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
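
Every test above follows the same shape: invoke the new `spark.tvf` method, run the equivalent `SELECT * FROM ...` statement, and assert that the two DataFrames match. A minimal standalone version of that pattern (a sketch, assuming a local Spark build that includes this patch):

```py
# Minimal repro of the assertion pattern used throughout the suite;
# assumes a local Spark build with this patch applied.
from pyspark.sql import SparkSession
import pyspark.sql.functions as sf
from pyspark.testing.utils import assertDataFrameEqual

spark = SparkSession.builder.master("local[1]").getOrCreate()
actual = spark.tvf.posexplode(sf.array(sf.lit(1), sf.lit(2)))
expected = spark.sql("SELECT * FROM posexplode(array(1, 2))")
assertDataFrameEqual(actual=actual, expected=expected)
spark.stop()
```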
diff --git a/python/pyspark/sql/tvf.py b/python/pyspark/sql/tvf.py
new file mode 100644
index 000000000000..1d0febf9ba3a
--- /dev/null
+++ b/python/pyspark/sql/tvf.py
@@ -0,0 +1,712 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Optional
+
+from pyspark.errors import PySparkValueError
+from pyspark.sql.column import Column
+from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.session import SparkSession
+
+__all__ = ["TableValuedFunction"]
+
+
+class TableValuedFunction:
+ """
+ Interface for invoking table-valued functions in Spark SQL.
+ """
+
+ def __init__(self, sparkSession: SparkSession):
+ self._sparkSession = sparkSession
+
+ def range(
+ self,
+ start: int,
+ end: Optional[int] = None,
+ step: int = 1,
+ numPartitions: Optional[int] = None,
+ ) -> DataFrame:
+ """
+ Create a :class:`DataFrame` with a single :class:`pyspark.sql.types.LongType` column named
+ ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with
+ step value ``step``.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ start : int
+ the start value
+ end : int, optional
+ the end value (exclusive)
+ step : int, optional
+ the incremental step (default: 1)
+ numPartitions : int, optional
+ the number of partitions of the DataFrame
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ Examples
+ --------
+ >>> spark.tvf.range(1, 7, 2).show()
+ +---+
+ | id|
+ +---+
+ | 1|
+ | 3|
+ | 5|
+ +---+
+
+ If only one argument is specified, it will be used as the end value.
+
+ >>> spark.tvf.range(3).show()
+ +---+
+ | id|
+ +---+
+ | 0|
+ | 1|
+ | 2|
+ +---+
+ """
+ return self._sparkSession.range(start, end, step, numPartitions)
+
+ def explode(self, collection: Column) -> DataFrame:
+ """
+ Returns a :class:`DataFrame` containing a new row for each element
+ in the given array or map.
+ Uses the default column name `col` for elements in the array and
+ `key` and `value` for elements in the map unless specified otherwise.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ collection : :class:`~pyspark.sql.Column`
+ Target column to work on.
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ See Also
+ --------
+ :meth:`pyspark.sql.functions.explode`
+
+ Examples
+ --------
+ Example 1: Exploding an array column
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.explode(sf.array(sf.lit(1), sf.lit(2), sf.lit(3))).show()
+ +---+
+ |col|
+ +---+
+ | 1|
+ | 2|
+ | 3|
+ +---+
+
+ Example 2: Exploding a map column
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.explode(
+ ... sf.create_map(sf.lit("a"), sf.lit("b"), sf.lit("c"),
sf.lit("d"))
+ ... ).show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | a| b|
+ | c| d|
+ +---+-----+
+
+ Example 3: Exploding an array of struct column
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.explode(sf.array(
+ ... sf.named_struct(sf.lit("a"), sf.lit(1), sf.lit("b"),
sf.lit(2)),
+ ... sf.named_struct(sf.lit("a"), sf.lit(3), sf.lit("b"), sf.lit(4))
+ ... )).select("col.*").show()
+ +---+---+
+ | a| b|
+ +---+---+
+ | 1| 2|
+ | 3| 4|
+ +---+---+
+
+ Example 4: Exploding an empty array column
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.explode(sf.array()).show()
+ +---+
+ |col|
+ +---+
+ +---+
+
+ Example 5: Exploding an empty map column
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.explode(sf.create_map()).show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ +---+-----+
+ """
+ return self._fn("explode", collection)
+
+ def explode_outer(self, collection: Column) -> DataFrame:
+ """
+ Returns a :class:`DataFrame` containing a new row for each element
+ in the given array or map.
+ Unlike explode, if the array/map is null or empty then null is produced.
+ Uses the default column name `col` for elements in the array and
+ `key` and `value` for elements in the map unless specified otherwise.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ collection : :class:`~pyspark.sql.Column`
+ target column to work on.
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ See Also
+ --------
+ :meth:`pyspark.sql.functions.explode_outer`
+
+ Examples
+ --------
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.explode_outer(sf.array(sf.lit("foo"),
sf.lit("bar"))).show()
+ +---+
+ |col|
+ +---+
+ |foo|
+ |bar|
+ +---+
+ >>> spark.tvf.explode_outer(sf.array()).show()
+ +----+
+ | col|
+ +----+
+ |NULL|
+ +----+
+ >>> spark.tvf.explode_outer(sf.create_map(sf.lit("x"),
sf.lit(1.0))).show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | x| 1.0|
+ +---+-----+
+ >>> spark.tvf.explode_outer(sf.create_map()).show()
+ +----+-----+
+ | key|value|
+ +----+-----+
+ |NULL| NULL|
+ +----+-----+
+ """
+ return self._fn("explode_outer", collection)
+
+ def inline(self, input: Column) -> DataFrame:
+ """
+ Explodes an array of structs into a table.
+
+ This function takes an input column containing an array of structs and returns a
+ new DataFrame where each struct in the array is exploded into a separate row.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ input : :class:`~pyspark.sql.Column`
+ Input column of values to explode.
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ See Also
+ --------
+ :meth:`pyspark.sql.functions.inline`
+
+ Examples
+ --------
+ Example 1: Using inline with a single struct array
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.inline(sf.array(
+ ... sf.named_struct(sf.lit("a"), sf.lit(1), sf.lit("b"),
sf.lit(2)),
+ ... sf.named_struct(sf.lit("a"), sf.lit(3), sf.lit("b"), sf.lit(4))
+ ... )).show()
+ +---+---+
+ | a| b|
+ +---+---+
+ | 1| 2|
+ | 3| 4|
+ +---+---+
+
+ Example 2: Using inline with an empty struct array column
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.inline(sf.array().astype("array<struct<a:int,b:int>>")).show()
+ +---+---+
+ | a| b|
+ +---+---+
+ +---+---+
+
+ Example 3: Using inline with a struct array column containing null values
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.inline(sf.array(
+ ... sf.named_struct(sf.lit("a"), sf.lit(1), sf.lit("b"),
sf.lit(2)),
+ ... sf.lit(None),
+ ... sf.named_struct(sf.lit("a"), sf.lit(3), sf.lit("b"), sf.lit(4))
+ ... )).show()
+ +----+----+
+ | a| b|
+ +----+----+
+ | 1| 2|
+ |NULL|NULL|
+ | 3| 4|
+ +----+----+
+ """
+ return self._fn("inline", input)
+
+ def inline_outer(self, input: Column) -> DataFrame:
+ """
+ Explodes an array of structs into a table.
+ Unlike inline, if the array is null or empty then null is produced for each nested column.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ input : :class:`~pyspark.sql.Column`
+ input column of values to explode.
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ See Also
+ --------
+ :meth:`pyspark.sql.functions.inline_outer`
+
+ Examples
+ --------
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.inline_outer(sf.array(
+ ... sf.named_struct(sf.lit("a"), sf.lit(1), sf.lit("b"),
sf.lit(2)),
+ ... sf.named_struct(sf.lit("a"), sf.lit(3), sf.lit("b"), sf.lit(4))
+ ... )).show()
+ +---+---+
+ | a| b|
+ +---+---+
+ | 1| 2|
+ | 3| 4|
+ +---+---+
+ >>> spark.tvf.inline_outer(sf.array().astype("array<struct<a:int,b:int>>")).show()
+ +----+----+
+ | a| b|
+ +----+----+
+ |NULL|NULL|
+ +----+----+
+ """
+ return self._fn("inline_outer", input)
+
+ def json_tuple(self, input: Column, *fields: Column) -> DataFrame:
+ """
+ Creates a new row for a json column according to the given field names.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ input : :class:`~pyspark.sql.Column`
+ string column in json format
+ fields : :class:`~pyspark.sql.Column`
+ a field or fields to extract
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ See Also
+ --------
+ :meth:`pyspark.sql.functions.json_tuple`
+
+ Examples
+ --------
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.json_tuple(
+ ... sf.lit('{"f1": "value1", "f2": "value2"}'), sf.lit("f1"),
sf.lit("f2")
+ ... ).show()
+ +------+------+
+ | c0| c1|
+ +------+------+
+ |value1|value2|
+ +------+------+
+ """
+ from pyspark.sql.classic.column import _to_seq, _to_java_column
+
+ if len(fields) == 0:
+ raise PySparkValueError(
+ errorClass="CANNOT_BE_EMPTY",
+ messageParameters={"item": "field"},
+ )
+
+ sc = self._sparkSession.sparkContext
+ return DataFrame(
+ self._sparkSession._jsparkSession.tvf().json_tuple(
+ _to_java_column(input), _to_seq(sc, fields, _to_java_column)
+ ),
+ self._sparkSession,
+ )
+
+ def posexplode(self, collection: Column) -> DataFrame:
+ """
+ Returns a :class:`DataFrame` containing a new row for each element with position
+ in the given array or map.
+ Uses the default column name `pos` for position, and `col` for elements in the
+ array and `key` and `value` for elements in the map unless specified otherwise.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ collection : :class:`~pyspark.sql.Column`
+ target column to work on.
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ See Also
+ --------
+ :meth:`pyspark.sql.functions.posexplode`
+
+ Examples
+ --------
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.posexplode(sf.array(sf.lit(1), sf.lit(2), sf.lit(3))).show()
+ +---+---+
+ |pos|col|
+ +---+---+
+ | 0| 1|
+ | 1| 2|
+ | 2| 3|
+ +---+---+
+ >>> spark.tvf.posexplode(sf.create_map(sf.lit("a"),
sf.lit("b"))).show()
+ +---+---+-----+
+ |pos|key|value|
+ +---+---+-----+
+ | 0| a| b|
+ +---+---+-----+
+ """
+ return self._fn("posexplode", collection)
+
+ def posexplode_outer(self, collection: Column) -> DataFrame:
+ """
+ Returns a :class:`DataFrame` containing a new row for each element with position
+ in the given array or map.
+ Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced.
+ Uses the default column name `pos` for position, and `col` for elements in the
+ array and `key` and `value` for elements in the map unless specified otherwise.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ collection : :class:`~pyspark.sql.Column`
+ target column to work on.
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ See Also
+ --------
+ :meth:`pyspark.sql.functions.posexplode_outer`
+
+ Examples
+ --------
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.posexplode_outer(sf.array(sf.lit("foo"),
sf.lit("bar"))).show()
+ +---+---+
+ |pos|col|
+ +---+---+
+ | 0|foo|
+ | 1|bar|
+ +---+---+
+ >>> spark.tvf.posexplode_outer(sf.array()).show()
+ +----+----+
+ | pos| col|
+ +----+----+
+ |NULL|NULL|
+ +----+----+
+ >>> spark.tvf.posexplode_outer(sf.create_map(sf.lit("x"),
sf.lit(1.0))).show()
+ +---+---+-----+
+ |pos|key|value|
+ +---+---+-----+
+ | 0| x| 1.0|
+ +---+---+-----+
+ >>> spark.tvf.posexplode_outer(sf.create_map()).show()
+ +----+----+-----+
+ | pos| key|value|
+ +----+----+-----+
+ |NULL|NULL| NULL|
+ +----+----+-----+
+ """
+ return self._fn("posexplode_outer", collection)
+
+ def stack(self, n: Column, *fields: Column) -> DataFrame:
+ """
+ Separates `col1`, ..., `colk` into `n` rows. Uses column names col0, col1, etc. by default
+ unless specified otherwise.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ n : :class:`~pyspark.sql.Column`
+ the number of rows to separate
+ fields : :class:`~pyspark.sql.Column`
+ input elements to be separated
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ See Also
+ --------
+ :meth:`pyspark.sql.functions.stack`
+
+ Examples
+ --------
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.stack(sf.lit(2), sf.lit(1), sf.lit(2), sf.lit(3)).show()
+ +----+----+
+ |col0|col1|
+ +----+----+
+ | 1| 2|
+ | 3|NULL|
+ +----+----+
+ """
+ from pyspark.sql.classic.column import _to_seq, _to_java_column
+
+ sc = self._sparkSession.sparkContext
+ return DataFrame(
+ self._sparkSession._jsparkSession.tvf().stack(
+ _to_java_column(n), _to_seq(sc, fields, _to_java_column)
+ ),
+ self._sparkSession,
+ )
+
+ def collations(self) -> DataFrame:
+ """
+ Get all of the Spark SQL string collations.
+
+ .. versionadded:: 4.0.0
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ Examples
+ --------
+ >>> spark.tvf.collations().show()
+ +-------+-------+-------------+...
+ |CATALOG| SCHEMA| NAME|...
+ +-------+-------+-------------+...
+ ...
+ +-------+-------+-------------+...
+ """
+ return self._fn("collations")
+
+ def sql_keywords(self) -> DataFrame:
+ """
+ Get Spark SQL keywords.
+
+ .. versionadded:: 4.0.0
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ Examples
+ --------
+ >>> spark.tvf.sql_keywords().show()
+ +-------------+--------+
+ | keyword|reserved|
+ +-------------+--------+
+ ...
+ +-------------+--------+...
+ """
+ return self._fn("sql_keywords")
+
+ def variant_explode(self, input: Column) -> DataFrame:
+ """
+ Separates a variant object/array into multiple rows containing its fields/elements.
+
+ Its result schema is `struct<pos int, key string, value variant>`. `pos` is the position of
+ the field/element in its parent object/array, and `value` is the field/element value.
+ `key` is the field name when exploding a variant object, or is NULL when exploding a variant
+ array. It ignores any input that is not a variant array/object, including SQL NULL, variant
+ null, and any other variant values.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ input : :class:`~pyspark.sql.Column`
+ input column of values to explode.
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ Examples
+ --------
+ Example 1: Using variant_explode with a variant array
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.variant_explode(sf.parse_json(sf.lit('["hello",
"world"]'))).show()
+ +---+----+-------+
+ |pos| key| value|
+ +---+----+-------+
+ | 0|NULL|"hello"|
+ | 1|NULL|"world"|
+ +---+----+-------+
+
+ Example 2: Using variant_explode with a variant object
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.variant_explode(sf.parse_json(sf.lit('{"a": true, "b":
3.14}'))).show()
+ +---+---+-----+
+ |pos|key|value|
+ +---+---+-----+
+ | 0| a| true|
+ | 1| b| 3.14|
+ +---+---+-----+
+
+ Example 3: Using variant_explode with an empty variant array
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.variant_explode(sf.parse_json(sf.lit('[]'))).show()
+ +---+---+-----+
+ |pos|key|value|
+ +---+---+-----+
+ +---+---+-----+
+
+ Example 4: Using variant_explode with an empty variant object
+
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.variant_explode(sf.parse_json(sf.lit('{}'))).show()
+ +---+---+-----+
+ |pos|key|value|
+ +---+---+-----+
+ +---+---+-----+
+ """
+ return self._fn("variant_explode", input)
+
+ def variant_explode_outer(self, input: Column) -> DataFrame:
+ """
+ Separates a variant object/array into multiple rows containing its fields/elements.
+
+ Its result schema is `struct<pos int, key string, value variant>`. `pos` is the position of
+ the field/element in its parent object/array, and `value` is the field/element value.
+ `key` is the field name when exploding a variant object, or is NULL when exploding a variant
+ array. Unlike variant_explode, if the given variant is not a variant array/object, including
+ SQL NULL, variant null, and any other variant values, then NULL is produced.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ input : :class:`~pyspark.sql.Column`
+ input column of values to explode.
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ Examples
+ --------
+ >>> import pyspark.sql.functions as sf
+ >>> spark.tvf.variant_explode_outer(sf.parse_json(sf.lit('["hello",
"world"]'))).show()
+ +---+----+-------+
+ |pos| key| value|
+ +---+----+-------+
+ | 0|NULL|"hello"|
+ | 1|NULL|"world"|
+ +---+----+-------+
+ >>> spark.tvf.variant_explode_outer(sf.parse_json(sf.lit('[]'))).show()
+ +----+----+-----+
+ | pos| key|value|
+ +----+----+-----+
+ |NULL|NULL| NULL|
+ +----+----+-----+
+ >>> spark.tvf.variant_explode_outer(sf.parse_json(sf.lit('{"a": true,
"b": 3.14}'))).show()
+ +---+---+-----+
+ |pos|key|value|
+ +---+---+-----+
+ | 0| a| true|
+ | 1| b| 3.14|
+ +---+---+-----+
+ >>> spark.tvf.variant_explode_outer(sf.parse_json(sf.lit('{}'))).show()
+ +----+----+-----+
+ | pos| key|value|
+ +----+----+-----+
+ |NULL|NULL| NULL|
+ +----+----+-----+
+ """
+ return self._fn("variant_explode_outer", input)
+
+ def _fn(self, functionName: str, *args: Column) -> DataFrame:
+ from pyspark.sql.classic.column import _to_java_column
+
+ return DataFrame(
+ getattr(self._sparkSession._jsparkSession.tvf(), functionName)(
+ *(_to_java_column(arg) for arg in args)
+ ),
+ self._sparkSession,
+ )
+
+
+def _test() -> None:
+ import os
+ import doctest
+ import sys
+ import pyspark.sql.tvf
+
+ os.chdir(os.environ["SPARK_HOME"])
+
+ globs = pyspark.sql.tvf.__dict__.copy()
+ globs["spark"] = SparkSession.builder.master("local[4]").appName("sql.tvf
tests").getOrCreate()
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.tvf,
+ globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
+ )
+ globs["spark"].stop()
+ if failure_count:
+ sys.exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
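
Most methods in this file funnel through the `_fn` helper, which looks up the identically named method on the JVM-side `TableValuedFunction` and converts each argument to a Java column; only `json_tuple` and `stack` handle their varargs explicitly, and `json_tuple` validates up front that at least one field is given. A short sketch of the caller-visible behavior (assuming a classic, non-Connect session `spark`):

```py
# Sketch of the caller-visible behavior; assumes a classic local session.
import pyspark.sql.functions as sf
from pyspark.errors import PySparkValueError

spark.tvf.json_tuple(sf.lit('{"a":1}'), sf.lit("a")).show()  # one row, column c0

try:
    spark.tvf.json_tuple(sf.lit('{"a":1}'))  # no fields given
except PySparkValueError as e:
    print(e)  # CANNOT_BE_EMPTY with item "field", as tested above
```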
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 5488d11d868f..1773cdcf0a0a 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -59,7 +59,7 @@ from pyspark.errors.exceptions.base import QueryContextType
from pyspark.find_spark_home import _find_spark_home
from pyspark.sql.dataframe import DataFrame
from pyspark.sql import Row
-from pyspark.sql.types import StructType, StructField
+from pyspark.sql.types import StructType, StructField, VariantVal
from pyspark.sql.functions import col, when
@@ -899,6 +899,8 @@ def assertDataFrameEqual(
elif isinstance(val1, Decimal) and isinstance(val2, Decimal):
if abs(val1 - val2) > (Decimal(atol) + Decimal(rtol) * abs(val2)):
return False
+ elif isinstance(val1, VariantVal) and isinstance(val2, VariantVal):
+ return compare_vals(val1.toPython(), val2.toPython())
else:
if val1 != val2:
return False
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
index 31ceecb9e4ca..cb8c2a66ad28 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
@@ -629,6 +629,13 @@ abstract class SparkSession extends Serializable with Closeable {
*/
def readStream: DataStreamReader
+ /**
+ * Returns a [[TableValuedFunction]] that can be used to call a table-valued function (TVF).
+ *
+ * @since 4.0.0
+ */
+ def tvf: TableValuedFunction
+
/**
* (Scala-specific) Implicit methods available in Scala for converting common Scala objects into
* `DataFrame`s.
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/TableValuedFunction.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/TableValuedFunction.scala
new file mode 100644
index 000000000000..c03abe0e3d97
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/TableValuedFunction.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.api
+
+import _root_.java.lang
+
+import org.apache.spark.sql.{Column, Row}
+
+abstract class TableValuedFunction {
+
+ /**
+ * Creates a `Dataset` with a single `LongType` column named `id`, containing elements in a
+ * range from 0 to `end` (exclusive) with step value 1.
+ *
+ * @since 4.0.0
+ */
+ def range(end: Long): Dataset[lang.Long]
+
+ /**
+ * Creates a `Dataset` with a single `LongType` column named `id`, containing elements in a
+ * range from `start` to `end` (exclusive) with step value 1.
+ *
+ * @since 4.0.0
+ */
+ def range(start: Long, end: Long): Dataset[lang.Long]
+
+ /**
+ * Creates a `Dataset` with a single `LongType` column named `id`, containing elements in a
+ * range from `start` to `end` (exclusive) with a step value.
+ *
+ * @since 4.0.0
+ */
+ def range(start: Long, end: Long, step: Long): Dataset[lang.Long]
+
+ /**
+ * Creates a `Dataset` with a single `LongType` column named `id`, containing elements in a
+ * range from `start` to `end` (exclusive) with a step value, with partition number specified.
+ *
+ * @since 4.0.0
+ */
+ def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[lang.Long]
+
+ /**
+ * Creates a `DataFrame` containing a new row for each element in the given array or map column.
+ * Uses the default column name `col` for elements in the array and `key` and `value` for
+ * elements in the map unless specified otherwise.
+ *
+ * @group generator_funcs
+ * @since 4.0.0
+ */
+ def explode(collection: Column): Dataset[Row]
+
+ /**
+ * Creates a `DataFrame` containing a new row for each element in the given array or map column.
+ * Uses the default column name `col` for elements in the array and `key` and `value` for
+ * elements in the map unless specified otherwise. Unlike explode, if the array/map is null or
+ * empty then null is produced.
+ *
+ * @group generator_funcs
+ * @since 4.0.0
+ */
+ def explode_outer(collection: Column): Dataset[Row]
+
+ /**
+ * Creates a `DataFrame` containing a new row for each element in the given array of structs.
+ *
+ * @group generator_funcs
+ * @since 4.0.0
+ */
+ def inline(input: Column): Dataset[Row]
+
+ /**
+ * Creates a `DataFrame` containing a new row for each element in the given array of structs.
+ * Unlike inline, if the array is null or empty then null is produced for each nested column.
+ *
+ * @group generator_funcs
+ * @since 4.0.0
+ */
+ def inline_outer(input: Column): Dataset[Row]
+
+ /**
+ * Creates a `DataFrame` containing a new row for a json column according to the given field
+ * names.
+ *
+ * @group json_funcs
+ * @since 4.0.0
+ */
+ @scala.annotation.varargs
+ def json_tuple(input: Column, fields: Column*): Dataset[Row]
+
+ /**
+ * Creates a `DataFrame` containing a new row for each element with position in the given array
+ * or map column. Uses the default column name `pos` for position, and `col` for elements in the
+ * array and `key` and `value` for elements in the map unless specified otherwise.
+ *
+ * @group generator_funcs
+ * @since 4.0.0
+ */
+ def posexplode(collection: Column): Dataset[Row]
+
+ /**
+ * Creates a `DataFrame` containing a new row for each element with position in the given array
+ * or map column. Uses the default column name `pos` for position, and `col` for elements in the
+ * array and `key` and `value` for elements in the map unless specified otherwise. Unlike
+ * posexplode, if the array/map is null or empty then the row (null, null) is produced.
+ *
+ * @group generator_funcs
+ * @since 4.0.0
+ */
+ def posexplode_outer(collection: Column): Dataset[Row]
+
+ /**
+ * Separates `col1`, ..., `colk` into `n` rows. Uses column names col0, col1, etc. by default
+ * unless specified otherwise.
+ *
+ * @group generator_funcs
+ * @since 4.0.0
+ */
+ @scala.annotation.varargs
+ def stack(n: Column, fields: Column*): Dataset[Row]
+
+ /**
+ * Gets all of the Spark SQL string collations.
+ *
+ * @group generator_funcs
+ * @since 4.0.0
+ */
+ def collations(): Dataset[Row]
+
+ /**
+ * Gets Spark SQL keywords.
+ *
+ * @group generator_funcs
+ * @since 4.0.0
+ */
+ def sql_keywords(): Dataset[Row]
+
+ /**
+ * Separates a variant object/array into multiple rows containing its fields/elements. Its
+ * result schema is `struct<pos int, key string, value variant>`. `pos` is the position of
+ * the field/element in its parent object/array, and `value` is the field/element value. `key`
+ * is the field name when exploding a variant object, or is NULL when exploding a variant array.
+ * It ignores any input that is not a variant array/object, including SQL NULL, variant null,
+ * and any other variant values.
+ *
+ * @group variant_funcs
+ * @since 4.0.0
+ */
+ def variant_explode(input: Column): Dataset[Row]
+
+ /**
+ * Separates a variant object/array into multiple rows containing its fields/elements. Its
+ * result schema is `struct<pos int, key string, value variant>`. `pos` is the position of
+ * the field/element in its parent object/array, and `value` is the field/element value. `key`
+ * is the field name when exploding a variant object, or is NULL when exploding a variant array.
+ * Unlike variant_explode, if the given variant is not a variant array/object, including SQL
+ * NULL, variant null, and any other variant values, then NULL is produced.
+ *
+ * @group variant_funcs
+ * @since 4.0.0
+ */
+ def variant_explode_outer(input: Column): Dataset[Row]
+}
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
index 1003e5c21d63..a7b9137c3400 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -77,6 +77,7 @@ message Relation {
CommonInlineUserDefinedDataSource common_inline_user_defined_data_source = 40;
WithRelations with_relations = 41;
Transpose transpose = 42;
+ UnresolvedTableValuedFunction unresolved_table_valued_function = 43;
// NA functions
NAFill fill_na = 90;
@@ -902,6 +903,14 @@ message Transpose {
repeated Expression index_columns = 2;
}
+message UnresolvedTableValuedFunction {
+ // (Required) name (or unparsed name for user defined function) for the unresolved function.
+ string function_name = 1;
+
+ // (Optional) Function arguments. Empty arguments are allowed.
+ repeated Expression arguments = 2;
+}
+
message ToSchema {
// (Required) The input relation.
Relation input = 1;
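
On the wire, each `spark.tvf.<name>(...)` call from the Connect client becomes one of these relations. A sketch of the message shape using the regenerated Python bindings (the real client builds this through `pyspark.sql.connect.plan`, including the `arguments` expressions, which are omitted here):

```py
# Sketch of the new relation message; assumes the regenerated bindings
# from this patch are importable.
from pyspark.sql.connect.proto import relations_pb2

rel = relations_pb2.Relation()
rel.unresolved_table_valued_function.function_name = "explode"
# `arguments` is a repeated Expression, left empty here for brevity.
print(rel)
```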
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4e6994f9c2f8..a9d2bd482150 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,7 +45,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedTranspose}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedTableValuedFunction, UnresolvedTranspose}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
@@ -203,6 +203,8 @@ class SparkConnectPlanner(
case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot)
case proto.Relation.RelTypeCase.TRANSPOSE => transformTranspose(rel.getTranspose)
+ case proto.Relation.RelTypeCase.UNRESOLVED_TABLE_VALUED_FUNCTION =>
+ transformUnresolvedTableValuedFunction(rel.getUnresolvedTableValuedFunction)
case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
transformRepartitionByExpression(rel.getRepartitionByExpression)
case proto.Relation.RelTypeCase.MAP_PARTITIONS =>
@@ -1132,6 +1134,13 @@ class SparkConnectPlanner(
UnresolvedTranspose(indices = indices, child = child)
}
+ private def transformUnresolvedTableValuedFunction(
+ rel: proto.UnresolvedTableValuedFunction): LogicalPlan = {
+ UnresolvedTableValuedFunction(
+ rel.getFunctionName,
+ rel.getArgumentsList.asScala.map(transformExpression).toSeq)
+ }
+
private def transformUnpivot(rel: proto.Unpivot): LogicalPlan = {
val ids = rel.getIdsList.asScala.toArray.map { expr =>
column(transformExpression(expr))
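
This planner hook completes the Connect path: the incoming relation is mapped directly onto Catalyst's `UnresolvedTableValuedFunction`, so resolution proceeds the same way it does for the SQL form. A hedged end-to-end sketch against a Connect server (the `sc://` endpoint is an assumption):

```py
# End-to-end over Spark Connect; the endpoint below is an assumption.
from pyspark.sql import SparkSession
import pyspark.sql.functions as sf

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
spark.tvf.explode(sf.array(sf.lit(1), sf.lit(2))).show()
# +---+
# |col|
# +---+
# |  1|
# |  2|
# +---+
```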
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 8d54c1862c82..823356af3195 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -676,6 +676,9 @@ class SparkSession private(
/** @inheritdoc */
def readStream: DataStreamReader = new DataStreamReader(self)
+ /** @inheritdoc */
+ def tvf: TableValuedFunction = new TableValuedFunction(self)
+
// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
object implicits extends SQLImplicits {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
new file mode 100644
index 000000000000..406b67e6f3b8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedTableValuedFunction
+
+class TableValuedFunction(sparkSession: SparkSession)
+ extends api.TableValuedFunction {
+
+ /** @inheritdoc */
+ override def range(end: Long): Dataset[java.lang.Long] = {
+ sparkSession.range(end)
+ }
+
+ /** @inheritdoc */
+ override def range(start: Long, end: Long): Dataset[java.lang.Long] = {
+ sparkSession.range(start, end)
+ }
+
+ /** @inheritdoc */
+ override def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
+ sparkSession.range(start, end, step)
+ }
+
+ /** @inheritdoc */
+ override def range(
+ start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
+ sparkSession.range(start, end, step, numPartitions)
+ }
+
+ private def fn(name: String, args: Seq[Column]): Dataset[Row] = {
+ Dataset.ofRows(
+ sparkSession,
+ UnresolvedTableValuedFunction(name, args.map(sparkSession.expression)))
+ }
+
+ /** @inheritdoc */
+ override def explode(collection: Column): Dataset[Row] =
+ fn("explode", Seq(collection))
+
+ /** @inheritdoc */
+ override def explode_outer(collection: Column): Dataset[Row] =
+ fn("explode_outer", Seq(collection))
+
+ /** @inheritdoc */
+ override def inline(input: Column): Dataset[Row] =
+ fn("inline", Seq(input))
+
+ /** @inheritdoc */
+ override def inline_outer(input: Column): Dataset[Row] =
+ fn("inline_outer", Seq(input))
+
+ /** @inheritdoc */
+ override def json_tuple(input: Column, fields: Column*): Dataset[Row] =
+ fn("json_tuple", input +: fields)
+
+ /** @inheritdoc */
+ override def posexplode(collection: Column): Dataset[Row] =
+ fn("posexplode", Seq(collection))
+
+ /** @inheritdoc */
+ override def posexplode_outer(collection: Column): Dataset[Row] =
+ fn("posexplode_outer", Seq(collection))
+
+ /** @inheritdoc */
+ override def stack(n: Column, fields: Column*): Dataset[Row] =
+ fn("stack", n +: fields)
+
+ /** @inheritdoc */
+ override def collations(): Dataset[Row] =
+ fn("collations", Seq.empty)
+
+ /** @inheritdoc */
+ override def sql_keywords(): Dataset[Row] =
+ fn("sql_keywords", Seq.empty)
+
+ /** @inheritdoc */
+ override def variant_explode(input: Column): Dataset[Row] =
+ fn("variant_explode", Seq(input))
+
+ /** @inheritdoc */
+ override def variant_explode_outer(input: Column): Dataset[Row] =
+ fn("variant_explode_outer", Seq(input))
+}
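
This is the class the classic PySpark wrapper reaches through Py4J: `python/pyspark/sql/tvf.py` above calls `self._sparkSession._jsparkSession.tvf()` and then the method of the same name here. A sketch of that correspondence from the Python side (assuming a classic local session `spark`):

```py
# The classic (non-Connect) Python path resolves to this Scala class via Py4J;
# a sketch assuming a local session `spark`.
jvm_tvf = spark._jsparkSession.tvf()  # org.apache.spark.sql.TableValuedFunction
print(jvm_tvf.getClass().getName())
spark.tvf.sql_keywords().show(3)      # dispatches through the same JVM object
```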
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
new file mode 100644
index 000000000000..c2f53ff56d1a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSparkSession
+
+class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSession {
+
+ test("explode") {
+ val actual1 = spark.tvf.explode(array(lit(1), lit(2)))
+ val expected1 = spark.sql("SELECT * FROM explode(array(1, 2))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.explode(map(lit("a"), lit(1), lit("b"), lit(2)))
+ val expected2 = spark.sql("SELECT * FROM explode(map('a', 1, 'b', 2))")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.explode(array())
+ val expected3 = spark.sql("SELECT * FROM explode(array())")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.explode(map())
+ val expected4 = spark.sql("SELECT * FROM explode(map())")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.explode(lit(null).cast("array<int>"))
+ val expected5 = spark.sql("SELECT * FROM explode(null :: array<int>)")
+ checkAnswer(actual5, expected5)
+
+ val actual6 = spark.tvf.explode(lit(null).cast("map<string, int>"))
+ val expected6 = spark.sql("SELECT * FROM explode(null :: map<string,
int>)")
+ checkAnswer(actual6, expected6)
+ }
+
+ test("explode_outer") {
+ val actual1 = spark.tvf.explode_outer(array(lit(1), lit(2)))
+ val expected1 = spark.sql("SELECT * FROM explode_outer(array(1, 2))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.explode_outer(map(lit("a"), lit(1), lit("b"),
lit(2)))
+ val expected2 = spark.sql("SELECT * FROM explode_outer(map('a', 1, 'b',
2))")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.explode_outer(array())
+ val expected3 = spark.sql("SELECT * FROM explode_outer(array())")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.explode_outer(map())
+ val expected4 = spark.sql("SELECT * FROM explode_outer(map())")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.explode_outer(lit(null).cast("array<int>"))
+ val expected5 = spark.sql("SELECT * FROM explode_outer(null ::
array<int>)")
+ checkAnswer(actual5, expected5)
+
+ val actual6 = spark.tvf.explode_outer(lit(null).cast("map<string, int>"))
+ val expected6 = spark.sql("SELECT * FROM explode_outer(null :: map<string,
int>)")
+ checkAnswer(actual6, expected6)
+ }
+
+ test("inline") {
+ val actual1 = spark.tvf.inline(array(struct(lit(1), lit("a")),
struct(lit(2), lit("b"))))
+ val expected1 = spark.sql("SELECT * FROM inline(array(struct(1, 'a'),
struct(2, 'b')))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.inline(array().cast("array<struct<a:int,b:int>>"))
+ val expected2 = spark.sql("SELECT * FROM inline(array() ::
array<struct<a:int,b:int>>)")
+ checkAnswer(actual2, expected2)
+
+ val actual3 = spark.tvf.inline(array(
+ named_struct(lit("a"), lit(1), lit("b"), lit(2)),
+ lit(null),
+ named_struct(lit("a"), lit(3), lit("b"), lit(4))
+ ))
+ val expected3 = spark.sql(
+ "SELECT * FROM " +
+ "inline(array(named_struct('a', 1, 'b', 2), null, named_struct('a', 3,
'b', 4)))")
+ checkAnswer(actual3, expected3)
+ }
+
+ test("inline_outer") {
+ val actual1 = spark.tvf.inline_outer(array(struct(lit(1), lit("a")),
struct(lit(2), lit("b"))))
+ val expected1 = spark.sql("SELECT * FROM inline_outer(array(struct(1,
'a'), struct(2, 'b')))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.inline_outer(array().cast("array<struct<a:int,b:int>>"))
+ val expected2 = spark.sql("SELECT * FROM inline_outer(array() :: array<struct<a:int,b:int>>)")
+ checkAnswer(actual2, expected2)
+
+ val actual3 = spark.tvf.inline_outer(array(
+ named_struct(lit("a"), lit(1), lit("b"), lit(2)),
+ lit(null),
+ named_struct(lit("a"), lit(3), lit("b"), lit(4))
+ ))
+ val expected3 = spark.sql(
+ "SELECT * FROM " +
+ "inline_outer(array(named_struct('a', 1, 'b', 2), null,
named_struct('a', 3, 'b', 4)))")
+ checkAnswer(actual3, expected3)
+ }
+
+ test("json_tuple") {
+ val actual = spark.tvf.json_tuple(lit("""{"a":1,"b":2}"""), lit("a"),
lit("b"))
+ val expected = spark.sql("""SELECT * FROM json_tuple('{"a":1,"b":2}', 'a',
'b')""")
+ checkAnswer(actual, expected)
+
+ val ex = intercept[AnalysisException] {
+ spark.tvf.json_tuple(lit("""{"a":1,"b":2}""")).collect()
+ }
+ assert(ex.errorClass.get == "WRONG_NUM_ARGS.WITHOUT_SUGGESTION")
+ assert(ex.messageParameters("functionName") == "`json_tuple`")
+ }
+
+ test("posexplode") {
+ val actual1 = spark.tvf.posexplode(array(lit(1), lit(2)))
+ val expected1 = spark.sql("SELECT * FROM posexplode(array(1, 2))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.posexplode(map(lit("a"), lit(1), lit("b"), lit(2)))
+ val expected2 = spark.sql("SELECT * FROM posexplode(map('a', 1, 'b', 2))")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.posexplode(array())
+ val expected3 = spark.sql("SELECT * FROM posexplode(array())")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.posexplode(map())
+ val expected4 = spark.sql("SELECT * FROM posexplode(map())")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.posexplode(lit(null).cast("array<int>"))
+ val expected5 = spark.sql("SELECT * FROM posexplode(null :: array<int>)")
+ checkAnswer(actual5, expected5)
+
+ val actual6 = spark.tvf.posexplode(lit(null).cast("map<string, int>"))
+ val expected6 = spark.sql("SELECT * FROM posexplode(null :: map<string,
int>)")
+ checkAnswer(actual6, expected6)
+ }
+
+ test("posexplode_outer") {
+ val actual1 = spark.tvf.posexplode_outer(array(lit(1), lit(2)))
+ val expected1 = spark.sql("SELECT * FROM posexplode_outer(array(1, 2))")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.posexplode_outer(map(lit("a"), lit(1), lit("b"),
lit(2)))
+ val expected2 = spark.sql("SELECT * FROM posexplode_outer(map('a', 1, 'b',
2))")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.posexplode_outer(array())
+ val expected3 = spark.sql("SELECT * FROM posexplode_outer(array())")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.posexplode_outer(map())
+ val expected4 = spark.sql("SELECT * FROM posexplode_outer(map())")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.posexplode_outer(lit(null).cast("array<int>"))
+ val expected5 = spark.sql("SELECT * FROM posexplode_outer(null ::
array<int>)")
+ checkAnswer(actual5, expected5)
+
+ val actual6 = spark.tvf.posexplode_outer(lit(null).cast("map<string,
int>"))
+ val expected6 = spark.sql("SELECT * FROM posexplode_outer(null ::
map<string, int>)")
+ checkAnswer(actual6, expected6)
+ }
+
+ test("stack") {
+ val actual = spark.tvf.stack(lit(2), lit(1), lit(2), lit(3))
+ val expected = spark.sql("SELECT * FROM stack(2, 1, 2, 3)")
+ checkAnswer(actual, expected)
+ }
+
+ test("collations") {
+ val actual = spark.tvf.collations()
+ val expected = spark.sql("SELECT * FROM collations()")
+ checkAnswer(actual, expected)
+ }
+
+ test("sql_keywords") {
+ val actual = spark.tvf.sql_keywords()
+ val expected = spark.sql("SELECT * FROM sql_keywords()")
+ checkAnswer(actual, expected)
+ }
+
+ test("variant_explode") {
+ val actual1 = spark.tvf.variant_explode(parse_json(lit("""["hello",
"world"]""")))
+ val expected1 = spark.sql(
+ """SELECT * FROM variant_explode(parse_json('["hello", "world"]'))""")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.variant_explode(parse_json(lit("""{"a": true, "b":
3.14}""")))
+ val expected2 = spark.sql(
+ """SELECT * FROM variant_explode(parse_json('{"a": true, "b":
3.14}'))""")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.variant_explode(parse_json(lit("[]")))
+ val expected3 = spark.sql("SELECT * FROM
variant_explode(parse_json('[]'))")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.variant_explode(parse_json(lit("{}")))
+ val expected4 = spark.sql("SELECT * FROM
variant_explode(parse_json('{}'))")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.variant_explode(lit(null).cast("variant"))
+ val expected5 = spark.sql("SELECT * FROM variant_explode(null :: variant)")
+ checkAnswer(actual5, expected5)
+
+ // not a variant object/array
+ val actual6 = spark.tvf.variant_explode(parse_json(lit("1")))
+ val expected6 = spark.sql("SELECT * FROM variant_explode(parse_json('1'))")
+ checkAnswer(actual6, expected6)
+ }
+
+ test("variant_explode_outer") {
+ val actual1 = spark.tvf.variant_explode_outer(parse_json(lit("""["hello",
"world"]""")))
+ val expected1 = spark.sql(
+ """SELECT * FROM variant_explode_outer(parse_json('["hello",
"world"]'))""")
+ checkAnswer(actual1, expected1)
+
+ val actual2 = spark.tvf.variant_explode_outer(parse_json(lit("""{"a":
true, "b": 3.14}""")))
+ val expected2 = spark.sql(
+ """SELECT * FROM variant_explode_outer(parse_json('{"a": true, "b":
3.14}'))""")
+ checkAnswer(actual2, expected2)
+
+ // empty
+ val actual3 = spark.tvf.variant_explode_outer(parse_json(lit("[]")))
+ val expected3 = spark.sql("SELECT * FROM
variant_explode_outer(parse_json('[]'))")
+ checkAnswer(actual3, expected3)
+
+ val actual4 = spark.tvf.variant_explode_outer(parse_json(lit("{}")))
+ val expected4 = spark.sql("SELECT * FROM
variant_explode_outer(parse_json('{}'))")
+ checkAnswer(actual4, expected4)
+
+ // null
+ val actual5 = spark.tvf.variant_explode_outer(lit(null).cast("variant"))
+ val expected5 = spark.sql("SELECT * FROM variant_explode_outer(null ::
variant)")
+ checkAnswer(actual5, expected5)
+
+ // not a variant object/array
+ val actual6 = spark.tvf.variant_explode_outer(parse_json(lit("1")))
+ val expected6 = spark.sql("SELECT * FROM
variant_explode_outer(parse_json('1'))")
+ checkAnswer(actual6, expected6)
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]