This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new dd153307cb97 [SPARK-50953][PYTHON][CONNECT] Add support for non-literal paths in VariantGet dd153307cb97 is described below commit dd153307cb9735fd05a41124eca2a136f40f3b3f Author: Harsh Motwani <harsh.motw...@databricks.com> AuthorDate: Mon Feb 10 21:46:18 2025 +0800 [SPARK-50953][PYTHON][CONNECT] Add support for non-literal paths in VariantGet ### What changes were proposed in this pull request? This PR allows the `variant_get` expression to support non-literal path inputs. ### Why are the changes needed? Currently, `variant_get` only supports literal paths as the second argument. Users may have columns containing paths which they would want to extract from variants. This PR allows this functionality. ### Does this PR introduce _any_ user-facing change? Yes, prior to this PR, `variant_get` did not have support for non-literal paths. ### How was this patch tested? Unit tests to make sure that: 1. The VariantGet/TryVariantGet expressions with non-literal paths have the expected behavior regardless of codegen mode. 2. VariantGet expressions with non-literal paths do not get pushed down as this functionality has not been implemented. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49609 from harshmotw-db/harsh-motwani_data/variant_get_column. 
Authored-by: Harsh Motwani <harsh.motw...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- python/pyspark/sql/connect/functions/builtin.py | 16 +++- python/pyspark/sql/functions/builtin.py | 49 ++++++++---- python/pyspark/sql/tests/test_functions.py | 8 +- .../scala/org/apache/spark/sql/functions.scala | 40 +++++++++- .../expressions/variant/variantExpressions.scala | 66 +++++++++++----- .../datasources/PushVariantIntoScan.scala | 8 +- .../scala/org/apache/spark/sql/VariantSuite.scala | 88 ++++++++++++++++++++++ .../datasources/PushVariantIntoScanSuite.scala | 22 +++++- 8 files changed, 253 insertions(+), 44 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index f13eeab12dd3..51685def7dbc 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -2161,15 +2161,23 @@ def is_variant_null(v: "ColumnOrName") -> Column: is_variant_null.__doc__ = pysparkfuncs.is_variant_null.__doc__ -def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column: - return _invoke_function("variant_get", _to_col(v), lit(path), lit(targetType)) +def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) -> Column: + assert isinstance(path, (Column, str)) + if isinstance(path, str): + return _invoke_function("variant_get", _to_col(v), lit(path), lit(targetType)) + else: + return _invoke_function("variant_get", _to_col(v), path, lit(targetType)) variant_get.__doc__ = pysparkfuncs.variant_get.__doc__ -def try_variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column: - return _invoke_function("try_variant_get", _to_col(v), lit(path), lit(targetType)) +def try_variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) -> Column: + assert isinstance(path, (Column, str)) + if isinstance(path, str): + return _invoke_function("try_variant_get", _to_col(v), lit(path), lit(targetType)) + 
else: + return _invoke_function("try_variant_get", _to_col(v), path, lit(targetType)) try_variant_get.__doc__ = pysparkfuncs.try_variant_get.__doc__ diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 4575bf730fca..2b6d8569fdf8 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -20427,7 +20427,7 @@ def is_variant_null(v: "ColumnOrName") -> Column: @_try_remote_functions -def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column: +def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) -> Column: """ Extracts a sub-variant from `v` according to `path`, and then cast the sub-variant to `targetType`. Returns null if the path does not exist. Throws an exception if the cast fails. @@ -20438,9 +20438,10 @@ def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column: ---------- v : :class:`~pyspark.sql.Column` or str a variant column or column name - path : str - the extraction path. A valid path should start with `$` and is followed by zero or more - segments like `[123]`, `.name`, `['name']`, or `["name"]`. + path : :class:`~pyspark.sql.Column` or str + a column containing the extraction path strings or a string representing the extraction + path. A valid path should start with `$` and is followed by zero or more segments like + `[123]`, `.name`, `['name']`, or `["name"]`. 
targetType : str the target data type to cast into, in a DDL-formatted string @@ -20451,21 +20452,29 @@ def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column: Examples -------- - >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ]) + >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }''', 'path': '$.a'} ]) >>> df.select(variant_get(parse_json(df.json), "$.a", "int").alias("r")).collect() [Row(r=1)] >>> df.select(variant_get(parse_json(df.json), "$.b", "int").alias("r")).collect() [Row(r=None)] + >>> df.select(variant_get(parse_json(df.json), df.path, "int").alias("r")).collect() + [Row(r=1)] """ from pyspark.sql.classic.column import _to_java_column - return _invoke_function( - "variant_get", _to_java_column(v), _enum_to_value(path), _enum_to_value(targetType) - ) + assert isinstance(path, (Column, str)) + if isinstance(path, str): + return _invoke_function( + "variant_get", _to_java_column(v), _enum_to_value(path), _enum_to_value(targetType) + ) + else: + return _invoke_function( + "variant_get", _to_java_column(v), _to_java_column(path), _enum_to_value(targetType) + ) @_try_remote_functions -def try_variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column: +def try_variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) -> Column: """ Extracts a sub-variant from `v` according to `path`, and then cast the sub-variant to `targetType`. Returns null if the path does not exist or the cast fails. @@ -20476,9 +20485,10 @@ def try_variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column: ---------- v : :class:`~pyspark.sql.Column` or str a variant column or column name - path : str - the extraction path. A valid path should start with `$` and is followed by zero or more - segments like `[123]`, `.name`, `['name']`, or `["name"]`. + path : :class:`~pyspark.sql.Column` or str + a column containing the extraction path strings or a string representing the extraction + path. 
A valid path should start with `$` and is followed by zero or more segments like + `[123]`, `.name`, `['name']`, or `["name"]`. targetType : str the target data type to cast into, in a DDL-formatted string @@ -20489,19 +20499,26 @@ def try_variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column: Examples -------- - >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ]) + >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }''', 'path': '$.a'} ]) >>> df.select(try_variant_get(parse_json(df.json), "$.a", "int").alias("r")).collect() [Row(r=1)] >>> df.select(try_variant_get(parse_json(df.json), "$.b", "int").alias("r")).collect() [Row(r=None)] >>> df.select(try_variant_get(parse_json(df.json), "$.a", "binary").alias("r")).collect() [Row(r=None)] + >>> df.select(try_variant_get(parse_json(df.json), df.path, "int").alias("r")).collect() + [Row(r=1)] """ from pyspark.sql.classic.column import _to_java_column - return _invoke_function( - "try_variant_get", _to_java_column(v), _enum_to_value(path), _enum_to_value(targetType) - ) + if isinstance(path, str): + return _invoke_function( + "try_variant_get", _to_java_column(v), _enum_to_value(path), _enum_to_value(targetType) + ) + else: + return _invoke_function( + "try_variant_get", _to_java_column(v), _to_java_column(path), _enum_to_value(targetType) + ) @_try_remote_functions diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 39db72b235bf..b627bc793f05 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -1496,7 +1496,9 @@ class FunctionsTestsMixin: self.assertEqual("""{"b":[{"c":"str2"}]}""", actual["var_lit"]) def test_variant_expressions(self): - df = self.spark.createDataFrame([Row(json="""{ "a" : 1 }"""), Row(json="""{ "b" : 2 }""")]) + df = self.spark.createDataFrame( + [Row(json="""{ "a" : 1 }""", path="$.a"), Row(json="""{ "b" : 2 }""", path="$.b")] + ) v = F.parse_json(df.json) def 
check(resultDf, expected): @@ -1510,6 +1512,10 @@ class FunctionsTestsMixin: check(df.select(F.variant_get(v, "$.b", "int")), [None, 2]) check(df.select(F.variant_get(v, "$.a", "double")), [1.0, None]) + # non-literal variant_get + check(df.select(F.variant_get(v, df.path, "int")), [1, 2]) + check(df.select(F.try_variant_get(v, df.path, "binary")), [None, None]) + with self.assertRaises(SparkRuntimeException) as ex: df.select(F.variant_get(v, "$.a", "binary")).collect() diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 5670e513287e..ffa3a03e4224 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -7115,7 +7115,7 @@ object functions { def is_variant_null(v: Column): Column = Column.fn("is_variant_null", v) /** - * Extracts a sub-variant from `v` according to `path`, and then cast the sub-variant to + * Extracts a sub-variant from `v` according to `path` string, and then cast the sub-variant to * `targetType`. Returns null if the path does not exist. Throws an exception if the cast fails. * * @param v @@ -7132,7 +7132,25 @@ object functions { Column.fn("variant_get", v, lit(path), lit(targetType)) /** - * Extracts a sub-variant from `v` according to `path`, and then cast the sub-variant to + * Extracts a sub-variant from `v` according to `path` column, and then cast the sub-variant to + * `targetType`. Returns null if the path does not exist. Throws an exception if the cast fails. + * + * @param v + * a variant column. + * @param path + * the column containing the extraction path strings. A valid path string should start with + * `$` and is followed by zero or more segments like `[123]`, `.name`, `['name']`, or + * `["name"]`. + * @param targetType + * the target data type to cast into, in a DDL-formatted string. 
+ * @group variant_funcs + * @since 4.0.0 + */ + def variant_get(v: Column, path: Column, targetType: String): Column = + Column.fn("variant_get", v, path, lit(targetType)) + + /** + * Extracts a sub-variant from `v` according to `path` string, and then cast the sub-variant to * `targetType`. Returns null if the path does not exist or the cast fails.. * * @param v @@ -7148,6 +7166,24 @@ object functions { def try_variant_get(v: Column, path: String, targetType: String): Column = Column.fn("try_variant_get", v, lit(path), lit(targetType)) + /** + * Extracts a sub-variant from `v` according to `path` column, and then cast the sub-variant to + * `targetType`. Returns null if the path does not exist or the cast fails.. + * + * @param v + * a variant column. + * @param path + * the column containing the extraction path strings. A valid path string should start with + * `$` and is followed by zero or more segments like `[123]`, `.name`, `['name']`, or + * `["name"]`. + * @param targetType + * the target data type to cast into, in a DDL-formatted string. + * @group variant_funcs + * @since 4.0.0 + */ + def try_variant_get(v: Column, path: Column, targetType: String): Column = + Column.fn("try_variant_get", v, lit(path), lit(targetType)) + /** * Returns schema in the SQL format of a variant. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index f722329097bc..0a72e792a04f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -246,13 +246,6 @@ case class VariantGet( val check = super.checkInputDataTypes() if (check.isFailure) { check - } else if (!path.foldable) { - DataTypeMismatch( - errorSubClass = "NON_FOLDABLE_INPUT", - messageParameters = Map( - "inputName" -> toSQLId("path"), - "inputType" -> toSQLType(path.dataType), - "inputExpr" -> toSQLExpr(path))) } else if (!VariantGet.checkDataType(targetType)) { DataTypeMismatch( errorSubClass = "CAST_WITHOUT_SUGGESTION", @@ -265,10 +258,12 @@ case class VariantGet( override lazy val dataType: DataType = targetType.asNullable - @transient private lazy val parsedPath = { - val pathValue = path.eval().toString - VariantPathParser.parse(pathValue).getOrElse { - throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName) + @transient private lazy val parsedPath: Option[Array[VariantPathSegment]] = { + if (path.foldable) { + val pathValue = path.eval().toString + Some(VariantGet.getParsedPath(pathValue, prettyName)) + } else { + None } } @@ -287,23 +282,37 @@ case class VariantGet( timeZoneId, zoneId) - protected override def nullSafeEval(input: Any, path: Any): Any = { - VariantGet.variantGet(input.asInstanceOf[VariantVal], parsedPath, dataType, castArgs) + protected override def nullSafeEval(input: Any, path: Any): Any = parsedPath match { + case Some(pp) => + VariantGet.variantGet(input.asInstanceOf[VariantVal], pp, dataType, castArgs) + case _ => + VariantGet.variantGet(input.asInstanceOf[VariantVal], path.asInstanceOf[UTF8String], dataType, + castArgs, 
prettyName) } protected override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val childCode = child.genCode(ctx) val tmp = ctx.freshVariable("tmp", classOf[Object]) - val parsedPathArg = ctx.addReferenceObj("parsedPath", parsedPath) + val childCode = child.genCode(ctx) val dataTypeArg = ctx.addReferenceObj("dataType", dataType) val castArgsArg = ctx.addReferenceObj("castArgs", castArgs) + val (pathCode, parsedPathArg) = if (parsedPath.isEmpty) { + val pathCode = path.genCode(ctx) + (pathCode, pathCode.value) + } else { + ( + new ExprCode(EmptyBlock, FalseLiteral, TrueLiteral), + ctx.addReferenceObj("parsedPath", parsedPath.get) + ) + } val code = code""" ${childCode.code} - boolean ${ev.isNull} = ${childCode.isNull}; + ${pathCode.code} + boolean ${ev.isNull} = ${childCode.isNull} || ${pathCode.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { Object $tmp = org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet( - ${childCode.value}, $parsedPathArg, $dataTypeArg, $castArgsArg); + ${childCode.value}, $parsedPathArg, $dataTypeArg, + $castArgsArg${if (parsedPath.isEmpty) s""", "$prettyName"""" else ""}); if ($tmp == null) { ${ev.isNull} = true; } else { @@ -350,6 +359,15 @@ case object VariantGet { case _ => false } + /** + * Get parsed Array[VariantPathSegment] from string representing path + */ + def getParsedPath(pathValue: String, prettyName: String): Array[VariantPathSegment] = { + VariantPathParser.parse(pathValue).getOrElse { + throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName) + } + } + /** The actual implementation of the `VariantGet` expression. 
*/ def variantGet( input: VariantVal, @@ -368,6 +386,20 @@ case object VariantGet { VariantGet.cast(v, dataType, castArgs) } + /** + * Implementation of the `VariantGet` expression where the path is provided as a UTF8String + */ + def variantGet( + input: VariantVal, + path: UTF8String, + dataType: DataType, + castArgs: VariantCastArgs, + prettyName: String): Any = { + val pathValue = path.toString + val parsedPath = VariantGet.getParsedPath(pathValue, prettyName) + variantGet(input, parsedPath, dataType, castArgs) + } + /** * A simple wrapper of the `cast` function that takes `Variant` rather than `VariantVal`. The * `Cast` expression uses it and makes the implementation simpler. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala index 33ba4f772a13..e9cc23c6a5ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala @@ -95,9 +95,11 @@ object RequestedVariantField { def fullVariant: RequestedVariantField = RequestedVariantField(VariantMetadata("$", failOnError = true, "UTC"), VariantType) - def apply(v: VariantGet): RequestedVariantField = + def apply(v: VariantGet): RequestedVariantField = { + assert(v.path.foldable) RequestedVariantField( VariantMetadata(v.path.eval().toString, v.failOnError, v.timeZoneId.get), v.dataType) + } def apply(c: Cast): RequestedVariantField = RequestedVariantField( @@ -212,7 +214,7 @@ class VariantInRelation { // fields, which also changes the struct type containing it, and it is difficult to reconstruct // the original struct value. This is not a big loss, because we need the full variant anyway. 
def collectRequestedFields(expr: Expression): Unit = expr match { - case v@VariantGet(StructPathToVariant(fields), _, _, _, _) => + case v@VariantGet(StructPathToVariant(fields), path, _, _, _) if path.foldable => addField(fields, RequestedVariantField(v)) case c@Cast(StructPathToVariant(fields), _, _, _) => addField(fields, RequestedVariantField(c)) case IsNotNull(StructPath(_, _)) | IsNull(StructPath(_, _)) => @@ -240,7 +242,7 @@ class VariantInRelation { // Rewrite patterns should be consistent with visit patterns in `collectRequestedFields`. expr.transformDown { - case g@VariantGet(v@StructPathToVariant(fields), _, _, _, _) => + case g@VariantGet(v@StructPathToVariant(fields), path, _, _, _) if path.foldable => // Rewrite the attribute in advance, rather than depending on the last branch to rewrite it. // Ww need to avoid the `v@StructPathToVariant(fields)` branch to rewrite the child again. GetStructField(rewriteAttribute(v), fields(RequestedVariantField(g))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala index 09b29b668b13..b6fe4af28ab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, ExpressionEvalHelper, Literal} import org.apache.spark.sql.catalyst.expressions.variant.{VariantExpressionEvalUtils, VariantGet} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLId +import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -108,6 +110,92 @@ class VariantSuite extends QueryTest with 
SharedSparkSession with ExpressionEval checkAnswer(df.select(try_variant_get(v, "$.a", "binary")), rows(null, null)) } + test("non-literal variant_get") { + def rows(results: Any*): Seq[Row] = results.map(Row(_)) + + Seq("CODEGEN_ONLY", "NO_CODEGEN").foreach { codegenMode => + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + // The first three rows have valid paths while the final row has an invalid path + val df = Seq(("""{"a" : 1}""", "$.a", 2), ("""{"b" : 2}""", "$", 1), + ("""{"c" : 3}""", null, 1), (null, null, 1), (null, "$.a", 1), + ("""{"d" : 3}""", "abc", 0)).toDF("json", "path", "valid") + val v = parse_json(col("json")) + val df1 = df.where($"valid" > 0).select(variant_get(v, col("path"), "string")) + checkAnswer(df1, rows("1", """{"b":2}""", null, null, null)) + assert(df1.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == + (codegenMode == "CODEGEN_ONLY")) + // Invalid path + val df2 = df.select(variant_get(v, col("path"), "string")) + checkError( + exception = intercept[SparkRuntimeException] { df2.collect() }, + condition = "INVALID_VARIANT_GET_PATH", + parameters = Map("path" -> "abc", "functionName" -> toSQLId("variant_get"))) + assert(df2.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == + (codegenMode == "CODEGEN_ONLY")) + // Invalid cast + val df3 = df.where($"valid" > 1).select(variant_get(v, col("path"), "binary")) + checkError( + exception = intercept[SparkRuntimeException] { df3.collect() }, + condition = "INVALID_VARIANT_CAST", + parameters = Map("value" -> "1", "dataType" -> "\"BINARY\"") + ) + assert(df3.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == + (codegenMode == "CODEGEN_ONLY")) + + // try_variant_get + val df4 = df.where($"valid" > 0).select(try_variant_get(v, col("path"), "string")) + checkAnswer(df4, rows("1", """{"b":2}""", null, null, null)) + assert(df4.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == + (codegenMode == "CODEGEN_ONLY")) + // 
Invalid path + val df5 = df.select(try_variant_get(v, col("path"), "string")) + checkError( + exception = intercept[SparkRuntimeException] { df5.collect() }, + condition = "INVALID_VARIANT_GET_PATH", + parameters = Map("path" -> "abc", "functionName" -> toSQLId("try_variant_get"))) + assert(df5.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == + (codegenMode == "CODEGEN_ONLY")) + // Invalid cast + val df6 = df.where($"valid" > 1).select(try_variant_get(v, col("path"), "binary")) + checkAnswer(df6, rows(null)) + assert(df6.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == + (codegenMode == "CODEGEN_ONLY")) + + // SQL API + withTable("t") { + df.withColumn("v", parse_json(col("json"))).write.saveAsTable("t") + // variant_get + checkAnswer(sql("select variant_get(v, path, 'string') from t where valid > 0"), + rows("1", """{"b":2}""", null, null, null)) + checkError( + exception = intercept[SparkRuntimeException] { + sql("select variant_get(v, path, 'string') from t").collect() + }, + condition = "INVALID_VARIANT_GET_PATH", + parameters = Map("path" -> "abc", "functionName" -> toSQLId("variant_get"))) + checkError( + exception = intercept[SparkRuntimeException] { + sql("select variant_get(v, path, 'binary') from t where valid > 1").collect() + }, + condition = "INVALID_VARIANT_CAST", + parameters = Map("value" -> "1", "dataType" -> "\"BINARY\"") + ) + // try_variant_get + checkAnswer(sql("select try_variant_get(v, path, 'string') from t where valid > 0"), + rows("1", """{"b":2}""", null, null, null)) + checkError( + exception = intercept[SparkRuntimeException] { + sql("select try_variant_get(v, path, 'string') from t").collect() + }, + condition = "INVALID_VARIANT_GET_PATH", + parameters = Map("path" -> "abc", "functionName" -> toSQLId("try_variant_get"))) + checkAnswer(sql("select try_variant_get(v, path, 'binary') from t where valid > 1"), + rows(null)) + } + } + } + } + test("round trip tests") { val rand = new Random(42) val input 
= Seq.fill(50) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala index 2a866dcd66f0..5515c4053bc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala @@ -59,7 +59,7 @@ class PushVariantIntoScanSuite extends SharedSparkSession { testOnFormats { format => sql("create table T (v variant, vs struct<v1 variant, v2 variant, i int>, " + - "va array<variant>, vd variant default parse_json('1')) " + + "va array<variant>, vd variant default parse_json('1'), s string) " + s"using $format") sql("select variant_get(v, '$.a', 'int') as a, v, cast(v as struct<b float>) as v from T") @@ -162,6 +162,26 @@ class PushVariantIntoScanSuite extends SharedSparkSession { assert(vd.dataType == VariantType) case _ => fail() } + + // No push down if the path in variant_get is not a literal + sql("select variant_get(v, '$.a', 'int') as a, variant_get(v, s, 'int') v2, v, " + + "cast(v as struct<b float>) as v from T") + .queryExecution.optimizedPlan match { + case Project(projectList, l: LogicalRelation) => + val output = l.output + val v = output(0) + val s = output(4) + checkAlias(projectList(0), "a", GetStructField(v, 0)) + checkAlias(projectList(1), "v2", VariantGet(GetStructField(v, 1), s, + targetType = IntegerType, failOnError = true, timeZoneId = Some(localTimeZone))) + checkAlias(projectList(2), "v", GetStructField(v, 1)) + checkAlias(projectList(3), "v", GetStructField(v, 2)) + assert(v.dataType == StructType(Array( + field(0, IntegerType, "$.a"), + field(1, VariantType, "$", timeZone = "UTC"), + field(2, StructType(Array(StructField("b", FloatType))), "$")))) + case _ => fail() + } } test("No push down for JSON") { 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org