This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d817c9a60f51 [SPARK-47775][SQL] Support remaining scalar types in the variant spec
d817c9a60f51 is described below

commit d817c9a60f51ef8035c8d2b37a995976ae54aa47
Author:     Chenhao Li <chenhao...@databricks.com>
AuthorDate: Wed Apr 10 22:51:17 2024 +0800

    [SPARK-47775][SQL] Support remaining scalar types in the variant spec

    ### What changes were proposed in this pull request?

    This PR adds support for the remaining scalar types defined in the variant spec
    (DATE, TIMESTAMP, TIMESTAMP_NTZ, FLOAT, BINARY). The current `parse_json` expression
    doesn't produce these types, but we need them when we support casting a corresponding
    Spark type into the variant type.

    ### Why are the changes needed?

    This PR is preparation for the cast-to-variant feature and will make that later PR smaller.

    ### Does this PR introduce _any_ user-facing change?

    Yes. Existing variant expressions can decode more variant scalar types.

    ### How was this patch tested?

    Unit tests. We manually construct variant values with these new scalar types and test the
    existing variant expressions on them.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #45945 from chenhao-db/support_atomic_types.

    Authored-by: Chenhao Li <chenhao...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/unsafe/types/VariantVal.java  |   8 +-
 .../org/apache/spark/types/variant/Variant.java    |  69 ++++++++++++-
 .../apache/spark/types/variant/VariantUtil.java    |  71 ++++++++++++-
 .../spark/sql/catalyst/expressions/Cast.scala      |   7 +-
 .../expressions/variant/variantExpressions.scala   |  84 +++++++++++-----
 .../spark/sql/catalyst/json/JacksonGenerator.scala |   2 +-
 .../variant/VariantExpressionSuite.scala           | 112 +++++++++++++++++++++
 7 files changed, 314 insertions(+), 39 deletions(-)
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
index 652c05daf344..a441bab4ac41 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
@@ -21,6 +21,8 @@ import org.apache.spark.unsafe.Platform;
 import org.apache.spark.types.variant.Variant;
 
 import java.io.Serializable;
+import java.time.ZoneId;
+import java.time.ZoneOffset;
 import java.util.Arrays;
 
 /**
@@ -99,13 +101,17 @@ public class VariantVal implements Serializable {
         '}';
   }
 
+  public String toJson(ZoneId zoneId) {
+    return new Variant(value, metadata).toJson(zoneId);
+  }
+
   /**
    * @return A human-readable representation of the Variant value. It is always a JSON string at
    * this moment.
    */
   @Override
   public String toString() {
-    return new Variant(value, metadata).toJson();
+    return toJson(ZoneOffset.UTC);
   }
 
   /**
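For illustration, a minimal sketch (not part of the commit) of how the new zone-aware API behaves. It hand-builds a TIMESTAMP scalar; it assumes the metadata version byte is 1 and that a primitive header is the type id shifted left by two bits, matching the `primitiveHeader` helper used by the new test suite further down:

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;
    import java.time.ZoneId;
    import org.apache.spark.unsafe.types.VariantVal;

    public class ToJsonDemo {
      public static void main(String[] args) {
        byte[] metadata = {1, 0, 0};                // assumed version byte + empty dictionary
        byte[] value = ByteBuffer.allocate(9)       // 1-byte header + 8-byte content
            .order(ByteOrder.LITTLE_ENDIAN)
            .put((byte) (12 << 2))                  // assumed TIMESTAMP (type id 12) header
            .putLong(0L)                            // 0 microseconds since the Unix epoch
            .array();
        VariantVal v = new VariantVal(value, metadata);
        System.out.println(v.toString());           // "1970-01-01 00:00:00+00:00" (always UTC)
        System.out.println(v.toJson(ZoneId.of("America/Los_Angeles")));
                                                    // "1969-12-31 16:00:00-08:00"
      }
    }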
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
index 8340aadd261f..4aeb2c6e1435 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
@@ -23,7 +23,16 @@ import com.fasterxml.jackson.core.JsonGenerator;
 import java.io.CharArrayWriter;
 import java.io.IOException;
 import java.math.BigDecimal;
+import java.time.Instant;
+import java.time.LocalDate;
+import java.time.ZoneId;
+import java.time.ZoneOffset;
+import java.time.format.DateTimeFormatter;
+import java.time.format.DateTimeFormatterBuilder;
+import java.time.temporal.ChronoUnit;
 import java.util.Arrays;
+import java.util.Base64;
+import java.util.Locale;
 
 import static org.apache.spark.types.variant.VariantUtil.*;
 
@@ -89,6 +98,16 @@ public final class Variant {
     return VariantUtil.getDecimal(value, pos);
   }
 
+  // Get a float value from the variant.
+  public float getFloat() {
+    return VariantUtil.getFloat(value, pos);
+  }
+
+  // Get a binary value from the variant.
+  public byte[] getBinary() {
+    return VariantUtil.getBinary(value, pos);
+  }
+
   // Get a string value from the variant.
   public String getString() {
     return VariantUtil.getString(value, pos);
@@ -188,9 +207,9 @@
 
   // Stringify the variant in JSON format.
   // Throw `MALFORMED_VARIANT` if the variant is malformed.
-  public String toJson() {
+  public String toJson(ZoneId zoneId) {
     StringBuilder sb = new StringBuilder();
-    toJsonImpl(value, metadata, pos, sb);
+    toJsonImpl(value, metadata, pos, sb, zoneId);
     return sb.toString();
   }
 
@@ -208,7 +227,30 @@
     }
   }
 
-  static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb) {
+  // A simplified and more performant version of `sb.append(escapeJson(str))`. It is used when we
+  // know `str` doesn't contain any special character that needs escaping.
+  static void appendQuoted(StringBuilder sb, String str) {
+    sb.append('"');
+    sb.append(str);
+    sb.append('"');
+  }
+
+  private static final DateTimeFormatter TIMESTAMP_NTZ_FORMATTER = new DateTimeFormatterBuilder()
+      .append(DateTimeFormatter.ISO_LOCAL_DATE)
+      .appendLiteral(' ')
+      .append(DateTimeFormatter.ISO_LOCAL_TIME)
+      .toFormatter(Locale.US);
+
+  private static final DateTimeFormatter TIMESTAMP_FORMATTER = new DateTimeFormatterBuilder()
+      .append(TIMESTAMP_NTZ_FORMATTER)
+      .appendOffset("+HH:MM", "+00:00")
+      .toFormatter(Locale.US);
+
+  private static Instant microsToInstant(long timestamp) {
+    return Instant.EPOCH.plus(timestamp, ChronoUnit.MICROS);
+  }
+
+  static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb, ZoneId zoneId) {
     switch (VariantUtil.getType(value, pos)) {
       case OBJECT:
         handleObject(value, pos, (size, idSize, offsetSize, idStart, offsetStart, dataStart) -> {
@@ -220,7 +262,7 @@
             if (i != 0) sb.append(',');
             sb.append(escapeJson(getMetadataKey(metadata, id)));
             sb.append(':');
-            toJsonImpl(value, metadata, elementPos, sb);
+            toJsonImpl(value, metadata, elementPos, sb, zoneId);
           }
           sb.append('}');
           return null;
@@ -233,7 +275,7 @@
             int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize);
             int elementPos = dataStart + offset;
             if (i != 0) sb.append(',');
-            toJsonImpl(value, metadata, elementPos, sb);
+            toJsonImpl(value, metadata, elementPos, sb, zoneId);
           }
           sb.append(']');
           return null;
@@ -257,6 +299,23 @@
       case DECIMAL:
         sb.append(VariantUtil.getDecimal(value, pos).toPlainString());
         break;
+      case DATE:
+        appendQuoted(sb, LocalDate.ofEpochDay((int) VariantUtil.getLong(value, pos)).toString());
+        break;
+      case TIMESTAMP:
+        appendQuoted(sb, TIMESTAMP_FORMATTER.format(
+            microsToInstant(VariantUtil.getLong(value, pos)).atZone(zoneId)));
+        break;
+      case TIMESTAMP_NTZ:
+        appendQuoted(sb, TIMESTAMP_NTZ_FORMATTER.format(
+            microsToInstant(VariantUtil.getLong(value, pos)).atZone(ZoneOffset.UTC)));
+        break;
+      case FLOAT:
+        sb.append(VariantUtil.getFloat(value, pos));
+        break;
+      case BINARY:
+        appendQuoted(sb, Base64.getEncoder().encodeToString(VariantUtil.getBinary(value, pos)));
+        break;
     }
   }
 }
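A self-contained sketch (not from the patch, using only java.time) that mirrors the two formatters above and reproduces expected strings from the new test suite; the DST jump is why the test's `skippedTime` renders as 03:00:00-07:00 in Los Angeles:

    import java.time.Instant;
    import java.time.ZoneId;
    import java.time.format.DateTimeFormatter;
    import java.time.format.DateTimeFormatterBuilder;
    import java.time.temporal.ChronoUnit;
    import java.util.Locale;

    public class TimestampFormatDemo {
      // Same construction as TIMESTAMP_NTZ_FORMATTER / TIMESTAMP_FORMATTER above.
      private static final DateTimeFormatter NTZ = new DateTimeFormatterBuilder()
          .append(DateTimeFormatter.ISO_LOCAL_DATE)
          .appendLiteral(' ')
          .append(DateTimeFormatter.ISO_LOCAL_TIME)
          .toFormatter(Locale.US);
      private static final DateTimeFormatter TS = new DateTimeFormatterBuilder()
          .append(NTZ)
          .appendOffset("+HH:MM", "+00:00")
          .toFormatter(Locale.US);

      public static void main(String[] args) {
        long micros = 1300010400000000L;  // `skippedTime` in the new test suite
        Instant instant = Instant.EPOCH.plus(micros, ChronoUnit.MICROS);
        System.out.println(TS.format(instant.atZone(ZoneId.of("UTC"))));
        // 2011-03-13 10:00:00+00:00
        System.out.println(TS.format(instant.atZone(ZoneId.of("America/Los_Angeles"))));
        // 2011-03-13 03:00:00-07:00 (the second after 01:59:59 PST jumps to 03:00:00 PDT)
      }
    }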
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
index 1d579188ccdb..e4e9cc8b4cfa 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
@@ -23,6 +23,7 @@ import scala.collection.immutable.Map$;
 
 import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.util.Arrays;
 
 /**
  * This class defines constants related to the variant format and provides functions for
@@ -101,6 +102,21 @@
   public static final int DECIMAL8 = 9;
   // 16-byte decimal. Content is 1-byte scale + 16-byte little-endian signed integer.
   public static final int DECIMAL16 = 10;
+  // Date value. Content is 4-byte little-endian signed integer that represents the number of days
+  // from the Unix epoch.
+  public static final int DATE = 11;
+  // Timestamp value. Content is 8-byte little-endian signed integer that represents the number of
+  // microseconds elapsed since the Unix epoch, 1970-01-01 00:00:00 UTC. It is displayed to users in
+  // their local time zones and may be displayed differently depending on the execution environment.
+  public static final int TIMESTAMP = 12;
+  // Timestamp_ntz value. It has the same content as `TIMESTAMP` but should always be interpreted
+  // as if the local time zone is UTC.
+  public static final int TIMESTAMP_NTZ = 13;
+  // 4-byte IEEE float.
+  public static final int FLOAT = 14;
+  // Binary value. The content is (4-byte little-endian unsigned integer representing the binary
+  // size) + (size bytes of binary content).
+  public static final int BINARY = 15;
   // Long string value. The content is (4-byte little-endian unsigned integer representing the
   // string size) + (size bytes of string content).
   public static final int LONG_STR = 16;
@@ -212,6 +228,11 @@
     STRING,
     DOUBLE,
     DECIMAL,
+    DATE,
+    TIMESTAMP,
+    TIMESTAMP_NTZ,
+    FLOAT,
+    BINARY,
   }
 
   // Get the value type of variant value `value[pos...]`. It is only legal to call `get*` if
@@ -247,6 +268,16 @@
       case DECIMAL8:
       case DECIMAL16:
         return Type.DECIMAL;
+      case DATE:
+        return Type.DATE;
+      case TIMESTAMP:
+        return Type.TIMESTAMP;
+      case TIMESTAMP_NTZ:
+        return Type.TIMESTAMP_NTZ;
+      case FLOAT:
+        return Type.FLOAT;
+      case BINARY:
+        return Type.BINARY;
       case LONG_STR:
         return Type.STRING;
       default:
@@ -283,9 +314,13 @@
       case INT2:
         return 3;
       case INT4:
+      case DATE:
+      case FLOAT:
         return 5;
       case INT8:
       case DOUBLE:
+      case TIMESTAMP:
+      case TIMESTAMP_NTZ:
         return 9;
       case DECIMAL4:
         return 6;
@@ -293,6 +328,7 @@
         return 10;
       case DECIMAL16:
         return 18;
+      case BINARY:
       case LONG_STR:
         return 1 + U32_SIZE + readUnsigned(value, pos + 1, U32_SIZE);
       default:
@@ -318,23 +354,31 @@
   }
 
   // Get a long value from variant value `value[pos...]`.
+  // It is only legal to call it if `getType` returns one of `Type.LONG/DATE/TIMESTAMP/
+  // TIMESTAMP_NTZ`. If the type is `DATE`, the return value is guaranteed to fit into an int and
+  // represents the number of days from the Unix epoch. If the type is `TIMESTAMP/TIMESTAMP_NTZ`,
+  // the return value represents the number of microseconds from the Unix epoch.
   // Throw `MALFORMED_VARIANT` if the variant is malformed.
   public static long getLong(byte[] value, int pos) {
     checkIndex(pos, value.length);
     int basicType = value[pos] & BASIC_TYPE_MASK;
     int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
-    if (basicType != PRIMITIVE) throw unexpectedType(Type.LONG);
+    String exceptionMessage = "Expect type to be LONG/DATE/TIMESTAMP/TIMESTAMP_NTZ";
+    if (basicType != PRIMITIVE) throw new IllegalStateException(exceptionMessage);
     switch (typeInfo) {
       case INT1:
         return readLong(value, pos + 1, 1);
       case INT2:
         return readLong(value, pos + 1, 2);
       case INT4:
+      case DATE:
         return readLong(value, pos + 1, 4);
       case INT8:
+      case TIMESTAMP:
+      case TIMESTAMP_NTZ:
         return readLong(value, pos + 1, 8);
       default:
-        throw unexpectedType(Type.LONG);
+        throw new IllegalStateException(exceptionMessage);
     }
   }
 
@@ -380,6 +424,29 @@
     return result.stripTrailingZeros();
   }
 
+  // Get a float value from variant value `value[pos...]`.
+  // Throw `MALFORMED_VARIANT` if the variant is malformed.
+  public static float getFloat(byte[] value, int pos) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != PRIMITIVE || typeInfo != FLOAT) throw unexpectedType(Type.FLOAT);
+    return Float.intBitsToFloat((int) readLong(value, pos + 1, 4));
+  }
+
+  // Get a binary value from variant value `value[pos...]`.
+  // Throw `MALFORMED_VARIANT` if the variant is malformed.
+  public static byte[] getBinary(byte[] value, int pos) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+    if (basicType != PRIMITIVE || typeInfo != BINARY) throw unexpectedType(Type.BINARY);
+    int start = pos + 1 + U32_SIZE;
+    int length = readUnsigned(value, pos + 1, U32_SIZE);
+    checkIndex(start + length - 1, value.length);
+    return Arrays.copyOfRange(value, start, start + length);
+  }
+
   // Get a string value from variant value `value[pos...]`.
   // Throw `MALFORMED_VARIANT` if the variant is malformed.
   public static String getString(byte[] value, int pos) {
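A sketch of the byte-level layout these decoders expect, encoding one FLOAT scalar and reading it back. It assumes PRIMITIVE == 0, BASIC_TYPE_BITS == 2 and TYPE_INFO_MASK == 0x3f, which are not shown in this hunk but match the masking logic in `getFloat` above:

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;

    public class FloatLayoutDemo {
      public static void main(String[] args) {
        int FLOAT = 14;                            // type id from VariantUtil
        byte[] value = ByteBuffer.allocate(5)      // 1-byte header + 4-byte content
            .order(ByteOrder.LITTLE_ENDIAN)        // the spec stores content little-endian
            .put((byte) (FLOAT << 2))              // header: basic type PRIMITIVE, type info FLOAT
            .putInt(Float.floatToIntBits(1.23f))   // 4-byte IEEE-754 bits
            .array();

        // Decode the way getFloat does: mask the header, then reassemble the
        // little-endian int and reinterpret its bits as a float.
        int basicType = value[0] & 0x3;
        int typeInfo = (value[0] >> 2) & 0x3f;
        int bits = ByteBuffer.wrap(value, 1, 4).order(ByteOrder.LITTLE_ENDIAN).getInt();
        System.out.println(basicType + " " + typeInfo + " " + Float.intBitsToFloat(bits));
        // prints: 0 14 1.23
      }
    }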
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8a077d9e9acb..94cf7130d485 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1114,7 +1114,7 @@ case class Cast(
       _ => throw QueryExecutionErrors.cannotCastFromNullTypeError(to)
     } else if (from.isInstanceOf[VariantType]) {
       buildCast[VariantVal](_, v => {
-        variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId)
+        variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId, zoneId)
       })
     } else {
       to match {
@@ -1211,11 +1211,12 @@ case class Cast(
     case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) =>
       val tmp = ctx.freshVariable("tmp", classOf[Object])
       val dataTypeArg = ctx.addReferenceObj("dataType", to)
-      val zoneIdArg = ctx.addReferenceObj("zoneId", timeZoneId)
+      val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
+      val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
       val failOnError = evalMode != EvalMode.TRY
       val cls = classOf[variant.VariantGet].getName
       code"""
-        Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneIdArg);
+        Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg);
         if ($tmp == null) {
           $evNull = true;
         } else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index c5e316dc6c8c..8b09bf5f7de0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.variant
 
+import java.time.ZoneId
+
 import scala.util.parsing.combinator.RegexParsers
 
 import org.apache.spark.SparkRuntimeException
@@ -170,7 +172,8 @@ case class VariantGet(
       parsedPath,
       dataType,
       failOnError,
-      timeZoneId)
+      timeZoneId,
+      zoneId)
   }
 
   protected override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -178,14 +181,15 @@ case class VariantGet(
     val tmp = ctx.freshVariable("tmp", classOf[Object])
     val parsedPathArg = ctx.addReferenceObj("parsedPath", parsedPath)
     val dataTypeArg = ctx.addReferenceObj("dataType", dataType)
-    val zoneIdArg = ctx.addReferenceObj("zoneId", timeZoneId)
+    val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
+    val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
     val code = code"""
       ${childCode.code}
       boolean ${ev.isNull} = ${childCode.isNull};
       ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
       if (!${ev.isNull}) {
         Object $tmp = org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet(
-          ${childCode.value}, $parsedPathArg, $dataTypeArg, $failOnError, $zoneIdArg);
+          ${childCode.value}, $parsedPathArg, $dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg);
         if ($tmp == null) {
           ${ev.isNull} = true;
         } else {
@@ -228,7 +232,8 @@
       parsedPath: Array[VariantPathParser.PathSegment],
       dataType: DataType,
       failOnError: Boolean,
-      zoneId: Option[String]): Any = {
+      zoneStr: Option[String],
+      zoneId: ZoneId): Any = {
     var v = new Variant(input.getValue, input.getMetadata)
     for (path <- parsedPath) {
       v = path match {
@@ -238,7 +243,7 @@
       }
       if (v == null) return null
     }
-    VariantGet.cast(v, dataType, failOnError, zoneId)
+    VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
   }
 
   /**
@@ -249,9 +254,10 @@
       input: VariantVal,
       dataType: DataType,
       failOnError: Boolean,
-      zoneId: Option[String]): Any = {
+      zoneStr: Option[String],
+      zoneId: ZoneId): Any = {
     val v = new Variant(input.getValue, input.getMetadata)
-    VariantGet.cast(v, dataType, failOnError, zoneId)
+    VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
   }
 
   /**
@@ -261,9 +267,19 @@
   * "hello" to int). If the cast fails, throw an exception when `failOnError` is true, or return a
   * SQL NULL when it is false.
   */
-  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: Option[String]): Any = {
-    def invalidCast(): Any =
-      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, dataType) else null
+  def cast(
+      v: Variant,
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneStr: Option[String],
+      zoneId: ZoneId): Any = {
+    def invalidCast(): Any = {
+      if (failOnError) {
+        throw QueryExecutionErrors.invalidVariantCast(v.toJson(zoneId), dataType)
+      } else {
+        null
+      }
+    }
 
     if (dataType == VariantType) return new VariantVal(v.getValue, v.getMetadata)
     val variantType = v.getType
@@ -273,15 +289,22 @@
     val input = variantType match {
       case Type.OBJECT | Type.ARRAY =>
         return if (dataType.isInstanceOf[StringType]) {
-          UTF8String.fromString(v.toJson)
+          UTF8String.fromString(v.toJson(zoneId))
        } else {
          invalidCast()
        }
-      case Type.BOOLEAN => v.getBoolean
-      case Type.LONG => v.getLong
-      case Type.STRING => UTF8String.fromString(v.getString)
-      case Type.DOUBLE => v.getDouble
-      case Type.DECIMAL => Decimal(v.getDecimal)
+      case Type.BOOLEAN => Literal(v.getBoolean, BooleanType)
+      case Type.LONG => Literal(v.getLong, LongType)
+      case Type.STRING => Literal(UTF8String.fromString(v.getString), StringType)
+      case Type.DOUBLE => Literal(v.getDouble, DoubleType)
+      case Type.DECIMAL =>
+        val d = Decimal(v.getDecimal)
+        Literal(Decimal(v.getDecimal), DecimalType(d.precision, d.scale))
+      case Type.DATE => Literal(v.getLong.toInt, DateType)
+      case Type.TIMESTAMP => Literal(v.getLong, TimestampType)
+      case Type.TIMESTAMP_NTZ => Literal(v.getLong, TimestampNTZType)
+      case Type.FLOAT => Literal(v.getFloat, FloatType)
+      case Type.BINARY => Literal(v.getBinary, BinaryType)
       // We have handled other cases and should never reach here. This case is only intended
       // to bypass the compiler exhaustiveness check.
       case _ => throw QueryExecutionErrors.unreachableError()
@@ -289,15 +312,17 @@
     }
     // We mostly use the `Cast` expression to implement the cast. However, `Cast` silently
     // ignores the overflow in the long/decimal -> timestamp cast, and we want to enforce
     // strict overflow checks.
-    input match {
-      case l: Long if dataType == TimestampType =>
-        try Math.multiplyExact(l, MICROS_PER_SECOND)
+    input.dataType match {
+      case LongType if dataType == TimestampType =>
+        try Math.multiplyExact(input.value.asInstanceOf[Long], MICROS_PER_SECOND)
         catch {
           case _: ArithmeticException => invalidCast()
         }
-      case d: Decimal if dataType == TimestampType =>
+      case _: DecimalType if dataType == TimestampType =>
         try {
-          d.toJavaBigDecimal
+          input.value
+            .asInstanceOf[Decimal]
+            .toJavaBigDecimal
             .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
             .toBigInteger
             .longValueExact()
@@ -305,9 +330,8 @@
           case _: ArithmeticException => invalidCast()
         }
       case _ =>
-        val inputLiteral = Literal(input)
-        if (Cast.canAnsiCast(inputLiteral.dataType, dataType)) {
-          val result = Cast(inputLiteral, dataType, zoneId, EvalMode.TRY).eval()
+        if (Cast.canAnsiCast(input.dataType, dataType)) {
+          val result = Cast(input, dataType, zoneStr, EvalMode.TRY).eval()
           if (result == null) invalidCast() else result
         } else {
           invalidCast()
@@ -318,7 +342,7 @@
         val size = v.arraySize()
         val array = new Array[Any](size)
         for (i <- 0 until size) {
-          array(i) = cast(v.getElementAtIndex(i), elementType, failOnError, zoneId)
+          array(i) = cast(v.getElementAtIndex(i), elementType, failOnError, zoneStr, zoneId)
         }
         new GenericArrayData(array)
       } else {
@@ -332,7 +356,7 @@
         for (i <- 0 until size) {
           val field = v.getFieldAtIndex(i)
           keyArray(i) = UTF8String.fromString(field.key)
-          valueArray(i) = cast(field.value, valueType, failOnError, zoneId)
+          valueArray(i) = cast(field.value, valueType, failOnError, zoneStr, zoneId)
         }
         ArrayBasedMapData(keyArray, valueArray)
       } else {
@@ -345,7 +369,8 @@
           val field = v.getFieldAtIndex(i)
           st.getFieldIndex(field.key) match {
             case Some(idx) =>
-              row.update(idx, cast(field.value, fields(idx).dataType, failOnError, zoneId))
+              row.update(idx,
+                cast(field.value, fields(idx).dataType, failOnError, zoneStr, zoneId))
             case _ =>
           }
         }
@@ -576,6 +601,11 @@ object SchemaOfVariant {
     case Type.DECIMAL =>
       val d = v.getDecimal
       DecimalType(d.precision(), d.scale())
+    case Type.DATE => DateType
+    case Type.TIMESTAMP => TimestampType
+    case Type.TIMESTAMP_NTZ => TimestampNTZType
+    case Type.FLOAT => FloatType
+    case Type.BINARY => BinaryType
   }
 
   /**
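One detail of the `VariantGet.cast` hunk above is worth making concrete: seconds are scaled to microseconds with exact arithmetic so overflow surfaces as an error rather than wrapping silently. A standalone sketch (not from the patch):

    import java.math.BigDecimal;
    import java.math.BigInteger;

    public class OverflowDemo {
      private static final long MICROS_PER_SECOND = 1_000_000L;

      public static void main(String[] args) {
        long seconds = 9_223_372_036_855L;  // fits in a long as seconds, not as micros
        System.out.println(seconds * MICROS_PER_SECOND);   // wraps to a negative value
        try {
          Math.multiplyExact(seconds, MICROS_PER_SECOND);  // strict: throws
        } catch (ArithmeticException e) {
          System.out.println("long overflow detected");    // cast() turns this into invalidCast()
        }

        BigInteger micros = new BigDecimal("9223372036855.775808")  // Long.MaxValue + 1 micros
            .multiply(new BigDecimal(MICROS_PER_SECOND)).toBigInteger();
        try {
          micros.longValueExact();                         // strict: throws
        } catch (ArithmeticException e) {
          System.out.println("decimal overflow detected"); // cast() turns this into invalidCast()
        }
      }
    }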
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index 1964b5f24b34..80f2b2a0070c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -317,7 +317,7 @@ class JacksonGenerator(
   }
 
   def write(v: VariantVal): Unit = {
-    gen.writeRawValue(v.toString)
+    gen.writeRawValue(v.toJson(options.zoneId))
   }
 
   def writeLineEnding(): Unit = {
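The new test below builds raw variant bytes by hand on top of an empty metadata dictionary, as its opening comment explains. For reference, a sketch of that metadata layout; it assumes the version byte is 1 and 1-byte offsets, details that are not spelled out in this diff:

    public class MetadataDemo {
      public static void main(String[] args) {
        // Empty dictionary: version, size 0, and the single "one past the end" offset 0.
        byte[] empty = {1, 0, 0};

        // For contrast, a one-key dictionary for "a": version, size 1,
        // offsets [0, 1], then the bytes of the string "a".
        byte[] oneKey = {1, 1, 0, 1, 'a'};

        System.out.println(empty.length + " " + oneKey.length); // 3 5
      }
    }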
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index a5863e80a26c..24675518646d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -685,4 +686,115 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
       Array(null, null, 1)
     )
   }
+
+  test("atomic types that are not produced by parse_json") {
+    // Dictionary size is `0` for value 0. An empty dictionary contains one offset `0` for the
+    // one-past-the-end position (i.e. the sum of all string lengths).
+    val emptyMetadata = Array[Byte](VERSION, 0, 0)
+
+    def checkToJson(value: Array[Byte], expected: String): Unit = {
+      val input = Literal(new VariantVal(value, emptyMetadata))
+      checkEvaluation(StructsToJson(Map.empty, input), expected)
+    }
+
+    def checkCast(value: Array[Byte], dataType: DataType, expected: Any): Unit = {
+      val input = Literal(new VariantVal(value, emptyMetadata))
+      checkEvaluation(Cast(input, dataType, evalMode = EvalMode.ANSI), expected)
+    }
+
+    checkToJson(Array(primitiveHeader(DATE), 0, 0, 0, 0), "\"1970-01-01\"")
+    checkToJson(Array(primitiveHeader(DATE), -1, -1, -1, 127), "\"+5881580-07-11\"")
+    checkToJson(Array(primitiveHeader(DATE), 0, 0, 0, -128), "\"-5877641-06-23\"")
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+      checkCast(Array(primitiveHeader(DATE), 0, 0, 0, 0), TimestampType, 0L)
+      checkCast(Array(primitiveHeader(DATE), 1, 0, 0, 0), TimestampType, MICROS_PER_DAY)
+    }
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
+      checkCast(Array(primitiveHeader(DATE), 0, 0, 0, 0), TimestampType, 8 * MICROS_PER_HOUR)
+      checkCast(Array(primitiveHeader(DATE), 1, 0, 0, 0), TimestampType,
+        MICROS_PER_DAY + 8 * MICROS_PER_HOUR)
+    }
+
+    def littleEndianLong(value: Long): Array[Byte] =
+      BigInt(value).toByteArray.reverse.padTo(8, 0.toByte)
+
+    val time1 = littleEndianLong(0)
+    // In America/Los_Angeles timezone, timestamp value `skippedTime` is 2011-03-13 03:00:00.
+    // The next second of 2011-03-13 01:59:59 jumps to 2011-03-13 03:00:00.
+    val skippedTime = 1300010400000000L
+    val time2 = littleEndianLong(skippedTime)
+    val time3 = littleEndianLong(skippedTime - 1)
+    val time4 = littleEndianLong(Long.MinValue)
+    val time5 = littleEndianLong(Long.MaxValue)
+    val time6 = littleEndianLong(-62198755200000000L)
+    val timestampHeader = Array(primitiveHeader(TIMESTAMP))
+    val timestampNtzHeader = Array(primitiveHeader(TIMESTAMP_NTZ))
+
+    for (timeZone <- Seq("UTC", "America/Los_Angeles")) {
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
+        checkToJson(timestampNtzHeader ++ time1, "\"1970-01-01 00:00:00\"")
+        checkToJson(timestampNtzHeader ++ time2, "\"2011-03-13 10:00:00\"")
+        checkToJson(timestampNtzHeader ++ time3, "\"2011-03-13 09:59:59.999999\"")
+        checkToJson(timestampNtzHeader ++ time4, "\"-290308-12-21 19:59:05.224192\"")
+        checkToJson(timestampNtzHeader ++ time5, "\"+294247-01-10 04:00:54.775807\"")
+        checkToJson(timestampNtzHeader ++ time6, "\"-0001-01-01 00:00:00\"")
+
+        checkCast(timestampNtzHeader ++ time1, DateType, 0)
+        checkCast(timestampNtzHeader ++ time2, DateType, 15046)
+        checkCast(timestampNtzHeader ++ time3, DateType, 15046)
+        checkCast(timestampNtzHeader ++ time4, DateType, -106751992)
+        checkCast(timestampNtzHeader ++ time5, DateType, 106751991)
+        checkCast(timestampNtzHeader ++ time6, DateType, -719893)
+      }
+    }
+
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+      checkToJson(timestampHeader ++ time1, "\"1970-01-01 00:00:00+00:00\"")
+      checkToJson(timestampHeader ++ time2, "\"2011-03-13 10:00:00+00:00\"")
+      checkToJson(timestampHeader ++ time3, "\"2011-03-13 09:59:59.999999+00:00\"")
+      checkToJson(timestampHeader ++ time4, "\"-290308-12-21 19:59:05.224192+00:00\"")
+      checkToJson(timestampHeader ++ time5, "\"+294247-01-10 04:00:54.775807+00:00\"")
+      checkToJson(timestampHeader ++ time6, "\"-0001-01-01 00:00:00+00:00\"")
+
+      checkCast(timestampHeader ++ time1, DateType, 0)
+      checkCast(timestampHeader ++ time2, DateType, 15046)
+      checkCast(timestampHeader ++ time3, DateType, 15046)
+      checkCast(timestampHeader ++ time4, DateType, -106751992)
+      checkCast(timestampHeader ++ time5, DateType, 106751991)
+      checkCast(timestampHeader ++ time6, DateType, -719893)
+    }
+
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
+      checkToJson(timestampHeader ++ time1, "\"1969-12-31 16:00:00-08:00\"")
+      checkToJson(timestampHeader ++ time2, "\"2011-03-13 03:00:00-07:00\"")
+      checkToJson(timestampHeader ++ time3, "\"2011-03-13 01:59:59.999999-08:00\"")
+      checkToJson(timestampHeader ++ time4, "\"-290308-12-21 12:06:07.224192-07:52\"")
+      checkToJson(timestampHeader ++ time5, "\"+294247-01-09 20:00:54.775807-08:00\"")
+      checkToJson(timestampHeader ++ time6, "\"-0002-12-31 16:07:02-07:52\"")
+
+      checkCast(timestampHeader ++ time1, DateType, -1)
+      checkCast(timestampHeader ++ time2, DateType, 15046)
+      checkCast(timestampHeader ++ time3, DateType, 15046)
+      checkCast(timestampHeader ++ time4, DateType, -106751992)
+      checkCast(timestampHeader ++ time5, DateType, 106751990)
+      checkCast(timestampHeader ++ time6, DateType, -719894)
+    }
+
+    checkToJson(Array(primitiveHeader(FLOAT)) ++
+      BigInt(java.lang.Float.floatToIntBits(1.23F)).toByteArray.reverse, "1.23")
+    checkToJson(Array(primitiveHeader(FLOAT)) ++
+      BigInt(java.lang.Float.floatToIntBits(-0.0F)).toByteArray.reverse, "-0.0")
+    // Note: 1.23F.toDouble != 1.23.
+    checkCast(Array(primitiveHeader(FLOAT)) ++
+      BigInt(java.lang.Float.floatToIntBits(1.23F)).toByteArray.reverse, DoubleType, 1.23F.toDouble)
+
+    checkToJson(Array(primitiveHeader(BINARY), 0, 0, 0, 0), "\"\"")
+    checkToJson(Array(primitiveHeader(BINARY), 1, 0, 0, 0, 1), "\"AQ==\"")
+    checkToJson(Array(primitiveHeader(BINARY), 2, 0, 0, 0, 1, 2), "\"AQI=\"")
+    checkToJson(Array(primitiveHeader(BINARY), 3, 0, 0, 0, 1, 2, 3), "\"AQID\"")
+    checkCast(Array(primitiveHeader(BINARY), 3, 0, 0, 0, 1, 2, 3), StringType,
+      "\u0001\u0002\u0003")
+    checkCast(Array(primitiveHeader(BINARY), 5, 0, 0, 0, 72, 101, 108, 108, 111), StringType,
+      "Hello")
+  }
 }
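Two expectations in the test above are easy to misread. The float note first: widening 1.23f to a double does not give 1.23, which is why the cast expectation is written as `1.23F.toDouble` rather than a double literal. A sketch:

    public class FloatWideningDemo {
      public static void main(String[] args) {
        float f = 1.23f;
        System.out.println((double) f);          // 1.2300000190734863
        System.out.println((double) f == 1.23);  // false
      }
    }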
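Likewise, the BINARY expectations: the JSON rendering added in `toJsonImpl` is standard Base64 of the raw bytes, which is where "AQ==", "AQI=", and "AQID" come from:

    import java.util.Base64;

    public class Base64Demo {
      public static void main(String[] args) {
        Base64.Encoder enc = Base64.getEncoder();
        System.out.println(enc.encodeToString(new byte[] {1}));        // AQ==
        System.out.println(enc.encodeToString(new byte[] {1, 2}));     // AQI=
        System.out.println(enc.encodeToString(new byte[] {1, 2, 3}));  // AQID
      }
    }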