This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dd153307cb97 [SPARK-50953][PYTHON][CONNECT] Add support for 
non-literal paths in VariantGet
dd153307cb97 is described below

commit dd153307cb9735fd05a41124eca2a136f40f3b3f
Author: Harsh Motwani <harsh.motw...@databricks.com>
AuthorDate: Mon Feb 10 21:46:18 2025 +0800

    [SPARK-50953][PYTHON][CONNECT] Add support for non-literal paths in 
VariantGet
    
    ### What changes were proposed in this pull request?
    
    This PR allows the `variant_get` expression to support non-literal path 
inputs.
    
    ### Why are the changes needed?
    
    Currently, `variant_get` only supports literal paths as the second 
argument. Users may have columns containing paths which they would want to 
extract from variants. This PR allows this functionality.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, prior to this PR, `variant_get` did not have support for non-literal 
paths.
    
    ### How was this patch tested?
    
    Unit tests to make sure that:
    1. The VariantGet/TryVariantGet expressions with non-literal paths have the 
expected behavior regardless of codegen mode.
    2. VariantGet expressions with non-literal paths do not get pushed down as 
this functionality has not been implemented.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49609 from harshmotw-db/harsh-motwani_data/variant_get_column.
    
    Authored-by: Harsh Motwani <harsh.motw...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 python/pyspark/sql/connect/functions/builtin.py    | 16 +++-
 python/pyspark/sql/functions/builtin.py            | 49 ++++++++----
 python/pyspark/sql/tests/test_functions.py         |  8 +-
 .../scala/org/apache/spark/sql/functions.scala     | 40 +++++++++-
 .../expressions/variant/variantExpressions.scala   | 66 +++++++++++-----
 .../datasources/PushVariantIntoScan.scala          |  8 +-
 .../scala/org/apache/spark/sql/VariantSuite.scala  | 88 ++++++++++++++++++++++
 .../datasources/PushVariantIntoScanSuite.scala     | 22 +++++-
 8 files changed, 253 insertions(+), 44 deletions(-)

diff --git a/python/pyspark/sql/connect/functions/builtin.py 
b/python/pyspark/sql/connect/functions/builtin.py
index f13eeab12dd3..51685def7dbc 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -2161,15 +2161,23 @@ def is_variant_null(v: "ColumnOrName") -> Column:
 is_variant_null.__doc__ = pysparkfuncs.is_variant_null.__doc__
 
 
-def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
-    return _invoke_function("variant_get", _to_col(v), lit(path), 
lit(targetType))
+def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) 
-> Column:
+    assert isinstance(path, (Column, str))
+    if isinstance(path, str):
+        return _invoke_function("variant_get", _to_col(v), lit(path), 
lit(targetType))
+    else:
+        return _invoke_function("variant_get", _to_col(v), path, 
lit(targetType))
 
 
 variant_get.__doc__ = pysparkfuncs.variant_get.__doc__
 
 
-def try_variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
-    return _invoke_function("try_variant_get", _to_col(v), lit(path), 
lit(targetType))
+def try_variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: 
str) -> Column:
+    assert isinstance(path, (Column, str))
+    if isinstance(path, str):
+        return _invoke_function("try_variant_get", _to_col(v), lit(path), 
lit(targetType))
+    else:
+        return _invoke_function("try_variant_get", _to_col(v), path, 
lit(targetType))
 
 
 try_variant_get.__doc__ = pysparkfuncs.try_variant_get.__doc__
diff --git a/python/pyspark/sql/functions/builtin.py 
b/python/pyspark/sql/functions/builtin.py
index 4575bf730fca..2b6d8569fdf8 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -20427,7 +20427,7 @@ def is_variant_null(v: "ColumnOrName") -> Column:
 
 
 @_try_remote_functions
