This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 19fa431ef611 [SPARK-46300][PYTHON][CONNECT] Match minor behaviours in Column with full test coverage
19fa431ef611 is described below
commit 19fa431ef61181bd9bfe96a74f6d977b720d281e
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Dec 7 15:50:11 2023 +0900
[SPARK-46300][PYTHON][CONNECT] Match minor behaviours in Column with full test coverage
### What changes were proposed in this pull request?
This PR matches corner-case behaviours in `Column` between Spark Connect and non-Spark Connect, and adds unit tests that bring `pyspark.sql.column` to full test coverage.
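For illustration, a minimal sketch (not part of the patch, and assuming an active `SparkSession` named `spark`) of the corner cases that now behave the same with and without Spark Connect:

```python
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql import functions as sf

df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

# Arguments to Column.substr that are neither int nor Column are rejected
# the same way on both sides; the Spark Connect error now reports
# "startPos" instead of "length".
try:
    sf.col("name").substr("", "")
except PySparkTypeError:
    pass  # NOT_COLUMN_OR_INT, arg_name=startPos

# metadata= is only allowed together with a single alias; Spark Connect now
# raises the same PySparkValueError as classic PySpark.
try:
    df.name.alias("a", "b", metadata={"key": "value"})
except PySparkValueError:
    pass  # ONLY_ALLOWED_FOR_SINGLE_COLUMN, arg_name=metadata

# Column arguments to substr keep working on both sides.
df2 = spark.createDataFrame([(3, 4, "Alice"), (2, 3, "Bob")], ["sidx", "eidx", "name"])
df2.select(df2.name.substr(df2.sidx, df2.eidx).alias("col")).collect()
# [Row(col='ice'), Row(col='ob')]
```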
### Why are the changes needed?
- For feature parity.
- To improve the test coverage.
See
https://app.codecov.io/gh/apache/spark/commit/1a651753f4e760643d719add3b16acd311454c76/blob/python/pyspark/sql/column.py
This code is currently not covered by tests.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Manually ran the new unit tests.
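For reference, a sketch of running the touched classic test module with the standard library's unittest runner (assuming a Spark checkout with PySpark importable; the exact invocation is not part of this patch):

```python
import unittest

# Load and run the Column tests touched by this change; the Spark Connect
# counterpart lives in pyspark.sql.tests.connect.test_connect_column.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "pyspark.sql.tests.test_column.ColumnTests"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```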
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44228 from HyukjinKwon/SPARK-46300.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/column.py | 16 +++++++++--
python/pyspark/sql/connect/column.py | 2 +-
python/pyspark/sql/connect/expressions.py | 5 ++++
.../sql/tests/connect/test_connect_column.py | 2 +-
python/pyspark/sql/tests/test_column.py | 32 +++++++++++++++++++++-
python/pyspark/sql/tests/test_functions.py | 14 +++++++++-
python/pyspark/sql/tests/test_types.py | 12 ++++++++
7 files changed, 76 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 9357b4842bbd..198dd9ff3e40 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -75,7 +75,7 @@ def _to_java_expr(col: "ColumnOrName") -> JavaObject:
@overload
def _to_seq(sc: SparkContext, cols: Iterable[JavaObject]) -> JavaObject:
- pass
+ ...
@overload
@@ -84,7 +84,7 @@ def _to_seq(
cols: Iterable["ColumnOrName"],
converter: Optional[Callable[["ColumnOrName"], JavaObject]],
) -> JavaObject:
- pass
+ ...
def _to_seq(
@@ -924,10 +924,20 @@ class Column:
Examples
--------
+
+ Example 1. Using integers for the input arguments.
+
>>> df = spark.createDataFrame(
... [(2, "Alice"), (5, "Bob")], ["age", "name"])
>>> df.select(df.name.substr(1, 3).alias("col")).collect()
[Row(col='Ali'), Row(col='Bob')]
+
+ Example 2. Using columns for the input arguments.
+
+ >>> df = spark.createDataFrame(
+ ... [(3, 4, "Alice"), (2, 3, "Bob")], ["sidx", "eidx", "name"])
+ >>> df.select(df.name.substr(df.sidx, df.eidx).alias("col")).collect()
+ [Row(col='ice'), Row(col='ob')]
"""
if type(startPos) != type(length):
raise PySparkTypeError(
@@ -1199,7 +1209,7 @@ class Column:
else:
return Column(getattr(self._jc, "as")(alias[0]))
else:
- if metadata:
+ if metadata is not None:
raise PySparkValueError(
error_class="ONLY_ALLOWED_FOR_SINGLE_COLUMN",
message_parameters={"arg_name": "metadata"},
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index a6d9ca8a2ff4..13b00fd83d8b 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -256,7 +256,7 @@ class Column:
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_INT",
- message_parameters={"arg_name": "length", "arg_type": type(length).__name__},
+ message_parameters={"arg_name": "startPos", "arg_type": type(length).__name__},
)
return Column(UnresolvedFunction("substr", [self._expr, start_expr, length_expr]))
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 88c4f4d267b3..384422eed7d1 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -97,6 +97,11 @@ class Expression:
def alias(self, *alias: str, **kwargs: Any) -> "ColumnAlias":
metadata = kwargs.pop("metadata", None)
+ if len(alias) > 1 and metadata is not None:
+ raise PySparkValueError(
+ error_class="ONLY_ALLOWED_FOR_SINGLE_COLUMN",
+ message_parameters={"arg_name": "metadata"},
+ )
assert not kwargs, "Unexpected kwargs where passed: %s" % kwargs
return ColumnAlias(self, list(alias), metadata)
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index f9a9fa95a373..be351e133841 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -155,7 +155,7 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
exception=pe.exception,
error_class="NOT_COLUMN_OR_INT",
message_parameters={
- "arg_name": "length",
+ "arg_name": "startPos",
"arg_type": "float",
},
)
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index 622c1f7b2104..e51ae69814bd 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -20,7 +20,7 @@ from itertools import chain
from pyspark.sql import Column, Row
from pyspark.sql import functions as sf
from pyspark.sql.types import StructType, StructField, LongType
-from pyspark.errors import AnalysisException, PySparkTypeError
+from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -218,6 +218,36 @@ class ColumnTestsMixin:
).withColumn("square_value", mapping_expr[sf.col("key")])
self.assertEqual(df.count(), 3)
+ def test_alias_negative(self):
+ with self.assertRaises(PySparkValueError) as pe:
+ self.spark.range(1).id.alias("a", "b", metadata={})
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="ONLY_ALLOWED_FOR_SINGLE_COLUMN",
+ message_parameters={"arg_name": "metadata"},
+ )
+
+ def test_cast_negative(self):
+ with self.assertRaises(PySparkTypeError) as pe:
+ self.spark.range(1).id.cast(123)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_DATATYPE_OR_STR",
+ message_parameters={"arg_name": "dataType", "arg_type": "int"},
+ )
+
+ def test_over_negative(self):
+ with self.assertRaises(PySparkTypeError) as pe:
+ self.spark.range(1).id.over(123)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_WINDOWSPEC",
+ message_parameters={"arg_name": "window", "arg_type": "int"},
+ )
+
class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 2bdcfa6085fd..2ac7ddbcba59 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -346,7 +346,7 @@ class FunctionsTestsMixin:
df = self.spark.createDataFrame([["nick"]], schema=["name"])
with self.assertRaises(PySparkTypeError) as pe:
- df.select(F.col("name").substr(0, F.lit(1)))
+ F.col("name").substr(0, F.lit(1))
self.check_error(
exception=pe.exception,
@@ -359,6 +359,18 @@ class FunctionsTestsMixin:
},
)
+ with self.assertRaises(PySparkTypeError) as pe:
+ F.col("name").substr("", "")
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_COLUMN_OR_INT",
+ message_parameters={
+ "arg_name": "startPos",
+ "arg_type": "str",
+ },
+ )
+
for name in string_functions:
self.assertEqual(
df.select(getattr(F, name)("name")).first()[0],
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 06064e58c794..992abc8e82d9 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -883,6 +883,18 @@ class TypesTestsMixin:
self.assertEqual("v", df.select(df.d["k"]).first()[0])
self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
+ # Deprecated behaviors
+ map_col = F.create_map(F.lit(0), F.lit(100), F.lit(1), F.lit(200))
+ self.assertEqual(
+ self.spark.range(1).withColumn("mapped", map_col.getItem(F.col("id"))).first()[1], 100
+ )
+
+ struct_col = F.struct(F.lit(0), F.lit(100), F.lit(1), F.lit(200))
+ self.assertEqual(
+ self.spark.range(1).withColumn("struct", struct_col.getField(F.lit("col1"))).first()[1],
+ 0,
+ )
+
def test_infer_long_type(self):
longrow = [Row(f1="a", f2=100000000000000)]
df = self.sc.parallelize(longrow).toDF()