This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8c4d6764674f [SPARK-47559][SQL] Codegen Support for variant `parse_json`
8c4d6764674f is described below

commit 8c4d6764674fcddf30245dfc25ef825eabba0ace
Author: panbingkun <panbing...@baidu.com>
AuthorDate: Thu Mar 28 20:34:42 2024 +0800

    [SPARK-47559][SQL] Codegen Support for variant `parse_json`
    
    ### What changes were proposed in this pull request?
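    This PR adds codegen support for the variant `parse_json` expression. `ParseJson` no longer mixes in `CodegenFallback`; it is now a `RuntimeReplaceable` whose replacement is a `StaticInvoke` of the new helper `VariantExpressionEvalUtils.parseJson`, which holds the parsing and error-handling logic previously in `nullSafeEval`.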
    
    ### Why are the changes needed?
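    Improve codegen coverage. Previously `ParseJson` extended `CodegenFallback`. After this change, a projection that uses `parse_json` is compiled with whole-stage codegen, which the new "parse_json - Codegen Support" test verifies by checking for `WholeStageCodegenExec`.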
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
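    - Added new unit tests: `VariantExpressionEvalUtilsSuite` (tests the `parseJson` helper directly) and `VariantEndToEndSuite` (round-trip, nested-variant, and whole-stage codegen checks). The existing `parse_json` tests were moved out of `VariantExpressionSuite`.
    - Passed GitHub Actions (GA) CI.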
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #45714 from panbingkun/ParseJson_CodeGenerator.
    
    Authored-by: panbingkun <panbing...@baidu.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 ...ions.scala => VariantExpressionEvalUtils.scala} |  35 ++----
 .../expressions/variant/variantExpressions.scala   |  34 ++---
 .../variant/VariantExpressionEvalUtilsSuite.scala  | 125 +++++++++++++++++++
 .../variant/VariantExpressionSuite.scala           | 137 +--------------------
 .../apache/spark/sql/VariantEndToEndSuite.scala    |  84 +++++++++++++
 5 files changed, 229 insertions(+), 186 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
similarity index 56%
copy from 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
copy to 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
index cab61d2b12c2..74fae91f98a6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
@@ -19,35 +19,17 @@ package org.apache.spark.sql.catalyst.expressions.variant
 
 import scala.util.control.NonFatal
 
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.BadRecordException
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types._
 import org.apache.spark.types.variant.{VariantBuilder, 
VariantSizeLimitException, VariantUtil}
-import org.apache.spark.unsafe.types._
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = "_FUNC_(jsonStr) - Parse a JSON string as an Variant value. Throw an 
exception when the string is not valid JSON value.",
-  examples = """
-    Examples:
-      > SELECT _FUNC_('{"a":1,"b":0.8}');
-       {"a":1,"b":0.8}
-  """,
-  since = "4.0.0",
-  group = "variant_funcs"
-)
-// scalastyle:on line.size.limit
-case class ParseJson(child: Expression) extends UnaryExpression
-  with NullIntolerant with ExpectsInputTypes with CodegenFallback {
-  override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
-
-  override def dataType: DataType = VariantType
-
-  override def prettyName: String = "parse_json"
+/**
+ * A utility class for constructing variant expressions.
+ */
+object VariantExpressionEvalUtils {
 
-  protected override def nullSafeEval(input: Any): Any = {
+  def parseJson(input: UTF8String): VariantVal = {
     try {
       val v = VariantBuilder.parseJson(input.toString)
       new VariantVal(v.getValue, v.getMetadata)
@@ -56,10 +38,7 @@ case class ParseJson(child: Expression) extends 
UnaryExpression
         throw 
QueryExecutionErrors.variantSizeLimitError(VariantUtil.SIZE_LIMIT, "parse_json")
       case NonFatal(e) =>
         throw 
QueryExecutionErrors.malformedRecordsDetectedInRecordParsingError(
-        input.toString, BadRecordException(() => 
input.asInstanceOf[UTF8String], cause = e))
+          input.toString, BadRecordException(() => input, cause = e))
     }
   }
-
-  override protected def withNewChildInternal(newChild: Expression): ParseJson 
=
-    copy(child = newChild)
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index cab61d2b12c2..00708f863e81 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -17,15 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions.variant
 
-import scala.util.control.NonFatal
-
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.BadRecordException
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.types._
-import org.apache.spark.types.variant.{VariantBuilder, 
VariantSizeLimitException, VariantUtil}
-import org.apache.spark.unsafe.types._
 
 // scalastyle:off line.size.limit
 @ExpressionDescription(
@@ -39,27 +33,23 @@ import org.apache.spark.unsafe.types._
   group = "variant_funcs"
 )
 // scalastyle:on line.size.limit
-case class ParseJson(child: Expression) extends UnaryExpression
-  with NullIntolerant with ExpectsInputTypes with CodegenFallback {
+case class ParseJson(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes with RuntimeReplaceable {
+
+  override lazy val replacement: Expression = StaticInvoke(
+    VariantExpressionEvalUtils.getClass,
+    VariantType,
+    "parseJson",
+    Seq(child),
+    inputTypes,
+    returnNullable = false)
+
   override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
 
   override def dataType: DataType = VariantType
 
   override def prettyName: String = "parse_json"
 
-  protected override def nullSafeEval(input: Any): Any = {
-    try {
-      val v = VariantBuilder.parseJson(input.toString)
-      new VariantVal(v.getValue, v.getMetadata)
-    } catch {
-      case _: VariantSizeLimitException =>
-        throw 
QueryExecutionErrors.variantSizeLimitError(VariantUtil.SIZE_LIMIT, "parse_json")
-      case NonFatal(e) =>
-        throw 
QueryExecutionErrors.malformedRecordsDetectedInRecordParsingError(
-        input.toString, BadRecordException(() => 
input.asInstanceOf[UTF8String], cause = e))
-    }
-  }
-
   override protected def withNewChildInternal(newChild: Expression): ParseJson 
=
     copy(child = newChild)
 }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala
new file mode 100644
index 000000000000..574d5daa361e
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtilsSuite.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.variant
+
+import org.apache.spark.{SparkFunSuite, SparkThrowable}
+import org.apache.spark.types.variant.VariantUtil._
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
+
+class VariantExpressionEvalUtilsSuite extends SparkFunSuite {
+
+  test("parseJson type coercion") {
+    def check(json: String, expectedValue: Array[Byte], expectedMetadata: 
Array[Byte]): Unit = {
+      val actual = 
VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json))
+      val expected = new VariantVal(expectedValue, expectedMetadata)
+      assert(actual === expected)
+    }
+
+    // Dictionary size is `0` for value 0. An empty dictionary contains one 
offset `0` for the
+    // one-past-the-end position (i.e. the sum of all string lengths).
+    val emptyMetadata = Array[Byte](VERSION, 0, 0)
+    check("null", Array(primitiveHeader(NULL)), emptyMetadata)
+    check("true", Array(primitiveHeader(TRUE)), emptyMetadata)
+    check("false", Array(primitiveHeader(FALSE)), emptyMetadata)
+    check("1", Array(primitiveHeader(INT1), 1), emptyMetadata)
+    check("-1", Array(primitiveHeader(INT1), -1), emptyMetadata)
+    check("127", Array(primitiveHeader(INT1), 127), emptyMetadata)
+    check("128", Array(primitiveHeader(INT2), -128, 0), emptyMetadata)
+    check("-32768", Array(primitiveHeader(INT2), 0, -128), emptyMetadata)
+    check("-32769", Array(primitiveHeader(INT4), -1, 127, -1, -1), 
emptyMetadata)
+    check("2147483647", Array(primitiveHeader(INT4), -1, -1, -1, 127), 
emptyMetadata)
+    check("2147483648", Array(primitiveHeader(INT8), 0, 0, 0, -128, 0, 0, 0, 
0), emptyMetadata)
+    check("9223372036854775807",
+      Array(primitiveHeader(INT8), -1, -1, -1, -1, -1, -1, -1, 127), 
emptyMetadata)
+    check("-9223372036854775808",
+      Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0, -128), emptyMetadata)
+    check("9223372036854775808",
+      Array(primitiveHeader(DECIMAL16), 0, 0, 0, 0, 0, 0, 0, 0, -128, 0, 0, 0, 
0, 0, 0, 0, 0),
+      emptyMetadata)
+    check("1.0", Array(primitiveHeader(DECIMAL4), 1, 10, 0, 0, 0), 
emptyMetadata)
+    check("1.01", Array(primitiveHeader(DECIMAL4), 2, 101, 0, 0, 0), 
emptyMetadata)
+    check("99999.9999", Array(primitiveHeader(DECIMAL4), 4, -1, -55, -102, 
59), emptyMetadata)
+    check("99999.99999",
+      Array(primitiveHeader(DECIMAL8), 5, -1, -29, 11, 84, 2, 0, 0, 0), 
emptyMetadata)
+    check("0.000000001", Array(primitiveHeader(DECIMAL4), 9, 1, 0, 0, 0), 
emptyMetadata)
+    check("0.0000000001",
+      Array(primitiveHeader(DECIMAL8), 10, 1, 0, 0, 0, 0, 0, 0, 0), 
emptyMetadata)
+    check("9" * 38,
+      Array[Byte](primitiveHeader(DECIMAL16), 0) ++ BigInt("9" * 
38).toByteArray.reverse,
+      emptyMetadata)
+    check("1" + "0" * 38,
+      Array(primitiveHeader(DOUBLE)) ++
+        BigInt(java.lang.Double.doubleToLongBits(1E38)).toByteArray.reverse,
+      emptyMetadata)
+    check("\"\"", Array(shortStrHeader(0)), emptyMetadata)
+    check("\"abcd\"", Array(shortStrHeader(4), 'a', 'b', 'c', 'd'), 
emptyMetadata)
+    check("\"" + ("x" * 63) + "\"",
+      Array(shortStrHeader(63)) ++ Array.fill(63)('x'.toByte), emptyMetadata)
+    check("\"" + ("y" * 64) + "\"",
+      Array[Byte](primitiveHeader(LONG_STR), 64, 0, 0, 0) ++ 
Array.fill(64)('y'.toByte),
+      emptyMetadata)
+    check("{}", Array(objectHeader(false, 1, 1),
+      /* size */ 0,
+      /* offset list */ 0), emptyMetadata)
+    check("[]", Array(arrayHeader(false, 1),
+      /* size */ 0,
+      /* offset list */ 0), emptyMetadata)
+    check("""{"a": 1, "b": 2, "c": "3"}""", Array(objectHeader(false, 1, 1),
+      /* size */ 3,
+      /* id list */ 0, 1, 2,
+      /* offset list */ 0, 2, 4, 6,
+      /* field data */ primitiveHeader(INT1), 1, primitiveHeader(INT1), 2, 
shortStrHeader(1), '3'),
+      Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c'))
+    check("""{"z": 1, "y": 2, "x": "3"}""", Array(objectHeader(false, 1, 1),
+      /* size */ 3,
+      /* id list */ 2, 1, 0,
+      /* offset list */ 4, 2, 0, 6,
+      /* field data */ primitiveHeader(INT1), 1, primitiveHeader(INT1), 2, 
shortStrHeader(1), '3'),
+      Array(VERSION, 3, 0, 1, 2, 3, 'z', 'y', 'x'))
+    check("""[null, true, {"false" : 0}]""", Array(arrayHeader(false, 1),
+      /* size */ 3,
+      /* offset list */ 0, 1, 2, 9,
+      /* element data */ primitiveHeader(NULL), primitiveHeader(TRUE), 
objectHeader(false, 1, 1),
+      /* size */ 1,
+      /* id list */ 0,
+      /* offset list */ 0, 2,
+      /* field data */ primitiveHeader(INT1), 0),
+      Array(VERSION, 1, 0, 5, 'f', 'a', 'l', 's', 'e'))
+  }
+
+  test("parseJson negative") {
+    def checkException(json: String, errorClass: String, parameters: 
Map[String, String]): Unit = {
+      checkError(
+        exception = intercept[SparkThrowable] {
+          VariantExpressionEvalUtils.parseJson(UTF8String.fromString(json))
+        },
+        errorClass = errorClass,
+        parameters = parameters
+      )
+    }
+    for (json <- Seq("", "[", "+1", "1a", """{"a": 1, "b": 2, "a": "3"}""")) {
+      checkException(json, "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION",
+        Map("badRecord" -> json, "failFastMode" -> "FAILFAST"))
+    }
+    for (json <- Seq("\"" + "a" * (16 * 1024 * 1024) + "\"",
+      (0 to 4 * 1024 * 1024).mkString("[", ",", "]"))) {
+      checkException(json, "VARIANT_SIZE_LIMIT",
+        Map("sizeLimit" -> "16.0 MiB", "functionName" -> "`parse_json`"))
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index 2793b1c8c1fb..51c610bb4609 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.variant
 
-import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException}
+import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.types.variant.VariantUtil._
@@ -34,141 +34,6 @@ class VariantExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     }
   }
 
-  test("parse_json") {
-    def check(json: String, expectedValue: Array[Byte], expectedMetadata: 
Array[Byte]): Unit = {
-      checkEvaluation(ParseJson(Literal(json)), new VariantVal(expectedValue, 
expectedMetadata))
-    }
-
-    // Dictionary size is `0` for value 0. An empty dictionary contains one 
offset `0` for the
-    // one-past-the-end position (i.e. the sum of all string lengths).
-    val emptyMetadata = Array[Byte](VERSION, 0, 0)
-    check("null", Array(primitiveHeader(NULL)), emptyMetadata)
-    check("true", Array(primitiveHeader(TRUE)), emptyMetadata)
-    check("false", Array(primitiveHeader(FALSE)), emptyMetadata)
-    check("1", Array(primitiveHeader(INT1), 1), emptyMetadata)
-    check("-1", Array(primitiveHeader(INT1), -1), emptyMetadata)
-    check("127", Array(primitiveHeader(INT1), 127), emptyMetadata)
-    check("128", Array(primitiveHeader(INT2), -128, 0), emptyMetadata)
-    check("-32768", Array(primitiveHeader(INT2), 0, -128), emptyMetadata)
-    check("-32769", Array(primitiveHeader(INT4), -1, 127, -1, -1), 
emptyMetadata)
-    check("2147483647", Array(primitiveHeader(INT4), -1, -1, -1, 127), 
emptyMetadata)
-    check("2147483648", Array(primitiveHeader(INT8), 0, 0, 0, -128, 0, 0, 0, 
0), emptyMetadata)
-    check("9223372036854775807",
-      Array(primitiveHeader(INT8), -1, -1, -1, -1, -1, -1, -1, 127), 
emptyMetadata)
-    check("-9223372036854775808",
-      Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0, -128), emptyMetadata)
-    check("9223372036854775808",
-      Array(primitiveHeader(DECIMAL16), 0, 0, 0, 0, 0, 0, 0, 0, -128, 0, 0, 0, 
0, 0, 0, 0, 0),
-      emptyMetadata)
-    check("1.0", Array(primitiveHeader(DECIMAL4), 1, 10, 0, 0, 0), 
emptyMetadata)
-    check("1.01", Array(primitiveHeader(DECIMAL4), 2, 101, 0, 0, 0), 
emptyMetadata)
-    check("99999.9999", Array(primitiveHeader(DECIMAL4), 4, -1, -55, -102, 
59), emptyMetadata)
-    check("99999.99999",
-      Array(primitiveHeader(DECIMAL8), 5, -1, -29, 11, 84, 2, 0, 0, 0), 
emptyMetadata)
-    check("0.000000001", Array(primitiveHeader(DECIMAL4), 9, 1, 0, 0, 0), 
emptyMetadata)
-    check("0.0000000001",
-      Array(primitiveHeader(DECIMAL8), 10, 1, 0, 0, 0, 0, 0, 0, 0), 
emptyMetadata)
-    check("9" * 38,
-      Array[Byte](primitiveHeader(DECIMAL16), 0) ++ BigInt("9" * 
38).toByteArray.reverse,
-      emptyMetadata)
-    check("1" + "0" * 38,
-      Array(primitiveHeader(DOUBLE)) ++
-        BigInt(java.lang.Double.doubleToLongBits(1E38)).toByteArray.reverse,
-      emptyMetadata)
-    check("\"\"", Array(shortStrHeader(0)), emptyMetadata)
-    check("\"abcd\"", Array(shortStrHeader(4), 'a', 'b', 'c', 'd'), 
emptyMetadata)
-    check("\"" + ("x" * 63) + "\"",
-      Array(shortStrHeader(63)) ++ Array.fill(63)('x'.toByte), emptyMetadata)
-    check("\"" + ("y" * 64) + "\"",
-      Array[Byte](primitiveHeader(LONG_STR), 64, 0, 0, 0) ++ 
Array.fill(64)('y'.toByte),
-      emptyMetadata)
-    check("{}", Array(objectHeader(false, 1, 1),
-      /* size */ 0,
-      /* offset list */ 0), emptyMetadata)
-    check("[]", Array(arrayHeader(false, 1),
-      /* size */ 0,
-      /* offset list */ 0), emptyMetadata)
-    check("""{"a": 1, "b": 2, "c": "3"}""", Array(objectHeader(false, 1, 1),
-      /* size */ 3,
-      /* id list */ 0, 1, 2,
-      /* offset list */ 0, 2, 4, 6,
-      /* field data */ primitiveHeader(INT1), 1, primitiveHeader(INT1), 2, 
shortStrHeader(1), '3'),
-      Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c'))
-    check("""{"z": 1, "y": 2, "x": "3"}""", Array(objectHeader(false, 1, 1),
-      /* size */ 3,
-      /* id list */ 2, 1, 0,
-      /* offset list */ 4, 2, 0, 6,
-      /* field data */ primitiveHeader(INT1), 1, primitiveHeader(INT1), 2, 
shortStrHeader(1), '3'),
-      Array(VERSION, 3, 0, 1, 2, 3, 'z', 'y', 'x'))
-    check("""[null, true, {"false" : 0}]""", Array(arrayHeader(false, 1),
-      /* size */ 3,
-      /* offset list */ 0, 1, 2, 9,
-      /* element data */ primitiveHeader(NULL), primitiveHeader(TRUE), 
objectHeader(false, 1, 1),
-      /* size */ 1,
-      /* id list */ 0,
-      /* offset list */ 0, 2,
-      /* field data */ primitiveHeader(INT1), 0),
-      Array(VERSION, 1, 0, 5, 'f', 'a', 'l', 's', 'e'))
-  }
-
-  test("parse_json negative") {
-    for (json <- Seq("", "[", "+1", "1a", """{"a": 1, "b": 2, "a": "3"}""")) {
-      checkExceptionInExpression[SparkException](ParseJson(Literal(json)),
-        "Malformed records are detected in record parsing")
-    }
-    for (json <- Seq("\"" + "a" * (16 * 1024 * 1024) + "\"",
-      (0 to 4 * 1024 * 1024).mkString("[", ",", "]"))) {
-      
checkExceptionInExpression[SparkRuntimeException](ParseJson(Literal(json)),
-        "Cannot build variant bigger than 16.0 MiB")
-    }
-  }
-
-  test("round-trip") {
-    def check(input: String, output: String = null): Unit = {
-      checkEvaluation(
-        StructsToJson(Map.empty, ParseJson(Literal(input))),
-        if (output != null) output else input
-      )
-    }
-
-    check("null")
-    check("true")
-    check("false")
-    check("-1")
-    check("1.0E10")
-    check("\"\"")
-    check("\"" + ("a" * 63) + "\"")
-    check("\"" + ("b" * 64) + "\"")
-    // scalastyle:off nonascii
-    check("\"" + ("你好,世界" * 20) + "\"")
-    // scalastyle:on nonascii
-    check("[]")
-    check("{}")
-    // scalastyle:off nonascii
-    check(
-      "[null, true,   false,-1, 1e10, \"\\uD83D\\uDE05\", [ ], { } ]",
-      "[null,true,false,-1,1.0E10,\"😅\",[],{}]"
-    )
-    // scalastyle:on nonascii
-    check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
-  }
-
-  test("to_json with nested variant") {
-    checkEvaluation(
-      StructsToJson(Map.empty, CreateArray(Seq(ParseJson(Literal("{}")),
-        ParseJson(Literal("\"\"")),
-        ParseJson(Literal("[1, 2, 3]"))))),
-      "[{},\"\",[1,2,3]]"
-    )
-    checkEvaluation(
-      StructsToJson(Map.empty, CreateNamedStruct(Seq(
-        Literal("a"), ParseJson(Literal("""{ "x": 1, "y": null, "z": "str" 
}""")),
-        Literal("b"), ParseJson(Literal("[[]]")),
-        Literal("c"), ParseJson(Literal("false"))))),
-      """{"a":{"x":1,"y":null,"z":"str"},"b":[[]],"c":false}"""
-    )
-  }
-
   test("to_json malformed") {
     def check(value: Array[Byte], metadata: Array[Byte],
               errorClass: String = "MALFORMED_VARIANT"): Unit = {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
new file mode 100644
index 000000000000..cf12001fa71b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{CreateArray, 
CreateNamedStruct, Literal, StructsToJson}
+import org.apache.spark.sql.catalyst.expressions.variant.ParseJson
+import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.types.variant.VariantBuilder
+import org.apache.spark.unsafe.types.VariantVal
+
+class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("parse_json/to_json round-trip") {
+    def check(input: String, output: String = null): Unit = {
+      val df = Seq(input).toDF("v")
+      val variantDF = df.select(Column(StructsToJson(Map.empty, 
ParseJson(Column("v").expr))))
+      val expected = if (output != null) output else input
+      checkAnswer(variantDF, Seq(Row(expected)))
+    }
+
+    check("null")
+    check("true")
+    check("false")
+    check("-1")
+    check("1.0E10")
+    check("\"\"")
+    check("\"" + ("a" * 63) + "\"")
+    check("\"" + ("b" * 64) + "\"")
+    // scalastyle:off nonascii
+    check("\"" + ("你好,世界" * 20) + "\"")
+    // scalastyle:on nonascii
+    check("[]")
+    check("{}")
+    // scalastyle:off nonascii
+    check(
+      "[null, true,   false,-1, 1e10, \"\\uD83D\\uDE05\", [ ], { } ]",
+      "[null,true,false,-1,1.0E10,\"😅\",[],{}]"
+    )
+    // scalastyle:on nonascii
+    check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
+  }
+
+  test("to_json with nested variant") {
+    val df = Seq(1).toDF("v")
+    val variantDF1 = df.select(
+      Column(StructsToJson(Map.empty, CreateArray(Seq(
+        ParseJson(Literal("{}")), ParseJson(Literal("\"\"")), 
ParseJson(Literal("[1, 2, 3]")))))))
+    checkAnswer(variantDF1, Seq(Row("[{},\"\",[1,2,3]]")))
+
+    val variantDF2 = df.select(
+      Column(StructsToJson(Map.empty, CreateNamedStruct(Seq(
+        Literal("a"), ParseJson(Literal("""{ "x": 1, "y": null, "z": "str" 
}""")),
+        Literal("b"), ParseJson(Literal("[[]]")),
+        Literal("c"), ParseJson(Literal("false")))))))
+    checkAnswer(variantDF2, 
Seq(Row("""{"a":{"x":1,"y":null,"z":"str"},"b":[[]],"c":false}""")))
+  }
+
+  test("parse_json - Codegen Support") {
+    val df = Seq(("1", """{"a": 1}""")).toDF("key", "v").toDF()
+    val variantDF = df.select(Column(ParseJson(Column("v").expr)))
+    val plan = variantDF.queryExecution.executedPlan
+    assert(plan.isInstanceOf[WholeStageCodegenExec])
+    val v = VariantBuilder.parseJson("""{"a":1}""")
+    val expected = new VariantVal(v.getValue, v.getMetadata)
+    checkAnswer(variantDF, Seq(Row(expected)))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to