This is an automated email from the ASF dual-hosted git repository.
xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 58561178215 [SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions
58561178215 is described below
commit 5856117821532f108113c329d515b13bb7c5b8f5
Author: Xinrong Meng <[email protected]>
AuthorDate: Mon Sep 12 10:07:18 2022 -0700
[SPARK-40131][PYTHON] Support NumPy ndarray in built-in functions
### What changes were proposed in this pull request?
Support NumPy ndarray in built-in functions (`pyspark.sql.functions`) by
introducing the Py4J input converter `NumpyArrayConverter`. The converter
converts an ndarray to a Java array.
The mapping between ndarray dtypes and Java primitive types is defined as
below:
```py
np.dtype("int64"): gateway.jvm.long,
np.dtype("int32"): gateway.jvm.int,
np.dtype("int16"): gateway.jvm.short,
# Mapping to gateway.jvm.byte causes
# TypeError: 'bytes' object does not support item assignment
np.dtype("int8"): gateway.jvm.short,
np.dtype("float32"): gateway.jvm.float,
np.dtype("float64"): gateway.jvm.double,
np.dtype("bool"): gateway.jvm.boolean,
```
### Why are the changes needed?
As part of [SPARK-39405](https://issues.apache.org/jira/browse/SPARK-39405)
for NumPy support in SQL.
### Does this PR introduce _any_ user-facing change?
Yes. NumPy ndarray is supported in built-in functions.
Take `lit` for example,
```py
>>> spark.range(1).select(lit(np.array([1, 2], dtype='int16'))).dtypes
[('ARRAY(1S, 2S)', 'array<smallint>')]
>>> spark.range(1).select(lit(np.array([1, 2], dtype='int32'))).dtypes
[('ARRAY(1, 2)', 'array<int>')]
>>> spark.range(1).select(lit(np.array([1, 2], dtype='float32'))).dtypes
[("ARRAY(CAST('1.0' AS FLOAT), CAST('2.0' AS FLOAT))", 'array<float>')]
>>> spark.range(1).select(lit(np.array([]))).dtypes
[('ARRAY()', 'array<double>')]
```
### How was this patch tested?
Unit tests.
Closes #37635 from xinrong-meng/builtin_ndarray.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Xinrong Meng <[email protected]>
---
python/pyspark/sql/tests/test_functions.py | 23 ++++++++++++++
python/pyspark/sql/types.py | 49 +++++++++++++++++++++++++++++-
2 files changed, 71 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index c2e6d2e07f2..94047f22e10 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1034,6 +1034,29 @@ class FunctionsTests(ReusedSQLTestCase):
res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
self.assertEqual([Row(c=1), Row(c=0)], res)
+ @unittest.skipIf(not have_numpy, "NumPy not installed")
+ def test_ndarray_input(self):
+ import numpy as np
+
+ arr_dtype_to_spark_dtypes = [
+ ("int8", [("b", "array<smallint>")]),
+ ("int16", [("b", "array<smallint>")]),
+ ("int32", [("b", "array<int>")]),
+ ("int64", [("b", "array<bigint>")]),
+ ("float32", [("b", "array<float>")]),
+ ("float64", [("b", "array<double>")]),
+ ]
+ for t, expected_spark_dtypes in arr_dtype_to_spark_dtypes:
+ arr = np.array([1, 2]).astype(t)
+ self.assertEqual(
+ expected_spark_dtypes, self.spark.range(1).select(lit(arr).alias("b")).dtypes
+ )
+ arr = np.array([1, 2]).astype(np.uint)
+ with self.assertRaisesRegex(
+ TypeError, "The type of array scalar '%s' is not supported" % arr.dtype
+ ):
+ self.spark.range(1).select(lit(arr).alias("b"))
+
def test_binary_math_function(self):
funcs, expected = zip(*[(atan2, 0.13664), (hypot, 8.07527), (pow, 2.14359), (pmod, 1.1)])
df = self.spark.range(1).select(*(func(1.1, 8) for func in funcs))
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c1e6a738bc6..365c903487c 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -46,7 +46,7 @@ from typing import (
)
from py4j.protocol import register_input_converter
-from py4j.java_gateway import GatewayClient, JavaClass, JavaObject
+from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.utils import has_numpy
@@ -2268,12 +2268,59 @@ class NumpyScalarConverter:
return obj.item()
+class NumpyArrayConverter:
+ def _from_numpy_type_to_java_type(
+ self, nt: "np.dtype", gateway: JavaGateway
+ ) -> Optional[JavaClass]:
+ """Convert NumPy type to Py4J Java type."""
+ if nt in [np.dtype("int8"), np.dtype("int16")]:
+ # Mapping int8 to gateway.jvm.byte causes
+ # TypeError: 'bytes' object does not support item assignment
+ return gateway.jvm.short
+ elif nt == np.dtype("int32"):
+ return gateway.jvm.int
+ elif nt == np.dtype("int64"):
+ return gateway.jvm.long
+ elif nt == np.dtype("float32"):
+ return gateway.jvm.float
+ elif nt == np.dtype("float64"):
+ return gateway.jvm.double
+ elif nt == np.dtype("bool"):
+ return gateway.jvm.boolean
+
+ return None
+
+ def can_convert(self, obj: Any) -> bool:
+ return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1
+
+ def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject:
+ from pyspark import SparkContext
+
+ gateway = SparkContext._gateway
+ assert gateway is not None
+ plist = obj.tolist()
+
+ if len(obj) > 0 and isinstance(plist[0], str):
+ jtpe = gateway.jvm.String
+ else:
+ jtpe = self._from_numpy_type_to_java_type(obj.dtype, gateway)
+ if jtpe is None:
+ raise TypeError("The type of array scalar '%s' is not supported" % (obj.dtype))
+ jarr = gateway.new_array(jtpe, len(obj))
+ for i in range(len(plist)):
+ jarr[i] = plist[i]
+ return jarr
+
+
# datetime is a subclass of date, we should register DatetimeConverter first
register_input_converter(DatetimeNTZConverter())
register_input_converter(DatetimeConverter())
register_input_converter(DateConverter())
register_input_converter(DayTimeIntervalTypeConverter())
register_input_converter(NumpyScalarConverter())
+# NumPy array satisfies py4j.java_collections.ListConverter,
+# so prepend NumpyArrayConverter
+register_input_converter(NumpyArrayConverter(), prepend=True)
def _test() -> None:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]