-def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
+def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) 
-> Column:
     """
     Extracts a sub-variant from `v` according to `path`, and then cast the 
sub-variant to
     `targetType`. Returns null if the path does not exist. Throws an exception 
if the cast fails.
@@ -20438,9 +20438,10 @@ def variant_get(v: "ColumnOrName", path: str, 
targetType: str) -> Column:
     ----------
     v : :class:`~pyspark.sql.Column` or str
         a variant column or column name
-    path : str
-        the extraction path. A valid path should start with `$` and is 
followed by zero or more
-        segments like `[123]`, `.name`, `['name']`, or `["name"]`.
+    path : :class:`~pyspark.sql.Column` or str
+        a column containing the extraction path strings or a string 
representing the extraction
+        path. A valid path should start with `$` and is followed by zero or 
more segments like
+        `[123]`, `.name`, `['name']`, or `["name"]`.
     targetType : str
         the target data type to cast into, in a DDL-formatted string
 
@@ -20451,21 +20452,29 @@ def variant_get(v: "ColumnOrName", path: str, 
targetType: str) -> Column:
 
     Examples
     --------
-    >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
+    >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }''', 'path': 
'$.a'} ])
     >>> df.select(variant_get(parse_json(df.json), "$.a", 
"int").alias("r")).collect()
     [Row(r=1)]
     >>> df.select(variant_get(parse_json(df.json), "$.b", 
"int").alias("r")).collect()
     [Row(r=None)]
+    >>> df.select(variant_get(parse_json(df.json), df.path, 
"int").alias("r")).collect()
+    [Row(r=1)]
     """
     from pyspark.sql.classic.column import _to_java_column
 
-    return _invoke_function(
-        "variant_get", _to_java_column(v), _enum_to_value(path), 
_enum_to_value(targetType)
-    )
+    assert isinstance(path, (Column, str))
+    if isinstance(path, str):
+        return _invoke_function(
+            "variant_get", _to_java_column(v), _enum_to_value(path), 
_enum_to_value(targetType)
+        )
+    else:
+        return _invoke_function(
+            "variant_get", _to_java_column(v), _to_java_column(path), 
_enum_to_value(targetType)
+        )
 
 
 @_try_remote_functions
-def try_variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
+def try_variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: 
str) -> Column:
     """
     Extracts a sub-variant from `v` according to `path`, and then cast the 
sub-variant to
     `targetType`. Returns null if the path does not exist or the cast fails.
@@ -20476,9 +20485,10 @@ def try_variant_get(v: "ColumnOrName", path: str, 
targetType: str) -> Column:
     ----------
     v : :class:`~pyspark.sql.Column` or str
         a variant column or column name
-    path : str
-        the extraction path. A valid path should start with `$` and is 
followed by zero or more
-        segments like `[123]`, `.name`, `['name']`, or `["name"]`.
+    path : :class:`~pyspark.sql.Column` or str
+        a column containing the extraction path strings or a string 
representing the extraction
+        path. A valid path should start with `$` and is followed by zero or 
more segments like
+        `[123]`, `.name`, `['name']`, or `["name"]`.
     targetType : str
         the target data type to cast into, in a DDL-formatted string
 
@@ -20489,19 +20499,26 @@ def try_variant_get(v: "ColumnOrName", path: str, 
targetType: str) -> Column:
 
     Examples
     --------
-    >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
+    >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }''', 'path': 
'$.a'} ])
     >>> df.select(try_variant_get(parse_json(df.json), "$.a", 
"int").alias("r")).collect()
     [Row(r=1)]
     >>> df.select(try_variant_get(parse_json(df.json), "$.b", 
"int").alias("r")).collect()
     [Row(r=None)]
     >>> df.select(try_variant_get(parse_json(df.json), "$.a", 
"binary").alias("r")).collect()
     [Row(r=None)]
+    >>> df.select(try_variant_get(parse_json(df.json), df.path, 
"int").alias("r")).collect()
+    [Row(r=1)]
     """
     from pyspark.sql.classic.column import _to_java_column
 
-    return _invoke_function(
-        "try_variant_get", _to_java_column(v), _enum_to_value(path), 
_enum_to_value(targetType)
-    )
+    assert isinstance(path, (Column, str))
+    if isinstance(path, str):
+        return _invoke_function(
+            "try_variant_get", _to_java_column(v), _enum_to_value(path), 
_enum_to_value(targetType)
+        )
+    else:
+        return _invoke_function(
+            "try_variant_get", _to_java_column(v), _to_java_column(path), 
_enum_to_value(targetType)
+        )
 
 
 @_try_remote_functions
diff --git a/python/pyspark/sql/tests/test_functions.py 
b/python/pyspark/sql/tests/test_functions.py
index 39db72b235bf..b627bc793f05 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1496,7 +1496,9 @@ class FunctionsTestsMixin:
         self.assertEqual("""{"b":[{"c":"str2"}]}""", actual["var_lit"])
 
     def test_variant_expressions(self):
