This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new fde833c5326  [SPARK-40130][PYTHON] Support NumPy scalars in built-in functions
fde833c5326 is described below

commit fde833c532630092204dc54299702676e1de8b74
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Thu Aug 25 11:30:18 2022 -0700

[SPARK-40130][PYTHON] Support NumPy scalars in built-in functions

### What changes were proposed in this pull request?
Support NumPy scalars in built-in functions by introducing a Py4J input converter, `NumpyScalarConverter`. Specifically,
- `np.int8`, `np.int16`, `np.int32`, and `np.int64` are mapped to Spark `int`/`bigint`.
- `np.float32` and `np.float64` are mapped to Spark `double`.

Note that 2147483647 (the `java.lang.Integer` maximum) is the largest value that maps to Spark `int`; anything above it maps to `bigint`:
```py
>>> df.select(lit(np.int64(max_int + 1))).dtypes
[('2147483648', 'bigint')]
>>> df.select(lit(np.int64(max_int))).dtypes
[('2147483647', 'int')]
```

### Why are the changes needed?
As part of [SPARK-39405](https://issues.apache.org/jira/browse/SPARK-39405) for NumPy support in SQL.

### Does this PR introduce _any_ user-facing change?
Yes. NumPy scalars are now accepted by built-in functions whose parameters take scalars; affected functions include `lit`, `when`, `array_contains`, `array_position`, `element_at`, and `array_remove`.

Take `lit` for example:
```py
>>> df.select(lit(np.int8(1))).dtypes
[('1', 'int')]
>>> df.select(lit(np.float32(1))).dtypes
[('1.0', 'double')]
```

### How was this patch tested?
Unit tests.

Closes #37560 from xinrong-meng/builtin_np.

Authored-by: Xinrong Meng <xinr...@apache.org>
Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 python/pyspark/sql/tests/test_functions.py | 29 +++++++++++++++++++++++++++++
 python/pyspark/sql/types.py                | 13 +++++++++++++
 python/pyspark/sql/utils.py                |  9 +++++++++
 3 files changed, 51 insertions(+)

diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 71c6bc33dbb..102ebef8317 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -20,6 +20,7 @@ from inspect import getmembers, isfunction
 from itertools import chain
 import re
 import math
+import unittest
 
 from py4j.protocol import Py4JJavaError
 from pyspark.sql import Row, Window, types
@@ -55,6 +56,7 @@ from pyspark.sql.functions import (
 )
 from pyspark.sql import functions
 from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
+from pyspark.testing.utils import have_numpy
 
 
 class FunctionsTests(ReusedSQLTestCase):
@@ -974,6 +976,33 @@ class FunctionsTests(ReusedSQLTestCase):
             )
         )
 
+    @unittest.skipIf(not have_numpy, "NumPy not installed")
+    def test_np_scalar_input(self):
+        import numpy as np
+        from pyspark.sql.functions import array_contains, array_position
+
+        df = self.spark.createDataFrame([([1, 2, 3],), ([],)], ["data"])
+        for dtype in [np.int8, np.int16, np.int32, np.int64]:
+            self.assertEqual(df.select(lit(dtype(1))).dtypes, [("1", "int")])
+            res = df.select(array_contains(df.data, dtype(1)).alias("b")).collect()
+            self.assertEqual([Row(b=True), Row(b=False)], res)
+            res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
+            self.assertEqual([Row(c=1), Row(c=0)], res)
+
+        # java.lang.Integer max: 2147483647
+        max_int = 2147483647
+        # Convert int to bigint automatically
+        self.assertEqual(df.select(lit(np.int32(max_int))).dtypes, [("2147483647", "int")])
+        self.assertEqual(df.select(lit(np.int64(max_int + 1))).dtypes, [("2147483648", "bigint")])
+
+        df = self.spark.createDataFrame([([1.0, 2.0, 3.0],), ([],)], ["data"])
+        for dtype in [np.float32, np.float64]:
+            self.assertEqual(df.select(lit(dtype(1))).dtypes, [("1.0", "double")])
+            res = df.select(array_contains(df.data, dtype(1)).alias("b")).collect()
+            self.assertEqual([Row(b=True), Row(b=False)], res)
+            res = df.select(array_position(df.data, dtype(1)).alias("c")).collect()
+            self.assertEqual([Row(c=1), Row(c=0)], res)
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 8255aca8f52..c1e6a738bc6 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -49,6 +49,10 @@ from py4j.protocol import register_input_converter
 from py4j.java_gateway import GatewayClient, JavaClass, JavaObject
 
 from pyspark.serializers import CloudPickleSerializer
+from pyspark.sql.utils import has_numpy
+
+if has_numpy:
+    import numpy as np
 
 T = TypeVar("T")
 U = TypeVar("U")
@@ -2256,11 +2260,20 @@ class DayTimeIntervalTypeConverter:
         )
 
 
+class NumpyScalarConverter:
+    def can_convert(self, obj: Any) -> bool:
+        return has_numpy and isinstance(obj, np.generic)
+
+    def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any:
+        return obj.item()
+
+
 # datetime is a subclass of date, we should register DatetimeConverter first
 register_input_converter(DatetimeNTZConverter())
 register_input_converter(DatetimeConverter())
 register_input_converter(DateConverter())
 register_input_converter(DayTimeIntervalTypeConverter())
+register_input_converter(NumpyScalarConverter())
 
 
 def _test() -> None:
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index e4a0299164e..2ff13cd2bba 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -29,6 +29,15 @@ from py4j.protocol import Py4JJavaError
 from pyspark import SparkContext
 from pyspark.find_spark_home import _find_spark_home
 
+has_numpy = False
+try:
+    import numpy as np  # noqa: F401
+
+    has_numpy = True
+except ImportError:
+    pass
+
+
 if TYPE_CHECKING:
     from pyspark.sql.session import SparkSession
     from pyspark.sql.dataframe import DataFrame
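---

Editor's note: the PR description lists `when`, `element_at`, and `array_remove` among the affected functions but only demonstrates `lit`, `array_contains`, and `array_position`. Below is a hedged usage sketch of the remaining three; it is not part of the patch, and the SparkSession setup and sample DataFrame are illustrative assumptions.

```py
# A hedged sketch (not from the patch): NumPy scalars passed to `when`,
# `element_at`, and `array_remove`, three affected functions named in the
# PR description. Session and data here are illustrative assumptions.
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import array_remove, col, element_at, when

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([([1, 2, 3],)], ["data"])

# when(): the np.int32 branch value is unwrapped to a plain Python int by
# NumpyScalarConverter, so the result column gets Spark type `int`.
df.select(when(col("data")[0] == 1, np.int32(10)).otherwise(0).alias("w")).show()

# element_at(): NumPy index (1-based); array_remove(): NumPy element to drop.
df.select(element_at(df.data, np.int64(2)).alias("e")).show()
df.select(array_remove(df.data, np.int8(3)).alias("r")).show()

spark.stop()
```

Before this converter was registered, such calls would typically fail or require a manual cast to a built-in Python type (e.g. `int(np.int64(2))`); `NumpyScalarConverter.convert` now performs the equivalent `obj.item()` unwrap automatically before the value crosses the Py4J gateway.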