This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 47068db3568 [SPARK-41903][CONNECT][PYTHON] `Literal` should support 1-dim ndarray
47068db3568 is described below
commit 47068db3568212618548898f8958ba5c09f07ffe
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 16 09:50:21 2023 +0900
[SPARK-41903][CONNECT][PYTHON] `Literal` should support 1-dim ndarray
### What changes were proposed in this pull request?
`Literal` should support 1-dim ndarray
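For illustration only, a minimal sketch of the new behaviour (the values are hypothetical; assumes NumPy and the Spark Connect Python client are available):

```python
import numpy as np
from pyspark.sql.connect.functions import lit

# A 1-dim ndarray is now accepted and converted element-wise into an
# array column, taking the same path as a plain Python list.
col = lit(np.array([1, 2, 3], dtype=np.int64))
```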
### Why are the changes needed?
Feature parity with vanilla PySpark, whose `lit` already accepts a 1-dim ndarray.
### Does this PR introduce _any_ user-facing change?
Yes, `lit` in Spark Connect now accepts 1-dim ndarray inputs.
### How was this patch tested?
Enabled the previously skipped parity unit test `test_ndarray_input`.
Closes #39570 from zhengruifeng/connect_lit_ndaray.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/functions.py | 15 ++++++++++++++-
python/pyspark/sql/tests/connect/test_parity_functions.py | 5 -----
2 files changed, 14 insertions(+), 6 deletions(-)
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index e1286f7d66e..045b1366fc5 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -31,6 +31,8 @@ from typing import (
cast,
)
+import numpy as np
+
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import (
CaseWhen,
@@ -42,7 +44,7 @@ from pyspark.sql.connect.expressions import (
LambdaFunction,
)
from pyspark.sql import functions as pysparkfuncs
-from pyspark.sql.types import DataType, StructType, ArrayType
+from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType
if TYPE_CHECKING:
from pyspark.sql.connect._typing import ColumnOrName
@@ -192,6 +194,17 @@ def lit(col: Any) -> Column:
     if isinstance(col, Column):
         return col
     elif isinstance(col, list):
+        return array(*[lit(c) for c in col])
+    elif isinstance(col, np.ndarray) and col.ndim == 1:
+        if _from_numpy_type(col.dtype) is None:
+            raise TypeError("The type of array scalar '%s' is not supported" % (col.dtype))
+
+        # NumpyArrayConverter for Py4J can not support ndarray with int8 values.
+        # Actually this is not a problem for Connect, but here still convert it
+        # to int16 for compatibility.
+        if col.dtype == np.int8:
+            col = col.astype(np.int16)
+
         return array(*[lit(c) for c in col])
     else:
         return Column(LiteralExpression._from_value(col))
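As a standalone sketch of the dtype handling added above (the rejection example assumes string dtypes are not mapped by `_from_numpy_type`; the error text follows the `TypeError` raised in the diff):

```python
import numpy as np
from pyspark.sql.connect.functions import lit

# int8 arrays are widened to int16 before conversion, mirroring the
# Py4J NumpyArrayConverter limitation noted in the comment above.
lit(np.array([1, -2, 3], dtype=np.int8))

# Dtypes that _from_numpy_type cannot map to a Spark type are rejected.
try:
    lit(np.array(["a", "b"]))  # string dtype, assumed unsupported here
except TypeError as e:
    print(e)  # "The type of array scalar '...' is not supported"
```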
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index 51422d555d2..d88bf67614b 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -64,11 +64,6 @@ class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
     def test_map_functions(self):
         super().test_map_functions()
 
-    # TODO(SPARK-41903): Support data type ndarray
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_ndarray_input(self):
-        super().test_ndarray_input()
-
     # TODO(SPARK-41902): Parity in String representation of higher_order_function's output
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_nested_higher_order_function(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]