This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8c4d6764674f [SPARK-47559][SQL] Codegen Support for variant `parse_json`
8c4d6764674f is described below
commit 8c4d6764674fcddf30245dfc25ef825eabba0ace
Author: panbingkun <[email protected]>
AuthorDate: Thu Mar 28 20:34:42 2024 +0800
[SPARK-47559][SQL] Codegen Support for variant `parse_json`
### What changes were proposed in this pull request?
This PR adds codegen support for `parse_json`: the expression becomes `RuntimeReplaceable`, delegating evaluation to a new `VariantExpressionEvalUtils.parseJson` helper via `StaticInvoke` instead of falling back to interpreted evaluation (`CodegenFallback`).
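A minimal sketch of the new expression shape, condensed from the diff below (not a standalone compile unit; the interpreted eval logic moves into `VariantExpressionEvalUtils.parseJson`):
```scala
// Sketch, condensed from the diff below: ParseJson drops CodegenFallback and
// becomes RuntimeReplaceable, rewriting itself into a StaticInvoke of the new
// VariantExpressionEvalUtils.parseJson helper, which whole-stage codegen can
// call directly.
case class ParseJson(child: Expression)
  extends UnaryExpression with ExpectsInputTypes with RuntimeReplaceable {

  override lazy val replacement: Expression = StaticInvoke(
    VariantExpressionEvalUtils.getClass,  // eval logic moved out of the expression
    VariantType,
    "parseJson",
    Seq(child),
    inputTypes,
    returnNullable = false)

  override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
  override def dataType: DataType = VariantType
  override def prettyName: String = "parse_json"

  override protected def withNewChildInternal(newChild: Expression): ParseJson =
    copy(child = newChild)
}
```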
### Why are the changes needed?
Improve codegen coverage.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- Add new UT.
- Pass GA.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #45714 from panbingkun/ParseJson_CodeGenerator.
Authored-by: panbingkun <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
...ions.scala => VariantExpressionEvalUtils.scala} | 35 ++----
.../expressions/variant/variantExpressions.scala | 34 ++---
.../variant/VariantExpressionEvalUtilsSuite.scala | 125 +++++++++++++++++++
.../variant/VariantExpressionSuite.scala | 137 +--------------------
.../apache/spark/sql/VariantEndToEndSuite.scala | 84 +++++++++++++
5 files changed, 229 insertions(+), 186 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
similarity index 56%
copy from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
copy to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
index cab61d2b12c2..74fae91f98a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
@@ -19,35 +19,17 @@ package org.apache.spark.sql.catalyst.expressions.variant
import scala.util.control.NonFatal
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.BadRecordException
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types._
import org.apache.spark.types.variant.{VariantBuilder, VariantSizeLimitException, VariantUtil}
-import org.apache.spark.unsafe.types._
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(jsonStr) - Parse a JSON string as an Variant value. Throw an exception when the string is not valid JSON value.",
- examples = """
- Examples:
- > SELECT _FUNC_('{"a":1,"b":0.8}');
- {"a":1,"b":0.8}
- """,
- since = "4.0.0",
- group = "variant_funcs"
-)
-// scalastyle:on line.size.limit
-case class ParseJson(child: Expression) extends UnaryExpression
- with NullIntolerant with ExpectsInputTypes with CodegenFallback {
- override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
-
- override def dataType: DataType = VariantType
-
- override def prettyName: String = "parse_json"
+/**
+ * A utility class for constructing variant expressions.
+ */
+object VariantExpressionEvalUtils {
- protected override def nullSafeEval(input: Any): Any = {
+ def parseJson(input: UTF8String): VariantVal = {
try {
val v = VariantBuilder.parseJson(input.toString)
new VariantVal(v.getValue, v.getMetadata)
@@ -56,10 +38,7 @@ case class ParseJson(child: Expression) extends UnaryExpression
throw QueryExecutionErrors.variantSizeLimitError(VariantUtil.SIZE_LIMIT, "parse_json")
case NonFatal(e) =>
throw QueryExecutionErrors.malformedRecordsDetectedInRecordParsingError(
- input.toString, BadRecordException(() => input.asInstanceOf[UTF8String], cause = e))
+ input.toString, BadRecordException(() => input, cause = e))
}
}
-
- override protected def withNewChildInternal(newChild: Expression): ParseJson =
- copy(child = newChild)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index cab61d2b12c2..00708f863e81 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -17,15 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.variant
-import scala.util.control.NonFatal
-
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.BadRecordException
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.types._
-import org.apache.spark.types.variant.{VariantBuilder, VariantSizeLimitException, VariantUtil}
-import org.apache.spark.unsafe.types._
// scalastyle:off line.size.limit
@ExpressionDescription(
@@ -39,27 +33,23 @@ import org.apache.spark.unsafe.types._
group = "variant_funcs"
)
// scalastyle:on line.size.limit
-case class ParseJson(child: Expression) extends UnaryExpression
- with NullIntolerant with ExpectsInputTypes with CodegenFallback {
+case class ParseJson(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with RuntimeReplaceable {
+
+ override lazy val replacement: Expression = StaticInvoke(
+ VariantExpressionEvalUtils.getClass,
+ VariantType,
+ "parseJson",
+ Seq(child),
+ inputTypes,
+ returnNullable = false)
+
override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
override def dataType: DataType = VariantType
override def prettyName: String = "parse_json"
- protected override def nullSafeEval(input: Any): Any = {
- try {
- val v = VariantBuilder.parseJson(input.toString)
- new VariantVal(v.getValue, v.getMetadata)
- } catch {
- case _: VariantSizeLimitException =>
- throw QueryExecutionErrors.variantSizeLimitError(VariantUtil.SIZE_LIMIT, "parse_json")
- case NonFatal(e) =>
- throw QueryExecutionErrors.malformedRecordsDetectedInRecordParsingError(
- input.toString, BadRecordException(() => input.asInstanceOf[UTF8String], cause = e))
- }
- }
-
override protected def withNewChildInternal(newChild: Expression): ParseJson =
copy(child = newChild)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala
new file mode 100644
index 000000000000..574d5daa361e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.variant
+
+import org.apache.spark.{SparkFunSuite, SparkThrowable}
+import org.apache.spark.types.variant.VariantUtil._
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
+
+class VariantExpressionEvalUtilsSuite extends SparkFunSuite {
+
+ test("parseJson type coercion") {
+ def check(json: String, expectedValue: Array[Byte], expectedMetadata: Array[Byte]): Unit = {
+ val actual = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json))
+ val expected = new VariantVal(expectedValue, expectedMetadata)
+ assert(actual === expected)
+ }
+
+ // Dictionary size is `0` for value 0. An empty dictionary contains one offset `0` for the
+ // one-past-the-end position (i.e. the sum of all string lengths).
+ val emptyMetadata = Array[Byte](VERSION, 0, 0)
+ check("null", Array(primitiveHeader(NULL)), emptyMetadata)
+ check("true", Array(primitiveHeader(TRUE)), emptyMetadata)
+ check("false", Array(primitiveHeader(FALSE)), emptyMetadata)
+ check("1", Array(primitiveHeader(INT1), 1), emptyMetadata)
+ check("-1", Array(primitiveHeader(INT1), -1), emptyMetadata)
+ check("127", Array(primitiveHeader(INT1), 127), emptyMetadata)
+ check("128", Array(primitiveHeader(INT2), -128, 0), emptyMetadata)
+ check("-32768", Array(primitiveHeader(INT2), 0, -128), emptyMetadata)
+ check("-32769", Array(primitiveHeader(INT4), -1, 127, -1, -1),
emptyMetadata)
+ check("2147483647", Array(primitiveHeader(INT4), -1, -1, -1, 127),
emptyMetadata)
+ check("2147483648", Array(primitiveHeader(INT8), 0, 0, 0, -128, 0, 0, 0,
0), emptyMetadata)
+ check("9223372036854775807",
+ Array(primitiveHeader(INT8), -1, -1, -1, -1, -1, -1, -1, 127), emptyMetadata)
+ check("-9223372036854775808",
+ Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0, -128), emptyMetadata)
+ check("9223372036854775808",
+ Array(primitiveHeader(DECIMAL16), 0, 0, 0, 0, 0, 0, 0, 0, -128, 0, 0, 0, 0, 0, 0, 0, 0),
+ emptyMetadata)
+ check("1.0", Array(primitiveHeader(DECIMAL4), 1, 10, 0, 0, 0),
emptyMetadata)
+ check("1.01", Array(primitiveHeader(DECIMAL4), 2, 101, 0, 0, 0),
emptyMetadata)
+ check("99999.9999", Array(primitiveHeader(DECIMAL4), 4, -1, -55, -102,
59), emptyMetadata)
+ check("99999.99999",
+ Array(primitiveHeader(DECIMAL8), 5, -1, -29, 11, 84, 2, 0, 0, 0), emptyMetadata)
+ check("0.000000001", Array(primitiveHeader(DECIMAL4), 9, 1, 0, 0, 0),
emptyMetadata)
+ check("0.0000000001",
+ Array(primitiveHeader(DECIMAL8), 10, 1, 0, 0, 0, 0, 0, 0, 0), emptyMetadata)
+ check("9" * 38,
+ Array[Byte](primitiveHeader(DECIMAL16), 0) ++ BigInt("9" * 38).toByteArray.reverse,
+ emptyMetadata)
+ check("1" + "0" * 38,
+ Array(primitiveHeader(DOUBLE)) ++
+ BigInt(java.lang.Double.doubleToLongBits(1E38)).toByteArray.reverse,
+ emptyMetadata)
+ check("\"\"", Array(shortStrHeader(0)), emptyMetadata)
+ check("\"abcd\"", Array(shortStrHeader(4), 'a', 'b', 'c', 'd'),
emptyMetadata)
+ check("\"" + ("x" * 63) + "\"",
+ Array(shortStrHeader(63)) ++ Array.fill(63)('x'.toByte), emptyMetadata)
+ check("\"" + ("y" * 64) + "\"",
+ Array[Byte](primitiveHeader(LONG_STR), 64, 0, 0, 0) ++ Array.fill(64)('y'.toByte),
+ emptyMetadata)
+ check("{}", Array(objectHeader(false, 1, 1),
+ /* size */ 0,
+ /* offset list */ 0), emptyMetadata)
+ check("[]", Array(arrayHeader(false, 1),
+ /* size */ 0,
+ /* offset list */ 0), emptyMetadata)
+ check("""{"a": 1, "b": 2, "c": "3"}""", Array(objectHeader(false, 1, 1),
+ /* size */ 3,
+ /* id list */ 0, 1, 2,
+ /* offset list */ 0, 2, 4, 6,
+ /* field data */ primitiveHeader(INT1), 1, primitiveHeader(INT1), 2, shortStrHeader(1), '3'),
+ Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c'))
+ check("""{"z": 1, "y": 2, "x": "3"}""", Array(objectHeader(false, 1, 1),
+ /* size */ 3,
+ /* id list */ 2, 1, 0,
+ /* offset list */ 4, 2, 0, 6,
+ /* field data */ primitiveHeader(INT1), 1, primitiveHeader(INT1), 2, shortStrHeader(1), '3'),
+ Array(VERSION, 3, 0, 1, 2, 3, 'z', 'y', 'x'))
+ check("""[null, true, {"false" : 0}]""", Array(arrayHeader(false, 1),
+ /* size */ 3,
+ /* offset list */ 0, 1, 2, 9,
+ /* element data */ primitiveHeader(NULL), primitiveHeader(TRUE), objectHeader(false, 1, 1),
+ /* size */ 1,
+ /* id list */ 0,
+ /* offset list */ 0, 2,
+ /* field data */ primitiveHeader(INT1), 0),
+ Array(VERSION, 1, 0, 5, 'f', 'a', 'l', 's', 'e'))
+ }
+
+ test("parseJson negative") {
+ def checkException(json: String, errorClass: String, parameters: Map[String, String]): Unit = {
+ checkError(
+ exception = intercept[SparkThrowable] {
+ VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json))
+ },
+ errorClass = errorClass,
+ parameters = parameters
+ )
+ }
+ for (json <- Seq("", "[", "+1", "1a", """{"a": 1, "b": 2, "a": "3"}""")) {
+ checkException(json, "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION",
+ Map("badRecord" -> json, "failFastMode" -> "FAILFAST"))
+ }
+ for (json <- Seq("\"" + "a" * (16 * 1024 * 1024) + "\"",
+ (0 to 4 * 1024 * 1024).mkString("[", ",", "]"))) {
+ checkException(json, "VARIANT_SIZE_LIMIT",
+ Map("sizeLimit" -> "16.0 MiB", "functionName" -> "`parse_json`"))
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index 2793b1c8c1fb..51c610bb4609 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.variant
-import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException}
+import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.types.variant.VariantUtil._
@@ -34,141 +34,6 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
- test("parse_json") {
- def check(json: String, expectedValue: Array[Byte], expectedMetadata: Array[Byte]): Unit = {
- checkEvaluation(ParseJson(Literal(json)), new VariantVal(expectedValue, expectedMetadata))
- }
-
- // Dictionary size is `0` for value 0. An empty dictionary contains one offset `0` for the
- // one-past-the-end position (i.e. the sum of all string lengths).
- val emptyMetadata = Array[Byte](VERSION, 0, 0)
- check("null", Array(primitiveHeader(NULL)), emptyMetadata)
- check("true", Array(primitiveHeader(TRUE)), emptyMetadata)
- check("false", Array(primitiveHeader(FALSE)), emptyMetadata)
- check("1", Array(primitiveHeader(INT1), 1), emptyMetadata)
- check("-1", Array(primitiveHeader(INT1), -1), emptyMetadata)
- check("127", Array(primitiveHeader(INT1), 127), emptyMetadata)
- check("128", Array(primitiveHeader(INT2), -128, 0), emptyMetadata)
- check("-32768", Array(primitiveHeader(INT2), 0, -128), emptyMetadata)
- check("-32769", Array(primitiveHeader(INT4), -1, 127, -1, -1),
emptyMetadata)
- check("2147483647", Array(primitiveHeader(INT4), -1, -1, -1, 127),
emptyMetadata)
- check("2147483648", Array(primitiveHeader(INT8), 0, 0, 0, -128, 0, 0, 0,
0), emptyMetadata)
- check("9223372036854775807",
- Array(primitiveHeader(INT8), -1, -1, -1, -1, -1, -1, -1, 127), emptyMetadata)
- check("-9223372036854775808",
- Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0, -128), emptyMetadata)
- check("9223372036854775808",
- Array(primitiveHeader(DECIMAL16), 0, 0, 0, 0, 0, 0, 0, 0, -128, 0, 0, 0, 0, 0, 0, 0, 0),
- emptyMetadata)
- check("1.0", Array(primitiveHeader(DECIMAL4), 1, 10, 0, 0, 0),
emptyMetadata)
- check("1.01", Array(primitiveHeader(DECIMAL4), 2, 101, 0, 0, 0),
emptyMetadata)
- check("99999.9999", Array(primitiveHeader(DECIMAL4), 4, -1, -55, -102,
59), emptyMetadata)
- check("99999.99999",
- Array(primitiveHeader(DECIMAL8), 5, -1, -29, 11, 84, 2, 0, 0, 0), emptyMetadata)
- check("0.000000001", Array(primitiveHeader(DECIMAL4), 9, 1, 0, 0, 0),
emptyMetadata)
- check("0.0000000001",
- Array(primitiveHeader(DECIMAL8), 10, 1, 0, 0, 0, 0, 0, 0, 0), emptyMetadata)
- check("9" * 38,
- Array[Byte](primitiveHeader(DECIMAL16), 0) ++ BigInt("9" * 38).toByteArray.reverse,
- emptyMetadata)
- check("1" + "0" * 38,
- Array(primitiveHeader(DOUBLE)) ++
- BigInt(java.lang.Double.doubleToLongBits(1E38)).toByteArray.reverse,
- emptyMetadata)
- check("\"\"", Array(shortStrHeader(0)), emptyMetadata)
- check("\"abcd\"", Array(shortStrHeader(4), 'a', 'b', 'c', 'd'),
emptyMetadata)
- check("\"" + ("x" * 63) + "\"",
- Array(shortStrHeader(63)) ++ Array.fill(63)('x'.toByte), emptyMetadata)
- check("\"" + ("y" * 64) + "\"",
- Array[Byte](primitiveHeader(LONG_STR), 64, 0, 0, 0) ++ Array.fill(64)('y'.toByte),
- emptyMetadata)
- check("{}", Array(objectHeader(false, 1, 1),
- /* size */ 0,
- /* offset list */ 0), emptyMetadata)
- check("[]", Array(arrayHeader(false, 1),
- /* size */ 0,
- /* offset list */ 0), emptyMetadata)
- check("""{"a": 1, "b": 2, "c": "3"}""", Array(objectHeader(false, 1, 1),
- /* size */ 3,
- /* id list */ 0, 1, 2,
- /* offset list */ 0, 2, 4, 6,
- /* field data */ primitiveHeader(INT1), 1, primitiveHeader(INT1), 2, shortStrHeader(1), '3'),
- Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c'))
- check("""{"z": 1, "y": 2, "x": "3"}""", Array(objectHeader(false, 1, 1),
- /* size */ 3,
- /* id list */ 2, 1, 0,
- /* offset list */ 4, 2, 0, 6,
- /* field data */ primitiveHeader(INT1), 1, primitiveHeader(INT1), 2, shortStrHeader(1), '3'),
- Array(VERSION, 3, 0, 1, 2, 3, 'z', 'y', 'x'))
- check("""[null, true, {"false" : 0}]""", Array(arrayHeader(false, 1),
- /* size */ 3,
- /* offset list */ 0, 1, 2, 9,
- /* element data */ primitiveHeader(NULL), primitiveHeader(TRUE), objectHeader(false, 1, 1),
- /* size */ 1,
- /* id list */ 0,
- /* offset list */ 0, 2,
- /* field data */ primitiveHeader(INT1), 0),
- Array(VERSION, 1, 0, 5, 'f', 'a', 'l', 's', 'e'))
- }
-
- test("parse_json negative") {
- for (json <- Seq("", "[", "+1", "1a", """{"a": 1, "b": 2, "a": "3"}""")) {
- checkExceptionInExpression[SparkException](ParseJson(Literal(json)),
- "Malformed records are detected in record parsing")
- }
- for (json <- Seq("\"" + "a" * (16 * 1024 * 1024) + "\"",
- (0 to 4 * 1024 * 1024).mkString("[", ",", "]"))) {
- checkExceptionInExpression[SparkRuntimeException](ParseJson(Literal(json)),
- "Cannot build variant bigger than 16.0 MiB")
- }
- }
-
- test("round-trip") {
- def check(input: String, output: String = null): Unit = {
- checkEvaluation(
- StructsToJson(Map.empty, ParseJson(Literal(input))),
- if (output != null) output else input
- )
- }
-
- check("null")
- check("true")
- check("false")
- check("-1")
- check("1.0E10")
- check("\"\"")
- check("\"" + ("a" * 63) + "\"")
- check("\"" + ("b" * 64) + "\"")
- // scalastyle:off nonascii
- check("\"" + ("你好,世界" * 20) + "\"")
- // scalastyle:on nonascii
- check("[]")
- check("{}")
- // scalastyle:off nonascii
- check(
- "[null, true, false,-1, 1e10, \"\\uD83D\\uDE05\", [ ], { } ]",
- "[null,true,false,-1,1.0E10,\"😅\",[],{}]"
- )
- // scalastyle:on nonascii
- check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
- }
-
- test("to_json with nested variant") {
- checkEvaluation(
- StructsToJson(Map.empty, CreateArray(Seq(ParseJson(Literal("{}")),
- ParseJson(Literal("\"\"")),
- ParseJson(Literal("[1, 2, 3]"))))),
- "[{},\"\",[1,2,3]]"
- )
- checkEvaluation(
- StructsToJson(Map.empty, CreateNamedStruct(Seq(
- Literal("a"), ParseJson(Literal("""{ "x": 1, "y": null, "z": "str"
}""")),
- Literal("b"), ParseJson(Literal("[[]]")),
- Literal("c"), ParseJson(Literal("false"))))),
- """{"a":{"x":1,"y":null,"z":"str"},"b":[[]],"c":false}"""
- )
- }
-
test("to_json malformed") {
def check(value: Array[Byte], metadata: Array[Byte],
errorClass: String = "MALFORMED_VARIANT"): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
new file mode 100644
index 000000000000..cf12001fa71b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{CreateArray, CreateNamedStruct, Literal, StructsToJson}
+import org.apache.spark.sql.catalyst.expressions.variant.ParseJson
+import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.types.variant.VariantBuilder
+import org.apache.spark.unsafe.types.VariantVal
+
+class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
+ test("parse_json/to_json round-trip") {
+ def check(input: String, output: String = null): Unit = {
+ val df = Seq(input).toDF("v")
+ val variantDF = df.select(Column(StructsToJson(Map.empty, ParseJson(Column("v").expr))))
+ val expected = if (output != null) output else input
+ checkAnswer(variantDF, Seq(Row(expected)))
+ }
+
+ check("null")
+ check("true")
+ check("false")
+ check("-1")
+ check("1.0E10")
+ check("\"\"")
+ check("\"" + ("a" * 63) + "\"")
+ check("\"" + ("b" * 64) + "\"")
+ // scalastyle:off nonascii
+ check("\"" + ("你好,世界" * 20) + "\"")
+ // scalastyle:on nonascii
+ check("[]")
+ check("{}")
+ // scalastyle:off nonascii
+ check(
+ "[null, true, false,-1, 1e10, \"\\uD83D\\uDE05\", [ ], { } ]",
+ "[null,true,false,-1,1.0E10,\"😅\",[],{}]"
+ )
+ // scalastyle:on nonascii
+ check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
+ }
+
+ test("to_json with nested variant") {
+ val df = Seq(1).toDF("v")
+ val variantDF1 = df.select(
+ Column(StructsToJson(Map.empty, CreateArray(Seq(
+ ParseJson(Literal("{}")), ParseJson(Literal("\"\"")),
ParseJson(Literal("[1, 2, 3]")))))))
+ checkAnswer(variantDF1, Seq(Row("[{},\"\",[1,2,3]]")))
+
+ val variantDF2 = df.select(
+ Column(StructsToJson(Map.empty, CreateNamedStruct(Seq(
+ Literal("a"), ParseJson(Literal("""{ "x": 1, "y": null, "z": "str"
}""")),
+ Literal("b"), ParseJson(Literal("[[]]")),
+ Literal("c"), ParseJson(Literal("false")))))))
+ checkAnswer(variantDF2, Seq(Row("""{"a":{"x":1,"y":null,"z":"str"},"b":[[]],"c":false}""")))
+ }
+
+ test("parse_json - Codegen Support") {
+ val df = Seq(("1", """{"a": 1}""")).toDF("key", "v").toDF()
+ val variantDF = df.select(Column(ParseJson(Column("v").expr)))
+ val plan = variantDF.queryExecution.executedPlan
+ assert(plan.isInstanceOf[WholeStageCodegenExec])
+ val v = VariantBuilder.parseJson("""{"a":1}""")
+ val expected = new VariantVal(v.getValue, v.getMetadata)
+ checkAnswer(variantDF, Seq(Row(expected)))
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]