This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4cb364e6f615 [SPARK-47680][SQL] Add variant_explode expression
4cb364e6f615 is described below
commit 4cb364e6f615512811b3001597d0cf98a7a30b00
Author: Chenhao Li <[email protected]>
AuthorDate: Wed Apr 10 22:47:43 2024 +0800
[SPARK-47680][SQL] Add variant_explode expression
### What changes were proposed in this pull request?
This PR adds a new `VariantExplode` expression. It separates a variant
object/array into multiple rows containing its fields/elements. Its result
schema is `struct<pos int, key string, value variant>`. `pos` is the position
of the field/element in its parent object/array, and `value` is the
field/element value. `key` is the field name when exploding a variant object,
or is NULL when exploding a variant array. It ignores any input that is not a
variant array/object, including SQL NULL, variant null, and any other variant values.
It is exposed as two SQL expressions, `variant_explode` and
`variant_explode_outer`. The only difference is that whenever `variant_explode`
produces zero output row for an input row, `variant_explode_outer` will produce
one output row containing `{NULL, NULL, NULL}`.
Usage examples:
```
> SELECT variant_explode(parse_json('["hello", "world"]'));
0 NULL "hello"
1 NULL "world"
> SELECT variant_explode(parse_json('{"a": true, "b": 3.14}'));
0 a true
1 b 3.14
```
### Why are the changes needed?
This expression allows the user to process variant arrays and objects more
conveniently.
### Does this PR introduce _any_ user-facing change?
Yes. A new SQL expression is added.
### How was this patch tested?
Unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #45805 from chenhao-db/variant_explode.
Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/analysis/FunctionRegistry.scala | 4 +-
.../expressions/variant/variantExpressions.scala | 83 ++++++++++++++++++++++
.../scala/org/apache/spark/sql/VariantSuite.scala | 26 +++++++
3 files changed, 112 insertions(+), 1 deletion(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 99ae3adde44f..9447ea63b51f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -1096,7 +1096,9 @@ object TableFunctionRegistry {
generator[PosExplode]("posexplode"),
generator[PosExplode]("posexplode_outer", outer = true),
generator[Stack]("stack"),
- generator[SQLKeywords]("sql_keywords")
+ generator[SQLKeywords]("sql_keywords"),
+ generator[VariantExplode]("variant_explode"),
+ generator[VariantExplode]("variant_explode_outer", outer = true)
)
val builtin: SimpleTableFunctionRegistry = {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index 7d1a3cf00d2b..c5e316dc6c8c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.variant
import scala.util.parsing.combinator.RegexParsers
import org.apache.spark.SparkRuntimeException
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -419,6 +420,88 @@ object VariantGetExpressionBuilder extends
VariantGetExpressionBuilderBase(true)
// scalastyle:on line.size.limit
object TryVariantGetExpressionBuilder extends
VariantGetExpressionBuilderBase(false)
// scalastyle:off line.size.limit line.contains.tab
@ExpressionDescription(
  usage = "_FUNC_(expr) - It separates a variant object/array into multiple rows containing its fields/elements. Its result schema is `struct<pos int, key string, value variant>`. `pos` is the position of the field/element in its parent object/array, and `value` is the field/element value. `key` is the field name when exploding a variant object, or is NULL when exploding a variant array. It ignores any input that is not a variant array/object, including SQL NULL, variant null, and any other variant values.",
  examples = """
    Examples:
      > SELECT * from _FUNC_(parse_json('["hello", "world"]'));
       0	NULL	"hello"
       1	NULL	"world"
      > SELECT * from _FUNC_(parse_json('{"a": true, "b": 3.14}'));
       0	a	true
       1	b	3.14
  """,
  since = "4.0.0",
  group = "variant_funcs")
// scalastyle:on line.size.limit line.contains.tab
case class VariantExplode(child: Expression) extends UnaryExpression with Generator
  with ExpectsInputTypes {
  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType)

  override def prettyName: String = "variant_explode"

  override protected def withNewChildInternal(newChild: Expression): VariantExplode =
    copy(child = newChild)

  // Interpreted path: delegate to the shared implementation in the companion object.
  // A SQL NULL child evaluates to a Scala null here, hence the `inputVariant == null` flag.
  override def eval(input: InternalRow): IterableOnce[InternalRow] = {
    val inputVariant = child.eval(input).asInstanceOf[VariantVal]
    VariantExplode.variantExplode(inputVariant, inputVariant == null)
  }

  // Codegen path: emit a call to the companion's `variantExplode` (reached through the static
  // forwarder on the case class). The null flag is passed separately because the generated
  // child value is not valid when `isNull` is true.
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childCode = child.genCode(ctx)
    val cls = classOf[VariantExplode].getName
    val code = code"""
      ${childCode.code}
      scala.collection.Seq<InternalRow> ${ev.value} = $cls.variantExplode(
        ${childCode.value}, ${childCode.isNull});
    """
    ev.copy(code = code, isNull = FalseLiteral)
  }

  // One output row per exploded field/element: its position, its key (always NULL when the
  // input is a variant array), and its value wrapped as a variant.
  override def elementSchema: StructType = {
    new StructType()
      .add("pos", IntegerType, nullable = false)
      .add("key", StringType, nullable = true)
      .add("value", VariantType, nullable = false)
  }
}
+
object VariantExplode {
  /**
   * The actual implementation of the `VariantExplode` expression.
   *
   * We check `isNull` separately rather than `input == null` because the documentation of
   * `ExprCode` says that the value is not valid if `isNull` is set to `true`.
   */
  def variantExplode(input: VariantVal, isNull: Boolean): scala.collection.Seq[InternalRow] = {
    if (isNull) {
      Nil
    } else {
      val parsed = new Variant(input.getValue, input.getMetadata)
      parsed.getType match {
        case Type.OBJECT =>
          // One row per object field: (position, field name, field value).
          Seq.tabulate(parsed.objectSize()) { idx =>
            val field = parsed.getFieldAtIndex(idx)
            InternalRow(idx, UTF8String.fromString(field.key),
              new VariantVal(field.value.getValue, field.value.getMetadata))
          }
        case Type.ARRAY =>
          // One row per array element: (position, NULL key, element value).
          Seq.tabulate(parsed.arraySize()) { idx =>
            val elem = parsed.getElementAtIndex(idx)
            InternalRow(idx, null, new VariantVal(elem.getValue, elem.getMetadata))
          }
        // Anything that is not an object/array (scalars, variant null) explodes to zero rows.
        case _ => Nil
      }
    }
  }
}
+
@ExpressionDescription(
usage = "_FUNC_(v) - Returns schema in the SQL format of a variant.",
examples = """
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index 4f82dbc90dc5..d276ec4428b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -32,6 +32,8 @@ import org.apache.spark.unsafe.types.VariantVal
import org.apache.spark.util.ArrayImplicits._
class VariantSuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
test("basic tests") {
def verifyResult(df: DataFrame): Unit = {
val result = df.collect()
@@ -298,4 +300,28 @@ class VariantSuite extends QueryTest with
SharedSparkSession {
}
assert(ex.getErrorClass == "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE")
}
+
  test("variant_explode") {
    // Runs both variant_explode and variant_explode_outer over `input` (a JSON string, or a
    // Scala null to produce a SQL NULL) and compares against `expected`. The outer variant
    // must emit a single all-NULL row exactly when the non-outer variant emits no rows.
    def check(input: String, expected: Seq[Row]): Unit = {
      withView("v") {
        Seq(input).toDF("json").createOrReplaceTempView("v")
        checkAnswer(sql("select pos, key, to_json(value) from v, " +
          "lateral variant_explode(parse_json(json))"), expected)
        val expectedOuter = if (expected.isEmpty) Seq(Row(null, null, null)) else expected
        checkAnswer(sql("select pos, key, to_json(value) from v, " +
          "lateral variant_explode_outer(parse_json(json))"), expectedOuter)
      }
    }

    // Exercise both the codegen and the interpreted (eval) paths of the expression.
    Seq("true", "false").foreach { codegenEnabled =>
      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled) {
        // SQL NULL, scalar, and variant null inputs all produce zero rows.
        check(null, Nil)
        check("1", Nil)
        check("null", Nil)
        check("""{"a": [1, 2, 3], "b": true}""", Seq(Row(0, "a", "[1,2,3]"), Row(1, "b", "true")))
        check("""[null, "hello", {}]""",
          Seq(Row(0, null, "null"), Row(1, null, "\"hello\""), Row(2, null, "{}")))
      }
    }
  }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]