-        df = self.spark.createDataFrame([Row(json="""{ "a" : 1 }"""), 
Row(json="""{ "b" : 2 }""")])
+        df = self.spark.createDataFrame(
+            [Row(json="""{ "a" : 1 }""", path="$.a"), Row(json="""{ "b" : 2 
}""", path="$.b")]
+        )
         v = F.parse_json(df.json)
 
         def check(resultDf, expected):
@@ -1510,6 +1512,10 @@ class FunctionsTestsMixin:
         check(df.select(F.variant_get(v, "$.b", "int")), [None, 2])
         check(df.select(F.variant_get(v, "$.a", "double")), [1.0, None])
 
+        # non-literal variant_get
+        check(df.select(F.variant_get(v, df.path, "int")), [1, 2])
+        check(df.select(F.try_variant_get(v, df.path, "binary")), [None, None])
+
         with self.assertRaises(SparkRuntimeException) as ex:
             df.select(F.variant_get(v, "$.a", "binary")).collect()
 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index 5670e513287e..ffa3a03e4224 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -7115,7 +7115,7 @@ object functions {
   def is_variant_null(v: Column): Column = Column.fn("is_variant_null", v)
 
   /**
-   * Extracts a sub-variant from `v` according to `path`, and then cast the 
sub-variant to
+   * Extracts a sub-variant from `v` according to `path` string, and then cast 
the sub-variant to
    * `targetType`. Returns null if the path does not exist. Throws an 
exception if the cast fails.
    *
    * @param v
@@ -7132,7 +7132,25 @@ object functions {
     Column.fn("variant_get", v, lit(path), lit(targetType))
 
   /**
-   * Extracts a sub-variant from `v` according to `path`, and then cast the 
sub-variant to
+   * Extracts a sub-variant from `v` according to `path` column, and then cast 
the sub-variant to
+   * `targetType`. Returns null if the path does not exist. Throws an 
exception if the cast fails.
+   *
+   * @param v
+   *   a variant column.
+   * @param path
+   *   the column containing the extraction path strings. A valid path string 
should start with
+   *   `$` and is followed by zero or more segments like `[123]`, `.name`, 
`['name']`, or
+   *   `["name"]`.
+   * @param targetType
+   *   the target data type to cast into, in a DDL-formatted string.
+   * @group variant_funcs
+   * @since 4.0.0
+   */
+  def variant_get(v: Column, path: Column, targetType: String): Column =
+    Column.fn("variant_get", v, path, lit(targetType))
+
+  /**
+   * Extracts a sub-variant from `v` according to `path` string, and then cast 
the sub-variant to
   * `targetType`. Returns null if the path does not exist or the cast fails.
    *
    * @param v
@@ -7148,6 +7166,24 @@ object functions {
   def try_variant_get(v: Column, path: String, targetType: String): Column =
     Column.fn("try_variant_get", v, lit(path), lit(targetType))
 
+  /**
+   * Extracts a sub-variant from `v` according to `path` column, and then cast 
the sub-variant to
+   * `targetType`. Returns null if the path does not exist or the cast fails.
+   *
+   * @param v
+   *   a variant column.
+   * @param path
+   *   the column containing the extraction path strings. A valid path string 
should start with
+   *   `$` and is followed by zero or more segments like `[123]`, `.name`, 
`['name']`, or
+   *   `["name"]`.
+   * @param targetType
+   *   the target data type to cast into, in a DDL-formatted string.
+   * @group variant_funcs
+   * @since 4.0.0
+   */
+  def try_variant_get(v: Column, path: Column, targetType: String): Column =
+    Column.fn("try_variant_get", v, path, lit(targetType))
+
   /**
    * Returns schema in the SQL format of a variant.
    *
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index f722329097bc..0a72e792a04f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -246,13 +246,6 @@ case class VariantGet(
     val check = super.checkInputDataTypes()
     if (check.isFailure) {
       check
-    } else if (!path.foldable) {
-      DataTypeMismatch(
-        errorSubClass = "NON_FOLDABLE_INPUT",
-        messageParameters = Map(
-          "inputName" -> toSQLId("path"),
-          "inputType" -> toSQLType(path.dataType),
-          "inputExpr" -> toSQLExpr(path)))
     } else if (!VariantGet.checkDataType(targetType)) {
       DataTypeMismatch(
         errorSubClass = "CAST_WITHOUT_SUGGESTION",
@@ -265,10 +258,12 @@ case class VariantGet(
 
   override lazy val dataType: DataType = targetType.asNullable
 
-  @transient private lazy val parsedPath = {
-    val pathValue = path.eval().toString
-    VariantPathParser.parse(pathValue).getOrElse {
-      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+  @transient private lazy val parsedPath: Option[Array[VariantPathSegment]] = {
+    if (path.foldable) {
+      val pathValue = path.eval().toString
+      Some(VariantGet.getParsedPath(pathValue, prettyName))
+    } else {
+      None
     }
   }
 
@@ -287,23 +282,37 @@ case class VariantGet(
     timeZoneId,
     zoneId)
 
-  protected override def nullSafeEval(input: Any, path: Any): Any = {
-    VariantGet.variantGet(input.asInstanceOf[VariantVal], parsedPath, 
dataType, castArgs)
+  protected override def nullSafeEval(input: Any, path: Any): Any = parsedPath 
match {
+    case Some(pp) =>
+      VariantGet.variantGet(input.asInstanceOf[VariantVal], pp, dataType, 
castArgs)
+    case _ =>
+      VariantGet.variantGet(input.asInstanceOf[VariantVal], 
path.asInstanceOf[UTF8String], dataType,
+        castArgs, prettyName)
   }
 
   protected override def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
-    val childCode = child.genCode(ctx)
     val tmp = ctx.freshVariable("tmp", classOf[Object])
-    val parsedPathArg = ctx.addReferenceObj("parsedPath", parsedPath)
+    val childCode = child.genCode(ctx)
     val dataTypeArg = ctx.addReferenceObj("dataType", dataType)
     val castArgsArg = ctx.addReferenceObj("castArgs", castArgs)
+    val (pathCode, parsedPathArg) = if (parsedPath.isEmpty) {
+      val pathCode = path.genCode(ctx)
+      (pathCode, pathCode.value)
+    } else {
+      (
+        new ExprCode(EmptyBlock, FalseLiteral, TrueLiteral),
+        ctx.addReferenceObj("parsedPath", parsedPath.get)
+      )
+    }
     val code = code"""
       ${childCode.code}
-      boolean ${ev.isNull} = ${childCode.isNull};
+      ${pathCode.code}
+      boolean ${ev.isNull} = ${childCode.isNull} || ${pathCode.isNull};
       ${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
       if (!${ev.isNull}) {
         Object $tmp = 
org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet(
-          ${childCode.value}, $parsedPathArg, $dataTypeArg, $castArgsArg);
+          ${childCode.value}, $parsedPathArg, $dataTypeArg,
+          $castArgsArg${if (parsedPath.isEmpty) s""", "$prettyName"""" else 
""});
         if ($tmp == null) {
           ${ev.isNull} = true;
         } else {
@@ -350,6 +359,15 @@ case object VariantGet {
     case _ => false
   }
 
+  /**
+   * Get parsed Array[VariantPathSegment] from string representing path
+   */
+  def getParsedPath(pathValue: String, prettyName: String): 
Array[VariantPathSegment] = {
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
   /** The actual implementation of the `VariantGet` expression. */
   def variantGet(
       input: VariantVal,
@@ -368,6 +386,20 @@ case object VariantGet {
     VariantGet.cast(v, dataType, castArgs)
   }
 
+  /**
+   * Implementation of the `VariantGet` expression where the path is provided 
as a UTF8String
+   */
+  def variantGet(
+      input: VariantVal,
+      path: UTF8String,
+      dataType: DataType,
+      castArgs: VariantCastArgs,
+      prettyName: String): Any = {
+    val pathValue = path.toString
+    val parsedPath = VariantGet.getParsedPath(pathValue, prettyName)
+    variantGet(input, parsedPath, dataType, castArgs)
+  }
+
   /**
    * A simple wrapper of the `cast` function that takes `Variant` rather than 
`VariantVal`. The
    * `Cast` expression uses it and makes the implementation simpler.
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
index 33ba4f772a13..e9cc23c6a5ba 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -95,9 +95,11 @@ object RequestedVariantField {
   def fullVariant: RequestedVariantField =
     RequestedVariantField(VariantMetadata("$", failOnError = true, "UTC"), 
VariantType)
 
-  def apply(v: VariantGet): RequestedVariantField =
+  def apply(v: VariantGet): RequestedVariantField = {
+    assert(v.path.foldable)
     RequestedVariantField(
       VariantMetadata(v.path.eval().toString, v.failOnError, 
v.timeZoneId.get), v.dataType)
+  }
 
   def apply(c: Cast): RequestedVariantField =
     RequestedVariantField(
@@ -212,7 +214,7 @@ class VariantInRelation {
   // fields, which also changes the struct type containing it, and it is 
difficult to reconstruct
   // the original struct value. This is not a big loss, because we need the 
full variant anyway.
   def collectRequestedFields(expr: Expression): Unit = expr match {
-    case v@VariantGet(StructPathToVariant(fields), _, _, _, _) =>
+    case v@VariantGet(StructPathToVariant(fields), path, _, _, _) if 
path.foldable =>
       addField(fields, RequestedVariantField(v))
     case c@Cast(StructPathToVariant(fields), _, _, _) => addField(fields, 
RequestedVariantField(c))
     case IsNotNull(StructPath(_, _)) | IsNull(StructPath(_, _)) =>
@@ -240,7 +242,7 @@ class VariantInRelation {
 
     // Rewrite patterns should be consistent with visit patterns in 
`collectRequestedFields`.
     expr.transformDown {
-      case g@VariantGet(v@StructPathToVariant(fields), _, _, _, _) =>
+      case g@VariantGet(v@StructPathToVariant(fields), path, _, _, _) if 
path.foldable =>
         // Rewrite the attribute in advance, rather than depending on the last 
branch to rewrite it.
+        // We need to avoid the `v@StructPathToVariant(fields)` branch to 
rewrite the child again.
         GetStructField(rewriteAttribute(v), fields(RequestedVariantField(g)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index 09b29b668b13..b6fe4af28ab0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -29,6 +29,8 @@ import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, 
ExpressionEvalHelper, Literal}
 import 
org.apache.spark.sql.catalyst.expressions.variant.{VariantExpressionEvalUtils, 
VariantGet}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, 
GenericArrayData}
+import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLId
+import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -108,6 +110,92 @@ class VariantSuite extends QueryTest with 
SharedSparkSession with ExpressionEval
     checkAnswer(df.select(try_variant_get(v, "$.a", "binary")), rows(null, 
null))
   }
 
+  test("non-literal variant_get") {
+    def rows(results: Any*): Seq[Row] = results.map(Row(_))
+
+    Seq("CODEGEN_ONLY", "NO_CODEGEN").foreach { codegenMode =>
+      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
+        // The first three rows have valid paths while the final row has an 
invalid path
+        val df = Seq(("""{"a" : 1}""", "$.a", 2), ("""{"b" : 2}""", "$", 1),
+          ("""{"c" : 3}""", null, 1), (null, null, 1), (null, "$.a", 1),
+          ("""{"d" : 3}""", "abc", 0)).toDF("json", "path", "valid")
+        val v = parse_json(col("json"))
+        val df1 = df.where($"valid" > 0).select(variant_get(v, col("path"), 
"string"))
+        checkAnswer(df1, rows("1", """{"b":2}""", null, null, null))
+        
assert(df1.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+          (codegenMode == "CODEGEN_ONLY"))
+        // Invalid path
+        val df2 = df.select(variant_get(v, col("path"), "string"))
+        checkError(
+          exception = intercept[SparkRuntimeException] { df2.collect() },
+          condition = "INVALID_VARIANT_GET_PATH",
+          parameters = Map("path" -> "abc", "functionName" -> 
toSQLId("variant_get")))
+        
assert(df2.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+          (codegenMode == "CODEGEN_ONLY"))
+        // Invalid cast
+        val df3 = df.where($"valid" > 1).select(variant_get(v, col("path"), 
"binary"))
+        checkError(
+          exception = intercept[SparkRuntimeException] { df3.collect() },
+          condition = "INVALID_VARIANT_CAST",
+          parameters = Map("value" -> "1", "dataType" -> "\"BINARY\"")
+        )
+        
assert(df3.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+          (codegenMode == "CODEGEN_ONLY"))
+
+        // try_variant_get
+        val df4 = df.where($"valid" > 0).select(try_variant_get(v, 
col("path"), "string"))
+        checkAnswer(df4, rows("1", """{"b":2}""", null, null, null))
+        
assert(df4.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+          (codegenMode == "CODEGEN_ONLY"))
+        // Invalid path
+        val df5 = df.select(try_variant_get(v, col("path"), "string"))
+        checkError(
+          exception = intercept[SparkRuntimeException] { df5.collect() },
+          condition = "INVALID_VARIANT_GET_PATH",
+          parameters = Map("path" -> "abc", "functionName" -> 
toSQLId("try_variant_get")))
+        
assert(df5.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+          (codegenMode == "CODEGEN_ONLY"))
+        // Invalid cast
+        val df6 = df.where($"valid" > 1).select(try_variant_get(v, 
col("path"), "binary"))
+        checkAnswer(df6, rows(null))
+        
assert(df6.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] ==
+          (codegenMode == "CODEGEN_ONLY"))
+
+        // SQL API
+        withTable("t") {
+          df.withColumn("v", parse_json(col("json"))).write.saveAsTable("t")
+          // variant_get
+          checkAnswer(sql("select variant_get(v, path, 'string') from t where 
valid > 0"),
+            rows("1", """{"b":2}""", null, null, null))
+          checkError(
+            exception = intercept[SparkRuntimeException] {
+              sql("select variant_get(v, path, 'string') from t").collect()
+            },
+            condition = "INVALID_VARIANT_GET_PATH",
+            parameters = Map("path" -> "abc", "functionName" -> 
toSQLId("variant_get")))
+          checkError(
+            exception = intercept[SparkRuntimeException] {
+              sql("select variant_get(v, path, 'binary') from t where valid > 
1").collect()
+            },
+            condition = "INVALID_VARIANT_CAST",
+            parameters = Map("value" -> "1", "dataType" -> "\"BINARY\"")
+          )
+          // try_variant_get
+          checkAnswer(sql("select try_variant_get(v, path, 'string') from t 
where valid > 0"),
+            rows("1", """{"b":2}""", null, null, null))
+          checkError(
+            exception = intercept[SparkRuntimeException] {
+              sql("select try_variant_get(v, path, 'string') from t").collect()
+            },
+            condition = "INVALID_VARIANT_GET_PATH",
+            parameters = Map("path" -> "abc", "functionName" -> 
toSQLId("try_variant_get")))
+          checkAnswer(sql("select try_variant_get(v, path, 'binary') from t 
where valid > 1"),
+            rows(null))
+        }
+      }
+    }
+  }
+
   test("round trip tests") {
     val rand = new Random(42)
     val input = Seq.fill(50) {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
index 2a866dcd66f0..5515c4053bc1 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
@@ -59,7 +59,7 @@ class PushVariantIntoScanSuite extends SharedSparkSession {
 
   testOnFormats { format =>
     sql("create table T (v variant, vs struct<v1 variant, v2 variant, i int>, 
" +
-      "va array<variant>, vd variant default parse_json('1')) " +
+      "va array<variant>, vd variant default parse_json('1'), s string) " +
       s"using $format")
 
     sql("select variant_get(v, '$.a', 'int') as a, v, cast(v as struct<b 
float>) as v from T")
@@ -162,6 +162,26 @@ class PushVariantIntoScanSuite extends SharedSparkSession {
         assert(vd.dataType == VariantType)
       case _ => fail()
     }
+
+    // No push down if the path in variant_get is not a literal
+    sql("select variant_get(v, '$.a', 'int') as a, variant_get(v, s, 'int') 
v2, v, " +
+      "cast(v as struct<b float>) as v from T")
+      .queryExecution.optimizedPlan match {
+      case Project(projectList, l: LogicalRelation) =>
+        val output = l.output
+        val v = output(0)
+        val s = output(4)
+        checkAlias(projectList(0), "a", GetStructField(v, 0))
+        checkAlias(projectList(1), "v2", VariantGet(GetStructField(v, 1), s,
+          targetType = IntegerType, failOnError = true, timeZoneId = 
Some(localTimeZone)))
+        checkAlias(projectList(2), "v", GetStructField(v, 1))
+        checkAlias(projectList(3), "v", GetStructField(v, 2))
+        assert(v.dataType == StructType(Array(
+          field(0, IntegerType, "$.a"),
+          field(1, VariantType, "$", timeZone = "UTC"),
+          field(2, StructType(Array(StructField("b", FloatType))), "$"))))
+      case _ => fail()
+    }
   }
 
   test("No push down for JSON") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to