This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 541f0f9c374 [SPARK-42047][SPARK-41900][CONNECT][PYTHON] Literal should support Numpy datatypes
541f0f9c374 is described below

commit 541f0f9c3747fa592120c7e8c52b957d179f136b
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Jan 14 11:12:37 2023 +0800

    [SPARK-42047][SPARK-41900][CONNECT][PYTHON] Literal should support Numpy datatypes
    
    ### What changes were proposed in this pull request?
    Make `Literal` support numpy datatypes
    
    ### Why are the changes needed?
    for parity
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    enabled tests
    
    Closes #39551 from zhengruifeng/connect_fix_41900.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/expressions.py          | 32 +++++++++++++++++-----
 .../sql/tests/connect/test_parity_functions.py     |  7 +----
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 27397fc0c13..5620b4e2cc6 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -29,7 +29,10 @@ import decimal
 import datetime
 import warnings
 
+import numpy as np
+
 from pyspark.sql.types import (
+    _from_numpy_type,
     DateType,
     NullType,
     BooleanType,
@@ -192,19 +195,30 @@ class LiteralExpression(Expression):
             if isinstance(dataType, BinaryType):
                 assert isinstance(value, (bytes, bytearray))
             elif isinstance(dataType, BooleanType):
-                assert isinstance(value, bool)
+                assert isinstance(value, (bool, np.bool_))
+                value = bool(value)
             elif isinstance(dataType, ByteType):
-                assert isinstance(value, int) and JVM_BYTE_MIN <= int(value) <= JVM_BYTE_MAX
+                assert isinstance(value, (int, np.int8))
+                assert JVM_BYTE_MIN <= int(value) <= JVM_BYTE_MAX
+                value = int(value)
             elif isinstance(dataType, ShortType):
-                assert isinstance(value, int) and JVM_SHORT_MIN <= int(value) <= JVM_SHORT_MAX
+                assert isinstance(value, (int, np.int8, np.int16))
+                assert JVM_SHORT_MIN <= int(value) <= JVM_SHORT_MAX
+                value = int(value)
             elif isinstance(dataType, IntegerType):
-                assert isinstance(value, int) and JVM_INT_MIN <= int(value) <= JVM_INT_MAX
+                assert isinstance(value, (int, np.int8, np.int16, np.int32))
+                assert JVM_INT_MIN <= int(value) <= JVM_INT_MAX
+                value = int(value)
             elif isinstance(dataType, LongType):
-                assert isinstance(value, int) and JVM_LONG_MIN <= int(value) <= JVM_LONG_MAX
+                assert isinstance(value, (int, np.int8, np.int16, np.int32, np.int64))
+                assert JVM_LONG_MIN <= int(value) <= JVM_LONG_MAX
+                value = int(value)
             elif isinstance(dataType, FloatType):
-                assert isinstance(value, float)
+                assert isinstance(value, (float, np.float32))
+                value = float(value)
             elif isinstance(dataType, DoubleType):
-                assert isinstance(value, float)
+                assert isinstance(value, (float, np.float32, np.float64))
+                value = float(value)
             elif isinstance(dataType, DecimalType):
                 assert isinstance(value, decimal.Decimal)
             elif isinstance(dataType, StringType):
@@ -259,6 +273,10 @@ class LiteralExpression(Expression):
         elif isinstance(value, datetime.timedelta):
             return DayTimeIntervalType()
         else:
+            if isinstance(value, np.generic):
+                dt = _from_numpy_type(value.dtype)
+                if dt is not None:
+                    return dt
             raise ValueError(f"Unsupported Data Type {type(value).__name__}")
 
     @classmethod
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index dd7229d158f..e763352e936 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -59,7 +59,7 @@ class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
     def test_lit_list(self):
         super().test_lit_list()
 
-    # TODO(SPARK-41900): support Data Type int8
+    # TODO(SPARK-41283): Different column names of `lit(np.int8(1))`
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_lit_np_scalar(self):
         super().test_lit_np_scalar()
@@ -79,11 +79,6 @@ class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
     def test_nested_higher_order_function(self):
         super().test_nested_higher_order_function()
 
-    # TODO(SPARK-41900): support Data Type int8
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_np_scalar_input(self):
-        super().test_np_scalar_input()
-
     # TODO(SPARK-41901): Parity in String representation of Column
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_overlay(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to