This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 58561178215 [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions
58561178215 is described below

commit 5856117821532f108113c329d515b13bb7c5b8f5
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Mon Sep 12 10:07:18 2022 -0700

[SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions

### What changes were proposed in this pull request?
Support NumPy ndarray in built-in functions (`pyspark.sql.functions`) by introducing the Py4J input converter `NumpyArrayConverter`, which converts an ndarray to a Java array. The mapping from ndarray dtype to Java primitive type is defined as follows:
```py
np.dtype("int64"): gateway.jvm.long,
np.dtype("int32"): gateway.jvm.int,
np.dtype("int16"): gateway.jvm.short,
# Mapping to gateway.jvm.byte causes
# TypeError: 'bytes' object does not support item assignment
np.dtype("int8"): gateway.jvm.short,
np.dtype("float32"): gateway.jvm.float,
np.dtype("float64"): gateway.jvm.double,
np.dtype("bool"): gateway.jvm.boolean,
```

### Why are the changes needed?
This is part of [SPARK-39405](https://issues.apache.org/jira/browse/SPARK-39405), which tracks NumPy support in SQL.

### Does this PR introduce _any_ user-facing change?
Yes. NumPy ndarray is now supported in built-in functions. Take `lit` for example:
```py
>>> spark.range(1).select(lit(np.array([1, 2], dtype='int16'))).dtypes
[('ARRAY(1S, 2S)', 'array<smallint>')]
>>> spark.range(1).select(lit(np.array([1, 2], dtype='int32'))).dtypes
[('ARRAY(1, 2)', 'array<int>')]
>>> spark.range(1).select(lit(np.array([1, 2], dtype='float32'))).dtypes
[("ARRAY(CAST('1.0' AS FLOAT), CAST('2.0' AS FLOAT))", 'array<float>')]
>>> spark.range(1).select(lit(np.array([]))).dtypes
[('ARRAY()', 'array<double>')]
```

### How was this patch tested?
Unit tests.

Closes #37635 from xinrong-meng/builtin_ndarray.
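For illustration, a minimal end-to-end sketch of the new behavior, not part of the original patch: it assumes a PySpark build containing this change, NumPy installed, and a local SparkSession; the session setup and the column alias `a` are illustrative only.
```py
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import array_contains, lit

# Illustrative local session; any SparkSession on a build with this patch works.
spark = SparkSession.builder.master("local[1]").getOrCreate()

arr = np.array([1, 2, 3], dtype="int64")  # 1-dim ndarray -> array<bigint>
df = spark.range(1).select(lit(arr).alias("a"))

# Once converted, the ndarray literal behaves like any other array column.
df.select(array_contains(df.a, 2).alias("hit")).show()  # single row: true
```
Since `can_convert` requires `obj.ndim == 1`, only one-dimensional ndarrays take this path.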
Authored-by: Xinrong Meng <xinr...@apache.org>
Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 python/pyspark/sql/tests/test_functions.py | 23 ++++++++++++++
 python/pyspark/sql/types.py                | 49 +++++++++++++++++++++++++++++-
 2 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index c2e6d2e07f2..94047f22e10 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1034,6 +1034,29 @@ class FunctionsTests(ReusedSQLTestCase):
         res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
         self.assertEqual([Row(c=1), Row(c=0)], res)
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_ndarray_input(self):
+        import numpy as np
+
+        arr_dtype_to_spark_dtypes = [
+            ("int8", [("b", "array<smallint>")]),
+            ("int16", [("b", "array<smallint>")]),
+            ("int32", [("b", "array<int>")]),
+            ("int64", [("b", "array<bigint>")]),
+            ("float32", [("b", "array<float>")]),
+            ("float64", [("b", "array<double>")]),
+        ]
+        for t, expected_spark_dtypes in arr_dtype_to_spark_dtypes:
+            arr = np.array([1, 2]).astype(t)
+            self.assertEqual(
+                expected_spark_dtypes, self.spark.range(1).select(lit(arr).alias("b")).dtypes
+            )
+        arr = np.array([1, 2]).astype(np.uint)
+        with self.assertRaisesRegex(
+            TypeError, "The type of array scalar '%s' is not supported" % arr.dtype
+        ):
+            self.spark.range(1).select(lit(arr).alias("b"))
+
     def test_binary_math_function(self):
         funcs, expected = zip(*[(atan2, 0.13664), (hypot, 8.07527), (pow, 2.14359), (pmod, 1.1)])
         df = self.spark.range(1).select(*(func(1.1, 8) for func in funcs))

diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c1e6a738bc6..365c903487c 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -46,7 +46,7 @@ from typing import (
 )
 
 from py4j.protocol import register_input_converter
-from py4j.java_gateway import GatewayClient, JavaClass, JavaObject
+from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject
 
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.utils import has_numpy
@@ -2268,12 +2268,59 @@ class NumpyScalarConverter:
         return obj.item()
 
 
+class NumpyArrayConverter:
+    def _from_numpy_type_to_java_type(
+        self, nt: "np.dtype", gateway: JavaGateway
+    ) -> Optional[JavaClass]:
+        """Convert NumPy type to Py4J Java type."""
+        if nt in [np.dtype("int8"), np.dtype("int16")]:
+            # Mapping int8 to gateway.jvm.byte causes
+            # TypeError: 'bytes' object does not support item assignment
+            return gateway.jvm.short
+        elif nt == np.dtype("int32"):
+            return gateway.jvm.int
+        elif nt == np.dtype("int64"):
+            return gateway.jvm.long
+        elif nt == np.dtype("float32"):
+            return gateway.jvm.float
+        elif nt == np.dtype("float64"):
+            return gateway.jvm.double
+        elif nt == np.dtype("bool"):
+            return gateway.jvm.boolean
+
+        return None
+
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1
+
+    def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+        from pyspark import SparkContext
+
+        gateway = SparkContext._gateway
+        assert gateway is not None
+        plist = obj.tolist()
+
+        if len(obj) > 0 and isinstance(plist[0], str):
+            jtpe = gateway.jvm.String
+        else:
+            jtpe = self._from_numpy_type_to_java_type(obj.dtype, gateway)
+            if jtpe is None:
+                raise TypeError("The type of array scalar '%s' is not supported" % (obj.dtype))
+        jarr = gateway.new_array(jtpe, len(obj))
+        for i in range(len(plist)):
+            jarr[i] = plist[i]
+        return jarr
+
+
 # datetime is a subclass of date, we should register DatetimeConverter first
 register_input_converter(DatetimeNTZConverter())
 register_input_converter(DatetimeConverter())
 register_input_converter(DateConverter())
 register_input_converter(DayTimeIntervalTypeConverter())
 register_input_converter(NumpyScalarConverter())
+# NumPy array satisfies py4j.java_collections.ListConverter,
+# so prepend NumpyArrayConverter
+register_input_converter(NumpyArrayConverter(), prepend=True)
 
 
 def _test() -> None:
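A design note on the `prepend=True` registration at the end of the diff: Py4J tries its registered input converters in order, and the first whose `can_convert` returns true wins, so a 1-dim ndarray (which is iterable) would otherwise be claimed by the generic `py4j.java_collections.ListConverter`. The toy sketch below illustrates that ordering rule; it is a simplified stand-in, not Py4J's actual implementation, and all names in it are made up.
```py
from typing import Any, List

class SpecificConverter:
    """Toy stand-in for NumpyArrayConverter: only matches tuples."""
    def can_convert(self, obj: Any) -> bool:
        return isinstance(obj, tuple)
    def convert(self, obj: tuple) -> list:
        return list(obj)

class GenericConverter:
    """Toy stand-in for ListConverter: matches any iterable."""
    def can_convert(self, obj: Any) -> bool:
        return hasattr(obj, "__iter__")
    def convert(self, obj: Any) -> str:
        return "generic:" + repr(obj)

converters: List[Any] = [GenericConverter()]  # registered earlier

def register(conv: Any, prepend: bool = False) -> None:
    # prepend=True places conv ahead of earlier, more generic converters.
    if prepend:
        converters.insert(0, conv)
    else:
        converters.append(conv)

def convert_input(obj: Any) -> Any:
    # The first matching converter wins, so registration order decides.
    for conv in converters:
        if conv.can_convert(obj):
            return conv.convert(obj)
    return obj

register(SpecificConverter(), prepend=True)
print(convert_input((1, 2)))  # [1, 2] -- the specific converter won
```
Without `prepend=True`, the generic converter would match first and the specific one would never run, which mirrors why `NumpyArrayConverter` must sit ahead of Py4J's built-in list handling.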