This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 37afc3fff8c6 [SPARK-47803][FOLLOWUP] Fix cast binary/decimal to variant 37afc3fff8c6 is described below commit 37afc3fff8c65b612c1242f1fc9a66a2e04639ad Author: Chenhao Li <chenhao...@databricks.com> AuthorDate: Thu Apr 18 09:23:45 2024 +0800 [SPARK-47803][FOLLOWUP] Fix cast binary/decimal to variant ### What changes were proposed in this pull request? This PR fixes issues introduced in https://github.com/apache/spark/pull/45989: - `VariantBuilder.appendBinary` incorrectly uses the type tag for the string type. - `VariantExpressionEvalUtils.buildVariant` misses the decimal types. ### Why are the changes needed? It is a bug fix and allows Spark to read a map schema with variant value (for example, `map<string, variant>`) correctly. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New unit tests. We ensure that all supported types are covered (scalar types, array, map, struct, variant). ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46109 from chenhao-db/fix_cast_to_variant. 
Authored-by: Chenhao Li <chenhao...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../org/apache/spark/types/variant/VariantBuilder.java | 2 +- .../variant/VariantExpressionEvalUtils.scala | 3 ++- .../expressions/variant/VariantExpressionSuite.scala | 18 ++++++++++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java index ea7a7674baf5..2afba81d192e 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java @@ -223,7 +223,7 @@ public class VariantBuilder { public void appendBinary(byte[] binary) { checkCapacity(1 + U32_SIZE + binary.length); - writeBuffer[writePos++] = primitiveHeader(LONG_STR); + writeBuffer[writePos++] = primitiveHeader(BINARY); writeLong(writeBuffer, writePos, binary.length, U32_SIZE); writePos += U32_SIZE; System.arraycopy(binary, 0, writeBuffer, writePos, binary.length); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala index 4d1d70055f5e..ea90bb88a906 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala @@ -76,7 +76,8 @@ object VariantExpressionEvalUtils { case LongType => builder.appendLong(input.asInstanceOf[Long]) case FloatType => builder.appendFloat(input.asInstanceOf[Float]) case DoubleType => builder.appendDouble(input.asInstanceOf[Double]) - case StringType => builder.appendString(input.asInstanceOf[UTF8String].toString) + case _: DecimalType => 
builder.appendDecimal(input.asInstanceOf[Decimal].toJavaBigDecimal) + case _: StringType => builder.appendString(input.asInstanceOf[UTF8String].toString) case BinaryType => builder.appendBinary(input.asInstanceOf[Array[Byte]]) case DateType => builder.appendDate(input.asInstanceOf[Int]) case TimestampType => builder.appendTimestamp(input.asInstanceOf[Long]) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala index 9aa1dcd2ef95..1f9eec862bbe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala @@ -807,9 +807,27 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } check(null.asInstanceOf[String], null) + // The following tests cover all allowed scalar types. 
for (input <- Seq[Any](false, true, 0.toByte, 1.toShort, 2, 3L, 4.0F, 5.0D)) { check(input, input.toString) } + for (precision <- Seq(9, 18, 38)) { + val input = BigDecimal("9" * precision) + check(Literal.create(input, DecimalType(precision, 0)), input.toString) + } + check("", "\"\"") + check("x" * 128, "\"" + ("x" * 128) + "\"") + check(Array[Byte](1, 2, 3), "\"AQID\"") + check(Literal(0, DateType), "\"1970-01-01\"") + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + check(Literal(0L, TimestampType), "\"1970-01-01 00:00:00+00:00\"") + check(Literal(0L, TimestampNTZType), "\"1970-01-01 00:00:00\"") + } + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") { + check(Literal(0L, TimestampType), "\"1969-12-31 16:00:00-08:00\"") + check(Literal(0L, TimestampNTZType), "\"1970-01-01 00:00:00\"") + } + check(Array(null, "a", "b", "c"), """[null,"a","b","c"]""") check(Map("z" -> 1, "y" -> 2, "x" -> 3), """{"x":3,"y":2,"z":1}""") check(Array(parseJson("""{"a": 1,"b": [1, 2, 3]}"""), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org