This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 47068db3568 [SPARK-41903][CONNECT][PYTHON] `Literal` should support 1-dim ndarray
47068db3568 is described below

commit 47068db3568212618548898f8958ba5c09f07ffe
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 16 09:50:21 2023 +0900

    [SPARK-41903][CONNECT][PYTHON] `Literal` should support 1-dim ndarray
    
    ### What changes were proposed in this pull request?
    `Literal` should support 1-dim ndarray
    
    ### Why are the changes needed?
    Feature parity: vanilla PySpark's `lit` already accepts a 1-dim ndarray, so the Spark Connect `lit` should too.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `lit` in Spark Connect now accepts a 1-dim `np.ndarray`.
    
    ### How was this patch tested?
    Enabled the existing unit test `test_ndarray_input`, previously skipped in the parity suite.
    
    Closes #39570 from zhengruifeng/connect_lit_ndaray.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
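
A minimal sketch of what this change enables on the Spark Connect client (not part of the commit; assumes an existing Connect session bound to `spark`):

    import numpy as np
    from pyspark.sql.connect.functions import lit

    # Assumption: `spark` is an already-created Spark Connect session.
    # A 1-dim ndarray now becomes an array literal, equivalent to
    # array(lit(1), lit(2), lit(3)).
    spark.range(1).select(lit(np.array([1, 2, 3], dtype=np.int64))).show()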
---
 python/pyspark/sql/connect/functions.py                   | 15 ++++++++++++++-
 python/pyspark/sql/tests/connect/test_parity_functions.py |  5 -----
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index e1286f7d66e..045b1366fc5 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -31,6 +31,8 @@ from typing import (
     cast,
 )
 
+import numpy as np
+
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
     CaseWhen,
@@ -42,7 +44,7 @@ from pyspark.sql.connect.expressions import (
     LambdaFunction,
 )
 from pyspark.sql import functions as pysparkfuncs
-from pyspark.sql.types import DataType, StructType, ArrayType
+from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import ColumnOrName
@@ -192,6 +194,17 @@ def lit(col: Any) -> Column:
     if isinstance(col, Column):
         return col
     elif isinstance(col, list):
+        return array(*[lit(c) for c in col])
+    elif isinstance(col, np.ndarray) and col.ndim == 1:
+        if _from_numpy_type(col.dtype) is None:
+            raise TypeError("The type of array scalar '%s' is not supported" % (col.dtype))
+
+        # NumpyArrayConverter for Py4J can not support ndarray with int8 values.
+        # Actually this is not a problem for Connect, but here still convert it
+        # to int16 for compatibility.
+        if col.dtype == np.int8:
+            col = col.astype(np.int16)
+
         return array(*[lit(c) for c in col])
     else:
         return Column(LiteralExpression._from_value(col))
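
A rough sketch of the dtype handling added above (assumptions noted in the comments): `_from_numpy_type` gates which element dtypes are accepted, and int8 input is widened to int16 before each element is wrapped in `lit`.

    import numpy as np
    from pyspark.sql.types import _from_numpy_type

    # Supported numpy dtypes map to Spark data types; unsupported dtypes
    # return None, which is what makes lit() raise the TypeError above.
    print(_from_numpy_type(np.dtype(np.int64)))    # a Spark integral type
    print(_from_numpy_type(np.dtype(np.object_)))  # None

    # int8 arrays are widened to int16, mirroring the Py4J-based code path.
    arr = np.array([1, 2, 3], dtype=np.int8)
    print(arr.astype(np.int16).dtype)              # int16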
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index 51422d555d2..d88bf67614b 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -64,11 +64,6 @@ class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
     def test_map_functions(self):
         super().test_map_functions()
 
-    # TODO(SPARK-41903): Support data type ndarray
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_ndarray_input(self):
-        super().test_ndarray_input()
-
-    # TODO(SPARK-41902): Parity in String representation of higher_order_function's output
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_nested_higher_order_function(self):
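
For context, a hypothetical check in the spirit of the re-enabled test (the actual body lives in FunctionsTestsMixin.test_ndarray_input; `spark` is again assumed to be a Connect session):

    import numpy as np
    from pyspark.sql.connect.functions import lit

    # int8 is widened to int16 internally, but the values round-trip unchanged.
    arr = np.array([1, 2, 3], dtype=np.int8)
    row = spark.range(1).select(lit(arr).alias("a")).first()
    assert list(row["a"]) == [1, 2, 3]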

