This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 00d7094dc30 [SPARK-39809][PYTHON] Support CharType in PySpark
00d7094dc30 is described below

commit 00d7094dc3024ae594605b311dcc55e95d277d5f
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Jul 19 10:22:04 2022 +0900

    [SPARK-39809][PYTHON] Support CharType in PySpark

    ### What changes were proposed in this pull request?
    Support CharType in PySpark

    ### Why are the changes needed?
    for function parity

    ### Does this PR introduce _any_ user-facing change?
    yes, new type added

    ### How was this patch tested?
    added UT

    Closes #37215 from zhengruifeng/py_add_char.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/test_types.py | 26 ++++++++++++++++++---
 python/pyspark/sql/types.py            | 42 +++++++++++++++++++++++++++++++---
 2 files changed, 62 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 218cfc413db..b1609417a0c 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -38,6 +38,7 @@ from pyspark.sql.types import (
     DayTimeIntervalType,
     MapType,
     StringType,
+    CharType,
     VarcharType,
     StructType,
     StructField,
@@ -740,9 +741,12 @@ class TypesTests(ReusedSQLTestCase):
         from pyspark.sql.types import _all_atomic_types, _parse_datatype_string

         for k, t in _all_atomic_types.items():
-            if k != "varchar":
+            if k != "varchar" and k != "char":
                 self.assertEqual(t(), _parse_datatype_string(k))
         self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+        self.assertEqual(CharType(1), _parse_datatype_string("char(1)"))
+        self.assertEqual(CharType(10), _parse_datatype_string("char( 10 )"))
+        self.assertEqual(CharType(11), _parse_datatype_string("char( 11)"))
         self.assertEqual(VarcharType(1), _parse_datatype_string("varchar(1)"))
         self.assertEqual(VarcharType(10), _parse_datatype_string("varchar( 10 )"))
         self.assertEqual(VarcharType(11), _parse_datatype_string("varchar( 11)"))
@@ -1033,6 +1037,7 @@ class TypesTests(ReusedSQLTestCase):
         instances = [
             NullType(),
             StringType(),
+            CharType(10),
             VarcharType(10),
             BinaryType(),
             BooleanType(),
@@ -1138,6 +1143,15 @@ class DataTypeTests(unittest.TestCase):
         t3 = DecimalType(8)
         self.assertNotEqual(t2, t3)

+    def test_char_type(self):
+        v1 = CharType(10)
+        v2 = CharType(20)
+        self.assertTrue(v2 is not v1)
+        self.assertNotEqual(v1, v2)
+        v3 = CharType(10)
+        self.assertEqual(v1, v3)
+        self.assertFalse(v1 is v3)
+
     def test_varchar_type(self):
         v1 = VarcharType(10)
         v2 = VarcharType(20)
@@ -1221,14 +1235,18 @@ class DataTypeVerificationTests(unittest.TestCase):
         success_spec = [
             # String
             ("", StringType()),
-            ("", StringType()),
             (1, StringType()),
             (1.0, StringType()),
             ([], StringType()),
             ({}, StringType()),
+            # Char
+            ("", CharType(10)),
+            (1, CharType(10)),
+            (1.0, CharType(10)),
+            ([], CharType(10)),
+            ({}, CharType(10)),
             # Varchar
             ("", VarcharType(10)),
-            ("", VarcharType(10)),
             (1, VarcharType(10)),
             (1.0, VarcharType(10)),
             ([], VarcharType(10)),
@@ -1289,6 +1307,8 @@ class DataTypeVerificationTests(unittest.TestCase):
         failure_spec = [
             # String (match anything but None)
             (None, StringType(), ValueError),
+            # CharType (match anything but None)
+            (None, CharType(10), ValueError),
             # VarcharType (match anything but None)
             (None, VarcharType(10), ValueError),
             # UDT
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 7ab8f7c9c2d..e034ff75e10 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -56,6 +56,7 @@ U = TypeVar("U")
 __all__ = [
     "DataType",
     "NullType",
+    "CharType",
     "StringType",
     "VarcharType",
     "BinaryType",
@@ -182,6 +183,28 @@ class StringType(AtomicType, metaclass=DataTypeSingleton):
     pass


+class CharType(AtomicType):
+    """Char data type
+
+    Parameters
+    ----------
+    length : int
+        the length limitation.
+    """
+
+    def __init__(self, length: int):
+        self.length = length
+
+    def simpleString(self) -> str:
+        return "char(%d)" % (self.length)
+
+    def jsonValue(self) -> str:
+        return "char(%d)" % (self.length)
+
+    def __repr__(self) -> str:
+        return "CharType(%d)" % (self.length)
+
+
 class VarcharType(AtomicType):
     """Varchar data type

@@ -648,6 +671,10 @@ class StructType(DataType):
     >>> struct2 = StructType([StructField("f1", StringType(), True)])
     >>> struct1 == struct2
     True
+    >>> struct1 = StructType([StructField("f1", CharType(10), True)])
+    >>> struct2 = StructType([StructField("f1", CharType(10), True)])
+    >>> struct1 == struct2
+    True
     >>> struct1 = StructType([StructField("f1", VarcharType(10), True)])
     >>> struct2 = StructType([StructField("f1", VarcharType(10), True)])
     >>> struct1 == struct2
@@ -971,6 +998,7 @@ class UserDefinedType(DataType):

 _atomic_types: List[Type[DataType]] = [
     StringType,
+    CharType,
     VarcharType,
     BinaryType,
     BooleanType,
@@ -993,6 +1021,7 @@ _all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dict(
     (v.typeName(), v) for v in _complex_types
 )

+_LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)")
 _LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
@@ -1015,6 +1044,8 @@ def _parse_datatype_string(s: str) -> DataType:
     StructType([StructField('a', ByteType(), True), StructField('b', DecimalType(16,8), True)])
     >>> _parse_datatype_string("a DOUBLE, b STRING")
     StructType([StructField('a', DoubleType(), True), StructField('b', StringType(), True)])
+    >>> _parse_datatype_string("a DOUBLE, b CHAR( 50 )")
+    StructType([StructField('a', DoubleType(), True), StructField('b', CharType(50), True)])
     >>> _parse_datatype_string("a DOUBLE, b VARCHAR( 50 )")
     StructType([StructField('a', DoubleType(), True), StructField('b', VarcharType(50), True)])
     >>> _parse_datatype_string("a: array< short>")
@@ -1085,7 +1116,7 @@ def _parse_datatype_json_string(json_string: str) -> DataType:
     ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
     ...     assert datatype == python_datatype
     >>> for cls in _all_atomic_types.values():
-    ...     if cls is not VarcharType:
+    ...     if cls is not VarcharType and cls is not CharType:
     ...         check_datatype(cls())
     ...     else:
     ...         check_datatype(cls(1))
@@ -1112,6 +1143,7 @@ def _parse_datatype_json_string(json_string: str) -> DataType:
     ...     StructField("simpleMap", simple_maptype, True),
     ...     StructField("simpleStruct", simple_structtype, True),
     ...     StructField("boolean", BooleanType(), False),
+    ...     StructField("chars", CharType(10), False),
     ...     StructField("words", VarcharType(10), False),
     ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
     >>> check_datatype(complex_structtype)
@@ -1145,6 +1177,9 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
         if first_field is not None and second_field is None:
             return DayTimeIntervalType(first_field)
         return DayTimeIntervalType(first_field, second_field)
+    elif _LENGTH_CHAR.match(json_value):
+        m = _LENGTH_CHAR.match(json_value)
+        return CharType(int(m.group(1)))  # type: ignore[union-attr]
     elif _LENGTH_VARCHAR.match(json_value):
         m = _LENGTH_VARCHAR.match(json_value)
         return VarcharType(int(m.group(1)))  # type: ignore[union-attr]
@@ -1586,6 +1621,7 @@ _acceptable_types = {
     DoubleType: (float,),
     DecimalType: (decimal.Decimal,),
     StringType: (str,),
+    CharType: (str,),
     VarcharType: (str,),
     BinaryType: (bytearray, bytes),
     DateType: (datetime.date, datetime.datetime),
@@ -1697,8 +1733,8 @@ def _make_type_verifier(
                 new_msg("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
             )

-    if isinstance(dataType, (StringType, VarcharType)):
-        # StringType and VarcharType can work with any types
+    if isinstance(dataType, (StringType, CharType, VarcharType)):
+        # StringType, CharType and VarcharType can work with any types
         def verify_value(obj: Any) -> None:
             pass

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org