This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 37afc3fff8c6 [SPARK-47803][FOLLOWUP] Fix cast binary/decimal to variant
37afc3fff8c6 is described below

commit 37afc3fff8c65b612c1242f1fc9a66a2e04639ad
Author: Chenhao Li <chenhao...@databricks.com>
AuthorDate: Thu Apr 18 09:23:45 2024 +0800

    [SPARK-47803][FOLLOWUP] Fix cast binary/decimal to variant
    
    ### What changes were proposed in this pull request?
    
    This PR fixes issues introduced in 
https://github.com/apache/spark/pull/45989:
    - `VariantBuilder.appendBinary` incorrectly used the type tag of the string type (`LONG_STR`) instead of `BINARY`.
    - `VariantExpressionEvalUtils.buildVariant` did not handle the decimal types.
    
    ### Why are the changes needed?
    
    It is a bug fix and allows Spark to correctly read a map schema with a variant value type (for example, `map<string, variant>`).
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit tests. They cover at least every supported type (scalar types, array, map, struct, variant).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46109 from chenhao-db/fix_cast_to_variant.
    
    Authored-by: Chenhao Li <chenhao...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/types/variant/VariantBuilder.java |  2 +-
 .../variant/VariantExpressionEvalUtils.scala           |  3 ++-
 .../expressions/variant/VariantExpressionSuite.scala   | 18 ++++++++++++++++++
 3 files changed, 21 insertions(+), 2 deletions(-)

diff --git 
a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java
 
b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java
index ea7a7674baf5..2afba81d192e 100644
--- 
a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java
+++ 
b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java
@@ -223,7 +223,7 @@ public class VariantBuilder {
 
   public void appendBinary(byte[] binary) {
     checkCapacity(1 + U32_SIZE + binary.length);
-    writeBuffer[writePos++] = primitiveHeader(LONG_STR);
+    writeBuffer[writePos++] = primitiveHeader(BINARY);
     writeLong(writeBuffer, writePos, binary.length, U32_SIZE);
     writePos += U32_SIZE;
     System.arraycopy(binary, 0, writeBuffer, writePos, binary.length);
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
index 4d1d70055f5e..ea90bb88a906 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
@@ -76,7 +76,8 @@ object VariantExpressionEvalUtils {
       case LongType => builder.appendLong(input.asInstanceOf[Long])
       case FloatType => builder.appendFloat(input.asInstanceOf[Float])
       case DoubleType => builder.appendDouble(input.asInstanceOf[Double])
-      case StringType => 
builder.appendString(input.asInstanceOf[UTF8String].toString)
+      case _: DecimalType => 
builder.appendDecimal(input.asInstanceOf[Decimal].toJavaBigDecimal)
+      case _: StringType => 
builder.appendString(input.asInstanceOf[UTF8String].toString)
       case BinaryType => builder.appendBinary(input.asInstanceOf[Array[Byte]])
       case DateType => builder.appendDate(input.asInstanceOf[Int])
       case TimestampType => builder.appendTimestamp(input.asInstanceOf[Long])
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index 9aa1dcd2ef95..1f9eec862bbe 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -807,9 +807,27 @@ class VariantExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     }
 
     check(null.asInstanceOf[String], null)
+    // The following tests cover all allowed scalar types.
     for (input <- Seq[Any](false, true, 0.toByte, 1.toShort, 2, 3L, 4.0F, 
5.0D)) {
       check(input, input.toString)
     }
+    for (precision <- Seq(9, 18, 38)) {
+      val input = BigDecimal("9" * precision)
+      check(Literal.create(input, DecimalType(precision, 0)), input.toString)
+    }
+    check("", "\"\"")
+    check("x" * 128, "\"" + ("x" * 128) + "\"")
+    check(Array[Byte](1, 2, 3), "\"AQID\"")
+    check(Literal(0, DateType), "\"1970-01-01\"")
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+      check(Literal(0L, TimestampType), "\"1970-01-01 00:00:00+00:00\"")
+      check(Literal(0L, TimestampNTZType), "\"1970-01-01 00:00:00\"")
+    }
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
+      check(Literal(0L, TimestampType), "\"1969-12-31 16:00:00-08:00\"")
+      check(Literal(0L, TimestampNTZType), "\"1970-01-01 00:00:00\"")
+    }
+
     check(Array(null, "a", "b", "c"), """[null,"a","b","c"]""")
     check(Map("z" -> 1, "y" -> 2, "x" -> 3), """{"x":3,"y":2,"z":1}""")
     check(Array(parseJson("""{"a": 1,"b": [1, 2, 3]}"""),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to