This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f6d5ad3ec75b [SPARK-47366][SQL][PYTHON] Add VariantVal for PySpark
f6d5ad3ec75b is described below

commit f6d5ad3ec75be63472c6b21dda959972f5360ef2
Author: Gene Pang <gene.p...@databricks.com>
AuthorDate: Thu Apr 11 09:16:10 2024 +0900

    [SPARK-47366][SQL][PYTHON] Add VariantVal for PySpark
    
    ### What changes were proposed in this pull request?
    
    Add a `VariantVal` implementation for PySpark. It includes convenience methods to convert the Variant to a string, or to a Python object.
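    
    A minimal usage sketch (mirroring the `VariantVal` docstring example added in this PR; assumes an active `spark` session):
    
    ```python
    from pyspark.sql.functions import parse_json
    
    df = spark.createDataFrame([{"json": '{"a": 1}'}])
    v = df.select(parse_json(df.json).alias("var")).collect()[0].var
    v.toPython()   # {'a': 1}
    str(v)         # '{"a":1}'
    ```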
    
    ### Why are the changes needed?
    
    Allows users to work with Variant data more conveniently.
    
    ### Does this PR introduce _any_ user-facing change?
    
    This is new PySpark functionality to allow users to work with Variant data.
    
    ### How was this patch tested?
    
    Added unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45826 from gene-db/variant-pyspark.
    
    Lead-authored-by: Gene Pang <gene.p...@databricks.com>
    Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../source/reference/pyspark.sql/core_classes.rst  |   1 +
 python/docs/source/reference/pyspark.sql/index.rst |   1 +
 .../pyspark.sql/{index.rst => variant_val.rst}     |  32 +-
 python/pyspark/sql/__init__.py                     |   3 +-
 python/pyspark/sql/connect/conversion.py           |  40 +++
 python/pyspark/sql/pandas/types.py                 |  22 ++
 python/pyspark/sql/tests/test_types.py             |  64 ++++
 python/pyspark/sql/types.py                        |  59 +++-
 python/pyspark/sql/variant_utils.py                | 388 +++++++++++++++++++++
 .../org/apache/spark/sql/util/ArrowUtils.scala     |  10 +
 .../spark/sql/execution/arrow/ArrowWriter.scala    |  27 +-
 11 files changed, 614 insertions(+), 33 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/core_classes.rst b/python/docs/source/reference/pyspark.sql/core_classes.rst
index 65096da21de5..d3dbbc129cb7 100644
--- a/python/docs/source/reference/pyspark.sql/core_classes.rst
+++ b/python/docs/source/reference/pyspark.sql/core_classes.rst
@@ -49,3 +49,4 @@ Core Classes
     datasource.DataSourceRegistration
     datasource.InputPartition
     datasource.WriterCommitMessage
+    VariantVal
diff --git a/python/docs/source/reference/pyspark.sql/index.rst b/python/docs/source/reference/pyspark.sql/index.rst
index 9322a91fba25..93901ab7ce12 100644
--- a/python/docs/source/reference/pyspark.sql/index.rst
+++ b/python/docs/source/reference/pyspark.sql/index.rst
@@ -41,5 +41,6 @@ This page gives an overview of all public Spark SQL API.
     observation
     udf
     udtf
+    variant_val
     protobuf
     datasource
diff --git a/python/docs/source/reference/pyspark.sql/index.rst b/python/docs/source/reference/pyspark.sql/variant_val.rst
similarity index 70%
copy from python/docs/source/reference/pyspark.sql/index.rst
copy to python/docs/source/reference/pyspark.sql/variant_val.rst
index 9322a91fba25..a7f592c18e3a 100644
--- a/python/docs/source/reference/pyspark.sql/index.rst
+++ b/python/docs/source/reference/pyspark.sql/variant_val.rst
@@ -16,30 +16,12 @@
     under the License.
 
 
-=========
-Spark SQL
-=========
+==========
+VariantVal
+==========
+.. currentmodule:: pyspark.sql
 
-This page gives an overview of all public Spark SQL API.
+.. autosummary::
+    :toctree: api/
 
-.. toctree::
-    :maxdepth: 2
-
-    core_classes
-    spark_session
-    configuration
-    io
-    dataframe
-    column
-    data_types
-    row
-    functions
-    window
-    grouping
-    catalog
-    avro
-    observation
-    udf
-    udtf
-    protobuf
-    datasource
+    VariantVal.toPython
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index dd82b037a6b9..bc046da81d27 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -39,7 +39,7 @@ Important classes of Spark SQL and DataFrames:
     - :class:`pyspark.sql.Window`
       For working with window functions.
 """
-from pyspark.sql.types import Row
+from pyspark.sql.types import Row, VariantVal
 from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration, UDTFRegistration
 from pyspark.sql.session import SparkSession
 from pyspark.sql.column import Column
@@ -67,6 +67,7 @@ __all__ = [
     "Row",
     "DataFrameNaFunctions",
     "DataFrameStatFunctions",
+    "VariantVal",
     "Window",
     "WindowSpec",
     "DataFrameReader",
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index c86ee9c75fec..9b1007c41f9c 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -40,6 +40,8 @@ from pyspark.sql.types import (
     DecimalType,
     StringType,
     UserDefinedType,
+    VariantType,
+    VariantVal,
 )
 
 from pyspark.storagelevel import StorageLevel
@@ -95,6 +97,8 @@ class LocalDataToArrowConversion:
             return True
         elif isinstance(dataType, UserDefinedType):
             return True
+        elif isinstance(dataType, VariantType):
+            return True
         else:
             return False
 
@@ -290,6 +294,24 @@ class LocalDataToArrowConversion:
 
             return convert_udt
 
+        elif isinstance(dataType, VariantType):
+
+            def convert_variant(value: Any) -> Any:
+                if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
+                    return None
+                elif (
+                    isinstance(value, dict)
+                    and all(key in value for key in ["value", "metadata"])
+                    and all(isinstance(value[key], bytes) for key in ["value", "metadata"])
+                ):
+                    return VariantVal(value["value"], value["metadata"])
+                else:
+                    raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+            return convert_variant
+
         elif not nullable:
 
             def convert_other(value: Any) -> Any:
@@ -381,6 +403,8 @@ class ArrowTableToRowsConversion:
             return True
         elif isinstance(dataType, UserDefinedType):
             return True
+        elif isinstance(dataType, VariantType):
+            return True
         else:
             return False
 
@@ -488,6 +512,22 @@ class ArrowTableToRowsConversion:
 
             return convert_udt
 
+        elif isinstance(dataType, VariantType):
+
+            def convert_variant(value: Any) -> Any:
+                if value is None:
+                    return None
+                elif (
+                    isinstance(value, dict)
+                    and all(key in value for key in ["value", "metadata"])
+                    and all(isinstance(value[key], bytes) for key in ["value", "metadata"])
+                ):
+                    return VariantVal(value["value"], value["metadata"])
+                else:
+                    raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+            return convert_variant
+
         else:
             return lambda value: value
 
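For illustration (not part of the patch): both new conversion branches apply the same shape check before wrapping the Arrow struct as a `VariantVal`. A standalone sketch of that guard, with an illustrative helper name:

```python
from typing import Optional

from pyspark.errors import PySparkValueError
from pyspark.sql.types import VariantVal


def dict_to_variant(value: Optional[dict]) -> Optional[VariantVal]:
    # A variant column comes back from Arrow as a struct of two binary children,
    # which surfaces in Python as {"value": bytes, "metadata": bytes}.
    if value is None:
        return None
    if (
        isinstance(value, dict)
        and all(key in value for key in ["value", "metadata"])
        and all(isinstance(value[key], bytes) for key in ["value", "metadata"])
    ):
        return VariantVal(value["value"], value["metadata"])
    raise PySparkValueError(error_class="MALFORMED_VARIANT")
```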
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 3b48f8d8c319..559512bd00c1 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -47,6 +47,8 @@ from pyspark.sql.types import (
     NullType,
     DataType,
     UserDefinedType,
+    VariantType,
+    VariantVal,
     _create_row,
 )
 from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
@@ -108,6 +110,12 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
         arrow_type = pa.null()
     elif isinstance(dt, UserDefinedType):
         arrow_type = to_arrow_type(dt.sqlType())
+    elif type(dt) == VariantType:
+        fields = [
+            pa.field("value", pa.binary(), nullable=False),
+            pa.field("metadata", pa.binary(), nullable=False),
+        ]
+        arrow_type = pa.struct(fields)
     else:
         raise PySparkTypeError(
             error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
@@ -763,6 +771,20 @@ def _create_converter_to_pandas(
 
             return convert_udt
 
+        elif isinstance(dt, VariantType):
+
+            def convert_variant(value: Any) -> Any:
+                if (
+                    isinstance(value, dict)
+                    and all(key in value for key in ["value", "metadata"])
+                    and all(isinstance(value[key], bytes) for key in ["value", "metadata"])
+                ):
+                    return VariantVal(value["value"], value["metadata"])
+                else:
+                    raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+            return convert_variant
+
         else:
             return None
 
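For illustration (not part of the patch): the new `to_arrow_type` branch maps `VariantType` to an Arrow struct of two non-nullable binary buffers. The equivalent construction in plain `pyarrow`:

```python
import pyarrow as pa

# Mirrors the VariantType branch added to to_arrow_type above.
variant_arrow_type = pa.struct(
    [
        pa.field("value", pa.binary(), nullable=False),
        pa.field("metadata", pa.binary(), nullable=False),
    ]
)
print(variant_arrow_type)  # e.g. struct<value: binary not null, metadata: binary not null>
```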
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index bb854641906a..af13adbc21bb 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -58,6 +58,7 @@ from pyspark.sql.types import (
     BooleanType,
     NullType,
     VariantType,
+    VariantVal,
 )
 from pyspark.sql.types import (
     _array_signed_int_typecode_ctype_mappings,
@@ -1406,6 +1407,69 @@ class TypesTestsMixin:
         schema1 = self.spark.range(1).select(F.make_interval(F.lit(1))).schema
         self.assertEqual(schema1.fields[0].dataType, CalendarIntervalType())
 
+    def test_variant_type(self):
+        from decimal import Decimal
+
+        self.assertEqual(VariantType().simpleString(), "variant")
+
+        # Holds a tuple of (key, json string value, python value)
+        expected_values = [
+            ("str", '"%s"' % ("0123456789" * 10), "0123456789" * 10),
+            ("short_str", '"abc"', "abc"),
+            ("null", "null", None),
+            ("true", "true", True),
+            ("false", "false", False),
+            ("int1", "1", 1),
+            ("-int1", "-5", -5),
+            ("int2", "257", 257),
+            ("-int2", "-124", -124),
+            ("int4", "65793", 65793),
+            ("-int4", "-69633", -69633),
+            ("int8", "4295033089", 4295033089),
+            ("-int8", "-4294967297", -4294967297),
+            ("float4", "1.23456789e-30", 1.23456789e-30),
+            ("-float4", "-4.56789e+29", -4.56789e29),
+            ("dec4", "123.456", Decimal("123.456")),
+            ("-dec4", "-321.654", Decimal("-321.654")),
+            ("dec8", "429.4967297", Decimal("429.4967297")),
+            ("-dec8", "-5.678373902", Decimal("-5.678373902")),
+            ("dec16", "467440737095.51617", Decimal("467440737095.51617")),
+            ("-dec16", "-67.849438003827263", Decimal("-67.849438003827263")),
+            ("arr", '[1.1,"2",[3],{"4":5}]', [Decimal("1.1"), "2", [3], {"4": 5}]),
+            ("obj", '{"a":["123",{"b":2}],"c":3}', {"a": ["123", {"b": 2}], "c": 3}),
+        ]
+        json_str = "{%s}" % ",".join(['"%s": %s' % (t[0], t[1]) for t in expected_values])
+
+        df = self.spark.createDataFrame([({"json": json_str})])
+        row = df.select(
+            F.parse_json(df.json).alias("v"),
+            F.array([F.parse_json(F.lit('{"a": 1}'))]).alias("a"),
+            F.struct([F.parse_json(F.lit('{"b": "2"}'))]).alias("s"),
+            F.create_map([F.lit("k"), F.parse_json(F.lit('{"c": true}'))]).alias("m"),
+        ).collect()[0]
+        variants = [row["v"], row["a"][0], row["s"]["col1"], row["m"]["k"]]
+        for v in variants:
+            self.assertEqual(type(v), VariantVal)
+
+        # check str
+        as_string = str(variants[0])
+        for key, expected, _ in expected_values:
+            self.assertTrue('"%s":%s' % (key, expected) in as_string)
+        self.assertEqual(str(variants[1]), '{"a":1}')
+        self.assertEqual(str(variants[2]), '{"b":"2"}')
+        self.assertEqual(str(variants[3]), '{"c":true}')
+
+        # check toPython
+        as_python = variants[0].toPython()
+        for key, _, obj in expected_values:
+            self.assertEqual(as_python[key], obj)
+        self.assertEqual(variants[1].toPython(), {"a": 1})
+        self.assertEqual(variants[2].toPython(), {"b": "2"})
+        self.assertEqual(variants[3].toPython(), {"c": True})
+
+        # check repr
+        self.assertEqual(str(variants[0]), str(eval(repr(variants[0]))))
+
     def test_from_ddl(self):
         self.assertEqual(DataType.fromDDL("long"), LongType())
         self.assertEqual(
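For illustration (not part of the patch): the integer constants in `expected_values` above are picked so that each lands in a different variant integer width (INT1/INT2/INT4/INT8). A quick sanity check of the byte counts:

```python
# Bytes needed to hold each test constant as a signed little-endian integer;
# each one overflows the previous width, so INT1/INT2/INT4/INT8 all get exercised.
for name, n in [("int1", 1), ("int2", 257), ("int4", 65793), ("int8", 4295033089)]:
    nbytes = (n.bit_length() + 8) // 8
    print(name, n, nbytes)  # 1 -> 1, 257 -> 2, 65793 -> 3, 4295033089 -> 5
```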
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 9b1bab4c23fa..3546fd822814 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -48,6 +48,7 @@ from typing import (
 from pyspark.util import is_remote_only
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.utils import has_numpy, get_active_spark_context
+from pyspark.sql.variant_utils import VariantUtils
 from pyspark.errors import (
     PySparkNotImplementedError,
     PySparkTypeError,
@@ -95,6 +96,7 @@ __all__ = [
     "StructField",
     "StructType",
     "VariantType",
+    "VariantVal",
 ]
 
 
@@ -1341,7 +1343,13 @@ class VariantType(AtomicType):
     .. versionadded:: 4.0.0
     """
 
-    pass
+    def needConversion(self) -> bool:
+        return True
+
+    def fromInternal(self, obj: Dict) -> Optional["VariantVal"]:
+        if obj is None or not all(key in obj for key in ["value", "metadata"]):
+            return None
+        return VariantVal(obj["value"], obj["metadata"])
 
 
 class UserDefinedType(DataType):
@@ -1465,6 +1473,55 @@ class UserDefinedType(DataType):
         return type(self) == type(other)
 
 
+class VariantVal:
+    """
+    A class to represent a Variant value in Python.
+
+    .. versionadded:: 4.0.0
+
+    Parameters
+    ----------
+    value : bytes
+        The bytes representing the value component of the Variant.
+    metadata : bytes
+        The bytes representing the metadata component of the Variant.
+
+    Methods
+    -------
+    toPython()
+        Convert the VariantVal to a Python data structure.
+
+    Examples
+    --------
+    >>> from pyspark.sql.functions import *
+    >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
+    >>> v = df.select(parse_json(df.json).alias("var")).collect()[0].var
+    >>> v.toPython()
+    {'a': 1}
+    """
+
+    def __init__(self, value: bytes, metadata: bytes):
+        self.value = value
+        self.metadata = metadata
+
+    def __str__(self) -> str:
+        return VariantUtils.to_json(self.value, self.metadata)
+
+    def __repr__(self) -> str:
+        return "VariantVal(%r, %r)" % (self.value, self.metadata)
+
+    def toPython(self) -> Any:
+        """
+        Convert the VariantVal to a Python data structure.
+
+        Returns
+        -------
+        Any
+            A Python object that represents the Variant.
+        """
+        return VariantUtils.to_python(self.value, self.metadata)
+
+
 _atomic_types: List[Type[DataType]] = [
     StringType,
     CharType,
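For illustration (not part of the patch): a small sketch of the new `VariantType.fromInternal` / `VariantVal` contract, using a hand-encoded short-string variant. The byte literals follow the format documented in `variant_utils.py` below and are illustrative, not produced by Spark:

```python
from pyspark.sql.types import VariantType, VariantVal

# Hand-encoded variant for the string "abc":
#   value[0] = (length 3 << 2) | SHORT_STR -> 0x0D, followed by the UTF-8 bytes;
#   metadata = a minimal dictionary header with zero keys.
value, metadata = b"\x0dabc", b"\x01\x00\x00"

vt = VariantType()
assert vt.needConversion()                   # VariantType now opts into conversion
v = vt.fromInternal({"value": value, "metadata": metadata})
assert isinstance(v, VariantVal)
assert str(v) == '"abc"'                     # __str__ renders the variant as JSON
assert v.toPython() == "abc"                 # toPython yields the plain Python value
assert str(eval(repr(v))) == str(v)          # __repr__ round-trips, as the tests check
```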
diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
new file mode 100644
index 000000000000..9ca70365316d
--- /dev/null
+++ b/python/pyspark/sql/variant_utils.py
@@ -0,0 +1,388 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import decimal
+import json
+import struct
+from array import array
+from typing import Any, Callable, Dict, List, Tuple
+from pyspark.errors import PySparkValueError
+
+
+class VariantUtils:
+    """
+    A utility class for VariantVal.
+
+    Adapted from library at: org.apache.spark.types.variant.VariantUtil
+    """
+
+    BASIC_TYPE_BITS = 2
+    BASIC_TYPE_MASK = 0x3
+    TYPE_INFO_MASK = 0x3F
+    # The inclusive maximum value of the type info value. It is the size limit of `SHORT_STR`.
+    MAX_SHORT_STR_SIZE = 0x3F
+
+    # Below are all possible basic type values.
+    # Primitive value. The type info value must be one of the values in the below section.
+    PRIMITIVE = 0
+    # Short string value. The type info value is the string size, which must be in `[0,
+    # MAX_SHORT_STR_SIZE]`.
+    # The string content bytes directly follow the header byte.
+    SHORT_STR = 1
+    # Object value. The content contains a size, a list of field ids, a list of field offsets, and
+    # the actual field data. The length of the id list is `size`, while the length of the offset
+    # list is `size + 1`, where the last offset represents the total size of the field data. The
+    # fields in an object must be sorted by the field name in alphabetical order. Duplicate field
+    # names in one object are not allowed.
+    # We use 5 bits in the type info to specify the integer type of the object header: it should
+    # be 0_b4_b3b2_b1b0 (MSB is 0), where:
+    # - b4 specifies the type of size. When it is 0/1, `size` is a little-endian 1/4-byte
+    # unsigned integer.
+    # - b3b2/b1b0 specifies the integer type of id and offset. When the 2 bits are 0/1/2, the
+    # list contains 1/2/3-byte little-endian unsigned integers.
+    OBJECT = 2
+    # Array value. The content contains a size, a list of field offsets, and the actual element
+    # data. It is similar to an object without the id list. The length of the offset list
+    # is `size + 1`, where the last offset represents the total size of the element data.
+    # Its type info should be: 000_b2_b1b0:
+    # - b2 specifies the type of size.
+    # - b1b0 specifies the integer type of offset.
+    ARRAY = 3
+
+    # Below are all possible type info values for `PRIMITIVE`.
+    # JSON Null value. Empty content.
+    NULL = 0
+    # True value. Empty content.
+    TRUE = 1
+    # False value. Empty content.
+    FALSE = 2
+    # 1-byte little-endian signed integer.
+    INT1 = 3
+    # 2-byte little-endian signed integer.
+    INT2 = 4
+    # 4-byte little-endian signed integer.
+    INT4 = 5
+    # 8-byte little-endian signed integer.
+    INT8 = 6
+    # 8-byte IEEE double.
+    DOUBLE = 7
+    # 4-byte decimal. Content is 1-byte scale + 4-byte little-endian signed integer.
+    DECIMAL4 = 8
+    # 8-byte decimal. Content is 1-byte scale + 8-byte little-endian signed integer.
+    DECIMAL8 = 9
+    # 16-byte decimal. Content is 1-byte scale + 16-byte little-endian signed integer.
+    DECIMAL16 = 10
+    # Long string value. The content is (4-byte little-endian unsigned integer representing the
+    # string size) + (size bytes of string content).
+    LONG_STR = 16
+
+    U32_SIZE = 4
+
+    @classmethod
+    def to_json(cls, value: bytes, metadata: bytes) -> str:
+        """
+        Convert the VariantVal to a JSON string.
+        :return: JSON string
+        """
+        return cls._to_json(value, metadata, 0)
+
+    @classmethod
+    def to_python(cls, value: bytes, metadata: bytes) -> Any:
+        """
+        Convert the VariantVal to a nested Python object of Python data types.
+        :return: Python representation of the Variant nested structure
+        """
+        return cls._to_python(value, metadata, 0)
+
+    @classmethod
+    def _read_long(cls, data: bytes, pos: int, num_bytes: int, signed: bool) -> int:
+        cls._check_index(pos, len(data))
+        cls._check_index(pos + num_bytes - 1, len(data))
+        return int.from_bytes(data[pos : pos + num_bytes], byteorder="little", signed=signed)
+
+    @classmethod
+    def _check_index(cls, pos: int, length: int) -> None:
+        if pos < 0 or pos >= length:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+    @classmethod
+    def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]:
+        """
+        Returns the (basic_type, type_info) pair from the given position in the value.
+        """
+        basic_type = value[pos] & VariantUtils.BASIC_TYPE_MASK
+        type_info = (value[pos] >> VariantUtils.BASIC_TYPE_BITS) & VariantUtils.TYPE_INFO_MASK
+        return (basic_type, type_info)
+
+    @classmethod
+    def _get_metadata_key(cls, metadata: bytes, id: int) -> str:
+        """
+        Returns the key string from the dictionary in the metadata, corresponding to `id`.
+        """
+        cls._check_index(0, len(metadata))
+        offset_size = ((metadata[0] >> 6) & 0x3) + 1
+        dict_size = cls._read_long(metadata, 1, offset_size, signed=False)
+        if id >= dict_size:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        string_start = 1 + (dict_size + 2) * offset_size
+        offset = cls._read_long(metadata, 1 + (id + 1) * offset_size, offset_size, signed=False)
+        next_offset = cls._read_long(
+            metadata, 1 + (id + 2) * offset_size, offset_size, signed=False
+        )
+        if offset > next_offset:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        cls._check_index(string_start + next_offset - 1, len(metadata))
+        return metadata[string_start + offset : (string_start + next_offset)].decode("utf-8")
+
+    @classmethod
+    def _get_boolean(cls, value: bytes, pos: int) -> bool:
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.PRIMITIVE or (
+            type_info != VariantUtils.TRUE and type_info != VariantUtils.FALSE
+        ):
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        return type_info == VariantUtils.TRUE
+
+    @classmethod
+    def _get_long(cls, value: bytes, pos: int) -> int:
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.PRIMITIVE:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        if type_info == VariantUtils.INT1:
+            return cls._read_long(value, pos + 1, 1, signed=True)
+        elif type_info == VariantUtils.INT2:
+            return cls._read_long(value, pos + 1, 2, signed=True)
+        elif type_info == VariantUtils.INT4:
+            return cls._read_long(value, pos + 1, 4, signed=True)
+        elif type_info == VariantUtils.INT8:
+            return cls._read_long(value, pos + 1, 8, signed=True)
+        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+    @classmethod
+    def _get_string(cls, value: bytes, pos: int) -> str:
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type == VariantUtils.SHORT_STR or (
+            basic_type == VariantUtils.PRIMITIVE and type_info == VariantUtils.LONG_STR
+        ):
+            start = 0
+            length = 0
+            if basic_type == VariantUtils.SHORT_STR:
+                start = pos + 1
+                length = type_info
+            else:
+                start = pos + 1 + VariantUtils.U32_SIZE
+                length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
+            cls._check_index(start + length - 1, len(value))
+            return value[start : start + length].decode("utf-8")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+    @classmethod
+    def _get_double(cls, value: bytes, pos: int) -> float:
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.DOUBLE:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        return struct.unpack("d", value[pos + 1 : pos + 9])[0]
+
+    @classmethod
+    def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.PRIMITIVE:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        scale = value[pos + 1]
+        unscaled = 0
+        if type_info == VariantUtils.DECIMAL4:
+            unscaled = cls._read_long(value, pos + 2, 4, signed=True)
+        elif type_info == VariantUtils.DECIMAL8:
+            unscaled = cls._read_long(value, pos + 2, 8, signed=True)
+        elif type_info == VariantUtils.DECIMAL16:
+            cls._check_index(pos + 17, len(value))
+            unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True)
+        else:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale))
+
+    @classmethod
+    def _get_type(cls, value: bytes, pos: int) -> Any:
+        """
+        Returns the Python type of the Variant at the given position.
+        """
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type == VariantUtils.SHORT_STR:
+            return str
+        elif basic_type == VariantUtils.OBJECT:
+            return dict
+        elif basic_type == VariantUtils.ARRAY:
+            return array
+        elif type_info == VariantUtils.NULL:
+            return type(None)
+        elif type_info == VariantUtils.TRUE or type_info == VariantUtils.FALSE:
+            return bool
+        elif (
+            type_info == VariantUtils.INT1
+            or type_info == VariantUtils.INT2
+            or type_info == VariantUtils.INT4
+            or type_info == VariantUtils.INT8
+        ):
+            return int
+        elif type_info == VariantUtils.DOUBLE:
+            return float
+        elif (
+            type_info == VariantUtils.DECIMAL4
+            or type_info == VariantUtils.DECIMAL8
+            or type_info == VariantUtils.DECIMAL16
+        ):
+            return decimal.Decimal
+        elif type_info == VariantUtils.LONG_STR:
+            return str
+        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+    @classmethod
+    def _to_json(cls, value: bytes, metadata: bytes, pos: int) -> Any:
+        variant_type = cls._get_type(value, pos)
+        if variant_type == dict:
+
+            def handle_object(key_value_pos_list: list[Tuple[str, int]]) -> str:
+                key_value_list = [
+                    json.dumps(key) + ":" + cls._to_json(value, metadata, value_pos)
+                    for (key, value_pos) in key_value_pos_list
+                ]
+                return "{" + ",".join(key_value_list) + "}"
+
+            return cls._handle_object(value, metadata, pos, handle_object)
+        elif variant_type == array:
+
+            def handle_array(value_pos_list: list[int]) -> str:
+                value_list = [
+                    cls._to_json(value, metadata, value_pos) for value_pos in value_pos_list
+                ]
+                return "[" + ",".join(value_list) + "]"
+
+            return cls._handle_array(value, pos, handle_array)
+        else:
+            value = cls._get_scalar(variant_type, value, metadata, pos)
+            if value is None:
+                return "null"
+            if type(value) == bool:
+                return "true" if value else "false"
+            if type(value) == str:
+                return json.dumps(value)
+            return str(value)
+
+    @classmethod
+    def _to_python(cls, value: bytes, metadata: bytes, pos: int) -> Any:
+        variant_type = cls._get_type(value, pos)
+        if variant_type == dict:
+
+            def handle_object(key_value_pos_list: list[Tuple[str, int]]) -> Dict[str, Any]:
+                key_value_list = [
+                    (key, cls._to_python(value, metadata, value_pos))
+                    for (key, value_pos) in key_value_pos_list
+                ]
+                return dict(key_value_list)
+
+            return cls._handle_object(value, metadata, pos, handle_object)
+        elif variant_type == array:
+
+            def handle_array(value_pos_list: list[int]) -> List[Any]:
+                value_list = [
+                    cls._to_python(value, metadata, value_pos) for value_pos in value_pos_list
+                ]
+                return value_list
+
+            return cls._handle_array(value, pos, handle_array)
+        else:
+            return cls._get_scalar(variant_type, value, metadata, pos)
+
+    @classmethod
+    def _get_scalar(cls, variant_type: Any, value: bytes, metadata: bytes, pos: int) -> Any:
+        if isinstance(None, variant_type):
+            return None
+        elif variant_type == bool:
+            return cls._get_boolean(value, pos)
+        elif variant_type == int:
+            return cls._get_long(value, pos)
+        elif variant_type == str:
+            return cls._get_string(value, pos)
+        elif variant_type == float:
+            return cls._get_double(value, pos)
+        elif variant_type == decimal.Decimal:
+            return cls._get_decimal(value, pos)
+        else:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+
+    @classmethod
+    def _handle_object(
+        cls, value: bytes, metadata: bytes, pos: int, func: Callable[[list[Tuple[str, int]]], Any]
+    ) -> Any:
+        """
+        Parses the variant object at position `pos`.
+        Calls `func` with a list of (key, value position) pairs of the object.
+        """
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.OBJECT:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        large_size = ((type_info >> 4) & 0x1) != 0
+        size_bytes = VariantUtils.U32_SIZE if large_size else 1
+        num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
+        id_size = ((type_info >> 2) & 0x3) + 1
+        offset_size = ((type_info) & 0x3) + 1
+        id_start = pos + 1 + size_bytes
+        offset_start = id_start + num_fields * id_size
+        data_start = offset_start + (num_fields + 1) * offset_size
+
+        key_value_pos_list = []
+        for i in range(num_fields):
+            id = cls._read_long(value, id_start + id_size * i, id_size, signed=False)
+            offset = cls._read_long(
+                value, offset_start + offset_size * i, offset_size, signed=False
+            )
+            value_pos = data_start + offset
+            key_value_pos_list.append((cls._get_metadata_key(metadata, id), value_pos))
+        return func(key_value_pos_list)
+
+    @classmethod
+    def _handle_array(cls, value: bytes, pos: int, func: Callable[[list[int]], Any]) -> Any:
+        """
+        Parses the variant array at position `pos`.
+        Calls `func` with a list of element positions of the array.
+        """
+        cls._check_index(pos, len(value))
+        basic_type, type_info = cls._get_type_info(value, pos)
+        if basic_type != VariantUtils.ARRAY:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        large_size = ((type_info >> 2) & 0x1) != 0
+        size_bytes = VariantUtils.U32_SIZE if large_size else 1
+        num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
+        offset_size = (type_info & 0x3) + 1
+        offset_start = pos + 1 + size_bytes
+        data_start = offset_start + (num_fields + 1) * offset_size
+
+        value_pos_list = []
+        for i in range(num_fields):
+            offset = cls._read_long(
+                value, offset_start + offset_size * i, offset_size, signed=False
+            )
+            element_pos = data_start + offset
+            value_pos_list.append(element_pos)
+        return func(value_pos_list)
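For illustration (not part of the patch): to make the encoding described above concrete, a hand-decoded example of a one-field object (`{"a": 1}`). The byte layout is derived from the comments in this file; the literals are illustrative:

```python
from pyspark.sql.variant_utils import VariantUtils

# value buffer for {"a": 1}:
#   0x02        OBJECT header (1-byte size, 1-byte ids, 1-byte offsets)
#   0x01        one field
#   0x00        field id 0 -> metadata key "a"
#   0x00, 0x02  field offsets [0, 2] (last offset = total field-data size)
#   0x0C, 0x01  field data: INT1 header, then the value 1
value = bytes([0x02, 0x01, 0x00, 0x00, 0x02, 0x0C, 0x01])

# metadata buffer: version/header byte, dictionary size 1,
# string offsets [0, 1], then the key bytes "a".
metadata = bytes([0x01, 0x01, 0x00, 0x01]) + b"a"

print(VariantUtils.to_python(value, metadata))  # {'a': 1}
print(VariantUtils.to_json(value, metadata))    # {"a":1}
```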
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 92a4c687362d..d9bd3b0e612b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -125,6 +125,12 @@ private[sql] object ArrowUtils {
             largeVarTypes)).asJava)
       case udt: UserDefinedType[_] =>
         toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
+      case _: VariantType =>
+        val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null,
+          Map("variant" -> "true").asJava)
+        new Field(name, fieldType,
+          Seq(toArrowField("value", BinaryType, false, timeZoneId, largeVarTypes),
+            toArrowField("metadata", BinaryType, false, timeZoneId, largeVarTypes)).asJava)
       case dataType =>
         val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId,
           largeVarTypes), null)
@@ -143,6 +149,10 @@ private[sql] object ArrowUtils {
         val elementField = field.getChildren().get(0)
         val elementType = fromArrowField(elementField)
         ArrayType(elementType, containsNull = elementField.isNullable)
+      case ArrowType.Struct.INSTANCE if field.getMetadata.getOrDefault("variant", "") == "true"
+        && field.getChildren.asScala.map(_.getName).asJava
+          .containsAll(Seq("value", "metadata").asJava) =>
+        VariantType
       case ArrowType.Struct.INSTANCE =>
         val fields = field.getChildren().asScala.map { child =>
           val dt = fromArrowField(child)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 6680a5320fe3..ca7703bef48b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -84,6 +84,11 @@ object ArrowWriter {
       case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector)
       case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) =>
         new IntervalMonthDayNanoWriter(vector)
+      case (VariantType, vector: StructVector) =>
+        val children = (0 until vector.size()).map { ordinal =>
+          createFieldWriter(vector.getChildByOrdinal(ordinal))
+        }
+        new StructWriter(vector, children.toArray)
       case (dt, _) =>
         throw ExecutionErrors.unsupportedDataTypeError(dt)
     }
@@ -368,6 +373,8 @@ private[arrow] class StructWriter(
     val valueVector: StructVector,
     children: Array[ArrowFieldWriter]) extends ArrowFieldWriter {
 
+  lazy val isVariant = valueVector.getField.getMetadata.get("variant") == "true"
+
   override def setNull(): Unit = {
     var i = 0
     while (i < children.length) {
@@ -379,12 +386,20 @@ private[arrow] class StructWriter(
   }
 
   override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
-    val struct = input.getStruct(ordinal, children.length)
-    var i = 0
-    valueVector.setIndexDefined(count)
-    while (i < struct.numFields) {
-      children(i).write(struct, i)
-      i += 1
+    if (isVariant) {
+      valueVector.setIndexDefined(count)
+      val v = input.getVariant(ordinal)
+      val row = InternalRow(v.getValue, v.getMetadata)
+      children(0).write(row, 0)
+      children(1).write(row, 1)
+    } else {
+      val struct = input.getStruct(ordinal, children.length)
+      var i = 0
+      valueVector.setIndexDefined(count)
+      while (i < struct.numFields) {
+        children(i).write(struct, i)
+        i += 1
+      }
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

