This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 08808fb5079 [SPARK-39760][PYTHON] Support Varchar in PySpark
08808fb5079 is described below

commit 08808fb507947b51ea7656496612a81e11fe66bd
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Jul 18 15:55:55 2022 +0800

    [SPARK-39760][PYTHON] Support Varchar in PySpark
    
    ### What changes were proposed in this pull request?
    Add `VarcharType` to PySpark, including parsing of `varchar(n)` in datatype strings and JSON type values.
    
    ### Why are the changes needed?
    Feature parity with the Scala API, which already provides `VarcharType`.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, a new data type, `VarcharType`, is supported.
    
    ### How was this patch tested?
    1. Added unit tests.
    2. Manually checked the behavior against the Scala side:
    
    ```python
    
    In [1]: from pyspark.sql.types import *
       ...: from pyspark.sql.functions import *
       ...:
       ...: df = spark.createDataFrame([(1,), (11,)], ["value"])
       ...: ret = df.select(col("value").cast(VarcharType(10))).collect()
       ...:
    22/07/13 17:17:07 WARN CharVarcharUtils: The Spark cast operator does not 
support char/varchar type and simply treats them as string type. Please use 
string type directly to avoid confusion. Otherwise, you can set 
spark.sql.legacy.charVarcharAsString to true, so that Spark treat them as 
string type as same as Spark 3.0 and earlier
    
    In [2]:
    
    In [2]: schema = StructType([StructField("a", IntegerType(), True), 
(StructField("v", VarcharType(10), True))])
       ...: description = "this a table created via Catalog.createTable()"
       ...: table = spark.catalog.createTable("tab3_via_catalog", 
schema=schema, description=description)
       ...: table.schema
       ...:
    Out[2]: StructType([StructField('a', IntegerType(), True), StructField('v', 
StringType(), True)])
    ```
    
    ```scala
    scala> import org.apache.spark.sql.types._
    import org.apache.spark.sql.types._
    
    scala> import org.apache.spark.sql.functions._
    import org.apache.spark.sql.functions._
    
    scala> val df = spark.range(0, 10).selectExpr(" id AS value")
    df: org.apache.spark.sql.DataFrame = [value: bigint]
    
    scala> val ret = df.select(col("value").cast(VarcharType(10))).collect()
    22/07/13 17:28:56 WARN CharVarcharUtils: The Spark cast operator does not 
support char/varchar type and simply treats them as string type. Please use 
string type directly to avoid confusion. Otherwise, you can set 
spark.sql.legacy.charVarcharAsString to true, so that Spark treat them as 
string type as same as Spark 3.0 and earlier
    ret: Array[org.apache.spark.sql.Row] = Array([0], [1], [2], [3], [4], [5], 
[6], [7], [8], [9])
    
    scala>
    
    scala> val schema = StructType(StructField("a", IntegerType, true) :: 
(StructField("v", VarcharType(10), true) :: Nil))
    schema: org.apache.spark.sql.types.StructType = 
StructType(StructField(a,IntegerType,true),StructField(v,VarcharType(10),true))
    
    scala> val description = "this a table created via Catalog.createTable()"
    description: String = this a table created via Catalog.createTable()
    
    scala> val table = spark.catalog.createTable("tab3_via_catalog", 
source="json", schema=schema, description=description, 
options=Map.empty[String, String])
    table: org.apache.spark.sql.DataFrame = [a: int, v: string]
    
    scala> table.schema
    res0: org.apache.spark.sql.types.StructType = 
StructType(StructField(a,IntegerType,true),StructField(v,StringType,true))
    ```
    
    Closes #37173 from zhengruifeng/py_add_varchar.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../source/reference/pyspark.sql/data_types.rst    |  1 +
 python/pyspark/sql/tests/test_types.py             | 26 +++++++++++-
 python/pyspark/sql/types.py                        | 46 ++++++++++++++++++++--
 3 files changed, 68 insertions(+), 5 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/data_types.rst 
