This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1af0a510202 [SPARK-41710][CONNECT][PYTHON] Implement `Column.between`
1af0a510202 is described below
commit 1af0a510202bdadfbc1ab6d04b47fe01a23f4555
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Dec 26 17:48:10 2022 +0800
[SPARK-41710][CONNECT][PYTHON] Implement `Column.between`
### What changes were proposed in this pull request?
Implement `Column.between`
### Why are the changes needed?
API coverage
### Does this PR introduce _any_ user-facing change?
yes, new API
### How was this patch tested?
added UT
Closes #39216 from zhengruifeng/connect_column_between.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/column.py | 15 +++++-
.../sql/tests/connect/test_connect_column.py | 60 +++++++++++++++++++++-
2 files changed, 72 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 7cef9e7cd41..918d6cd2adc 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -44,6 +44,11 @@ from pyspark.sql.connect.expressions import (
if TYPE_CHECKING:
+ from pyspark.sql.connect._typing import (
+ LiteralType,
+ DateTimeLiteral,
+ DecimalLiteral,
+ )
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.window import WindowSpec
@@ -349,8 +354,14 @@ class Column:
def getItem(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("getItem() is not yet implemented.")
- def between(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("between() is not yet implemented.")
+ def between(
+ self,
+ lowerBound: Union["Column", "LiteralType", "DateTimeLiteral", "DecimalLiteral"],
+ upperBound: Union["Column", "LiteralType", "DateTimeLiteral", "DecimalLiteral"],
+ ) -> "Column":
+ return (self >= lowerBound) & (self <= upperBound)
+
+ between.__doc__ = PySparkColumn.between.__doc__
def getField(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("getField() is not yet implemented.")
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 6c5f594b29f..e0c883a7f76 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -479,12 +479,70 @@ class SparkConnectTests(SparkConnectSQLTestCase):
sdf.select(sdf.a.isin(sdf.b, 4, 5, 6)).toPandas(),
)
+ def test_between(self):
+ query = """
+ SELECT * FROM VALUES
+ (TIMESTAMP('2022-12-22 15:50:00'), DATE('2022-12-25'), 1.1),
+ (TIMESTAMP('2022-12-22 18:50:00'), NULL, 2.2),
+ (TIMESTAMP('2022-12-23 15:50:00'), DATE('2022-12-24'), 3.3),
+ (NULL, DATE('2022-12-22'), NULL)
+ AS tab(a, b, c)
+ """
+
+ # +-------------------+----------+----+
+ # | a| b| c|
+ # +-------------------+----------+----+
+ # |2022-12-22 15:50:00|2022-12-25| 1.1|
+ # |2022-12-22 18:50:00| null| 2.2|
+ # |2022-12-23 15:50:00|2022-12-24| 3.3|
+ # | null|2022-12-22|null|
+ # +-------------------+----------+----+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ self.assert_eq(
+ cdf.select(cdf.c.between(0, 2)).toPandas(),
+ sdf.select(sdf.c.between(0, 2)).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(cdf.c.between(1.1, 2.2)).toPandas(),
+ sdf.select(sdf.c.between(1.1, 2.2)).toPandas(),
+ )
+
+ self.assert_eq(
+ cdf.select(cdf.c.between(decimal.Decimal(0), decimal.Decimal(2))).toPandas(),
+ sdf.select(sdf.c.between(decimal.Decimal(0), decimal.Decimal(2))).toPandas(),
+ )
+
+ self.assert_eq(
+ cdf.select(
+ cdf.a.between(
+ datetime.datetime(2022, 12, 22, 17, 0, 0),
+ datetime.datetime(2022, 12, 23, 6, 0, 0),
+ )
+ ).toPandas(),
+ sdf.select(
+ sdf.a.between(
+ datetime.datetime(2022, 12, 22, 17, 0, 0),
+ datetime.datetime(2022, 12, 23, 6, 0, 0),
+ )
+ ).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(
+ cdf.b.between(datetime.date(2022, 12, 23), datetime.date(2022, 12, 24))
+ ).toPandas(),
+ sdf.select(
+ sdf.b.between(datetime.date(2022, 12, 23), datetime.date(2022, 12, 24))
+ ).toPandas(),
+ )
+
def test_unsupported_functions(self):
# SPARK-41225: Disable unsupported functions.
c = self.connect.range(1).id
for f in (
"getItem",
- "between",
"getField",
"withField",
"dropFields",
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]