This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d817c9a60f51 [SPARK-47775][SQL] Support remaining scalar types in the variant spec
d817c9a60f51 is described below

commit d817c9a60f51ef8035c8d2b37a995976ae54aa47
Author: Chenhao Li <chenhao...@databricks.com>
AuthorDate: Wed Apr 10 22:51:17 2024 +0800

    [SPARK-47775][SQL] Support remaining scalar types in the variant spec
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for the remaining scalar types defined in the variant
    spec (DATE, TIMESTAMP, TIMESTAMP_NTZ, FLOAT, BINARY). The current
    `parse_json` expression doesn't produce these types, but they will be
    needed once casting the corresponding Spark types into the variant type is
    supported.
    
    ### Why are the changes needed?
    
    This PR can be considered a preparation for the cast-to-variant feature
    and will make that later PR smaller.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Existing variant expressions can decode more variant scalar types.
    
    ### How was this patch tested?
    
    Unit tests. We manually construct variant values with these new scalar
    types and test the existing variant expressions on them.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45945 from chenhao-db/support_atomic_types.
    
    Authored-by: Chenhao Li <chenhao...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/unsafe/types/VariantVal.java  |   8 +-
 .../org/apache/spark/types/variant/Variant.java    |  69 ++++++++++++-
 .../apache/spark/types/variant/VariantUtil.java    |  71 ++++++++++++-
 .../spark/sql/catalyst/expressions/Cast.scala      |   7 +-
 .../expressions/variant/variantExpressions.scala   |  84 +++++++++++-----
 .../spark/sql/catalyst/json/JacksonGenerator.scala |   2 +-
 .../variant/VariantExpressionSuite.scala           | 112 +++++++++++++++++++++
 7 files changed, 314 insertions(+), 39 deletions(-)
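
As a quick illustration of what the patch enables (a minimal sketch, not part of this patch; it assumes the `VERSION`, `DATE`, and `primitiveHeader` members of `VariantUtil` shown in the diff below are accessible):

    import java.time.ZoneOffset;
    import org.apache.spark.unsafe.types.VariantVal;
    import static org.apache.spark.types.variant.VariantUtil.*;

    public class VariantDateDemo {
      public static void main(String[] args) {
        // Empty-dictionary metadata: version byte, dictionary size 0, end offset 0
        // (the same layout the new test suite builds by hand below).
        byte[] metadata = {VERSION, 0, 0};
        // DATE scalar: 1 header byte + 4-byte little-endian day count from the epoch.
        byte[] value = {primitiveHeader(DATE), 1, 0, 0, 0};  // day 1
        // Prints "1970-01-02"; DATE rendering does not depend on the zone argument.
        System.out.println(new VariantVal(value, metadata).toJson(ZoneOffset.UTC));
      }
    }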

diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
index 652c05daf344..a441bab4ac41 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
@@ -21,6 +21,8 @@ import org.apache.spark.unsafe.Platform;
 import org.apache.spark.types.variant.Variant;
 
 import java.io.Serializable;
+import java.time.ZoneId;
+import java.time.ZoneOffset;
 import java.util.Arrays;
 
 /**
@@ -99,13 +101,17 @@ public class VariantVal implements Serializable {
         '}';
   }
 
+  public String toJson(ZoneId zoneId) {
+    return new Variant(value, metadata).toJson(zoneId);
+  }
+
   /**
   * @return A human-readable representation of the Variant value. It is always a JSON string at
    * this moment.
    */
   @Override
   public String toString() {
-    return new Variant(value, metadata).toJson();
+    return toJson(ZoneOffset.UTC);
   }
 
   /**
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
index 8340aadd261f..4aeb2c6e1435 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
@@ -23,7 +23,16 @@ import com.fasterxml.jackson.core.JsonGenerator;
 import java.io.CharArrayWriter;
 import java.io.IOException;
 import java.math.BigDecimal;
+import java.time.Instant;
+import java.time.LocalDate;
+import java.time.ZoneId;
+import java.time.ZoneOffset;
+import java.time.format.DateTimeFormatter;
+import java.time.format.DateTimeFormatterBuilder;
+import java.time.temporal.ChronoUnit;
 import java.util.Arrays;
+import java.util.Base64;
+import java.util.Locale;
 
 import static org.apache.spark.types.variant.VariantUtil.*;
 
@@ -89,6 +98,16 @@ public final class Variant {
     return VariantUtil.getDecimal(value, pos);
   }
 
+  // Get a float value from the variant.
+  public float getFloat() {
+    return VariantUtil.getFloat(value, pos);
+  }
+
+  // Get a binary value from the variant.
+  public byte[] getBinary() {
+    return VariantUtil.getBinary(value, pos);
+  }
+
   // Get a string value from the variant.
   public String getString() {
     return VariantUtil.getString(value, pos);
@@ -188,9 +207,9 @@ public final class Variant {
 
   // Stringify the variant in JSON format.
   // Throw `MALFORMED_VARIANT` if the variant is malformed.
-  public String toJson() {
+  public String toJson(ZoneId zoneId) {
     StringBuilder sb = new StringBuilder();
-    toJsonImpl(value, metadata, pos, sb);
+    toJsonImpl(value, metadata, pos, sb, zoneId);
     return sb.toString();
   }
 
@@ -208,7 +227,30 @@ public final class Variant {
     }
   }
 
-  static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb) {
+  // A simplified and more performant version of `sb.append(escapeJson(str))`. It is used when we
+  // know `str` doesn't contain any special character that needs escaping.
+  static void appendQuoted(StringBuilder sb, String str) {
+    sb.append('"');
+    sb.append(str);
+    sb.append('"');
+  }
+
+  private static final DateTimeFormatter TIMESTAMP_NTZ_FORMATTER = new DateTimeFormatterBuilder()
+      .append(DateTimeFormatter.ISO_LOCAL_DATE)
+      .appendLiteral(' ')
+      .append(DateTimeFormatter.ISO_LOCAL_TIME)
+      .toFormatter(Locale.US);
+
+  private static final DateTimeFormatter TIMESTAMP_FORMATTER = new DateTimeFormatterBuilder()
+      .append(TIMESTAMP_NTZ_FORMATTER)
+      .appendOffset("+HH:MM", "+00:00")
+      .toFormatter(Locale.US);
+
+  private static Instant microsToInstant(long timestamp) {
+    return Instant.EPOCH.plus(timestamp, ChronoUnit.MICROS);
+  }
+
+  static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb, ZoneId zoneId) {
     switch (VariantUtil.getType(value, pos)) {
       case OBJECT:
         handleObject(value, pos, (size, idSize, offsetSize, idStart, offsetStart, dataStart) -> {
@@ -220,7 +262,7 @@ public final class Variant {
             if (i != 0) sb.append(',');
             sb.append(escapeJson(getMetadataKey(metadata, id)));
             sb.append(':');
-            toJsonImpl(value, metadata, elementPos, sb);
+            toJsonImpl(value, metadata, elementPos, sb, zoneId);
           }
           sb.append('}');
           return null;
@@ -233,7 +275,7 @@ public final class Variant {
            int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize);
             int elementPos = dataStart + offset;
             if (i != 0) sb.append(',');
-            toJsonImpl(value, metadata, elementPos, sb);
+            toJsonImpl(value, metadata, elementPos, sb, zoneId);
           }
           sb.append(']');
           return null;
@@ -257,6 +299,23 @@ public final class Variant {
       case DECIMAL:
         sb.append(VariantUtil.getDecimal(value, pos).toPlainString());
         break;
+      case DATE:
+        appendQuoted(sb, LocalDate.ofEpochDay((int) VariantUtil.getLong(value, pos)).toString());
+        break;
+      case TIMESTAMP:
+        appendQuoted(sb, TIMESTAMP_FORMATTER.format(
+            microsToInstant(VariantUtil.getLong(value, pos)).atZone(zoneId)));
+        break;
+      case TIMESTAMP_NTZ:
+        appendQuoted(sb, TIMESTAMP_NTZ_FORMATTER.format(
+            microsToInstant(VariantUtil.getLong(value, pos)).atZone(ZoneOffset.UTC)));
+        break;
+      case FLOAT:
+        sb.append(VariantUtil.getFloat(value, pos));
+        break;
+      case BINARY:
+        appendQuoted(sb, Base64.getEncoder().encodeToString(VariantUtil.getBinary(value, pos)));
+        break;
     }
   }
 }
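
Aside (not part of the patch): the zone handling above can be reproduced with plain java.time. A minimal sketch using the same formatter pattern as `TIMESTAMP_FORMATTER`, showing why `toJson` now threads a `ZoneId` down to the TIMESTAMP case while TIMESTAMP_NTZ always uses UTC:

    import java.time.Instant;
    import java.time.ZoneId;
    import java.time.format.DateTimeFormatter;
    import java.time.format.DateTimeFormatterBuilder;
    import java.util.Locale;

    public class TimestampFormatDemo {
      // Same shape as TIMESTAMP_FORMATTER above: date, space, time, +HH:MM offset.
      private static final DateTimeFormatter FMT = new DateTimeFormatterBuilder()
          .append(DateTimeFormatter.ISO_LOCAL_DATE)
          .appendLiteral(' ')
          .append(DateTimeFormatter.ISO_LOCAL_TIME)
          .appendOffset("+HH:MM", "+00:00")
          .toFormatter(Locale.US);

      public static void main(String[] args) {
        Instant epoch = Instant.EPOCH;  // micros value 0 in the variant encoding
        // The same physical instant renders differently depending on the zone:
        System.out.println(FMT.format(epoch.atZone(ZoneId.of("UTC"))));
        // -> 1970-01-01 00:00:00+00:00
        System.out.println(FMT.format(epoch.atZone(ZoneId.of("America/Los_Angeles"))));
        // -> 1969-12-31 16:00:00-08:00
      }
    }
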
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
index 1d579188ccdb..e4e9cc8b4cfa 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
@@ -23,6 +23,7 @@ import scala.collection.immutable.Map$;
 
 import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.util.Arrays;
 
 /**
 * This class defines constants related to the variant format and provides functions for
@@ -101,6 +102,21 @@ public class VariantUtil {
   public static final int DECIMAL8 = 9;
  // 16-byte decimal. Content is 1-byte scale + 16-byte little-endian signed integer.
   public static final int DECIMAL16 = 10;
+  // Date value. Content is 4-byte little-endian signed integer that represents the number of days
+  // from the Unix epoch.
+  public static final int DATE = 11;
+  // Timestamp value. Content is 8-byte little-endian signed integer that represents the number of
+  // microseconds elapsed since the Unix epoch, 1970-01-01 00:00:00 UTC. It is displayed to users in
+  // their local time zones and may be displayed differently depending on the execution environment.
+  public static final int TIMESTAMP = 12;
+  // Timestamp_ntz value. It has the same content as `TIMESTAMP` but should always be interpreted
+  // as if the local time zone is UTC.
+  public static final int TIMESTAMP_NTZ = 13;
+  // 4-byte IEEE float.
+  public static final int FLOAT = 14;
+  // Binary value. The content is (4-byte little-endian unsigned integer representing the binary
+  // size) + (size bytes of binary content).
+  public static final int BINARY = 15;
   // Long string value. The content is (4-byte little-endian unsigned integer representing the
   // string size) + (size bytes of string content).
   public static final int LONG_STR = 16;
@@ -212,6 +228,11 @@ public class VariantUtil {
     STRING,
     DOUBLE,
     DECIMAL,
+    DATE,
+    TIMESTAMP,
+    TIMESTAMP_NTZ,
+    FLOAT,
+    BINARY,
   }
 
  // Get the value type of variant value `value[pos...]`. It is only legal to call `get*` if
@@ -247,6 +268,16 @@ public class VariantUtil {
           case DECIMAL8:
           case DECIMAL16:
             return Type.DECIMAL;
+          case DATE:
+            return Type.DATE;
+          case TIMESTAMP:
+            return Type.TIMESTAMP;
+          case TIMESTAMP_NTZ:
+            return Type.TIMESTAMP_NTZ;
+          case FLOAT:
+            return Type.FLOAT;
+          case BINARY:
+            return Type.BINARY;
           case LONG_STR:
             return Type.STRING;
           default:
@@ -283,9 +314,13 @@ public class VariantUtil {
           case INT2:
             return 3;
           case INT4:
+          case DATE:
+          case FLOAT:
             return 5;
           case INT8:
           case DOUBLE:
+          case TIMESTAMP:
+          case TIMESTAMP_NTZ:
             return 9;
           case DECIMAL4:
             return 6;
@@ -293,6 +328,7 @@ public class VariantUtil {
             return 10;
           case DECIMAL16:
             return 18;
+          case BINARY:
           case LONG_STR:
             return 1 + U32_SIZE + readUnsigned(value, pos + 1, U32_SIZE);
           default:
@@ -318,23 +354,31 @@ public class VariantUtil {
   }
 
   // Get a long value from variant value `value[pos...]`.
+  // It is only legal to call it if `getType` returns one of `Type.LONG/DATE/TIMESTAMP/
+  // TIMESTAMP_NTZ`. If the type is `DATE`, the return value is guaranteed to fit into an int and
+  // represents the number of days from the Unix epoch. If the type is `TIMESTAMP/TIMESTAMP_NTZ`,
+  // the return value represents the number of microseconds from the Unix epoch.
   // Throw `MALFORMED_VARIANT` if the variant is malformed.
   public static long getLong(byte[] value, int pos) {
     checkIndex(pos, value.length);
     int basicType = value[pos] & BASIC_TYPE_MASK;
     int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
-    if (basicType != PRIMITIVE) throw unexpectedType(Type.LONG);
+    String exceptionMessage = "Expect type to be LONG/DATE/TIMESTAMP/TIMESTAMP_NTZ";
+    if (basicType != PRIMITIVE) throw new IllegalStateException(exceptionMessage);
     switch (typeInfo) {
       case INT1:
         return readLong(value, pos + 1, 1);
       case INT2:
         return readLong(value, pos + 1, 2);
       case INT4:
+      case DATE:
         return readLong(value, pos + 1, 4);
       case INT8:
+      case TIMESTAMP:
+      case TIMESTAMP_NTZ:
         return readLong(value, pos + 1, 8);
       default:
-        throw unexpectedType(Type.LONG);
+        throw new IllegalStateException(exceptionMessage);
     }
   }
 
@@ -380,6 +424,29 @@ public class VariantUtil {
     return result.stripTrailingZeros();
   }
 
+  // Get a float value from variant value `value[pos...]`.
+  // Throw `MALFORMED_VARIANT` if the variant is malformed.
+  public static float getFloat(byte[] value, int pos) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != PRIMITIVE || typeInfo != FLOAT) throw unexpectedType(Type.FLOAT);
+    return Float.intBitsToFloat((int) readLong(value, pos + 1, 4));
+  }
+
+  // Get a binary value from variant value `value[pos...]`.
+  // Throw `MALFORMED_VARIANT` if the variant is malformed.
+  public static byte[] getBinary(byte[] value, int pos) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != PRIMITIVE || typeInfo != BINARY) throw unexpectedType(Type.BINARY);
+    int start = pos + 1 + U32_SIZE;
+    int length = readUnsigned(value, pos + 1, U32_SIZE);
+    checkIndex(start + length - 1, value.length);
+    return Arrays.copyOfRange(value, start, start + length);
+  }
+
   // Get a string value from variant value `value[pos...]`.
   // Throw `MALFORMED_VARIANT` if the variant is malformed.
   public static String getString(byte[] value, int pos) {
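
A hedged sketch of what the new decoders above do with the byte layouts (self-contained; the buffer and length handling here are simplified stand-ins for the real header parsing, not the patch's code):

    import java.util.Arrays;
    import java.util.Base64;

    public class ScalarLayoutDemo {
      public static void main(String[] args) {
        // FLOAT content: 4-byte little-endian IEEE-754 bits, decoded with
        // Float.intBitsToFloat, the same way getFloat reassembles them.
        int bits = Float.floatToIntBits(1.23f);
        byte[] le = {(byte) bits, (byte) (bits >> 8), (byte) (bits >> 16), (byte) (bits >> 24)};
        int reread = (le[0] & 0xFF) | (le[1] & 0xFF) << 8
            | (le[2] & 0xFF) << 16 | (le[3] & 0xFF) << 24;
        System.out.println(Float.intBitsToFloat(reread));  // 1.23

        // BINARY content: 4-byte little-endian length, then the bytes themselves;
        // getBinary copies them out with Arrays.copyOfRange.
        byte[] content = {3, 0, 0, 0, 1, 2, 3};
        int length = content[0] & 0xFF;  // simplified: assumes length < 256
        byte[] bytes = Arrays.copyOfRange(content, 4, 4 + length);
        // toJson renders BINARY as Base64, e.g. "AQID" for bytes {1, 2, 3}:
        System.out.println(Base64.getEncoder().encodeToString(bytes));
      }
    }
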
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8a077d9e9acb..94cf7130d485 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1114,7 +1114,7 @@ case class Cast(
       _ => throw QueryExecutionErrors.cannotCastFromNullTypeError(to)
     } else if (from.isInstanceOf[VariantType]) {
       buildCast[VariantVal](_, v => {
-        variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId)
+        variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId, zoneId)
       })
     } else {
       to match {
@@ -1211,11 +1211,12 @@ case class Cast(
     case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) =>
       val tmp = ctx.freshVariable("tmp", classOf[Object])
       val dataTypeArg = ctx.addReferenceObj("dataType", to)
-      val zoneIdArg = ctx.addReferenceObj("zoneId", timeZoneId)
+      val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
+      val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
       val failOnError = evalMode != EvalMode.TRY
       val cls = classOf[variant.VariantGet].getName
       code"""
-        Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneIdArg);
+        Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg);
         if ($tmp == null) {
           $evNull = true;
         } else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index c5e316dc6c8c..8b09bf5f7de0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.variant
 
+import java.time.ZoneId
+
 import scala.util.parsing.combinator.RegexParsers
 
 import org.apache.spark.SparkRuntimeException
@@ -170,7 +172,8 @@ case class VariantGet(
       parsedPath,
       dataType,
       failOnError,
-      timeZoneId)
+      timeZoneId,
+      zoneId)
   }
 
   protected override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -178,14 +181,15 @@ case class VariantGet(
     val tmp = ctx.freshVariable("tmp", classOf[Object])
     val parsedPathArg = ctx.addReferenceObj("parsedPath", parsedPath)
     val dataTypeArg = ctx.addReferenceObj("dataType", dataType)
-    val zoneIdArg = ctx.addReferenceObj("zoneId", timeZoneId)
+    val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
+    val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
     val code = code"""
       ${childCode.code}
       boolean ${ev.isNull} = ${childCode.isNull};
      ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
       if (!${ev.isNull}) {
        Object $tmp = org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet(
-          ${childCode.value}, $parsedPathArg, $dataTypeArg, $failOnError, $zoneIdArg);
+          ${childCode.value}, $parsedPathArg, $dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg);
         if ($tmp == null) {
           ${ev.isNull} = true;
         } else {
@@ -228,7 +232,8 @@ case object VariantGet {
       parsedPath: Array[VariantPathParser.PathSegment],
       dataType: DataType,
       failOnError: Boolean,
-      zoneId: Option[String]): Any = {
+      zoneStr: Option[String],
+      zoneId: ZoneId): Any = {
     var v = new Variant(input.getValue, input.getMetadata)
     for (path <- parsedPath) {
       v = path match {
@@ -238,7 +243,7 @@ case object VariantGet {
       }
       if (v == null) return null
     }
-    VariantGet.cast(v, dataType, failOnError, zoneId)
+    VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
   }
 
   /**
@@ -249,9 +254,10 @@ case object VariantGet {
       input: VariantVal,
       dataType: DataType,
       failOnError: Boolean,
-      zoneId: Option[String]): Any = {
+      zoneStr: Option[String],
+      zoneId: ZoneId): Any = {
     val v = new Variant(input.getValue, input.getMetadata)
-    VariantGet.cast(v, dataType, failOnError, zoneId)
+    VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
   }
 
   /**
@@ -261,9 +267,19 @@ case object VariantGet {
    * "hello" to int). If the cast fails, throw an exception when `failOnError` 
is true, or return a
    * SQL NULL when it is false.
    */
-  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
-    def invalidCast(): Any =
-      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+  def cast(
+      v: Variant,
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneStr: Option[String],
+      zoneId: ZoneId): Any = {
+    def invalidCast(): Any = {
+      if (failOnError) {
+        throw QueryExecutionErrors.invalidVariantCast(v.toJson(zoneId), dataType)
+      } else {
+        null
+      }
+    }
 
     if (dataType == VariantType) return new VariantVal(v.getValue, v.getMetadata)
     val variantType = v.getType
@@ -273,15 +289,22 @@ case object VariantGet {
         val input = variantType match {
           case Type.OBJECT | Type.ARRAY =>
             return if (dataType.isInstanceOf[StringType]) {
-              UTF8String.fromString(v.toJson)
+              UTF8String.fromString(v.toJson(zoneId))
             } else {
               invalidCast()
             }
-          case Type.BOOLEAN => v.getBoolean
-          case Type.LONG => v.getLong
-          case Type.STRING => UTF8String.fromString(v.getString)
-          case Type.DOUBLE => v.getDouble
-          case Type.DECIMAL => Decimal(v.getDecimal)
+          case Type.BOOLEAN => Literal(v.getBoolean, BooleanType)
+          case Type.LONG => Literal(v.getLong, LongType)
+          case Type.STRING => Literal(UTF8String.fromString(v.getString), StringType)
+          case Type.DOUBLE => Literal(v.getDouble, DoubleType)
+          case Type.DECIMAL =>
+            val d = Decimal(v.getDecimal)
+            Literal(Decimal(v.getDecimal), DecimalType(d.precision, d.scale))
+          case Type.DATE => Literal(v.getLong.toInt, DateType)
+          case Type.TIMESTAMP => Literal(v.getLong, TimestampType)
+          case Type.TIMESTAMP_NTZ => Literal(v.getLong, TimestampNTZType)
+          case Type.FLOAT => Literal(v.getFloat, FloatType)
+          case Type.BINARY => Literal(v.getBinary, BinaryType)
           // We have handled other cases and should never reach here. This case is only intended
           // to by pass the compiler exhaustiveness check.
           case _ => throw QueryExecutionErrors.unreachableError()
@@ -289,15 +312,17 @@ case object VariantGet {
        // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
        // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
        // strict overflow checks.
-        input match {
-          case l: Long if dataType == TimestampType =>
-            try Math.multiplyExact(l, MICROS_PER_SECOND)
+        input.dataType match {
+          case LongType if dataType == TimestampType =>
+            try Math.multiplyExact(input.value.asInstanceOf[Long], MICROS_PER_SECOND)
             catch {
               case _: ArithmeticException => invalidCast()
             }
-          case d: Decimal if dataType == TimestampType =>
+          case _: DecimalType if dataType == TimestampType =>
             try {
-              d.toJavaBigDecimal
+              input.value
+                .asInstanceOf[Decimal]
+                .toJavaBigDecimal
                 .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
                 .toBigInteger
                 .longValueExact()
@@ -305,9 +330,8 @@ case object VariantGet {
               case _: ArithmeticException => invalidCast()
             }
           case _ =>
-            val inputLiteral = Literal(input)
-            if (Cast.canAnsiCast(inputLiteral.dataType, dataType)) {
-              val result = Cast(inputLiteral, dataType, zoneId, EvalMode.TRY).eval()
+            if (Cast.canAnsiCast(input.dataType, dataType)) {
+              val result = Cast(input, dataType, zoneStr, EvalMode.TRY).eval()
               if (result == null) invalidCast() else result
             } else {
               invalidCast()
@@ -318,7 +342,7 @@ case object VariantGet {
           val size = v.arraySize()
           val array = new Array[Any](size)
           for (i <- 0 until size) {
-            array(i) = cast(v.getElementAtIndex(i), elementType, failOnError, zoneId)
+            array(i) = cast(v.getElementAtIndex(i), elementType, failOnError, zoneStr, zoneId)
           }
           new GenericArrayData(array)
         } else {
@@ -332,7 +356,7 @@ case object VariantGet {
           for (i <- 0 until size) {
             val field = v.getFieldAtIndex(i)
             keyArray(i) = UTF8String.fromString(field.key)
-            valueArray(i) = cast(field.value, valueType, failOnError, zoneId)
+            valueArray(i) = cast(field.value, valueType, failOnError, zoneStr, zoneId)
           }
           ArrayBasedMapData(keyArray, valueArray)
         } else {
@@ -345,7 +369,8 @@ case object VariantGet {
             val field = v.getFieldAtIndex(i)
             st.getFieldIndex(field.key) match {
               case Some(idx) =>
-                row.update(idx, cast(field.value, fields(idx).dataType, failOnError, zoneId))
+                row.update(idx,
+                  cast(field.value, fields(idx).dataType, failOnError, zoneStr, zoneId))
               case _ =>
             }
           }
@@ -576,6 +601,11 @@ object SchemaOfVariant {
     case Type.DECIMAL =>
       val d = v.getDecimal
       DecimalType(d.precision(), d.scale())
+    case Type.DATE => DateType
+    case Type.TIMESTAMP => TimestampType
+    case Type.TIMESTAMP_NTZ => TimestampNTZType
+    case Type.FLOAT => FloatType
+    case Type.BINARY => BinaryType
   }
 
   /**
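
The "strict overflow checks" mentioned in the comment above come down to Math.multiplyExact; a minimal sketch of the difference (illustrative values, plain Java, not the patch's code):

    public class OverflowDemo {
      private static final long MICROS_PER_SECOND = 1_000_000L;

      public static void main(String[] args) {
        long seconds = Long.MAX_VALUE / 1_000;  // far too large to represent as microseconds
        // `seconds * MICROS_PER_SECOND` would silently wrap around; multiplyExact
        // throws instead, letting the cast report invalidCast() -- an error in
        // strict mode, SQL NULL in TRY mode.
        try {
          System.out.println(Math.multiplyExact(seconds, MICROS_PER_SECOND));
        } catch (ArithmeticException e) {
          System.out.println("overflow detected");
        }
      }
    }
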
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index 1964b5f24b34..80f2b2a0070c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -317,7 +317,7 @@ class JacksonGenerator(
   }
 
   def write(v: VariantVal): Unit = {
-    gen.writeRawValue(v.toString)
+    gen.writeRawValue(v.toJson(options.zoneId))
   }
 
   def writeLineEnding(): Unit = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index a5863e80a26c..24675518646d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -685,4 +686,115 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
       Array(null, null, 1)
     )
   }
+
+  test("atomic types that are not produced by parse_json") {
+    // Dictionary size is `0` for value 0. An empty dictionary contains one offset `0` for the
+    // one-past-the-end position (i.e. the sum of all string lengths).
+    val emptyMetadata = Array[Byte](VERSION, 0, 0)
+
+    def checkToJson(value: Array[Byte], expected: String): Unit = {
+      val input = Literal(new VariantVal(value, emptyMetadata))
+      checkEvaluation(StructsToJson(Map.empty, input), expected)
+    }
+
+    def checkCast(value: Array[Byte], dataType: DataType, expected: Any): Unit = {
+      val input = Literal(new VariantVal(value, emptyMetadata))
+      checkEvaluation(Cast(input, dataType, evalMode = EvalMode.ANSI), expected)
+    }
+
+    checkToJson(Array(primitiveHeader(DATE), 0, 0, 0, 0), "\"1970-01-01\"")
+    checkToJson(Array(primitiveHeader(DATE), -1, -1, -1, 127), "\"+5881580-07-11\"")
+    checkToJson(Array(primitiveHeader(DATE), 0, 0, 0, -128), "\"-5877641-06-23\"")
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+      checkCast(Array(primitiveHeader(DATE), 0, 0, 0, 0), TimestampType, 0L)
+      checkCast(Array(primitiveHeader(DATE), 1, 0, 0, 0), TimestampType, MICROS_PER_DAY)
+    }
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
+      checkCast(Array(primitiveHeader(DATE), 0, 0, 0, 0), TimestampType, 8 * MICROS_PER_HOUR)
+      checkCast(Array(primitiveHeader(DATE), 1, 0, 0, 0), TimestampType,
+        MICROS_PER_DAY + 8 * MICROS_PER_HOUR)
+    }
+
+    def littleEndianLong(value: Long): Array[Byte] =
+      BigInt(value).toByteArray.reverse.padTo(8, 0.toByte)
+
+    val time1 = littleEndianLong(0)
+    // In America/Los_Angeles timezone, timestamp value `skippedTime` is 2011-03-13 03:00:00.
+    // The next second of 2011-03-13 01:59:59 jumps to 2011-03-13 03:00:00.
+    val skippedTime = 1300010400000000L
+    val time2 = littleEndianLong(skippedTime)
+    val time3 = littleEndianLong(skippedTime - 1)
+    val time4 = littleEndianLong(Long.MinValue)
+    val time5 = littleEndianLong(Long.MaxValue)
+    val time6 = littleEndianLong(-62198755200000000L)
+    val timestampHeader = Array(primitiveHeader(TIMESTAMP))
+    val timestampNtzHeader = Array(primitiveHeader(TIMESTAMP_NTZ))
+
+    for (timeZone <- Seq("UTC", "America/Los_Angeles")) {
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
+        checkToJson(timestampNtzHeader ++ time1, "\"1970-01-01 00:00:00\"")
+        checkToJson(timestampNtzHeader ++ time2, "\"2011-03-13 10:00:00\"")
+        checkToJson(timestampNtzHeader ++ time3, "\"2011-03-13 09:59:59.999999\"")
+        checkToJson(timestampNtzHeader ++ time4, "\"-290308-12-21 19:59:05.224192\"")
+        checkToJson(timestampNtzHeader ++ time5, "\"+294247-01-10 04:00:54.775807\"")
+        checkToJson(timestampNtzHeader ++ time6, "\"-0001-01-01 00:00:00\"")
+
+        checkCast(timestampNtzHeader ++ time1, DateType, 0)
+        checkCast(timestampNtzHeader ++ time2, DateType, 15046)
+        checkCast(timestampNtzHeader ++ time3, DateType, 15046)
+        checkCast(timestampNtzHeader ++ time4, DateType, -106751992)
+        checkCast(timestampNtzHeader ++ time5, DateType, 106751991)
+        checkCast(timestampNtzHeader ++ time6, DateType, -719893)
+      }
+    }
+
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+      checkToJson(timestampHeader ++ time1, "\"1970-01-01 00:00:00+00:00\"")
+      checkToJson(timestampHeader ++ time2, "\"2011-03-13 10:00:00+00:00\"")
+      checkToJson(timestampHeader ++ time3, "\"2011-03-13 09:59:59.999999+00:00\"")
+      checkToJson(timestampHeader ++ time4, "\"-290308-12-21 19:59:05.224192+00:00\"")
+      checkToJson(timestampHeader ++ time5, "\"+294247-01-10 04:00:54.775807+00:00\"")
+      checkToJson(timestampHeader ++ time6, "\"-0001-01-01 00:00:00+00:00\"")
+
+      checkCast(timestampHeader ++ time1, DateType, 0)
+      checkCast(timestampHeader ++ time2, DateType, 15046)
+      checkCast(timestampHeader ++ time3, DateType, 15046)
+      checkCast(timestampHeader ++ time4, DateType, -106751992)
+      checkCast(timestampHeader ++ time5, DateType, 106751991)
+      checkCast(timestampHeader ++ time6, DateType, -719893)
+    }
+
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
+      checkToJson(timestampHeader ++ time1, "\"1969-12-31 16:00:00-08:00\"")
+      checkToJson(timestampHeader ++ time2, "\"2011-03-13 03:00:00-07:00\"")
+      checkToJson(timestampHeader ++ time3, "\"2011-03-13 01:59:59.999999-08:00\"")
+      checkToJson(timestampHeader ++ time4, "\"-290308-12-21 12:06:07.224192-07:52\"")
+      checkToJson(timestampHeader ++ time5, "\"+294247-01-09 20:00:54.775807-08:00\"")
+      checkToJson(timestampHeader ++ time6, "\"-0002-12-31 16:07:02-07:52\"")
+
+      checkCast(timestampHeader ++ time1, DateType, -1)
+      checkCast(timestampHeader ++ time2, DateType, 15046)
+      checkCast(timestampHeader ++ time3, DateType, 15046)
+      checkCast(timestampHeader ++ time4, DateType, -106751992)
+      checkCast(timestampHeader ++ time5, DateType, 106751990)
+      checkCast(timestampHeader ++ time6, DateType, -719894)
+    }
+
+    checkToJson(Array(primitiveHeader(FLOAT)) ++
+      BigInt(java.lang.Float.floatToIntBits(1.23F)).toByteArray.reverse, "1.23")
+    checkToJson(Array(primitiveHeader(FLOAT)) ++
+      BigInt(java.lang.Float.floatToIntBits(-0.0F)).toByteArray.reverse, "-0.0")
+    // Note: 1.23F.toDouble != 1.23.
+    checkCast(Array(primitiveHeader(FLOAT)) ++
+      BigInt(java.lang.Float.floatToIntBits(1.23F)).toByteArray.reverse, DoubleType, 1.23F.toDouble)
+
+    checkToJson(Array(primitiveHeader(BINARY), 0, 0, 0, 0), "\"\"")
+    checkToJson(Array(primitiveHeader(BINARY), 1, 0, 0, 0, 1), "\"AQ==\"")
+    checkToJson(Array(primitiveHeader(BINARY), 2, 0, 0, 0, 1, 2), "\"AQI=\"")
+    checkToJson(Array(primitiveHeader(BINARY), 3, 0, 0, 0, 1, 2, 3), "\"AQID\"")
+    checkCast(Array(primitiveHeader(BINARY), 3, 0, 0, 0, 1, 2, 3), StringType,
+      "\u0001\u0002\u0003")
+    checkCast(Array(primitiveHeader(BINARY), 5, 0, 0, 0, 72, 101, 108, 108, 111), StringType,
+      "Hello")
+  }
 }
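
As a side note on the `skippedTime` constant used in the tests above (a hedged, Spark-free sketch with java.time; the zone rules are what make the TIMESTAMP expectations differ from the TIMESTAMP_NTZ ones):

    import java.time.Instant;
    import java.time.ZoneId;

    public class DstDemo {
      public static void main(String[] args) {
        ZoneId la = ZoneId.of("America/Los_Angeles");
        long skippedTimeMicros = 1300010400000000L;
        Instant t = Instant.ofEpochSecond(skippedTimeMicros / 1_000_000);
        // 02:00-02:59 local time does not exist on 2011-03-13 in this zone
        // (spring-forward), so this instant renders as 03:00:00-07:00 while one
        // microsecond earlier renders as 01:59:59.999999-08:00.
        System.out.println(t.atZone(la));  // 2011-03-13T03:00-07:00[America/Los_Angeles]
      }
    }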


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
