This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f6d5ad3ec75b [SPARK-47366][SQL][PYTHON] Add VariantVal for PySpark
f6d5ad3ec75b is described below
commit f6d5ad3ec75be63472c6b21dda959972f5360ef2
Author: Gene Pang <[email protected]>
AuthorDate: Thu Apr 11 09:16:10 2024 +0900
[SPARK-47366][SQL][PYTHON] Add VariantVal for PySpark
### What changes were proposed in this pull request?
Add a `VariantVal` implementation for PySpark. It includes convenience methods to convert the Variant to a JSON string or to a Python object.
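A minimal usage sketch (assuming an active `spark` session; the column and alias names are illustrative):

```python
from pyspark.sql.functions import parse_json

df = spark.createDataFrame([{"json": '{"a": 1}'}])
v = df.select(parse_json(df.json).alias("var")).collect()[0].var

str(v)        # '{"a":1}'  - the Variant rendered as a JSON string
v.toPython()  # {'a': 1}   - the Variant converted to plain Python objects
```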
### Why are the changes needed?
Allows users to work with Variant data more conveniently.
### Does this PR introduce _any_ user-facing change?
Yes. This adds new PySpark functionality (the `VariantVal` class) that allows users to work with Variant data.
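As a sketch of the user-facing behavior (mirroring the new unit tests; assumes an active `spark` session), collected Variant values come back as `VariantVal`, including when nested inside arrays, structs, and maps:

```python
from pyspark.sql import functions as F
from pyspark.sql.types import VariantVal

row = spark.range(1).select(
    F.parse_json(F.lit('{"a": 1}')).alias("v"),
    F.array([F.parse_json(F.lit('{"b": "2"}'))]).alias("arr"),
).collect()[0]

assert isinstance(row["v"], VariantVal)
assert str(row["v"]) == '{"a":1}'
assert row["arr"][0].toPython() == {"b": "2"}
```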
### How was this patch tested?
Added unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #45826 from gene-db/variant-pyspark.
Lead-authored-by: Gene Pang <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../source/reference/pyspark.sql/core_classes.rst | 1 +
python/docs/source/reference/pyspark.sql/index.rst | 1 +
.../pyspark.sql/{index.rst => variant_val.rst} | 32 +-
python/pyspark/sql/__init__.py | 3 +-
python/pyspark/sql/connect/conversion.py | 40 +++
python/pyspark/sql/pandas/types.py | 22 ++
python/pyspark/sql/tests/test_types.py | 64 ++++
python/pyspark/sql/types.py | 59 +++-
python/pyspark/sql/variant_utils.py | 388 +++++++++++++++++++++
.../org/apache/spark/sql/util/ArrowUtils.scala | 10 +
.../spark/sql/execution/arrow/ArrowWriter.scala | 27 +-
11 files changed, 614 insertions(+), 33 deletions(-)
diff --git a/python/docs/source/reference/pyspark.sql/core_classes.rst b/python/docs/source/reference/pyspark.sql/core_classes.rst
index 65096da21de5..d3dbbc129cb7 100644
--- a/python/docs/source/reference/pyspark.sql/core_classes.rst
+++ b/python/docs/source/reference/pyspark.sql/core_classes.rst
@@ -49,3 +49,4 @@ Core Classes
datasource.DataSourceRegistration
datasource.InputPartition
datasource.WriterCommitMessage
+ VariantVal
diff --git a/python/docs/source/reference/pyspark.sql/index.rst b/python/docs/source/reference/pyspark.sql/index.rst
index 9322a91fba25..93901ab7ce12 100644
--- a/python/docs/source/reference/pyspark.sql/index.rst
+++ b/python/docs/source/reference/pyspark.sql/index.rst
@@ -41,5 +41,6 @@ This page gives an overview of all public Spark SQL API.
observation
udf
udtf
+ variant_val
protobuf
datasource
diff --git a/python/docs/source/reference/pyspark.sql/index.rst b/python/docs/source/reference/pyspark.sql/variant_val.rst
similarity index 70%
copy from python/docs/source/reference/pyspark.sql/index.rst
copy to python/docs/source/reference/pyspark.sql/variant_val.rst
index 9322a91fba25..a7f592c18e3a 100644
--- a/python/docs/source/reference/pyspark.sql/index.rst
+++ b/python/docs/source/reference/pyspark.sql/variant_val.rst
@@ -16,30 +16,12 @@
under the License.
-=========
-Spark SQL
-=========
+==========
+VariantVal
+==========
+.. currentmodule:: pyspark.sql
-This page gives an overview of all public Spark SQL API.
+.. autosummary::
+ :toctree: api/
-.. toctree::
- :maxdepth: 2
-
- core_classes
- spark_session
- configuration
- io
- dataframe
- column
- data_types
- row
- functions
- window
- grouping
- catalog
- avro
- observation
- udf
- udtf
- protobuf
- datasource
+ VariantVal.toPython
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index dd82b037a6b9..bc046da81d27 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -39,7 +39,7 @@ Important classes of Spark SQL and DataFrames:
- :class:`pyspark.sql.Window`
For working with window functions.
"""
-from pyspark.sql.types import Row
+from pyspark.sql.types import Row, VariantVal
from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration, UDTFRegistration
from pyspark.sql.session import SparkSession
from pyspark.sql.column import Column
@@ -67,6 +67,7 @@ __all__ = [
"Row",
"DataFrameNaFunctions",
"DataFrameStatFunctions",
+ "VariantVal",
"Window",
"WindowSpec",
"DataFrameReader",
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index c86ee9c75fec..9b1007c41f9c 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -40,6 +40,8 @@ from pyspark.sql.types import (
DecimalType,
StringType,
UserDefinedType,
+ VariantType,
+ VariantVal,
)
from pyspark.storagelevel import StorageLevel
@@ -95,6 +97,8 @@ class LocalDataToArrowConversion:
return True
elif isinstance(dataType, UserDefinedType):
return True
+ elif isinstance(dataType, VariantType):
+ return True
else:
return False
@@ -290,6 +294,24 @@ class LocalDataToArrowConversion:
return convert_udt
+ elif isinstance(dataType, VariantType):
+
+ def convert_variant(value: Any) -> Any:
+ if value is None:
+ if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
+ return None
+ elif (
+ isinstance(value, dict)
+ and all(key in value for key in ["value", "metadata"])
+                    and all(isinstance(value[key], bytes) for key in ["value", "metadata"])
+ ):
+ return VariantVal(value["value"], value["metadata"])
+ else:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+ return convert_variant
+
elif not nullable:
def convert_other(value: Any) -> Any:
@@ -381,6 +403,8 @@ class ArrowTableToRowsConversion:
return True
elif isinstance(dataType, UserDefinedType):
return True
+ elif isinstance(dataType, VariantType):
+ return True
else:
return False
@@ -488,6 +512,22 @@ class ArrowTableToRowsConversion:
return convert_udt
+ elif isinstance(dataType, VariantType):
+
+ def convert_variant(value: Any) -> Any:
+ if value is None:
+ return None
+ elif (
+ isinstance(value, dict)
+ and all(key in value for key in ["value", "metadata"])
+                    and all(isinstance(value[key], bytes) for key in ["value", "metadata"])
+ ):
+ return VariantVal(value["value"], value["metadata"])
+ else:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+ return convert_variant
+
else:
return lambda value: value
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 3b48f8d8c319..559512bd00c1 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -47,6 +47,8 @@ from pyspark.sql.types import (
NullType,
DataType,
UserDefinedType,
+ VariantType,
+ VariantVal,
_create_row,
)
from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
@@ -108,6 +110,12 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
arrow_type = pa.null()
elif isinstance(dt, UserDefinedType):
arrow_type = to_arrow_type(dt.sqlType())
+ elif type(dt) == VariantType:
+ fields = [
+ pa.field("value", pa.binary(), nullable=False),
+ pa.field("metadata", pa.binary(), nullable=False),
+ ]
+ arrow_type = pa.struct(fields)
else:
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
@@ -763,6 +771,20 @@ def _create_converter_to_pandas(
return convert_udt
+ elif isinstance(dt, VariantType):
+
+ def convert_variant(value: Any) -> Any:
+ if (
+ isinstance(value, dict)
+ and all(key in value for key in ["value", "metadata"])
+                and all(isinstance(value[key], bytes) for key in ["value", "metadata"])
+ ):
+ return VariantVal(value["value"], value["metadata"])
+ else:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+ return convert_variant
+
else:
return None
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index bb854641906a..af13adbc21bb 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -58,6 +58,7 @@ from pyspark.sql.types import (
BooleanType,
NullType,
VariantType,
+ VariantVal,
)
from pyspark.sql.types import (
_array_signed_int_typecode_ctype_mappings,
@@ -1406,6 +1407,69 @@ class TypesTestsMixin:
schema1 = self.spark.range(1).select(F.make_interval(F.lit(1))).schema
self.assertEqual(schema1.fields[0].dataType, CalendarIntervalType())
+ def test_variant_type(self):
+ from decimal import Decimal
+
+ self.assertEqual(VariantType().simpleString(), "variant")
+
+ # Holds a tuple of (key, json string value, python value)
+ expected_values = [
+ ("str", '"%s"' % ("0123456789" * 10), "0123456789" * 10),
+ ("short_str", '"abc"', "abc"),
+ ("null", "null", None),
+ ("true", "true", True),
+ ("false", "false", False),
+ ("int1", "1", 1),
+ ("-int1", "-5", -5),
+ ("int2", "257", 257),
+ ("-int2", "-124", -124),
+ ("int4", "65793", 65793),
+ ("-int4", "-69633", -69633),
+ ("int8", "4295033089", 4295033089),
+ ("-int8", "-4294967297", -4294967297),
+ ("float4", "1.23456789e-30", 1.23456789e-30),
+ ("-float4", "-4.56789e+29", -4.56789e29),
+ ("dec4", "123.456", Decimal("123.456")),
+ ("-dec4", "-321.654", Decimal("-321.654")),
+ ("dec8", "429.4967297", Decimal("429.4967297")),
+ ("-dec8", "-5.678373902", Decimal("-5.678373902")),
+ ("dec16", "467440737095.51617", Decimal("467440737095.51617")),
+ ("-dec16", "-67.849438003827263", Decimal("-67.849438003827263")),
+            ("arr", '[1.1,"2",[3],{"4":5}]', [Decimal("1.1"), "2", [3], {"4": 5}]),
+            ("obj", '{"a":["123",{"b":2}],"c":3}', {"a": ["123", {"b": 2}], "c": 3}),
+ ]
+        json_str = "{%s}" % ",".join(['"%s": %s' % (t[0], t[1]) for t in expected_values])
+
+ df = self.spark.createDataFrame([({"json": json_str})])
+ row = df.select(
+ F.parse_json(df.json).alias("v"),
+ F.array([F.parse_json(F.lit('{"a": 1}'))]).alias("a"),
+ F.struct([F.parse_json(F.lit('{"b": "2"}'))]).alias("s"),
+            F.create_map([F.lit("k"), F.parse_json(F.lit('{"c": true}'))]).alias("m"),
+ ).collect()[0]
+ variants = [row["v"], row["a"][0], row["s"]["col1"], row["m"]["k"]]
+ for v in variants:
+ self.assertEqual(type(v), VariantVal)
+
+ # check str
+ as_string = str(variants[0])
+ for key, expected, _ in expected_values:
+ self.assertTrue('"%s":%s' % (key, expected) in as_string)
+ self.assertEqual(str(variants[1]), '{"a":1}')
+ self.assertEqual(str(variants[2]), '{"b":"2"}')
+ self.assertEqual(str(variants[3]), '{"c":true}')
+
+ # check toPython
+ as_python = variants[0].toPython()
+ for key, _, obj in expected_values:
+ self.assertEqual(as_python[key], obj)
+ self.assertEqual(variants[1].toPython(), {"a": 1})
+ self.assertEqual(variants[2].toPython(), {"b": "2"})
+ self.assertEqual(variants[3].toPython(), {"c": True})
+
+ # check repr
+ self.assertEqual(str(variants[0]), str(eval(repr(variants[0]))))
+
def test_from_ddl(self):
self.assertEqual(DataType.fromDDL("long"), LongType())
self.assertEqual(
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 9b1bab4c23fa..3546fd822814 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -48,6 +48,7 @@ from typing import (
from pyspark.util import is_remote_only
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.utils import has_numpy, get_active_spark_context
+from pyspark.sql.variant_utils import VariantUtils
from pyspark.errors import (
PySparkNotImplementedError,
PySparkTypeError,
@@ -95,6 +96,7 @@ __all__ = [
"StructField",
"StructType",
"VariantType",
+ "VariantVal",
]
@@ -1341,7 +1343,13 @@ class VariantType(AtomicType):
.. versionadded:: 4.0.0
"""
- pass
+ def needConversion(self) -> bool:
+ return True
+
+ def fromInternal(self, obj: Dict) -> Optional["VariantVal"]:
+ if obj is None or not all(key in obj for key in ["value", "metadata"]):
+ return None
+ return VariantVal(obj["value"], obj["metadata"])
class UserDefinedType(DataType):
@@ -1465,6 +1473,55 @@ class UserDefinedType(DataType):
return type(self) == type(other)
+class VariantVal:
+ """
+ A class to represent a Variant value in Python.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ value : bytes
+ The bytes representing the value component of the Variant.
+ metadata : bytes
+ The bytes representing the metadata component of the Variant.
+
+ Methods
+ -------
+ toPython()
+ Convert the VariantVal to a Python data structure.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import *
+ >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
+ >>> v = df.select(parse_json(df.json).alias("var")).collect()[0].var
+ >>> v.toPython()
+ {'a': 1}
+ """
+
+ def __init__(self, value: bytes, metadata: bytes):
+ self.value = value
+ self.metadata = metadata
+
+ def __str__(self) -> str:
+ return VariantUtils.to_json(self.value, self.metadata)
+
+ def __repr__(self) -> str:
+ return "VariantVal(%r, %r)" % (self.value, self.metadata)
+
+ def toPython(self) -> Any:
+ """
+ Convert the VariantVal to a Python data structure.
+
+ Returns
+ -------
+ Any
+ A Python object that represents the Variant.
+ """
+ return VariantUtils.to_python(self.value, self.metadata)
+
+
_atomic_types: List[Type[DataType]] = [
StringType,
CharType,
diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
new file mode 100644
index 000000000000..9ca70365316d
--- /dev/null
+++ b/python/pyspark/sql/variant_utils.py
@@ -0,0 +1,388 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import decimal
+import json
+import struct
+from array import array
+from typing import Any, Callable, Dict, List, Tuple
+from pyspark.errors import PySparkValueError
+
+
+class VariantUtils:
+ """
+ A utility class for VariantVal.
+
+ Adapted from library at: org.apache.spark.types.variant.VariantUtil
+ """
+
+ BASIC_TYPE_BITS = 2
+ BASIC_TYPE_MASK = 0x3
+ TYPE_INFO_MASK = 0x3F
+    # The inclusive maximum value of the type info value. It is the size limit of `SHORT_STR`.
+ MAX_SHORT_STR_SIZE = 0x3F
+
+    # Below are all the possible basic type values.
+    # Primitive value. The type info value must be one of the values in the section below.
+ PRIMITIVE = 0
+    # Short string value. The type info value is the string size, which must be in `[0,
+ # MAX_SHORT_STR_SIZE]`.
+ # The string content bytes directly follow the header byte.
+ SHORT_STR = 1
+    # Object value. The content contains a size, a list of field ids, a list of field offsets, and
+    # the actual field data. The length of the id list is `size`, while the length of the offset
+    # list is `size + 1`, where the last offset represents the total size of the field data. The
+    # fields in an object must be sorted by the field name in alphabetical order. Duplicate field
+    # names in one object are not allowed.
+    # We use 5 bits in the type info to specify the integer type of the object header: it should
+    # be 0_b4_b3b2_b1b0 (MSB is 0), where:
+    # - b4 specifies the type of size. When it is 0/1, `size` is a little-endian 1/4-byte
+    # unsigned integer.
+    # - b3b2/b1b0 specifies the integer type of id and offset. When the 2 bits are 0/1/2, the
+    # list contains 1/2/3-byte little-endian unsigned integers.
+ OBJECT = 2
+    # Array value. The content contains a size, a list of field offsets, and the actual element
+    # data. It is similar to an object without the id list. The length of the offset list
+    # is `size + 1`, where the last offset represents the total size of the element data.
+ # Its type info should be: 000_b2_b1b0:
+ # - b2 specifies the type of size.
+ # - b1b0 specifies the integer type of offset.
+ ARRAY = 3
+
+    # Below are all the possible type info values for `PRIMITIVE`.
+ # JSON Null value. Empty content.
+ NULL = 0
+ # True value. Empty content.
+ TRUE = 1
+ # False value. Empty content.
+ FALSE = 2
+ # 1-byte little-endian signed integer.
+ INT1 = 3
+ # 2-byte little-endian signed integer.
+ INT2 = 4
+ # 4-byte little-endian signed integer.
+ INT4 = 5
+    # 8-byte little-endian signed integer.
+ INT8 = 6
+ # 8-byte IEEE double.
+ DOUBLE = 7
+    # 4-byte decimal. Content is 1-byte scale + 4-byte little-endian signed integer.
+    DECIMAL4 = 8
+    # 8-byte decimal. Content is 1-byte scale + 8-byte little-endian signed integer.
+    DECIMAL8 = 9
+    # 16-byte decimal. Content is 1-byte scale + 16-byte little-endian signed integer.
+    DECIMAL16 = 10
+    # Long string value. The content is (4-byte little-endian unsigned integer representing the
+ # string size) + (size bytes of string content).
+ LONG_STR = 16
+
+ U32_SIZE = 4
+
+ @classmethod
+ def to_json(cls, value: bytes, metadata: bytes) -> str:
+ """
+ Convert the VariantVal to a JSON string.
+ :return: JSON string
+ """
+ return cls._to_json(value, metadata, 0)
+
+ @classmethod
+    def to_python(cls, value: bytes, metadata: bytes) -> Any:
+ """
+ Convert the VariantVal to a nested Python object of Python data types.
+ :return: Python representation of the Variant nested structure
+ """
+ return cls._to_python(value, metadata, 0)
+
+ @classmethod
+    def _read_long(cls, data: bytes, pos: int, num_bytes: int, signed: bool) -> int:
+        cls._check_index(pos, len(data))
+        cls._check_index(pos + num_bytes - 1, len(data))
+        return int.from_bytes(data[pos : pos + num_bytes], byteorder="little", signed=signed)
+
+ @classmethod
+ def _check_index(cls, pos: int, length: int) -> None:
+ if pos < 0 or pos >= length:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+ @classmethod
+ def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]:
+ """
+        Returns the (basic_type, type_info) pair from the given position in the value.
+        """
+        basic_type = value[pos] & VariantUtils.BASIC_TYPE_MASK
+        type_info = (value[pos] >> VariantUtils.BASIC_TYPE_BITS) & VariantUtils.TYPE_INFO_MASK
+ return (basic_type, type_info)
+
+ @classmethod
+ def _get_metadata_key(cls, metadata: bytes, id: int) -> str:
+ """
+        Returns the key string from the dictionary in the metadata, corresponding to `id`.
+ """
+ cls._check_index(0, len(metadata))
+ offset_size = ((metadata[0] >> 6) & 0x3) + 1
+ dict_size = cls._read_long(metadata, 1, offset_size, signed=False)
+ if id >= dict_size:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+ string_start = 1 + (dict_size + 2) * offset_size
+        offset = cls._read_long(metadata, 1 + (id + 1) * offset_size, offset_size, signed=False)
+ next_offset = cls._read_long(
+ metadata, 1 + (id + 2) * offset_size, offset_size, signed=False
+ )
+ if offset > next_offset:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+ cls._check_index(string_start + next_offset - 1, len(metadata))
+        return metadata[string_start + offset : (string_start + next_offset)].decode("utf-8")
+
+ @classmethod
+ def _get_boolean(cls, value: bytes, pos: int) -> bool:
+ cls._check_index(pos, len(value))
+ basic_type, type_info = cls._get_type_info(value, pos)
+ if basic_type != VariantUtils.PRIMITIVE or (
+ type_info != VariantUtils.TRUE and type_info != VariantUtils.FALSE
+ ):
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+ return type_info == VariantUtils.TRUE
+
+ @classmethod
+ def _get_long(cls, value: bytes, pos: int) -> int:
+ cls._check_index(pos, len(value))
+ basic_type, type_info = cls._get_type_info(value, pos)
+ if basic_type != VariantUtils.PRIMITIVE:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+ if type_info == VariantUtils.INT1:
+ return cls._read_long(value, pos + 1, 1, signed=True)
+ elif type_info == VariantUtils.INT2:
+ return cls._read_long(value, pos + 1, 2, signed=True)
+ elif type_info == VariantUtils.INT4:
+ return cls._read_long(value, pos + 1, 4, signed=True)
+ elif type_info == VariantUtils.INT8:
+ return cls._read_long(value, pos + 1, 8, signed=True)
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+ @classmethod
+ def _get_string(cls, value: bytes, pos: int) -> str:
+ cls._check_index(pos, len(value))
+ basic_type, type_info = cls._get_type_info(value, pos)
+ if basic_type == VariantUtils.SHORT_STR or (
+            basic_type == VariantUtils.PRIMITIVE and type_info == VariantUtils.LONG_STR
+ ):
+ start = 0
+ length = 0
+ if basic_type == VariantUtils.SHORT_STR:
+ start = pos + 1
+ length = type_info
+ else:
+ start = pos + 1 + VariantUtils.U32_SIZE
+                length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
+ cls._check_index(start + length - 1, len(value))
+ return value[start : start + length].decode("utf-8")
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+ @classmethod
+ def _get_double(cls, value: bytes, pos: int) -> float:
+ cls._check_index(pos, len(value))
+ basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.DOUBLE:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+ return struct.unpack("d", value[pos + 1 : pos + 9])[0]
+
+ @classmethod
+ def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
+ cls._check_index(pos, len(value))
+ basic_type, type_info = cls._get_type_info(value, pos)
+ if basic_type != VariantUtils.PRIMITIVE:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+ scale = value[pos + 1]
+ unscaled = 0
+ if type_info == VariantUtils.DECIMAL4:
+ unscaled = cls._read_long(value, pos + 2, 4, signed=True)
+ elif type_info == VariantUtils.DECIMAL8:
+ unscaled = cls._read_long(value, pos + 2, 8, signed=True)
+ elif type_info == VariantUtils.DECIMAL16:
+ cls._check_index(pos + 17, len(value))
+            unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True)
+ else:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+ return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale))
+
+ @classmethod
+ def _get_type(cls, value: bytes, pos: int) -> Any:
+ """
+ Returns the Python type of the Variant at the given position.
+ """
+ cls._check_index(pos, len(value))
+ basic_type, type_info = cls._get_type_info(value, pos)
+ if basic_type == VariantUtils.SHORT_STR:
+ return str
+ elif basic_type == VariantUtils.OBJECT:
+ return dict
+ elif basic_type == VariantUtils.ARRAY:
+ return array
+ elif type_info == VariantUtils.NULL:
+ return type(None)
+ elif type_info == VariantUtils.TRUE or type_info == VariantUtils.FALSE:
+ return bool
+ elif (
+ type_info == VariantUtils.INT1
+ or type_info == VariantUtils.INT2
+ or type_info == VariantUtils.INT4
+ or type_info == VariantUtils.INT8
+ ):
+ return int
+ elif type_info == VariantUtils.DOUBLE:
+ return float
+ elif (
+ type_info == VariantUtils.DECIMAL4
+ or type_info == VariantUtils.DECIMAL8
+ or type_info == VariantUtils.DECIMAL16
+ ):
+ return decimal.Decimal
+ elif type_info == VariantUtils.LONG_STR:
+ return str
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+ @classmethod
+ def _to_json(cls, value: bytes, metadata: bytes, pos: int) -> Any:
+ variant_type = cls._get_type(value, pos)
+ if variant_type == dict:
+
+            def handle_object(key_value_pos_list: list[Tuple[str, int]]) -> str:
+ key_value_list = [
+                    json.dumps(key) + ":" + cls._to_json(value, metadata, value_pos)
+ for (key, value_pos) in key_value_pos_list
+ ]
+ return "{" + ",".join(key_value_list) + "}"
+
+ return cls._handle_object(value, metadata, pos, handle_object)
+ elif variant_type == array:
+
+ def handle_array(value_pos_list: list[int]) -> str:
+ value_list = [
+                    cls._to_json(value, metadata, value_pos) for value_pos in value_pos_list
+ ]
+ return "[" + ",".join(value_list) + "]"
+
+ return cls._handle_array(value, pos, handle_array)
+ else:
+ value = cls._get_scalar(variant_type, value, metadata, pos)
+ if value is None:
+ return "null"
+ if type(value) == bool:
+ return "true" if value else "false"
+ if type(value) == str:
+ return json.dumps(value)
+ return str(value)
+
+ @classmethod
+ def _to_python(cls, value: bytes, metadata: bytes, pos: int) -> Any:
+ variant_type = cls._get_type(value, pos)
+ if variant_type == dict:
+
+            def handle_object(key_value_pos_list: list[Tuple[str, int]]) -> Dict[str, Any]:
+ key_value_list = [
+ (key, cls._to_python(value, metadata, value_pos))
+ for (key, value_pos) in key_value_pos_list
+ ]
+ return dict(key_value_list)
+
+ return cls._handle_object(value, metadata, pos, handle_object)
+ elif variant_type == array:
+
+ def handle_array(value_pos_list: list[int]) -> List[Any]:
+ value_list = [
+                    cls._to_python(value, metadata, value_pos) for value_pos in value_pos_list
+ ]
+ return value_list
+
+ return cls._handle_array(value, pos, handle_array)
+ else:
+ return cls._get_scalar(variant_type, value, metadata, pos)
+
+ @classmethod
+    def _get_scalar(cls, variant_type: Any, value: bytes, metadata: bytes, pos: int) -> Any:
+ if isinstance(None, variant_type):
+ return None
+ elif variant_type == bool:
+ return cls._get_boolean(value, pos)
+ elif variant_type == int:
+ return cls._get_long(value, pos)
+ elif variant_type == str:
+ return cls._get_string(value, pos)
+ elif variant_type == float:
+ return cls._get_double(value, pos)
+ elif variant_type == decimal.Decimal:
+ return cls._get_decimal(value, pos)
+ else:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+ @classmethod
+ def _handle_object(
+        cls, value: bytes, metadata: bytes, pos: int, func: Callable[[list[Tuple[str, int]]], Any]
+ ) -> Any:
+ """
+ Parses the variant object at position `pos`.
+ Calls `func` with a list of (key, value position) pairs of the object.
+ """
+ cls._check_index(pos, len(value))
+ basic_type, type_info = cls._get_type_info(value, pos)
+ if basic_type != VariantUtils.OBJECT:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+ large_size = ((type_info >> 4) & 0x1) != 0
+ size_bytes = VariantUtils.U32_SIZE if large_size else 1
+ num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
+ id_size = ((type_info >> 2) & 0x3) + 1
+ offset_size = ((type_info) & 0x3) + 1
+ id_start = pos + 1 + size_bytes
+ offset_start = id_start + num_fields * id_size
+ data_start = offset_start + (num_fields + 1) * offset_size
+
+ key_value_pos_list = []
+ for i in range(num_fields):
+            id = cls._read_long(value, id_start + id_size * i, id_size, signed=False)
+            offset = cls._read_long(
+                value, offset_start + offset_size * i, offset_size, signed=False
+            )
+            value_pos = data_start + offset
+            key_value_pos_list.append((cls._get_metadata_key(metadata, id), value_pos))
+ return func(key_value_pos_list)
+
+ @classmethod
+    def _handle_array(cls, value: bytes, pos: int, func: Callable[[list[int]], Any]) -> Any:
+ """
+ Parses the variant array at position `pos`.
+ Calls `func` with a list of element positions of the array.
+ """
+ cls._check_index(pos, len(value))
+ basic_type, type_info = cls._get_type_info(value, pos)
+ if basic_type != VariantUtils.ARRAY:
+ raise PySparkValueError(error_class="MALFORMED_VARIANT")
+ large_size = ((type_info >> 2) & 0x1) != 0
+ size_bytes = VariantUtils.U32_SIZE if large_size else 1
+ num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
+ offset_size = (type_info & 0x3) + 1
+ offset_start = pos + 1 + size_bytes
+ data_start = offset_start + (num_fields + 1) * offset_size
+
+ value_pos_list = []
+ for i in range(num_fields):
+ offset = cls._read_long(
+                value, offset_start + offset_size * i, offset_size, signed=False
+ )
+ element_pos = data_start + offset
+ value_pos_list.append(element_pos)
+ return func(value_pos_list)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 92a4c687362d..d9bd3b0e612b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -125,6 +125,12 @@ private[sql] object ArrowUtils {
largeVarTypes)).asJava)
case udt: UserDefinedType[_] =>
toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
+ case _: VariantType =>
+      val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null,
+        Map("variant" -> "true").asJava)
+      new Field(name, fieldType,
+        Seq(toArrowField("value", BinaryType, false, timeZoneId, largeVarTypes),
+          toArrowField("metadata", BinaryType, false, timeZoneId, largeVarTypes)).asJava)
case dataType =>
val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId,
largeVarTypes), null)
@@ -143,6 +149,10 @@ private[sql] object ArrowUtils {
val elementField = field.getChildren().get(0)
val elementType = fromArrowField(elementField)
ArrayType(elementType, containsNull = elementField.isNullable)
+    case ArrowType.Struct.INSTANCE if field.getMetadata.getOrDefault("variant", "") == "true"
+ && field.getChildren.asScala.map(_.getName).asJava
+ .containsAll(Seq("value", "metadata").asJava) =>
+ VariantType
case ArrowType.Struct.INSTANCE =>
val fields = field.getChildren().asScala.map { child =>
val dt = fromArrowField(child)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 6680a5320fe3..ca7703bef48b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -84,6 +84,11 @@ object ArrowWriter {
case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector)
case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) =>
new IntervalMonthDayNanoWriter(vector)
+ case (VariantType, vector: StructVector) =>
+ val children = (0 until vector.size()).map { ordinal =>
+ createFieldWriter(vector.getChildByOrdinal(ordinal))
+ }
+ new StructWriter(vector, children.toArray)
case (dt, _) =>
throw ExecutionErrors.unsupportedDataTypeError(dt)
}
@@ -368,6 +373,8 @@ private[arrow] class StructWriter(
val valueVector: StructVector,
children: Array[ArrowFieldWriter]) extends ArrowFieldWriter {
+  lazy val isVariant = valueVector.getField.getMetadata.get("variant") == "true"
+
override def setNull(): Unit = {
var i = 0
while (i < children.length) {
@@ -379,12 +386,20 @@ private[arrow] class StructWriter(
}
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
- val struct = input.getStruct(ordinal, children.length)
- var i = 0
- valueVector.setIndexDefined(count)
- while (i < struct.numFields) {
- children(i).write(struct, i)
- i += 1
+ if (isVariant) {
+ valueVector.setIndexDefined(count)
+ val v = input.getVariant(ordinal)
+ val row = InternalRow(v.getValue, v.getMetadata)
+ children(0).write(row, 0)
+ children(1).write(row, 1)
+ } else {
+ val struct = input.getStruct(ordinal, children.length)
+ var i = 0
+ valueVector.setIndexDefined(count)
+ while (i < struct.numFields) {
+ children(i).write(struct, i)
+ i += 1
+ }
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]