This is an automated email from the ASF dual-hosted git repository. gengliang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new ffa4d198cec6 [SPARK-48067][SQL] Fix variant default columns ffa4d198cec6 is described below commit ffa4d198cec6620f0385a0e428b023d2ac4e3d5c Author: Richard Chen <r.c...@databricks.com> AuthorDate: Thu May 2 12:22:02 2024 -0700 [SPARK-48067][SQL] Fix variant default columns ### What changes were proposed in this pull request? Changes the literal `sql` representation of a variant value to `parse_json(variant.toJson)`. This is because there is no other representation of a literal variant. This allows variant default columns to work because default columns store a literal string representation in the schema struct field's metadata as the default value. ### Why are the changes needed? Previously, we could not set a variant default column like ``` create table t( v6 variant default parse_json('{\"k\": \"v\"}') ) ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added a unit test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46312 from richardc-db/fix_variant_default_cols. 
Authored-by: Richard Chen <r.c...@databricks.com> Signed-off-by: Gengliang Wang <gengli...@apache.org> --- .../spark/sql/catalyst/expressions/literals.scala | 4 + .../scala/org/apache/spark/sql/VariantSuite.scala | 145 ++++++++++++++++++++- 2 files changed, 146 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 0fad3eff2da5..4cffc7f0b53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -42,6 +42,7 @@ import org.json4s.JsonAST._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils import org.apache.spark.sql.catalyst.trees.TreePattern import org.apache.spark.sql.catalyst.trees.TreePattern.{LITERAL, NULL_LITERAL, TRUE_OR_FALSE_LITERAL} import org.apache.spark.sql.catalyst.types._ @@ -204,6 +205,8 @@ object Literal { create(new GenericInternalRow( struct.fields.map(f => default(f.dataType).value)), struct) case udt: UserDefinedType[_] => Literal(default(udt.sqlType).value, udt) + case VariantType => + create(VariantExpressionEvalUtils.castToVariant(0, IntegerType), VariantType) case other => throw QueryExecutionErrors.noDefaultForDataTypeError(dataType) } @@ -549,6 +552,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { s"${Literal(kv._1, mapType.keyType).sql}, ${Literal(kv._2, mapType.valueType).sql}" } s"MAP(${keysAndValues.mkString(", ")})" + case (v: VariantVal, variantType: VariantType) => s"PARSE_JSON('${v.toJson(timeZoneId)}')" case _ => value.toString } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala index 19e5f9ba63e6..caab98b6239a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala @@ -26,15 +26,17 @@ import scala.jdk.CollectionConverters._ import scala.util.Random import org.apache.spark.SparkRuntimeException -import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, ExpressionEvalHelper, Literal} +import org.apache.spark.sql.catalyst.expressions.variant.{VariantExpressionEvalUtils, VariantGet} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.VariantVal +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} import org.apache.spark.util.ArrayImplicits._ -class VariantSuite extends QueryTest with SharedSparkSession { +class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEvalHelper { import testImplicits._ test("basic tests") { @@ -445,4 +447,141 @@ class VariantSuite extends QueryTest with SharedSparkSession { } } } + + test("SPARK-48067: default variant columns works") { + withTable("t") { + sql("""create table t( + v1 variant default null, + v2 variant default parse_json(null), + v3 variant default cast(null as variant), + v4 variant default parse_json('1'), + v5 variant default parse_json('1'), + v6 variant default parse_json('{\"k\": \"v\"}'), + v7 variant default cast(5 as int), + v8 variant default cast('hello' as string), + v9 variant default parse_json(to_json(parse_json('{\"k\": \"v\"}'))) + ) using parquet""") + sql("""insert into t values(DEFAULT, DEFAULT, DEFAULT, DEFAULT, DEFAULT, DEFAULT, DEFAULT, + DEFAULT, 
DEFAULT)""") + + val expected = sql("""select + cast(null as variant) as v1, + parse_json(null) as v2, + cast(null as variant) as v3, + parse_json('1') as v4, + parse_json('1') as v5, + parse_json('{\"k\": \"v\"}') as v6, + cast(cast(5 as int) as variant) as v7, + cast('hello' as variant) as v8, + parse_json(to_json(parse_json('{\"k\": \"v\"}'))) as v9 + """) + val actual = sql("select * from t") + checkAnswer(actual, expected.collect()) + } + } + + Seq( + ( + "basic int parse json", + VariantExpressionEvalUtils.parseJson(UTF8String.fromString("1")), + VariantType + ), + ( + "basic json parse json", + VariantExpressionEvalUtils.parseJson(UTF8String.fromString("{\"k\": \"v\"}")), + VariantType + ), + ( + "basic null parse json", + VariantExpressionEvalUtils.parseJson(UTF8String.fromString("null")), + VariantType + ), + ( + "basic null", + null, + VariantType + ), + ( + "basic array", + new GenericArrayData(Array[Int](1, 2, 3, 4, 5)), + new ArrayType(IntegerType, false) + ), + ( + "basic string", + UTF8String.fromString("literal string"), + StringType + ), + ( + "basic timestamp", + 0L, + TimestampType + ), + ( + "basic int", + 0, + IntegerType + ), + ( + "basic struct", + Literal.default(new StructType().add("col0", StringType)).eval(), + new StructType().add("col0", StringType) + ), + ( + "complex struct with child variant", + Literal.default(new StructType() + .add("col0", StringType) + .add("col1", new StructType().add("col0", VariantType)) + .add("col2", VariantType) + .add("col3", new ArrayType(VariantType, false)) + ).eval(), + new StructType() + .add("col0", StringType) + .add("col1", new StructType().add("col0", VariantType)) + .add("col2", VariantType) + .add("col3", new ArrayType(VariantType, false)) + ), + ( + "basic array with null", + new GenericArrayData(Array[Any](1, 2, null)), + new ArrayType(IntegerType, true) + ), + ( + "basic map with null", + new ArrayBasedMapData( + new GenericArrayData(Array[Any](UTF8String.fromString("k1"), 
UTF8String.fromString("k2"))), + new GenericArrayData(Array[Any](1, null)) + ), + new MapType(StringType, IntegerType, true) + ) + ).foreach { case (testName, value, dt) => + test(s"SPARK-48067: Variant literal `sql` correctly recreates the variant - $testName") { + val l = Literal.create( + VariantExpressionEvalUtils.castToVariant(value, dt.asInstanceOf[DataType]), VariantType) + val jsonString = l.eval().asInstanceOf[VariantVal] + .toJson(DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)) + val expectedSql = s"PARSE_JSON('$jsonString')" + assert(l.sql == expectedSql) + val valueFromLiteralSql = + spark.sql(s"select ${l.sql}").collect()(0).getAs[VariantVal](0) + + // Cast the variants to their specified type to compare for logical equality. + // Currently, variant equality naively compares its value and metadata binaries. However, + // variant equality is more complex than this. + val castVariantExpr = VariantGet( + l, + Literal.create(UTF8String.fromString("$"), StringType), + dt, + true, + Some(DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone).toString()) + ) + val sqlVariantExpr = VariantGet( + Literal.create(valueFromLiteralSql, VariantType), + Literal.create(UTF8String.fromString("$"), StringType), + dt, + true, + Some(DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone).toString()) + ) + checkEvaluation(castVariantExpr, sqlVariantExpr.eval()) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org