b/python/docs/source/reference/pyspark.sql/data_types.rst
index d146c640477..775f0bf394a 100644
--- a/python/docs/source/reference/pyspark.sql/data_types.rst
+++ b/python/docs/source/reference/pyspark.sql/data_types.rst
@@ -40,6 +40,7 @@ Data Types
     NullType
     ShortType
     StringType
+    VarcharType
     StructField
     StructType
     TimestampType
diff --git a/python/pyspark/sql/tests/test_types.py 
b/python/pyspark/sql/tests/test_types.py
index ef0ad82dbb9..218cfc413db 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -38,6 +38,7 @@ from pyspark.sql.types import (
     DayTimeIntervalType,
     MapType,
     StringType,
+    VarcharType,
     StructType,
     StructField,
     ArrayType,
@@ -739,8 +740,12 @@ class TypesTests(ReusedSQLTestCase):
         from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
 
         for k, t in _all_atomic_types.items():
-            self.assertEqual(t(), _parse_datatype_string(k))
+            if k != "varchar":
+                self.assertEqual(t(), _parse_datatype_string(k))
         self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+        self.assertEqual(VarcharType(1), _parse_datatype_string("varchar(1)"))
+        self.assertEqual(VarcharType(10), _parse_datatype_string("varchar( 10  
 )"))
+        self.assertEqual(VarcharType(11), _parse_datatype_string("varchar( 
11)"))
         self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1  
,1)"))
         self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 
10,1 )"))
         self.assertEqual(DecimalType(11, 1), 
_parse_datatype_string("decimal(11,1)"))
@@ -1028,6 +1033,7 @@ class TypesTests(ReusedSQLTestCase):
         instances = [
             NullType(),
             StringType(),
+            VarcharType(10),
             BinaryType(),
             BooleanType(),
             DateType(),
@@ -1132,6 +1138,15 @@ class DataTypeTests(unittest.TestCase):
         t3 = DecimalType(8)
         self.assertNotEqual(t2, t3)
 
+    def test_varchar_type(self):
+        v1 = VarcharType(10)
+        v2 = VarcharType(20)
+        self.assertTrue(v2 is not v1)
+        self.assertNotEqual(v1, v2)
+        v3 = VarcharType(10)
+        self.assertEqual(v1, v3)
+        self.assertFalse(v1 is v3)
+
     # regression test for SPARK-10392
     def test_datetype_equal_zero(self):
         dt = DateType()
@@ -1211,6 +1226,13 @@ class DataTypeVerificationTests(unittest.TestCase):
             (1.0, StringType()),
             ([], StringType()),
             ({}, StringType()),
+            # Varchar
+            ("", VarcharType(10)),
+            ("", VarcharType(10)),
+            (1, VarcharType(10)),
+            (1.0, VarcharType(10)),
+            ([], VarcharType(10)),
+            ({}, VarcharType(10)),
             # UDT
             (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
             # Boolean
@@ -1267,6 +1289,8 @@ class DataTypeVerificationTests(unittest.TestCase):
         failure_spec = [
             # String (match anything but None)
             (None, StringType(), ValueError),
+            # VarcharType (match anything but None)
+            (None, VarcharType(10), ValueError),
             # UDT
             (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
             # Boolean
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index fa3f3dd7d88..7ab8f7c9c2d 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -57,6 +57,7 @@ __all__ = [
     "DataType",
     "NullType",
     "StringType",
+    "VarcharType",
     "BinaryType",
     "BooleanType",
     "DateType",
@@ -181,6 +182,28 @@ class StringType(AtomicType, metaclass=DataTypeSingleton):
     pass
 
 
+class VarcharType(AtomicType):
+    """Varchar data type
+
+    Parameters
+    ----------
+    length : int
+        the length limitation.
+    """
+
+    def __init__(self, length: int):
+        self.length = length
+
+    def simpleString(self) -> str:
+        return "varchar(%d)" % (self.length)
+
+    def jsonValue(self) -> str:
+        return "varchar(%d)" % (self.length)
+
+    def __repr__(self) -> str:
+        return "VarcharType(%d)" % (self.length)
+
+
 class BinaryType(AtomicType, metaclass=DataTypeSingleton):
     """Binary (byte array) data type."""
 
@@ -625,6 +648,10 @@ class StructType(DataType):
     >>> struct2 = StructType([StructField("f1", StringType(), True)])
     >>> struct1 == struct2
     True
+    >>> struct1 = StructType([StructField("f1", VarcharType(10), True)])
+    >>> struct2 = StructType([StructField("f1", VarcharType(10), True)])
+    >>> struct1 == struct2
+    True
     >>> struct1 = StructType([StructField("f1", StringType(), True)])
     >>> struct2 = StructType([StructField("f1", StringType(), True),
     ...     StructField("f2", IntegerType(), False)])
@@ -944,6 +971,7 @@ class UserDefinedType(DataType):
 
 _atomic_types: List[Type[DataType]] = [
     StringType,
+    VarcharType,
     BinaryType,
     BooleanType,
     DecimalType,
@@ -965,7 +993,7 @@ _all_complex_types: Dict[str, Type[Union[ArrayType, 
MapType, StructType]]] = dic
     (v.typeName(), v) for v in _complex_types
 )
 
-
+_LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to 
(day|hour|minute|second))?")
 
@@ -987,6 +1015,8 @@ def _parse_datatype_string(s: str) -> DataType:
     StructType([StructField('a', ByteType(), True), StructField('b', 
DecimalType(16,8), True)])
     >>> _parse_datatype_string("a DOUBLE, b STRING")
     StructType([StructField('a', DoubleType(), True), StructField('b', 
StringType(), True)])
+    >>> _parse_datatype_string("a DOUBLE, b VARCHAR( 50 )")
+    StructType([StructField('a', DoubleType(), True), StructField('b', 
VarcharType(50), True)])
     >>> _parse_datatype_string("a: array< short>")
     StructType([StructField('a', ArrayType(ShortType(), True), True)])
     >>> _parse_datatype_string(" map<string , string > ")
@@ -1055,7 +1085,10 @@ def _parse_datatype_json_string(json_string: str) -> 
DataType:
     ...     python_datatype = 
_parse_datatype_json_string(scala_datatype.json())
     ...     assert datatype == python_datatype
     >>> for cls in _all_atomic_types.values():
-    ...     check_datatype(cls())
+    ...     if cls is not VarcharType:
+    ...         check_datatype(cls())
+    ...     else:
+    ...         check_datatype(cls(1))
 
     >>> # Simple ArrayType.
     >>> simple_arraytype = ArrayType(StringType(), True)
@@ -1079,6 +1112,7 @@ def _parse_datatype_json_string(json_string: str) -> 
DataType:
     ...     StructField("simpleMap", simple_maptype, True),
     ...     StructField("simpleStruct", simple_structtype, True),
     ...     StructField("boolean", BooleanType(), False),
+    ...     StructField("words", VarcharType(10), False),
     ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
     >>> check_datatype(complex_structtype)
 
@@ -1111,6 +1145,9 @@ def _parse_datatype_json_value(json_value: Union[dict, 
str]) -> DataType:
             if first_field is not None and second_field is None:
                 return DayTimeIntervalType(first_field)
             return DayTimeIntervalType(first_field, second_field)
+        elif _LENGTH_VARCHAR.match(json_value):
+            m = _LENGTH_VARCHAR.match(json_value)
+            return VarcharType(int(m.group(1)))  # type: ignore[union-attr]
         else:
             raise ValueError("Could not parse datatype: %s" % json_value)
     else:
@@ -1549,6 +1586,7 @@ _acceptable_types = {
     DoubleType: (float,),
     DecimalType: (decimal.Decimal,),
     StringType: (str,),
+    VarcharType: (str,),
     BinaryType: (bytearray, bytes),
     DateType: (datetime.date, datetime.datetime),
     TimestampType: (datetime.datetime,),
@@ -1659,8 +1697,8 @@ def _make_type_verifier(
                 new_msg("%s can not accept object %r in type %s" % (dataType, 
obj, type(obj)))
             )
 
-    if isinstance(dataType, StringType):
-        # StringType can work with any types
+    if isinstance(dataType, (StringType, VarcharType)):
+        # StringType and VarcharType can work with any types
         def verify_value(obj: Any) -> None:
             pass
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to