This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fc1435d14d09 [SPARK-48415][PYTHON] Refactor `TypeName` to support parameterized datatypes
fc1435d14d09 is described below
commit fc1435d14d090b792a0f19372d6b11c7ff026372
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue May 28 08:39:28 2024 +0800
[SPARK-48415][PYTHON] Refactor `TypeName` to support parameterized datatypes
### What changes were proposed in this pull request?
1. Refactor the instance method `typeName` to support parameterized datatypes.
2. Remove the redundant `simpleString`/`jsonValue` methods, since they default
to the type name.
### Why are the changes needed?
To be consistent with the Scala side.
### Does this PR introduce _any_ user-facing change?
type names changes:
`CharType(10)`: `char` -> `char(10)`
`VarcharType(10)`: `varchar` -> `varchar(10)`
`DecimalType(10, 2)`: `decimal` -> `decimal(10,2)`
`DayTimeIntervalType(DAY, HOUR)`: `daytimeinterval` -> `interval day to hour`
`YearMonthIntervalType(YEAR, MONTH)`: `yearmonthinterval` -> `interval year to month`
### How was this patch tested?
CI.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #46738 from zhengruifeng/py_type_name.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/tests/test_types.py | 133 +++++++++++++++++++++++++++++++++
python/pyspark/sql/types.py | 74 +++++++-----------
2 files changed, 160 insertions(+), 47 deletions(-)
diff --git a/python/pyspark/sql/tests/test_types.py
b/python/pyspark/sql/tests/test_types.py
index 80f2c0fcbc03..cc482b886e3a 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -81,6 +81,139 @@ from pyspark.testing.utils import PySparkErrorTestUtils
class TypesTestsMixin:
+ def test_class_method_type_name(self):
+ for dataType, expected in [
+ (StringType, "string"),
+ (CharType, "char"),
+ (VarcharType, "varchar"),
+ (BinaryType, "binary"),
+ (BooleanType, "boolean"),
+ (DecimalType, "decimal"),
+ (FloatType, "float"),
+ (DoubleType, "double"),
+ (ByteType, "byte"),
+ (ShortType, "short"),
+ (IntegerType, "integer"),
+ (LongType, "long"),
+ (DateType, "date"),
+ (TimestampType, "timestamp"),
+ (TimestampNTZType, "timestamp_ntz"),
+ (NullType, "void"),
+ (VariantType, "variant"),
+ (YearMonthIntervalType, "yearmonthinterval"),
+ (DayTimeIntervalType, "daytimeinterval"),
+ (CalendarIntervalType, "interval"),
+ ]:
+ self.assertEqual(dataType.typeName(), expected)
+
+ def test_instance_method_type_name(self):
+ for dataType, expected in [
+ (StringType(), "string"),
+ (CharType(5), "char(5)"),
+ (VarcharType(10), "varchar(10)"),
+ (BinaryType(), "binary"),
+ (BooleanType(), "boolean"),
+ (DecimalType(), "decimal(10,0)"),
+ (DecimalType(10, 2), "decimal(10,2)"),
+ (FloatType(), "float"),
+ (DoubleType(), "double"),
+ (ByteType(), "byte"),
+ (ShortType(), "short"),
+ (IntegerType(), "integer"),
+ (LongType(), "long"),
+ (DateType(), "date"),
+ (TimestampType(), "timestamp"),
+ (TimestampNTZType(), "timestamp_ntz"),
+ (NullType(), "void"),
+ (VariantType(), "variant"),
+ (YearMonthIntervalType(), "interval year to month"),
+ (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval
year"),
+ (
+ YearMonthIntervalType(YearMonthIntervalType.YEAR,
YearMonthIntervalType.MONTH),
+ "interval year to month",
+ ),
+ (DayTimeIntervalType(), "interval day to second"),
+ (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
+ (
+ DayTimeIntervalType(DayTimeIntervalType.HOUR,
DayTimeIntervalType.SECOND),
+ "interval hour to second",
+ ),
+ (CalendarIntervalType(), "interval"),
+ ]:
+ self.assertEqual(dataType.typeName(), expected)
+
+ def test_simple_string(self):
+ for dataType, expected in [
+ (StringType(), "string"),
+ (CharType(5), "char(5)"),
+ (VarcharType(10), "varchar(10)"),
+ (BinaryType(), "binary"),
+ (BooleanType(), "boolean"),
+ (DecimalType(), "decimal(10,0)"),
+ (DecimalType(10, 2), "decimal(10,2)"),
+ (FloatType(), "float"),
+ (DoubleType(), "double"),
+ (ByteType(), "tinyint"),
+ (ShortType(), "smallint"),
+ (IntegerType(), "int"),
+ (LongType(), "bigint"),
+ (DateType(), "date"),
+ (TimestampType(), "timestamp"),
+ (TimestampNTZType(), "timestamp_ntz"),
+ (NullType(), "void"),
+ (VariantType(), "variant"),
+ (YearMonthIntervalType(), "interval year to month"),
+ (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval
year"),
+ (
+ YearMonthIntervalType(YearMonthIntervalType.YEAR,
YearMonthIntervalType.MONTH),
+ "interval year to month",
+ ),
+ (DayTimeIntervalType(), "interval day to second"),
+ (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
+ (
+ DayTimeIntervalType(DayTimeIntervalType.HOUR,
DayTimeIntervalType.SECOND),
+ "interval hour to second",
+ ),
+ (CalendarIntervalType(), "interval"),
+ ]:
+ self.assertEqual(dataType.simpleString(), expected)
+
+ def test_json_value(self):
+ for dataType, expected in [
+ (StringType(), "string"),
+ (CharType(5), "char(5)"),
+ (VarcharType(10), "varchar(10)"),
+ (BinaryType(), "binary"),
+ (BooleanType(), "boolean"),
+ (DecimalType(), "decimal(10,0)"),
+ (DecimalType(10, 2), "decimal(10,2)"),
+ (FloatType(), "float"),
+ (DoubleType(), "double"),
+ (ByteType(), "byte"),
+ (ShortType(), "short"),
+ (IntegerType(), "integer"),
+ (LongType(), "long"),
+ (DateType(), "date"),
+ (TimestampType(), "timestamp"),
+ (TimestampNTZType(), "timestamp_ntz"),
+ (NullType(), "void"),
+ (VariantType(), "variant"),
+ (YearMonthIntervalType(), "interval year to month"),
+ (YearMonthIntervalType(YearMonthIntervalType.YEAR), "interval
year"),
+ (
+ YearMonthIntervalType(YearMonthIntervalType.YEAR,
YearMonthIntervalType.MONTH),
+ "interval year to month",
+ ),
+ (DayTimeIntervalType(), "interval day to second"),
+ (DayTimeIntervalType(DayTimeIntervalType.DAY), "interval day"),
+ (
+ DayTimeIntervalType(DayTimeIntervalType.HOUR,
DayTimeIntervalType.SECOND),
+ "interval hour to second",
+ ),
+ (CalendarIntervalType(), "interval"),
+ ]:
+ self.assertEqual(dataType.jsonValue(), expected)
+
def test_apply_schema_to_row(self):
df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index b9db59e0a58a..563c63f5dfb1 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -115,7 +115,11 @@ class DataType:
return hash(str(self))
def __eq__(self, other: Any) -> bool:
- return isinstance(other, self.__class__) and self.__dict__ ==
other.__dict__
+ if isinstance(other, self.__class__):
+ self_dict = {k: v for k, v in self.__dict__.items() if k !=
"typeName"}
+ other_dict = {k: v for k, v in other.__dict__.items() if k !=
"typeName"}
+ return self_dict == other_dict
+ return False
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
@@ -124,6 +128,12 @@ class DataType:
def typeName(cls) -> str:
return cls.__name__[:-4].lower()
+ # The classmethod 'typeName' is not always consistent with the Scala side,
e.g.
+ # DecimalType(10, 2): 'decimal' vs 'decimal(10, 2)'
+ # This method is used in subclass initializer to replace 'typeName' if
they are different.
+ def _type_name(self) -> str:
+ return
self.__class__.__name__.removesuffix("Type").removesuffix("UDT").lower()
+
def simpleString(self) -> str:
return self.typeName()
@@ -215,24 +225,6 @@ class DataType:
if isinstance(dataType, (ArrayType, StructType, MapType)):
dataType._build_formatted_string(prefix, stringConcat, maxDepth -
1)
- # The method typeName() is not always the same as the Scala side.
- # Add this helper method to make TreeString() compatible with Scala side.
- @classmethod
- def _get_jvm_type_name(cls, dataType: "DataType") -> str:
- if isinstance(
- dataType,
- (
- DecimalType,
- CharType,
- VarcharType,
- DayTimeIntervalType,
- YearMonthIntervalType,
- ),
- ):
- return dataType.simpleString()
- else:
- return dataType.typeName()
-
# This singleton pattern does not work with pickle, you will get
# another object after pickle and unpickle
@@ -294,6 +286,7 @@ class StringType(AtomicType):
providers = [providerSpark, providerICU]
def __init__(self, collation: Optional[str] = None):
+ self.typeName = self._type_name # type: ignore[method-assign]
self.collationId = 0 if collation is None else
self.collationNameToId(collation)
@classmethod
@@ -315,7 +308,7 @@ class StringType(AtomicType):
return StringType.providerSpark
return StringType.providerICU
- def simpleString(self) -> str:
+ def _type_name(self) -> str:
if self.isUTF8BinaryCollation():
return "string"
@@ -348,12 +341,10 @@ class CharType(AtomicType):
"""
def __init__(self, length: int):
+ self.typeName = self._type_name # type: ignore[method-assign]
self.length = length
- def simpleString(self) -> str:
- return "char(%d)" % (self.length)
-
- def jsonValue(self) -> str:
+ def _type_name(self) -> str:
return "char(%d)" % (self.length)
def __repr__(self) -> str:
@@ -370,12 +361,10 @@ class VarcharType(AtomicType):
"""
def __init__(self, length: int):
+ self.typeName = self._type_name # type: ignore[method-assign]
self.length = length
- def simpleString(self) -> str:
- return "varchar(%d)" % (self.length)
-
- def jsonValue(self) -> str:
+ def _type_name(self) -> str:
return "varchar(%d)" % (self.length)
def __repr__(self) -> str:
@@ -474,14 +463,12 @@ class DecimalType(FractionalType):
"""
def __init__(self, precision: int = 10, scale: int = 0):
+ self.typeName = self._type_name # type: ignore[method-assign]
self.precision = precision
self.scale = scale
self.hasPrecisionInfo = True # this is a public API
- def simpleString(self) -> str:
- return "decimal(%d,%d)" % (self.precision, self.scale)
-
- def jsonValue(self) -> str:
+ def _type_name(self) -> str:
return "decimal(%d,%d)" % (self.precision, self.scale)
def __repr__(self) -> str:
@@ -556,6 +543,7 @@ class DayTimeIntervalType(AnsiIntervalType):
_inverted_fields = dict(zip(_fields.values(), _fields.keys()))
def __init__(self, startField: Optional[int] = None, endField:
Optional[int] = None):
+ self.typeName = self._type_name # type: ignore[method-assign]
if startField is None and endField is None:
# Default matched to scala side.
startField = DayTimeIntervalType.DAY
@@ -572,7 +560,7 @@ class DayTimeIntervalType(AnsiIntervalType):
self.startField = startField
self.endField = endField
- def _str_repr(self) -> str:
+ def _type_name(self) -> str:
fields = DayTimeIntervalType._fields
start_field_name = fields[self.startField]
end_field_name = fields[self.endField]
@@ -581,10 +569,6 @@ class DayTimeIntervalType(AnsiIntervalType):
else:
return "interval %s to %s" % (start_field_name, end_field_name)
- simpleString = _str_repr
-
- jsonValue = _str_repr
-
def __repr__(self) -> str:
return "%s(%d, %d)" % (type(self).__name__, self.startField,
self.endField)
@@ -614,6 +598,7 @@ class YearMonthIntervalType(AnsiIntervalType):
_inverted_fields = dict(zip(_fields.values(), _fields.keys()))
def __init__(self, startField: Optional[int] = None, endField:
Optional[int] = None):
+ self.typeName = self._type_name # type: ignore[method-assign]
if startField is None and endField is None:
# Default matched to scala side.
startField = YearMonthIntervalType.YEAR
@@ -630,7 +615,7 @@ class YearMonthIntervalType(AnsiIntervalType):
self.startField = startField
self.endField = endField
- def _str_repr(self) -> str:
+ def _type_name(self) -> str:
fields = YearMonthIntervalType._fields
start_field_name = fields[self.startField]
end_field_name = fields[self.endField]
@@ -639,10 +624,6 @@ class YearMonthIntervalType(AnsiIntervalType):
else:
return "interval %s to %s" % (start_field_name, end_field_name)
- simpleString = _str_repr
-
- jsonValue = _str_repr
-
def __repr__(self) -> str:
return "%s(%d, %d)" % (type(self).__name__, self.startField,
self.endField)
@@ -776,7 +757,7 @@ class ArrayType(DataType):
) -> None:
if maxDepth > 0:
stringConcat.append(
- f"{prefix}-- element:
{DataType._get_jvm_type_name(self.elementType)} "
+ f"{prefix}-- element: {self.elementType.typeName()} "
+ f"(containsNull = {str(self.containsNull).lower()})\n"
)
DataType._data_type_build_formatted_string(
@@ -924,12 +905,12 @@ class MapType(DataType):
maxDepth: int = JVM_INT_MAX,
) -> None:
if maxDepth > 0:
- stringConcat.append(f"{prefix}-- key:
{DataType._get_jvm_type_name(self.keyType)}\n")
+ stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n")
DataType._data_type_build_formatted_string(
self.keyType, f"{prefix} |", stringConcat, maxDepth
)
stringConcat.append(
- f"{prefix}-- value:
{DataType._get_jvm_type_name(self.valueType)} "
+ f"{prefix}-- value: {self.valueType.typeName()} "
+ f"(valueContainsNull =
{str(self.valueContainsNull).lower()})\n"
)
DataType._data_type_build_formatted_string(
@@ -1092,8 +1073,7 @@ class StructField(DataType):
) -> None:
if maxDepth > 0:
stringConcat.append(
- f"{prefix}-- {escape_meta_characters(self.name)}: "
- + f"{DataType._get_jvm_type_name(self.dataType)} "
+ f"{prefix}-- {escape_meta_characters(self.name)}:
{self.dataType.typeName()} "
+ f"(nullable = {str(self.nullable).lower()})\n"
)
DataType._data_type_build_formatted_string(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]