This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4a1f2419e243 [SPARK-47451][SQL] Support to_json(variant).
4a1f2419e243 is described below
commit 4a1f2419e243272064fde96529149316fe53bc10
Author: Chenhao Li <[email protected]>
AuthorDate: Mon Mar 25 14:58:01 2024 +0800
[SPARK-47451][SQL] Support to_json(variant).
### What changes were proposed in this pull request?
This PR adds the functionality to format a variant value as a JSON string.
It is exposed in the `to_json` expression by allowing the variant type (or a
nested type containing the variant type) as its input.
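As a quick illustration (a sketch, not part of the patch itself; output shown as expected):

```scala
// to_json now accepts a variant (or nested variant) input.
val df = spark.sql("""SELECT to_json(parse_json('{"a": 1, "b": [true, null]}')) AS j""")
df.show(false)
// expected: {"a":1,"b":[true,null]}
```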
### How was this patch tested?
Unit tests that validate the `to_json` result. The input includes both
`parse_json` results and manually constructed bytes.
Negative cases with malformed inputs are also covered.
Some tests disabled in https://github.com/apache/spark/pull/45479 are
re-enabled.
Closes #45575 from chenhao-db/variant_to_json.
Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
common/unsafe/pom.xml | 6 +
.../org/apache/spark/unsafe/types/VariantVal.java | 4 +-
.../src/main/resources/error/error-classes.json | 12 +
.../org/apache/spark/types/variant/Variant.java | 91 +++++++
.../apache/spark/types/variant/VariantUtil.java | 298 +++++++++++++++++++++
docs/sql-error-conditions.md | 12 +
.../sql/catalyst/expressions/jsonExpressions.scala | 8 +-
.../spark/sql/catalyst/json/JacksonGenerator.scala | 16 +-
.../variant/VariantExpressionSuite.scala | 255 +++++++++++++++++-
.../scala/org/apache/spark/sql/VariantSuite.scala | 9 +-
.../sql/expressions/ExpressionInfoSuite.scala | 3 -
11 files changed, 696 insertions(+), 18 deletions(-)
diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml
index 13b45f55a4ad..a5ef9847859a 100644
--- a/common/unsafe/pom.xml
+++ b/common/unsafe/pom.xml
@@ -47,6 +47,12 @@
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-variant_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+
<dependency>
<groupId>org.scala-lang.modules</groupId>
<artifactId>scala-parallel-collections_${scala.binary.version}</artifactId>
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
index e0f04d816d0d..652c05daf344 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
@@ -18,6 +18,7 @@
package org.apache.spark.unsafe.types;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.types.variant.Variant;
import java.io.Serializable;
import java.util.Arrays;
@@ -104,8 +105,7 @@ public class VariantVal implements Serializable {
*/
@Override
public String toString() {
- // NOTE: the encoding is not yet implemented, this is not the final implementation.
- return new String(value);
+ return new Variant(value, metadata).toJson();
}
/**
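With this change, `VariantVal.toString` renders the variant as JSON rather than echoing raw bytes. A sketch of the expected behavior, mirroring the pattern used in the test suites below:

```scala
// Sketch: toString on a parse_json result now yields the JSON rendering.
val v = spark.sql("""SELECT parse_json('[1, {"k": null}]') AS v""")
  .collect()(0).get(0).asInstanceOf[org.apache.spark.unsafe.types.VariantVal]
assert(v.toString == """[1,{"k":null}]""")
```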
diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 091f24d44f66..c219db8c6969 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -2876,6 +2876,12 @@
},
"sqlState" : "22023"
},
+ "MALFORMED_VARIANT" : {
+ "message" : [
+ "Variant binary is malformed. Please check the data source is valid."
+ ],
+ "sqlState" : "22023"
+ },
"MERGE_CARDINALITY_VIOLATION" : {
"message" : [
"The ON search condition of the MERGE statement matched a single row
from the target table with multiple rows of the source table.",
@@ -4555,6 +4561,12 @@
],
"sqlState" : "42883"
},
+ "VARIANT_CONSTRUCTOR_SIZE_LIMIT" : {
+ "message" : [
+ "Cannot construct a Variant larger than 16 MiB. The maximum allowed size
of a Variant value is 16 MiB."
+ ],
+ "sqlState" : "22023"
+ },
"VARIANT_SIZE_LIMIT" : {
"message" : [
"Cannot build variant bigger than <sizeLimit> in <functionName>.",
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
index e43b7ec8ac54..746b38c697d0 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
@@ -17,6 +17,14 @@
package org.apache.spark.types.variant;
+import com.fasterxml.jackson.core.JsonFactory;
+import com.fasterxml.jackson.core.JsonGenerator;
+
+import java.io.CharArrayWriter;
+import java.io.IOException;
+
+import static org.apache.spark.types.variant.VariantUtil.*;
+
/**
* This class is structurally equivalent to {@link org.apache.spark.unsafe.types.VariantVal}. We
* define a new class to avoid depending on or modifying Spark.
@@ -28,6 +36,15 @@ public final class Variant {
public Variant(byte[] value, byte[] metadata) {
this.value = value;
this.metadata = metadata;
+ // There is currently only one allowed version.
+ if (metadata.length < 1 || (metadata[0] & VERSION_MASK) != VERSION) {
+ throw malformedVariant();
+ }
+ // Don't attempt to use a Variant larger than 16 MiB. We'll never produce one, and it risks
+ // memory instability.
+ if (metadata.length > SIZE_LIMIT || value.length > SIZE_LIMIT) {
+ throw variantConstructorSizeLimit();
+ }
}
public byte[] getValue() {
@@ -37,4 +54,78 @@ public final class Variant {
public byte[] getMetadata() {
return metadata;
}
+
+ // Stringify the variant in JSON format.
+ // Throw `MALFORMED_VARIANT` if the variant is malformed.
+ public String toJson() {
+ StringBuilder sb = new StringBuilder();
+ toJsonImpl(value, metadata, 0, sb);
+ return sb.toString();
+ }
+
+ // Escape a string so that it can be pasted into JSON structure.
+ // For example, if `str` only contains a new-line character, then the result content is "\n"
+ // (4 characters).
+ static String escapeJson(String str) {
+ try (CharArrayWriter writer = new CharArrayWriter();
+ JsonGenerator gen = new JsonFactory().createGenerator(writer)) {
+ gen.writeString(str);
+ gen.flush();
+ return writer.toString();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb) {
+ switch (VariantUtil.getType(value, pos)) {
+ case OBJECT:
+ handleObject(value, pos, (size, idSize, offsetSize, idStart, offsetStart, dataStart) -> {
+ sb.append('{');
+ for (int i = 0; i < size; ++i) {
+ int id = readUnsigned(value, idStart + idSize * i, idSize);
+ int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize);
+ int elementPos = dataStart + offset;
+ if (i != 0) sb.append(',');
+ sb.append(escapeJson(getMetadataKey(metadata, id)));
+ sb.append(':');
+ toJsonImpl(value, metadata, elementPos, sb);
+ }
+ sb.append('}');
+ return null;
+ });
+ break;
+ case ARRAY:
+ handleArray(value, pos, (size, offsetSize, offsetStart, dataStart) -> {
+ sb.append('[');
+ for (int i = 0; i < size; ++i) {
+ int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize);
+ int elementPos = dataStart + offset;
+ if (i != 0) sb.append(',');
+ toJsonImpl(value, metadata, elementPos, sb);
+ }
+ sb.append(']');
+ return null;
+ });
+ break;
+ case NULL:
+ sb.append("null");
+ break;
+ case BOOLEAN:
+ sb.append(VariantUtil.getBoolean(value, pos));
+ break;
+ case LONG:
+ sb.append(VariantUtil.getLong(value, pos));
+ break;
+ case STRING:
+ sb.append(escapeJson(VariantUtil.getString(value, pos)));
+ break;
+ case DOUBLE:
+ sb.append(VariantUtil.getDouble(value, pos));
+ break;
+ case DECIMAL:
+ sb.append(VariantUtil.getDecimal(value, pos).toPlainString());
+ break;
+ }
+ }
}
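For reference, the escaping behavior that `escapeJson` delegates to Jackson can be reproduced in isolation. A minimal Scala sketch of the same API usage (values hypothetical):

```scala
import java.io.CharArrayWriter
import com.fasterxml.jackson.core.JsonFactory

// writeString emits a quoted, escaped JSON string literal: a lone newline
// character becomes the four characters "\n" (including the quotes).
val writer = new CharArrayWriter()
val gen = new JsonFactory().createGenerator(writer)
gen.writeString("\n")
gen.flush()
assert(writer.toString == "\"\\n\"")
gen.close()
```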
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
index d6e572f98901..b601b7c75eff 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
@@ -17,6 +17,13 @@
package org.apache.spark.types.variant;
+import org.apache.spark.QueryContext;
+import org.apache.spark.SparkRuntimeException;
+import scala.collection.immutable.Map$;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
/**
* This class defines constants related to the variant format and provides functions for
* manipulating variant binaries.
@@ -141,4 +148,295 @@ public class VariantUtil {
return (byte) (((largeSize ? 1 : 0) << (BASIC_TYPE_BITS + 2)) |
((offsetSize - 1) << BASIC_TYPE_BITS) | ARRAY);
}
+
+ // An exception indicating that the variant value or metadata doesn't follow the variant specification.
+ static SparkRuntimeException malformedVariant() {
+ return new SparkRuntimeException("MALFORMED_VARIANT",
+ Map$.MODULE$.<String, String>empty(), null, new QueryContext[]{}, "");
+ }
+
+ // An exception indicating that an external caller tried to call the Variant constructor with
+ // value or metadata exceeding the 16 MiB size limit. We will never construct a Variant this
+ // large, so it should only be possible to encounter this exception when reading a Variant
+ // produced by another tool.
+ static SparkRuntimeException variantConstructorSizeLimit() {
+ return new SparkRuntimeException("VARIANT_CONSTRUCTOR_SIZE_LIMIT",
+ Map$.MODULE$.<String, String>empty(), null, new QueryContext[]{}, "");
+ }
+
+ // Check the validity of an array index `pos`. Throw `MALFORMED_VARIANT` if it is out of
+ // bounds, meaning that the variant is malformed.
+ static void checkIndex(int pos, int length) {
+ if (pos < 0 || pos >= length) throw malformedVariant();
+ }
+
+ // Read a little-endian signed long value from `bytes[pos, pos + numBytes)`.
+ static long readLong(byte[] bytes, int pos, int numBytes) {
+ checkIndex(pos, bytes.length);
+ checkIndex(pos + numBytes - 1, bytes.length);
+ long result = 0;
+ // All bytes except the most significant byte should be unsign-extended and shifted (so we
+ // need `& 0xFF`). The most significant byte should be sign-extended and is handled after
+ // the loop.
+ for (int i = 0; i < numBytes - 1; ++i) {
+ long unsignedByteValue = bytes[pos + i] & 0xFF;
+ result |= unsignedByteValue << (8 * i);
+ }
+ long signedByteValue = bytes[pos + numBytes - 1];
+ result |= signedByteValue << (8 * (numBytes - 1));
+ return result;
+ }
+
+ // Read a little-endian unsigned int value from `bytes[pos, pos + numBytes)`. The value must
+ // fit into a non-negative int (`[0, Integer.MAX_VALUE]`).
+ static int readUnsigned(byte[] bytes, int pos, int numBytes) {
+ checkIndex(pos, bytes.length);
+ checkIndex(pos + numBytes - 1, bytes.length);
+ int result = 0;
+ // Similar to the `readLong` loop, but all bytes should be unsign-extended.
+ for (int i = 0; i < numBytes; ++i) {
+ int unsignedByteValue = bytes[pos + i] & 0xFF;
+ result |= unsignedByteValue << (8 * i);
+ }
+ if (result < 0) throw malformedVariant();
+ return result;
+ }
+
+ // The value type of a variant value. It is determined by the header byte but is not a 1:1
+ // mapping (for example, INT1/2/4/8 all map to `Type.LONG`).
+ public enum Type {
+ OBJECT,
+ ARRAY,
+ NULL,
+ BOOLEAN,
+ LONG,
+ STRING,
+ DOUBLE,
+ DECIMAL,
+ }
+
+ // Get the value type of variant value `value[pos...]`. It is only legal to call `get*` if
+ // `getType` returns this type (for example, it is only legal to call `getLong` if `getType`
+ // returns `Type.LONG`).
+ // Throw `MALFORMED_VARIANT` if the variant is malformed.
+ public static Type getType(byte[] value, int pos) {
+ checkIndex(pos, value.length);
+ int basicType = value[pos] & BASIC_TYPE_MASK;
+ int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+ switch (basicType) {
+ case SHORT_STR:
+ return Type.STRING;
+ case OBJECT:
+ return Type.OBJECT;
+ case ARRAY:
+ return Type.ARRAY;
+ default:
+ switch (typeInfo) {
+ case NULL:
+ return Type.NULL;
+ case TRUE:
+ case FALSE:
+ return Type.BOOLEAN;
+ case INT1:
+ case INT2:
+ case INT4:
+ case INT8:
+ return Type.LONG;
+ case DOUBLE:
+ return Type.DOUBLE;
+ case DECIMAL4:
+ case DECIMAL8:
+ case DECIMAL16:
+ return Type.DECIMAL;
+ case LONG_STR:
+ return Type.STRING;
+ default:
+ throw malformedVariant();
+ }
+ }
+ }
+
+ static IllegalStateException unexpectedType(Type type) {
+ return new IllegalStateException("Expect type to be " + type);
+ }
+
+ // Get a boolean value from variant value `value[pos...]`.
+ // Throw `MALFORMED_VARIANT` if the variant is malformed.
+ public static boolean getBoolean(byte[] value, int pos) {
+ checkIndex(pos, value.length);
+ int basicType = value[pos] & BASIC_TYPE_MASK;
+ int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+ if (basicType != PRIMITIVE || (typeInfo != TRUE && typeInfo != FALSE)) {
+ throw unexpectedType(Type.BOOLEAN);
+ }
+ return typeInfo == TRUE;
+ }
+
+ // Get a long value from variant value `value[pos...]`.
+ // Throw `MALFORMED_VARIANT` if the variant is malformed.
+ public static long getLong(byte[] value, int pos) {
+ checkIndex(pos, value.length);
+ int basicType = value[pos] & BASIC_TYPE_MASK;
+ int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+ if (basicType != PRIMITIVE) throw unexpectedType(Type.LONG);
+ switch (typeInfo) {
+ case INT1:
+ return readLong(value, pos + 1, 1);
+ case INT2:
+ return readLong(value, pos + 1, 2);
+ case INT4:
+ return readLong(value, pos + 1, 4);
+ case INT8:
+ return readLong(value, pos + 1, 8);
+ default:
+ throw unexpectedType(Type.LONG);
+ }
+ }
+
+ // Get a double value from variant value `value[pos...]`.
+ // Throw `MALFORMED_VARIANT` if the variant is malformed.
+ public static double getDouble(byte[] value, int pos) {
+ checkIndex(pos, value.length);
+ int basicType = value[pos] & BASIC_TYPE_MASK;
+ int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+ if (basicType != PRIMITIVE || typeInfo != DOUBLE) throw unexpectedType(Type.DOUBLE);
+ return Double.longBitsToDouble(readLong(value, pos + 1, 8));
+ }
+
+ // Get a decimal value from variant value `value[pos...]`.
+ // Throw `MALFORMED_VARIANT` if the variant is malformed.
+ public static BigDecimal getDecimal(byte[] value, int pos) {
+ checkIndex(pos, value.length);
+ int basicType = value[pos] & BASIC_TYPE_MASK;
+ int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+ if (basicType != PRIMITIVE) throw unexpectedType(Type.DECIMAL);
+ int scale = value[pos + 1];
+ BigDecimal result;
+ switch (typeInfo) {
+ case DECIMAL4:
+ result = BigDecimal.valueOf(readLong(value, pos + 2, 4), scale);
+ break;
+ case DECIMAL8:
+ result = BigDecimal.valueOf(readLong(value, pos + 2, 8), scale);
+ break;
+ case DECIMAL16:
+ checkIndex(pos + 17, value.length);
+ byte[] bytes = new byte[16];
+ // Copy the bytes in reverse order because the `BigInteger` constructor expects a
+ // big-endian representation.
+ for (int i = 0; i < 16; ++i) {
+ bytes[i] = value[pos + 17 - i];
+ }
+ result = new BigDecimal(new BigInteger(bytes), scale);
+ break;
+ default:
+ throw unexpectedType(Type.DECIMAL);
+ }
+ return result.stripTrailingZeros();
+ }
+
+ // Get a string value from variant value `value[pos...]`.
+ // Throw `MALFORMED_VARIANT` if the variant is malformed.
+ public static String getString(byte[] value, int pos) {
+ checkIndex(pos, value.length);
+ int basicType = value[pos] & BASIC_TYPE_MASK;
+ int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+ if (basicType == SHORT_STR || (basicType == PRIMITIVE && typeInfo == LONG_STR)) {
+ int start;
+ int length;
+ if (basicType == SHORT_STR) {
+ start = pos + 1;
+ length = typeInfo;
+ } else {
+ start = pos + 1 + U32_SIZE;
+ length = readUnsigned(value, pos + 1, U32_SIZE);
+ }
+ checkIndex(start + length - 1, value.length);
+ return new String(value, start, length);
+ }
+ throw unexpectedType(Type.STRING);
+ }
+
+ public interface ObjectHandler<T> {
+ /**
+ * @param size Number of object fields.
+ * @param idSize The integer size of the field id list.
+ * @param offsetSize The integer size of the offset list.
+ * @param idStart The starting index of the field id list in the variant value array.
+ * @param offsetStart The starting index of the offset list in the variant value array.
+ * @param dataStart The starting index of field data in the variant value array.
+ */
+ T apply(int size, int idSize, int offsetSize, int idStart, int offsetStart, int dataStart);
+ }
+
+ // A helper function to access a variant object. It provides `handler` with its required
+ // parameters and returns whatever the handler returns.
+ public static <T> T handleObject(byte[] value, int pos, ObjectHandler<T> handler) {
+ checkIndex(pos, value.length);
+ int basicType = value[pos] & BASIC_TYPE_MASK;
+ int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+ if (basicType != OBJECT) throw unexpectedType(Type.OBJECT);
+ // Refer to the comment of the `OBJECT` constant for the details of the object header
+ // encoding. Suppose `typeInfo` has a bit representation of 0_b4_b3b2_b1b0: the following
+ // line extracts b4 to determine whether the object uses a one-byte or four-byte size.
+ boolean largeSize = ((typeInfo >> 4) & 0x1) != 0;
+ int sizeBytes = (largeSize ? U32_SIZE : 1);
+ int size = readUnsigned(value, pos + 1, sizeBytes);
+ // Extracts b3b2 to determine the integer size of the field id list.
+ int idSize = ((typeInfo >> 2) & 0x3) + 1;
+ // Extracts b1b0 to determine the integer size of the offset list.
+ int offsetSize = (typeInfo & 0x3) + 1;
+ int idStart = pos + 1 + sizeBytes;
+ int offsetStart = idStart + size * idSize;
+ int dataStart = offsetStart + (size + 1) * offsetSize;
+ return handler.apply(size, idSize, offsetSize, idStart, offsetStart, dataStart);
+ }
+
+ public interface ArrayHandler<T> {
+ /**
+ * @param size Number of array elements.
+ * @param offsetSize The integer size of the offset list.
+ * @param offsetStart The starting index of the offset list in the variant value array.
+ * @param dataStart The starting index of element data in the variant value array.
+ */
+ T apply(int size, int offsetSize, int offsetStart, int dataStart);
+ }
+
+ // A helper function to access a variant array.
+ public static <T> T handleArray(byte[] value, int pos, ArrayHandler<T> handler) {
+ checkIndex(pos, value.length);
+ int basicType = value[pos] & BASIC_TYPE_MASK;
+ int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+ if (basicType != ARRAY) throw unexpectedType(Type.ARRAY);
+ // Refer to the comment of the `ARRAY` constant for the details of the array header
+ // encoding. Suppose `typeInfo` has a bit representation of 000_b2_b1b0: the following
+ // line extracts b2 to determine whether the array uses a one-byte or four-byte size.
+ boolean largeSize = ((typeInfo >> 2) & 0x1) != 0;
+ int sizeBytes = (largeSize ? U32_SIZE : 1);
+ int size = readUnsigned(value, pos + 1, sizeBytes);
+ // Extracts b1b0 to determine the integer size of the offset list.
+ int offsetSize = (typeInfo & 0x3) + 1;
+ int offsetStart = pos + 1 + sizeBytes;
+ int dataStart = offsetStart + (size + 1) * offsetSize;
+ return handler.apply(size, offsetSize, offsetStart, dataStart);
+ }
+
+ // Get a key at `id` in the variant metadata.
+ // Throw `MALFORMED_VARIANT` if the variant is malformed. An out-of-bounds `id` is also
+ // considered a malformed variant because it is read from the corresponding variant value.
+ public static String getMetadataKey(byte[] metadata, int id) {
+ checkIndex(0, metadata.length);
+ // Extracts the highest 2 bits in the metadata header to determine the integer size of the
+ // offset list.
+ int offsetSize = ((metadata[0] >> 6) & 0x3) + 1;
+ int dictSize = readUnsigned(metadata, 1, offsetSize);
+ if (id >= dictSize) throw malformedVariant();
+ // There are a header byte, a `dictSize` with `offsetSize` bytes, and `(dictSize + 1)`
+ // offsets before the string data.
+ int stringStart = 1 + (dictSize + 2) * offsetSize;
+ int offset = readUnsigned(metadata, 1 + (id + 1) * offsetSize, offsetSize);
+ int nextOffset = readUnsigned(metadata, 1 + (id + 2) * offsetSize, offsetSize);
+ if (offset > nextOffset) throw malformedVariant();
+ checkIndex(stringStart + nextOffset - 1, metadata.length);
+ return new String(metadata, stringStart + offset, nextOffset - offset);
+ }
}
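To make the metadata layout read by `getMetadataKey` concrete, here is a small hand-built example (a sketch, assuming the `VERSION` header constant from `VariantUtil` with its offset-size bits zero, i.e. 1-byte offsets, as the test suites below also assume):

```scala
// Metadata dictionary for the two keys "a" and "b", offsetSize = 1:
//   header       VERSION (offset-size bits 00 => 1-byte offsets)
//   dictSize     2
//   offsets      0, 1, 2          (dictSize + 1 entries)
//   string data  'a', 'b'
val metadata: Array[Byte] =
  Array[Byte](VERSION, /* dictSize */ 2, /* offsets */ 0, 1, 2) ++ Array[Byte]('a', 'b')
// getMetadataKey(metadata, 0) == "a" and getMetadataKey(metadata, 1) == "b":
// stringStart = 1 + (2 + 2) * 1 = 5, so key 0 spans metadata[5, 6) and
// key 1 spans metadata[6, 7).
```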
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index bab64caa3888..8b666c1ef9c8 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -1589,6 +1589,12 @@ Parse Mode: `<failFastMode>`. To process malformed records as null result, try s
For more details see [MALFORMED_RECORD_IN_PARSING](sql-error-conditions-malformed-record-in-parsing-error-class.html)
+### MALFORMED_VARIANT
+
+[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception)
+
+Variant binary is malformed. Please check that the data source is valid.
+
### MERGE_CARDINALITY_VIOLATION
[SQLSTATE: 23K01](sql-error-conditions-sqlstates.html#class-23-integrity-constraint-violation)
@@ -2732,6 +2738,12 @@ The variable `<variableName>` cannot be found. Verify the spelling and correctne
If you did not qualify the name with a schema and catalog, verify the current_schema() output, or qualify the name with the correct schema and catalog.
To tolerate the error on drop use DROP VARIABLE IF EXISTS.
+### VARIANT_CONSTRUCTOR_SIZE_LIMIT
+
+[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception)
+
+Cannot construct a Variant larger than 16 MiB. The maximum allowed size of a Variant value is 16 MiB.
+
### VARIANT_SIZE_LIMIT
[SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index b155987242b3..f35c6da4f8af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
import org.apache.spark.util.Utils
private[this] sealed trait PathInstruction
@@ -813,13 +813,17 @@ case class StructsToJson(
(map: Any) =>
gen.write(map.asInstanceOf[MapData])
getAndReset()
+ case _: VariantType =>
+ (v: Any) =>
+ gen.write(v.asInstanceOf[VariantVal])
+ getAndReset()
}
}
override def dataType: DataType = StringType
override def checkInputDataTypes(): TypeCheckResult = inputSchema match {
- case dt @ (_: StructType | _: MapType | _: ArrayType) =>
+ case dt @ (_: StructType | _: MapType | _: ArrayType | _: VariantType) =>
JacksonUtils.verifyType(prettyName, dt)
case _ =>
DataTypeMismatch(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index c2c6117e1e3a..1964b5f24b34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.VariantVal
import org.apache.spark.util.ArrayImplicits._
/**
@@ -46,11 +47,13 @@ class JacksonGenerator(
// we can directly access data in `ArrayData` without the help of `SpecificMutableRow`.
private type ValueWriter = (SpecializedGetters, Int) => Unit
- // `JackGenerator` can only be initialized with a `StructType`, a `MapType` or a `ArrayType`.
+ // `JackGenerator` can only be initialized with a `StructType`, a `MapType`, an `ArrayType` or
+ // a `VariantType`.
require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType]
- || dataType.isInstanceOf[ArrayType],
+ || dataType.isInstanceOf[ArrayType] || dataType.isInstanceOf[VariantType],
s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString}, " +
- s"${MapType.simpleString} or ${ArrayType.simpleString} but got ${dataType.catalogString}")
+ s"${MapType.simpleString}, ${ArrayType.simpleString} or ${VariantType.simpleString} but " +
+ s"got ${dataType.catalogString}")
// `ValueWriter`s for all fields of the schema
private lazy val rootFieldWriters: Array[ValueWriter] = dataType match {
@@ -202,6 +205,9 @@ class JacksonGenerator(
(row: SpecializedGetters, ordinal: Int) =>
writeObject(writeMapData(row.getMap(ordinal), mt, valueWriter))
+ case VariantType =>
+ (row: SpecializedGetters, ordinal: Int) => write(row.getVariant(ordinal))
+
// For UDT values, they should be in the SQL type's corresponding value type.
// We should not see values in the user-defined class at here.
// For example, VectorUDT's SQL type is an array of double. So, we should expect that v is
@@ -310,6 +316,10 @@ class JacksonGenerator(
mapType = dataType.asInstanceOf[MapType]))
}
+ def write(v: VariantVal): Unit = {
+ gen.writeRawValue(v.toString)
+ }
+
def writeLineEnding(): Unit = {
// Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8.
gen.writeRaw(lineSeparator)
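The new `write(v: VariantVal)` path relies on Jackson's `writeRawValue`, which splices an already-serialized JSON fragment into the output verbatim. A minimal sketch of that behavior (standard Jackson API, values hypothetical):

```scala
import java.io.StringWriter
import com.fasterxml.jackson.core.JsonFactory

// writeRawValue embeds pre-serialized JSON without re-escaping, which is why
// VariantVal.toString (now valid JSON) can be nested inside a larger document.
val sw = new StringWriter()
val g = new JsonFactory().createGenerator(sw)
g.writeStartObject()
g.writeFieldName("v")
g.writeRawValue("""{"a":1}""")
g.writeEndObject()
g.close()
assert(sw.toString == """{"v":{"a":1}}""")
```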
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index 22155c927e37..2793b1c8c1fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -18,11 +18,22 @@
package org.apache.spark.sql.catalyst.expressions.variant
import org.apache.spark.{SparkException, SparkFunSuite, SparkRuntimeException}
-import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal}
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.types.variant.VariantUtil._
import org.apache.spark.unsafe.types.VariantVal
class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+ // Zero-extend each byte in the array with the appropriate number of bytes.
+ // Used to manually construct variant binary values with a given offset size.
+ // E.g. padded(Array(1,2,3), 3) will produce Array(1,0,0,2,0,0,3,0,0).
+ private def padded(a: Array[Byte], size: Int): Array[Byte] = {
+ a.flatMap { b =>
+ val padding = List.fill(size - 1)(0.toByte)
+ b :: padding
+ }
+ }
+
test("parse_json") {
def check(json: String, expectedValue: Array[Byte], expectedMetadata: Array[Byte]): Unit = {
checkEvaluation(ParseJson(Literal(json)), new VariantVal(expectedValue, expectedMetadata))
@@ -111,4 +122,246 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
"Cannot build variant bigger than 16.0 MiB")
}
}
+
+ test("round-trip") {
+ def check(input: String, output: String = null): Unit = {
+ checkEvaluation(
+ StructsToJson(Map.empty, ParseJson(Literal(input))),
+ if (output != null) output else input
+ )
+ }
+
+ check("null")
+ check("true")
+ check("false")
+ check("-1")
+ check("1.0E10")
+ check("\"\"")
+ check("\"" + ("a" * 63) + "\"")
+ check("\"" + ("b" * 64) + "\"")
+ // scalastyle:off nonascii
+ check("\"" + ("你好,世界" * 20) + "\"")
+ // scalastyle:on nonascii
+ check("[]")
+ check("{}")
+ // scalastyle:off nonascii
+ check(
+ "[null, true, false,-1, 1e10, \"\\uD83D\\uDE05\", [ ], { } ]",
+ "[null,true,false,-1,1.0E10,\"😅\",[],{}]"
+ )
+ // scalastyle:on nonascii
+ check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
+ }
+
+ test("to_json with nested variant") {
+ checkEvaluation(
+ StructsToJson(Map.empty, CreateArray(Seq(ParseJson(Literal("{}")),
+ ParseJson(Literal("\"\"")),
+ ParseJson(Literal("[1, 2, 3]"))))),
+ "[{},\"\",[1,2,3]]"
+ )
+ checkEvaluation(
+ StructsToJson(Map.empty, CreateNamedStruct(Seq(
+ Literal("a"), ParseJson(Literal("""{ "x": 1, "y": null, "z": "str"
}""")),
+ Literal("b"), ParseJson(Literal("[[]]")),
+ Literal("c"), ParseJson(Literal("false"))))),
+ """{"a":{"x":1,"y":null,"z":"str"},"b":[[]],"c":false}"""
+ )
+ }
+
+ test("to_json malformed") {
+ def check(value: Array[Byte], metadata: Array[Byte],
+ errorClass: String = "MALFORMED_VARIANT"): Unit = {
+ checkErrorInExpression[SparkRuntimeException](
+ ResolveTimeZone.resolveTimeZones(
+ StructsToJson(Map.empty, Literal(new VariantVal(value, metadata)))),
+ errorClass
+ )
+ }
+
+ val emptyMetadata = Array[Byte](VERSION, 0, 0)
+ // INT8 only has 7 byte content.
+ check(Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0), emptyMetadata)
+ // DECIMAL16 only has 15 byte content.
+ check(Array(primitiveHeader(DECIMAL16)) ++ Array.fill(16)(0.toByte), emptyMetadata)
+ // Short string content too short.
+ check(Array(shortStrHeader(2), 'x'), emptyMetadata)
+ // Long string length too short (requires 4 bytes).
+ check(Array(primitiveHeader(LONG_STR), 0, 0, 0), emptyMetadata)
+ // Long string content too short.
+ check(Array(primitiveHeader(LONG_STR), 1, 0, 0, 0), emptyMetadata)
+ // Size is 1 but no content.
+ check(Array(arrayHeader(false, 1),
+ /* size */ 1,
+ /* offset list */ 0), emptyMetadata)
+ // Requires a 4-byte size but the actual size only has one byte.
+ check(Array(arrayHeader(true, 1),
+ /* size */ 0,
+ /* offset list */ 0), emptyMetadata)
+ // Offset out of bound.
+ check(Array(arrayHeader(false, 1),
+ /* size */ 1,
+ /* offset list */ 1, 1), emptyMetadata)
+ // Id out of bound.
+ check(Array(objectHeader(false, 1, 1),
+ /* size */ 1,
+ /* id list */ 0,
+ /* offset list */ 0, 2,
+ /* field data */ primitiveHeader(INT1), 1), emptyMetadata)
+ // Variant version is not 1.
+ check(Array(primitiveHeader(INT1), 0), Array[Byte](3, 0, 0))
+ check(Array(primitiveHeader(INT1), 0), Array[Byte](2, 0, 0))
+
+ // Construct binary values that are over 1 << 24 bytes, but otherwise valid.
+ val bigVersion = Array[Byte]((VERSION | (3 << 6)).toByte)
+ val a = Array.fill(1 << 24)('a'.toByte)
+ val hugeMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1) ++
+ a ++ Array[Byte]('b')
+ check(Array(primitiveHeader(TRUE)), hugeMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT")
+
+ // The keys are 'aaa....' and 'b'. Values are "yyy..." and 'true'.
+ val y = Array.fill(1 << 24)('y'.toByte)
+ val hugeObject = Array[Byte](objectHeader(true, 4, 4)) ++
+ /* size */ padded(Array(2), 4) ++
+ /* id list */ padded(Array(0, 1), 4) ++
+ // Second value starts at offset 5 + (1 << 24), which is `5001` little-endian. The last
+ // value is 1 byte, so the one-past-the-end value is `6001`.
+ /* offset list */ Array[Byte](0, 0, 0, 0, 5, 0, 0, 1, 6, 0, 0, 1) ++
+ /* field data */ Array[Byte](primitiveHeader(LONG_STR), 0, 0, 0, 1) ++ y ++ Array[Byte](
+ primitiveHeader(TRUE)
+ )
+
+ val smallMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0) ++
+ Array[Byte]('a', 'b')
+ check(hugeObject, smallMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT")
+ check(hugeObject, hugeMetadata, "VARIANT_CONSTRUCTOR_SIZE_LIMIT")
+ }
+
+ // Test valid forms of Variant that our writer would never produce.
+ test("to_json valid input") {
+ def check(expectedJson: String, value: Array[Byte], metadata: Array[Byte]): Unit = {
+ checkEvaluation(
+ StructsToJson(Map.empty, Literal(new VariantVal(value, metadata))),
+ expectedJson
+ )
+ }
+ // Some valid metadata formats. Check that they aren't rejected.
+ // Sorted string bit is set, and can be ignored.
+ val emptyMetadata2 = Array[Byte](VERSION | 1 << 4, 0, 0)
+ // Bit 5 is not defined in the spec, and can be ignored.
+ val emptyMetadata3 = Array[Byte](VERSION | 1 << 5, 0, 0)
+ // Can specify 3 bytes per size/offset, even if they aren't needed.
+ val header = (VERSION | (2 << 6)).toByte
+ val emptyMetadata4 = Array[Byte](header, 0, 0, 0, 0, 0, 0)
+ check("true", Array(primitiveHeader(TRUE)), emptyMetadata2)
+ check("true", Array(primitiveHeader(TRUE)), emptyMetadata3)
+ check("true", Array(primitiveHeader(TRUE)), emptyMetadata4)
+ }
+
+ // Test StructsToJson with manually constructed input that uses up to 4 bytes for offsets
+ // and sizes. We never produce 4-byte offsets, since they're only needed for >16 MiB values,
+ // which we error out on, but the reader should be able to handle them if some other writer
+ // decides to use them for smaller values.
+ test("to_json with large offsets and sizes") {
+ def check(expectedJson: String, value: Array[Byte], metadata: Array[Byte]): Unit = {
+ checkEvaluation(
+ StructsToJson(Map.empty, Literal(new VariantVal(value, metadata))),
+ expectedJson
+ )
+ }
+
+ for {
+ offsetSize <- 1 to 4
+ idSize <- 1 to 4
+ metadataSize <- 1 to 4
+ largeSize <- Seq(false, true)
+ } {
+ // Test array
+ val version = Array[Byte]((VERSION | ((metadataSize - 1) << 6)).toByte)
+ val emptyMetadata = version ++ padded(Array(0, 0), metadataSize)
+ // Construct a binary with the given sizes. Regardless, to_json should produce the same
+ // result.
+ val arrayValue = Array[Byte](arrayHeader(largeSize, offsetSize)) ++
+ /* size */ padded(Array(3), if (largeSize) 4 else 1) ++
+ /* offset list */ padded(Array(0, 1, 4, 5), offsetSize) ++
+ Array[Byte](/* values */ primitiveHeader(FALSE),
+ primitiveHeader(INT2), 2, 1, primitiveHeader(NULL))
+ check("[false,258,null]", arrayValue, emptyMetadata)
+
+ // Test object
+ val metadata = version ++
+ padded(Array(3, 0, 1, 2, 3), metadataSize) ++
+ Array[Byte]('a', 'b', 'c')
+ val objectValue = Array[Byte](objectHeader(largeSize, idSize, offsetSize)) ++
+ /* size */ padded(Array(3), if (largeSize) 4 else 1) ++
+ /* id list */ padded(Array(0, 1, 2), idSize) ++
+ /* offset list */ padded(Array(0, 2, 4, 6), offsetSize) ++
+ /* field data */ Array[Byte](primitiveHeader(INT1), 1,
+ primitiveHeader(INT1), 2, shortStrHeader(1), '3')
+
+ check("""{"a":1,"b":2,"c":"3"}""", objectValue, metadata)
+ }
+ }
+
+ test("to_json large binary") {
+ def check(expectedJson: String, value: Array[Byte], metadata: Array[Byte]): Unit = {
+ checkEvaluation(
+ StructsToJson(Map.empty, Literal(new VariantVal(value, metadata))),
+ expectedJson
+ )
+ }
+
+ // Create a binary that uses the max 1 << 24 bytes for both metadata and value.
+ val bigVersion = Array[Byte]((VERSION | (2 << 6)).toByte)
+ // Create a single huge value, followed by a one-byte string. We'll have 1 header byte, plus
+ // 12 bytes for size and offsets, plus 1 byte for the final value, so the large value is
+ // (1 << 24) - 14 bytes, or (-14, -1, -1) as a signed little-endian value.
+ val aSize = (1 << 24) - 14
+ val a = Array.fill(aSize)('a'.toByte)
+ val hugeMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, -14, -1, -1, -13, -1, -1) ++
+ a ++ Array[Byte]('b')
+ // Validate metadata in isolation.
+ check("true", Array(primitiveHeader(TRUE)), hugeMetadata)
+
+ // The object will contain a large string, and the following bytes:
+ // - object header and size: 1+4 bytes
+ // - ID list: 6 bytes
+ // - offset list: 9 bytes
+ // - field headers and string length: 6 bytes
+ // In order to get the full binary to 1 << 24, the large string is (1 << 24) - 26 bytes. As
+ // a signed little-endian value, this is (-26, -1, -1).
+ val ySize = (1 << 24) - 26
+ val y = Array.fill(ySize)('y'.toByte)
+ val hugeObject = Array[Byte](objectHeader(true, 3, 3)) ++
+ /* size */ padded(Array(2), 4) ++
+ /* id list */ padded(Array(0, 1), 3) ++
+ // Second offset is (-26, -1, -1), plus 5 bytes for string header, so (-21, -1, -1).
+ /* offset list */ Array[Byte](0, 0, 0, -21, -1, -1, -20, -1, -1) ++
+ /* field data */ Array[Byte](primitiveHeader(LONG_STR), -26, -1, -1, 0) ++ y ++ Array[Byte](
+ primitiveHeader(TRUE)
+ )
+ // Same as hugeObject, but with a short string.
+ val smallObject = Array[Byte](objectHeader(false, 1, 1)) ++
+ /* size */ Array[Byte](2) ++
+ /* id list */ Array[Byte](0, 1) ++
+ /* offset list */ Array[Byte](0, 6, 7) ++
+ /* field data */ Array[Byte](primitiveHeader(LONG_STR), 1, 0, 0, 0, 'y',
+ primitiveHeader(TRUE))
+ val smallMetadata = bigVersion ++ Array[Byte](2, 0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0) ++
+ Array[Byte]('a', 'b')
+
+ // Check all combinations of large/small value and metadata.
+ val expectedResult1 =
+ s"""{"${a.map(_.toChar).mkString}":"${y.map(_.toChar).mkString}","b":true}"""
+ check(expectedResult1, hugeObject, hugeMetadata)
+ val expectedResult2 =
+ s"""{"${a.map(_.toChar).mkString}":"y","b":true}"""
+ check(expectedResult2, smallObject, hugeMetadata)
+ val expectedResult3 =
+ s"""{"a":"${y.map(_.toChar).mkString}","b":true}"""
+ check(expectedResult3, hugeObject, smallMetadata)
+ val expectedResult4 =
+ s"""{"a":"y","b":true}"""
+ check(expectedResult4, smallObject, smallMetadata)
+ }
}
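The signed little-endian byte values used above ((-14, -1, -1) and friends) follow directly from two's-complement truncation. A quick sketch of the encoding these tests rely on:

```scala
// 3-byte little-endian encoding of (1 << 24) - 14 = 0xFFFFF2:
// bytes 0xF2, 0xFF, 0xFF read back as the signed values -14, -1, -1.
val n = (1 << 24) - 14
val bytes = Array.tabulate(3)(i => ((n >> (8 * i)) & 0xFF).toByte)
assert(bytes.sameElements(Array[Byte](-14, -1, -1)))
```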
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index 1a2c424938a1..3991b44d0bbb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -30,10 +30,7 @@ import org.apache.spark.unsafe.types.VariantVal
import org.apache.spark.util.ArrayImplicits._
class VariantSuite extends QueryTest with SharedSparkSession {
- // TODO(SPARK-45891): We need to ignore some tests for now because the `toString`
- // implementation doesn't match the `parse_json` implementation yet. We will shortly add a
- // new `toString` implementation and re-enable the tests.
- ignore("basic tests") {
+ test("basic tests") {
def verifyResult(df: DataFrame): Unit = {
val result = df.collect()
.map(_.get(0).asInstanceOf[VariantVal].toString)
@@ -43,8 +40,6 @@ class VariantSuite extends QueryTest with SharedSparkSession {
assert(result == expected)
}
- // At this point, JSON parsing logic is not really implemented. We just construct some
- // number inputs that are also valid JSON. This exercises passing VariantVal throughout
- // the system.
val query = spark.sql("select parse_json(repeat('1', id)) as v from range(1, 10)")
verifyResult(query)
@@ -142,7 +137,7 @@ class VariantSuite extends QueryTest with SharedSparkSession {
}
}
- ignore("write partitioned file") {
+ test("write partitioned file") {
def verifyResult(df: DataFrame): Unit = {
val result = df.selectExpr("v").collect()
.map(_.get(0).asInstanceOf[VariantVal].toString)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index 4c77f26949b0..19251330cffe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -196,9 +196,6 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
}
val exampleRe = """^(.+);\n(?s)(.+)$""".r
val ignoreSet = Set(
- // TODO(SPARK-45891): need to temporarily ignore it because the `toString` implementation
- // doesn't match the `parse_json` implementation yet.
- "org.apache.spark.sql.catalyst.expressions.variant.ParseJson",
// One of examples shows getting the current timestamp
"org.apache.spark.sql.catalyst.expressions.UnixTimestamp",
"org.apache.spark.sql.catalyst.expressions.CurrentDate",