This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 58561178215 [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions
58561178215 is described below

commit 5856117821532f108113c329d515b13bb7c5b8f5
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Mon Sep 12 10:07:18 2022 -0700

    [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions
    
    ### What changes were proposed in this pull request?
    Support NumPy ndarray in built-in functions (`pyspark.sql.functions`) by introducing the Py4J input converter `NumpyArrayConverter`, which converts an ndarray to a Java array.
    
    The mapping from ndarray dtype to Java primitive type is defined as follows:
    ```py
                np.dtype("int64"): gateway.jvm.long,
                np.dtype("int32"): gateway.jvm.int,
                np.dtype("int16"): gateway.jvm.short,
                # Mapping to gateway.jvm.byte causes
                #   TypeError: 'bytes' object does not support item assignment
                np.dtype("int8"): gateway.jvm.short,
                np.dtype("float32"): gateway.jvm.float,
                np.dtype("float64"): gateway.jvm.double,
                np.dtype("bool"): gateway.jvm.boolean,
    ```
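    
    For reference, here is a minimal sketch of how such a mapping is used to build a typed Java array over Py4J (assuming an active `SparkContext`; this mirrors the `convert` implementation in the diff below rather than adding anything new):
    ```py
    import numpy as np

    from pyspark import SparkContext

    # Sketch only: what the converter does for an int32 ndarray.
    gateway = SparkContext._gateway  # assumes an active SparkContext
    assert gateway is not None
    arr = np.array([1, 2, 3], dtype="int32")
    jarr = gateway.new_array(gateway.jvm.int, len(arr))  # int32 -> jvm.int
    for i, v in enumerate(arr.tolist()):
        jarr[i] = v  # element-wise copy into the Java int[]
    ```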
    
    ### Why are the changes needed?
    This is part of [SPARK-39405](https://issues.apache.org/jira/browse/SPARK-39405), which tracks NumPy support in SQL.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. NumPy ndarray is supported in built-in functions.
    
    Take `lit` for example:
    ```py
    >>> spark.range(1).select(lit(np.array([1, 2], dtype='int16'))).dtypes
    [('ARRAY(1S, 2S)', 'array<smallint>')]
    >>> spark.range(1).select(lit(np.array([1, 2], dtype='int32'))).dtypes
    [('ARRAY(1, 2)', 'array<int>')]
    >>> spark.range(1).select(lit(np.array([1, 2], dtype='float32'))).dtypes
    [("ARRAY(CAST('1.0' AS FLOAT), CAST('2.0' AS FLOAT))", 'array<float>')]
    >>> spark.range(1).select(lit(np.array([]))).dtypes
    [('ARRAY()', 'array<double>')]
    ```
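    
    Unsupported dtypes, such as unsigned integers, raise a `TypeError` instead. A sketch of the expected behavior, based on the unit test below (on typical 64-bit platforms `np.uint` resolves to `uint64`):
    ```py
    >>> spark.range(1).select(lit(np.array([1, 2], dtype=np.uint)))
    Traceback (most recent call last):
        ...
    TypeError: The type of array scalar 'uint64' is not supported
    ```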
    
    ### How was this patch tested?
    Unit tests.
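    
    A sketch of running the new test on its own with `unittest` (assumes a dev checkout with `pyspark` importable and NumPy installed):
    ```py
    import unittest

    from pyspark.sql.tests.test_functions import FunctionsTests

    # Load just test_ndarray_input from the test class and run it.
    suite = unittest.TestLoader().loadTestsFromName("test_ndarray_input", FunctionsTests)
    unittest.TextTestRunner(verbosity=2).run(suite)
    ```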
    
    Closes #37635 from xinrong-meng/builtin_ndarray.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 python/pyspark/sql/tests/test_functions.py | 23 ++++++++++++++
 python/pyspark/sql/types.py                | 49 +++++++++++++++++++++++++++++-
 2 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index c2e6d2e07f2..94047f22e10 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1034,6 +1034,29 @@ class FunctionsTests(ReusedSQLTestCase):
             res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
             self.assertEqual([Row(c=1), Row(c=0)], res)
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_ndarray_input(self):
+        import numpy as np
+
+        arr_dtype_to_spark_dtypes = [
+            ("int8", [("b", "array<smallint>")]),
+            ("int16", [("b", "array<smallint>")]),
+            ("int32", [("b", "array<int>")]),
+            ("int64", [("b", "array<bigint>")]),
+            ("float32", [("b", "array<float>")]),
+            ("float64", [("b", "array<double>")]),
+        ]
+        for t, expected_spark_dtypes in arr_dtype_to_spark_dtypes:
+            arr = np.array([1, 2]).astype(t)
+            self.assertEqual(
+                expected_spark_dtypes, self.spark.range(1).select(lit(arr).alias("b")).dtypes
+            )
+        arr = np.array([1, 2]).astype(np.uint)
+        with self.assertRaisesRegex(
+            TypeError, "The type of array scalar '%s' is not supported" % arr.dtype
+        ):
+            self.spark.range(1).select(lit(arr).alias("b"))
+
     def test_binary_math_function(self):
         funcs, expected = zip(*[(atan2, 0.13664), (hypot, 8.07527), (pow, 2.14359), (pmod, 1.1)])
         df = self.spark.range(1).select(*(func(1.1, 8) for func in funcs))
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c1e6a738bc6..365c903487c 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -46,7 +46,7 @@ from typing import (
 )
 
 from py4j.protocol import register_input_converter
-from py4j.java_gateway import GatewayClient, JavaClass, JavaObject
+from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject
 
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.utils import has_numpy
@@ -2268,12 +2268,59 @@ class NumpyScalarConverter:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def _from_numpy_type_to_java_type(
+        self, nt: "np.dtype", gateway: JavaGateway
+    ) -> Optional[JavaClass]:
+        """Convert NumPy type to Py4J Java type."""
+        if nt in [np.dtype("int8"), np.dtype("int16")]:
+            # Mapping int8 to gateway.jvm.byte causes
+            #   TypeError: 'bytes' object does not support item assignment
+            return gateway.jvm.short
+        elif nt == np.dtype("int32"):
+            return gateway.jvm.int
+        elif nt == np.dtype("int64"):
+            return gateway.jvm.long
+        elif nt == np.dtype("float32"):
+            return gateway.jvm.float
+        elif nt == np.dtype("float64"):
+            return gateway.jvm.double
+        elif nt == np.dtype("bool"):
+            return gateway.jvm.boolean
+
+        return None
+
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+        plist = obj.tolist()
+
+        if len(obj) > 0 and isinstance(plist[0], str):
+            jtpe = gateway.jvm.String
+        else:
+            jtpe = self._from_numpy_type_to_java_type(obj.dtype, gateway)
+            if jtpe is None:
+                raise TypeError("The type of array scalar '%s' is not supported" % (obj.dtype))
+        jarr = gateway.new_array(jtpe, len(obj))
+        for i in range(len(plist)):
+            jarr[i] = plist[i]
+        return jarr
+
+
 # datetime is a subclass of date, we should register DatetimeConverter first
 register_input_converter(DatetimeNTZConverter())
 register_input_converter(DatetimeConverter())
 register_input_converter(DateConverter())
 register_input_converter(DayTimeIntervalTypeConverter())
 register_input_converter(NumpyScalarConverter())
+# NumPy array satisfies py4j.java_collections.ListConverter,
+# so prepend NumpyArrayConverter
+register_input_converter(NumpyArrayConverter(), prepend=True)
 
 
 def _test() -> None:

