This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 40e6592e02c [SPARK-41328][CONNECT][PYTHON] Add logical and string API to Column
40e6592e02c is described below

commit 40e6592e02cbe679daec9e302e1027ffc64e7323
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Wed Nov 30 13:13:57 2022 +0900

    [SPARK-41328][CONNECT][PYTHON] Add logical and string API to Column

    ### What changes were proposed in this pull request?

    1. Upgrade `_typing.py` to use `Column`.
    2. Add logical operators (and, or, etc.) and strings (like, substr, etc.) to `Column`.
    3. Add basic tests for new API.

    ### Why are the changes needed?

    Improve API coverage

    ### Does this PR introduce _any_ user-facing change?

    NO

    ### How was this patch tested?

    UT

    Closes #38844 from amaliujia/refactor_column_back_up_2.

    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/_typing.py          |   4 +-
 python/pyspark/sql/connect/column.py           | 339 ++++++++++++++++++++-
 python/pyspark/sql/connect/function_builder.py |   6 +-
 .../sql/tests/connect/test_connect_basic.py    |  31 +-
 4 files changed, 367 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 8629d1c23cc..e5ade4cfcbe 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -26,7 +26,7 @@ from typing import Union, Optional
 import datetime
 import decimal
 
-from pyspark.sql.connect.column import ScalarFunctionExpression, Column
+from pyspark.sql.connect.column import Column
 
 ColumnOrName = Union[Column, str]
 
@@ -42,7 +42,7 @@ DateTimeLiteral = Union[datetime.datetime, datetime.date]
 
 
 class FunctionBuilderCallable(Protocol):
-    def __call__(self, *_: ColumnOrName) -> ScalarFunctionExpression:
+    def __call__(self, *_: ColumnOrName) -> Column:
         ...
 
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 83e8b28da0f..c53d2c90bf6 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-from typing import get_args, TYPE_CHECKING, Callable, Any, Union
+from typing import get_args, TYPE_CHECKING, Callable, Any, Union, overload
 
 import json
 import decimal
@@ -29,6 +29,17 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.client import SparkConnectClient
 
 import pyspark.sql.connect.proto as proto
 
+# TODO(SPARK-41329): solve the circular import between _typing and this class
+# if we want to reuse _type.PrimitiveType
+PrimitiveType = Union[bool, float, int, str]
+
+
+def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]:
+    def _(self: "Column") -> "Column":
+        return scalar_function(name, self)
+
+    return _
+
 
 def _bin_op(
     name: str, doc: str = "binary function", reverse: bool = False
@@ -219,6 +230,8 @@ class LiteralExpression(Expression):
                 else:
                     pair.value.CopyFrom(lit(value).to_plan(session).literal)
                 expr.literal.map.pairs.append(pair)
+        elif isinstance(self._value, Column):
+            expr.CopyFrom(self._value.to_plan(session))
         else:
             raise ValueError(f"Could not convert literal for type {type(self._value)}")
 
@@ -352,17 +365,326 @@ class Column(object):
     __rpow__ = _bin_op("pow", reverse=True)
     __ge__ = _bin_op(">=")
     __le__ = _bin_op("<=")
-    # __eq__ = _bin_op("==")  # ignore [assignment]
+
+    _eqNullSafe_doc = """
+    Equality test that is safe for null values.
+
+    Parameters
+    ----------
+    other
+        a value or :class:`Column`
+
+    Examples
+    --------
+    >>> from pyspark.sql import Row
+    >>> df1 = spark.createDataFrame([
+    ...     Row(id=1, value='foo'),
+    ...     Row(id=2, value=None)
+    ... ])
+    >>> df1.select(
+    ...     df1['value'] == 'foo',
+    ...     df1['value'].eqNullSafe('foo'),
+    ...     df1['value'].eqNullSafe(None)
+    ... ).show()
+    +-------------+---------------+----------------+
+    |(value = foo)|(value <=> foo)|(value <=> NULL)|
+    +-------------+---------------+----------------+
+    |         true|           true|           false|
+    |         null|          false|            true|
+    +-------------+---------------+----------------+
+    >>> df2 = spark.createDataFrame([
+    ...     Row(value = 'bar'),
+    ...     Row(value = None)
+    ... ])
+    >>> df1.join(df2, df1["value"] == df2["value"]).count()
+    0
+    >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()
+    1
+    >>> df2 = spark.createDataFrame([
+    ...     Row(id=1, value=float('NaN')),
+    ...     Row(id=2, value=42.0),
+    ...     Row(id=3, value=None)
+    ... ])
+    >>> df2.select(
+    ...     df2['value'].eqNullSafe(None),
+    ...     df2['value'].eqNullSafe(float('NaN')),
+    ...     df2['value'].eqNullSafe(42.0)
+    ... ).show()
+    +----------------+---------------+----------------+
+    |(value <=> NULL)|(value <=> NaN)|(value <=> 42.0)|
+    +----------------+---------------+----------------+
+    |           false|           true|           false|
+    |           false|          false|            true|
+    |            true|          false|           false|
+    +----------------+---------------+----------------+
+
+    Notes
+    -----
+    Unlike Pandas, PySpark doesn't consider NaN values to be NULL. See the
+    `NaN Semantics <https://spark.apache.org/docs/latest/sql-ref-datatypes.html#nan-semantics>`_
+    for details.
+    """
+    eqNullSafe = _bin_op("eqNullSafe", _eqNullSafe_doc)
+
+    __neg__ = _func_op("negate")
+
+    # `and`, `or`, `not` cannot be overloaded in Python,
+    # so use bitwise operators as boolean operators
+    __and__ = _bin_op("and")
+    __or__ = _bin_op("or")
+    __invert__ = _func_op("not")
+    __rand__ = _bin_op("and")
+    __ror__ = _bin_op("or")
+
+    # bitwise operators
+    _bitwiseOR_doc = """
+    Compute bitwise OR of this expression with another expression.
+
+    Parameters
+    ----------
+    other
+        a value or :class:`Column` to calculate bitwise or(|) with
+        this :class:`Column`.
+
+    Examples
+    --------
+    >>> from pyspark.sql import Row
+    >>> df = spark.createDataFrame([Row(a=170, b=75)])
+    >>> df.select(df.a.bitwiseOR(df.b)).collect()
+    [Row((a | b)=235)]
+    """
+    _bitwiseAND_doc = """
+    Compute bitwise AND of this expression with another expression.
+
+    Parameters
+    ----------
+    other
+        a value or :class:`Column` to calculate bitwise and(&) with
+        this :class:`Column`.
+
+    Examples
+    --------
+    >>> from pyspark.sql import Row
+    >>> df = spark.createDataFrame([Row(a=170, b=75)])
+    >>> df.select(df.a.bitwiseAND(df.b)).collect()
+    [Row((a & b)=10)]
+    """
+    _bitwiseXOR_doc = """
+    Compute bitwise XOR of this expression with another expression.
+
+    Parameters
+    ----------
+    other
+        a value or :class:`Column` to calculate bitwise xor(^) with
+        this :class:`Column`.
+
+    Examples
+    --------
+    >>> from pyspark.sql import Row
+    >>> df = spark.createDataFrame([Row(a=170, b=75)])
+    >>> df.select(df.a.bitwiseXOR(df.b)).collect()
+    [Row((a ^ b)=225)]
+    """
+
+    bitwiseOR = _bin_op("bitwiseOR", _bitwiseOR_doc)
+    bitwiseAND = _bin_op("bitwiseAND", _bitwiseAND_doc)
+    bitwiseXOR = _bin_op("bitwiseXOR", _bitwiseXOR_doc)
+
+    # string methods
+    def contains(self, other: Union[PrimitiveType, "Column"]) -> "Column":
+        """
+        Contains the other element. Returns a boolean :class:`Column` based on a string match.
+
+        Parameters
+        ----------
+        other
+            string in line. A value as a literal or a :class:`Column`.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.contains('o')).collect()
+        [Row(age=5, name='Bob')]
+        """
+        return _bin_op("contains")(self, other)
+
+    def startswith(self, other: Union[PrimitiveType, "Column"]) -> "Column":
+        """
+        String starts with. Returns a boolean :class:`Column` based on a string match.
+
+        Parameters
+        ----------
+        other : :class:`Column` or str
+            string at start of line (do not use a regex `^`)
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.startswith('Al')).collect()
+        [Row(age=2, name='Alice')]
+        >>> df.filter(df.name.startswith('^Al')).collect()
+        []
+        """
+        return _bin_op("startsWith")(self, other)
+
+    def endswith(self, other: Union[PrimitiveType, "Column"]) -> "Column":
+        """
+        String ends with. Returns a boolean :class:`Column` based on a string match.
+
+        Parameters
+        ----------
+        other : :class:`Column` or str
+            string at end of line (do not use a regex `$`)
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.endswith('ice')).collect()
+        [Row(age=2, name='Alice')]
+        >>> df.filter(df.name.endswith('ice$')).collect()
+        []
+        """
+        return _bin_op("endsWith")(self, other)
+
+    def like(self: "Column", other: str) -> "Column":
+        """
+        SQL like expression. Returns a boolean :class:`Column` based on a SQL LIKE match.
+
+        Parameters
+        ----------
+        other : str
+            a SQL LIKE pattern
+
+        See Also
+        --------
+        pyspark.sql.Column.rlike
+
+        Returns
+        -------
+        :class:`Column`
+            Column of booleans showing whether each element
+            in the Column is matched by SQL LIKE pattern.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.like('Al%')).collect()
+        [Row(age=2, name='Alice')]
+        """
+        return _bin_op("like")(self, other)
+
+    def rlike(self: "Column", other: str) -> "Column":
+        """
+        SQL RLIKE expression (LIKE with Regex). Returns a boolean :class:`Column` based on a regex
+        match.
+
+        Parameters
+        ----------
+        other : str
+            an extended regex expression
+
+        Returns
+        -------
+        :class:`Column`
+            Column of booleans showing whether each element
+            in the Column is matched by extended regex expression.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.rlike('ice$')).collect()
+        [Row(age=2, name='Alice')]
+        """
+        return _bin_op("rlike")(self, other)
+
+    def ilike(self: "Column", other: str) -> "Column":
+        """
+        SQL ILIKE expression (case insensitive LIKE). Returns a boolean :class:`Column`
+        based on a case insensitive match.
+
+        Parameters
+        ----------
+        other : str
+            a SQL LIKE pattern
+
+        See Also
+        --------
+        pyspark.sql.Column.rlike
+
+        Returns
+        -------
+        :class:`Column`
+            Column of booleans showing whether each element
+            in the Column is matched by SQL LIKE pattern.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.filter(df.name.ilike('%Ice')).collect()
+        [Row(age=2, name='Alice')]
+        """
+        return _bin_op("ilike")(self, other)
+
+    @overload
+    def substr(self, startPos: int, length: int) -> "Column":
+        ...
+
+    @overload
+    def substr(self, startPos: "Column", length: "Column") -> "Column":
+        ...
+
+    def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) -> "Column":
+        """
+        Return a :class:`Column` which is a substring of the column.
+
+        Parameters
+        ----------
+        startPos : :class:`Column` or int
+            start position
+        length : :class:`Column` or int
+            length of the substring
+
+        Returns
+        -------
+        :class:`Column`
+            Column representing whether each element of Column is substr of origin Column.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.select(df.name.substr(1, 3).alias("col")).collect()
+        [Row(col='Ali'), Row(col='Bob')]
+        """
+        if type(startPos) != type(length):
+            raise TypeError(
+                "startPos and length must be the same type. "
+                "Got {startPos_t} and {length_t}, respectively.".format(
+                    startPos_t=type(startPos),
+                    length_t=type(length),
+                )
+            )
+        from pyspark.sql.connect.function_builder import functions as F
+
+        if isinstance(length, int):
+            length_exp = self._lit(length)
+        elif isinstance(length, Column):
+            length_exp = length
+        else:
+            raise TypeError("Unsupported type for substr().")
+
+        if isinstance(startPos, int):
+            start_exp = self._lit(startPos)
+        else:
+            start_exp = startPos
+
+        return F.substr(self, start_exp, length_exp)
 
     def __eq__(self, other: Any) -> "Column":  # type: ignore[override]
         """Returns a binary expression with the current column as the left
         side and the other expression as the right side.
         """
-        from pyspark.sql.connect._typing import PrimitiveType
-        from pyspark.sql.connect.functions import lit
-
-        if isinstance(other, get_args(PrimitiveType)):
-            other = lit(other)
+        other = self._lit(other)
         return scalar_function("==", self, other)
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
@@ -380,5 +702,10 @@ class Column(object):
     def name(self) -> str:
         return self._expr.name()
 
+    # TODO(SPARK-41329): solve the circular import between functions.py and
+    # this class if we want to reuse functions.lit
+    def _lit(self, x: Any) -> "Column":
+        return Column(LiteralExpression(x))
+
     def __str__(self) -> str:
         return self._expr.__str__()
diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py
index b65348c6862..1edca287367 100644
--- a/python/pyspark/sql/connect/function_builder.py
+++ b/python/pyspark/sql/connect/function_builder.py
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.client import SparkConnectClient
 
 
-def _build(name: str, *args: "ColumnOrName") -> ScalarFunctionExpression:
+def _build(name: str, *args: "ColumnOrName") -> Column:
     """
     Simple wrapper function that converts the arguments into the appropriate types.
 
     Parameters
 
@@ -46,14 +46,14 @@ def _build(name: str, *args: "ColumnOrName") -> ScalarFunctionExpression:
         :class:`ScalarFunctionExpression`
     """
     cols = [x if isinstance(x, Column) else col(x) for x in args]
-    return ScalarFunctionExpression(name, *cols)
+    return Column(ScalarFunctionExpression(name, *cols))
 
 
 class FunctionBuilder:
     """This class is used to build arbitrary functions used in expressions"""
 
     def __getattr__(self, name: str) -> "FunctionBuilderCallable":
-        def _(*args: "ColumnOrName") -> ScalarFunctionExpression:
+        def _(*args: "ColumnOrName") -> Column:
             return _build(name, *args)
 
         _.__doc__ = f"""Function to apply {name}"""
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 47a50a2cecb..c499e393e19 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -116,8 +116,35 @@ class SparkConnectTests(SparkConnectSQLTestCase):
 
     def test_columns(self):
         # SPARK-41036: test `columns` API for python client.
-        columns = self.connect.read.table(self.tbl_name).columns
-        self.assertEqual(["id", "name"], columns)
+        df = self.connect.read.table(self.tbl_name)
+        df2 = self.spark.read.table(self.tbl_name)
+        self.assertEqual(["id", "name"], df.columns)
+
+        self.assert_eq(
+            df.filter(df.name.rlike("20")).toPandas(), df2.filter(df2.name.rlike("20")).toPandas()
+        )
+        self.assert_eq(
+            df.filter(df.name.like("20")).toPandas(), df2.filter(df2.name.like("20")).toPandas()
+        )
+        self.assert_eq(
+            df.filter(df.name.ilike("20")).toPandas(), df2.filter(df2.name.ilike("20")).toPandas()
+        )
+        self.assert_eq(
+            df.filter(df.name.contains("20")).toPandas(),
+            df2.filter(df2.name.contains("20")).toPandas(),
+        )
+        self.assert_eq(
+            df.filter(df.name.startswith("2")).toPandas(),
+            df2.filter(df2.name.startswith("2")).toPandas(),
+        )
+        self.assert_eq(
+            df.filter(df.name.endswith("0")).toPandas(),
+            df2.filter(df2.name.endswith("0")).toPandas(),
+        )
+        self.assert_eq(
+            df.select(df.name.substr(0, 1).alias("col")).toPandas(),
+            df2.select(df2.name.substr(0, 1).alias("col")).toPandas(),
+        )
 
     def test_collect(self):
         df = self.connect.read.table(self.tbl_name)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
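
For context, a minimal sketch of how the Column operators added in column.py can be exercised end to end through a Spark Connect session, mirroring what the updated test_columns test checks. This is illustrative only and not code from the patch: the connection URL, the sample data, and the use of the SparkSession.builder.remote entry point (available in Spark 3.4+ client builds) are assumptions for the example.

# Illustrative sketch only; assumes a Spark Connect server is reachable at the
# URL below and that a pyspark build with the Spark Connect client is installed.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# Hypothetical sample data for the example (not the table used by the test suite).
df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

# String predicates added to Column in this patch.
df.filter(df.name.contains("o")).show()
df.filter(df.name.startswith("Al") | df.name.endswith("ob")).show()
df.filter(df.name.like("Al%")).show()
df.filter(df.name.rlike("e$")).show()
df.filter(df.name.ilike("%ICE")).show()

# Logical operators map to &, | and ~ because `and`/`or`/`not` cannot be overloaded.
df.filter((df.age >= 2) & ~(df.name == "Bob")).show()

# Null-safe equality, substring, and bitwise operators.
df.select(
    df.name.eqNullSafe("Alice"),
    df.name.substr(1, 3).alias("prefix"),
    df.age.bitwiseOR(1),
    df.age.bitwiseAND(3),
    df.age.bitwiseXOR(2),
).show()

spark.stop()

The new test_columns test exercises the same operators against a real table, comparing each Spark Connect result with the existing PySpark DataFrame API result via assert_eq.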