This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new db9dbd90d8e [SPARK-42048][PYTHON][CONNECT] Fix the alias name for
numpy literals
db9dbd90d8e is described below
commit db9dbd90d8edff222636bebf25df2fb96adef534
Author: Takuya UESHIN <[email protected]>
AuthorDate: Mon Feb 20 09:08:13 2023 +0900
[SPARK-42048][PYTHON][CONNECT] Fix the alias name for numpy literals
### What changes were proposed in this pull request?
Fixes the alias name for numpy literals.
Also fixes `F.lit` in Spark Connect to support `np.bool_` objects.
### Why are the changes needed?
Currently the alias name for literals created from numpy scalars contains
something like `CAST(` ... `AS <type>)`, but it should be removed and return
only the value string as same as literals from Python numbers.
### Does this PR introduce _any_ user-facing change?
The alias name will be changed.
### How was this patch tested?
Modified/enabled related tests.
Closes #40076 from ueshin/issues/SPARK-42048/lit.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/expressions.py | 2 ++
python/pyspark/sql/functions.py | 2 +-
python/pyspark/sql/tests/connect/test_parity_functions.py | 5 -----
python/pyspark/sql/tests/test_functions.py | 15 ++++++++-------
4 files changed, 11 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/sql/connect/expressions.py
b/python/pyspark/sql/connect/expressions.py
index 876748d06d8..76e4252dce7 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -281,6 +281,8 @@ class LiteralExpression(Expression):
dt = _from_numpy_type(value.dtype)
if dt is not None:
return dt
+ elif isinstance(value, np.bool_):
+ return BooleanType()
raise TypeError(f"Unsupported Data Type {type(value).__name__}")
@classmethod
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b103af72e36..d296075fb0b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -181,7 +181,7 @@ def lit(col: Any) -> Column:
if has_numpy and isinstance(col, np.generic):
dt = _from_numpy_type(col.dtype)
if dt is not None:
- return _invoke_function("lit", col).astype(dt)
+ return _invoke_function("lit", col).astype(dt).alias(str(col))
return _invoke_function("lit", col)
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py
b/python/pyspark/sql/tests/connect/test_parity_functions.py
index 1ea33d2e370..a69e47effe4 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -48,11 +48,6 @@ class FunctionsParityTests(FunctionsTestsMixin,
ReusedConnectTestCase):
def test_lit_list(self):
super().test_lit_list()
- # TODO(SPARK-41283): Different column names of `lit(np.int8(1))`
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_lit_np_scalar(self):
- super().test_lit_np_scalar()
-
def test_raise_error(self):
self.check_raise_error(SparkConnectException)
diff --git a/python/pyspark/sql/tests/test_functions.py
b/python/pyspark/sql/tests/test_functions.py
index d8343b4fb47..8bc2b96cc51 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1184,16 +1184,17 @@ class FunctionsTestsMixin:
from pyspark.sql.functions import lit
dtype_to_spark_dtypes = [
- (np.int8, [("CAST(1 AS TINYINT)", "tinyint")]),
- (np.int16, [("CAST(1 AS SMALLINT)", "smallint")]),
- (np.int32, [("CAST(1 AS INT)", "int")]),
- (np.int64, [("CAST(1 AS BIGINT)", "bigint")]),
- (np.float32, [("CAST(1.0 AS FLOAT)", "float")]),
- (np.float64, [("CAST(1.0 AS DOUBLE)", "double")]),
+ (np.int8, [("1", "tinyint")]),
+ (np.int16, [("1", "smallint")]),
+ (np.int32, [("1", "int")]),
+ (np.int64, [("1", "bigint")]),
+ (np.float32, [("1.0", "float")]),
+ (np.float64, [("1.0", "double")]),
(np.bool_, [("true", "boolean")]),
]
for dtype, spark_dtypes in dtype_to_spark_dtypes:
-            self.assertEqual(self.spark.range(1).select(lit(dtype(1))).dtypes, spark_dtypes)
+            with self.subTest(dtype):
+                self.assertEqual(self.spark.range(1).select(lit(dtype(1))).dtypes, spark_dtypes)
@unittest.skipIf(not have_numpy, "NumPy not installed")
def test_np_scalar_input(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]