This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 541f0f9c374 [SPARK-42047][SPARK-41900][CONNECT][PYTHON] Literal should
support Numpy datatypes
541f0f9c374 is described below
commit 541f0f9c3747fa592120c7e8c52b957d179f136b
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Jan 14 11:12:37 2023 +0800
[SPARK-42047][SPARK-41900][CONNECT][PYTHON] Literal should support Numpy
datatypes
### What changes were proposed in this pull request?
Make `Literal` support numpy datatypes
### Why are the changes needed?
for parity
### Does this PR introduce _any_ user-facing change?
yes
### How was this patch tested?
enabled tests
Closes #39551 from zhengruifeng/connect_fix_41900.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/expressions.py | 32 +++++++++++++++++-----
.../sql/tests/connect/test_parity_functions.py | 7 +----
2 files changed, 26 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 27397fc0c13..5620b4e2cc6 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -29,7 +29,10 @@ import decimal
import datetime
import warnings
+import numpy as np
+
from pyspark.sql.types import (
+ _from_numpy_type,
DateType,
NullType,
BooleanType,
@@ -192,19 +195,30 @@ class LiteralExpression(Expression):
if isinstance(dataType, BinaryType):
assert isinstance(value, (bytes, bytearray))
elif isinstance(dataType, BooleanType):
- assert isinstance(value, bool)
+ assert isinstance(value, (bool, np.bool_))
+ value = bool(value)
elif isinstance(dataType, ByteType):
- assert isinstance(value, int) and JVM_BYTE_MIN <= int(value) <= JVM_BYTE_MAX
+ assert isinstance(value, (int, np.int8))
+ assert JVM_BYTE_MIN <= int(value) <= JVM_BYTE_MAX
+ value = int(value)
elif isinstance(dataType, ShortType):
- assert isinstance(value, int) and JVM_SHORT_MIN <= int(value) <= JVM_SHORT_MAX
+ assert isinstance(value, (int, np.int8, np.int16))
+ assert JVM_SHORT_MIN <= int(value) <= JVM_SHORT_MAX
+ value = int(value)
elif isinstance(dataType, IntegerType):
- assert isinstance(value, int) and JVM_INT_MIN <= int(value) <= JVM_INT_MAX
+ assert isinstance(value, (int, np.int8, np.int16, np.int32))
+ assert JVM_INT_MIN <= int(value) <= JVM_INT_MAX
+ value = int(value)
elif isinstance(dataType, LongType):
- assert isinstance(value, int) and JVM_LONG_MIN <= int(value) <= JVM_LONG_MAX
+ assert isinstance(value, (int, np.int8, np.int16, np.int32, np.int64))
+ assert JVM_LONG_MIN <= int(value) <= JVM_LONG_MAX
+ value = int(value)
elif isinstance(dataType, FloatType):
- assert isinstance(value, float)
+ assert isinstance(value, (float, np.float32))
+ value = float(value)
elif isinstance(dataType, DoubleType):
- assert isinstance(value, float)
+ assert isinstance(value, (float, np.float32, np.float64))
+ value = float(value)
elif isinstance(dataType, DecimalType):
assert isinstance(value, decimal.Decimal)
elif isinstance(dataType, StringType):
@@ -259,6 +273,10 @@ class LiteralExpression(Expression):
elif isinstance(value, datetime.timedelta):
return DayTimeIntervalType()
else:
+ if isinstance(value, np.generic):
+ dt = _from_numpy_type(value.dtype)
+ if dt is not None:
+ return dt
raise ValueError(f"Unsupported Data Type {type(value).__name__}")
@classmethod
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index dd7229d158f..e763352e936 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -59,7 +59,7 @@ class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
def test_lit_list(self):
super().test_lit_list()
- # TODO(SPARK-41900): support Data Type int8
+ # TODO(SPARK-41283): Different column names of `lit(np.int8(1))`
@unittest.skip("Fails in Spark Connect, should enable.")
def test_lit_np_scalar(self):
super().test_lit_np_scalar()
@@ -79,11 +79,6 @@ class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
def test_nested_higher_order_function(self):
super().test_nested_higher_order_function()
- # TODO(SPARK-41900): support Data Type int8
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_np_scalar_input(self):
- super().test_np_scalar_input()
-
# TODO(SPARK-41901): Parity in String representation of Column
@unittest.skip("Fails in Spark Connect, should enable.")
def test_overlay(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]