This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0bbd049a9ade [SPARK-48591][PYTHON] Simplify the if-else branches with `F.lit`
0bbd049a9ade is described below

commit 0bbd049a9adebd71f4262cb661a15cb01697acf5
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jun 12 16:49:29 2024 -0700

    [SPARK-48591][PYTHON] Simplify the if-else branches with `F.lit`
    
    ### What changes were proposed in this pull request?
    Simplify the if-else branches with `F.lit`, which accepts both Column and
    non-Column input.
    
    ### Why are the changes needed?
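    Code cleanup: the repeated `isinstance(..., Column)` / `LiteralExpression._from_value`
    branches are replaced by a single `F.lit` call.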
    
    ### Does this PR introduce _any_ user-facing change?
    No, this is a minor internal refactor.
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46946 from zhengruifeng/column_simplify.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/column.py | 45 ++++++++++++++++--------------------
 1 file changed, 20 insertions(+), 25 deletions(-)

diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index c38717afccda..7a619100e03a 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -27,6 +27,7 @@ from typing import (
     Any,
     Union,
     Optional,
+    cast,
 )
 
 from pyspark.sql.column import Column as ParentColumn
@@ -34,6 +35,7 @@ from pyspark.errors import PySparkTypeError, 
PySparkAttributeError, PySparkValue
 from pyspark.sql.types import DataType
 
 import pyspark.sql.connect.proto as proto
+from pyspark.sql.connect.functions import builtin as F
 from pyspark.sql.connect.expressions import (
     Expression,
     UnresolvedFunction,
@@ -308,14 +310,12 @@ class Column(ParentColumn):
                 message_parameters={},
             )
 
-        if isinstance(value, Column):
-            _value = value._expr
-        else:
-            _value = LiteralExpression._from_value(value)
-
-        _branches = self._expr._branches + [(condition._expr, _value)]
-
-        return Column(CaseWhen(branches=_branches, else_value=None))
+        return Column(
+            CaseWhen(
+                branches=self._expr._branches + [(condition._expr, 
F.lit(value)._expr)],
+                else_value=None,
+            )
+        )
 
     def otherwise(self, value: Any) -> ParentColumn:
         if not isinstance(self._expr, CaseWhen):
@@ -328,12 +328,12 @@ class Column(ParentColumn):
                 "otherwise() can only be applied once on a Column previously 
generated by when()"
             )
 
-        if isinstance(value, Column):
-            _value = value._expr
-        else:
-            _value = LiteralExpression._from_value(value)
-
-        return Column(CaseWhen(branches=self._expr._branches, 
else_value=_value))
+        return Column(
+            CaseWhen(
+                branches=self._expr._branches,
+                else_value=cast(Expression, F.lit(value)._expr),
+            )
+        )
 
     def like(self: ParentColumn, other: str) -> ParentColumn:
         return _bin_op("like", self, other)
@@ -457,14 +457,11 @@ class Column(ParentColumn):
         else:
             _cols = list(cols)
 
-        _exprs = [self._expr]
-        for c in _cols:
-            if isinstance(c, Column):
-                _exprs.append(c._expr)
-            else:
-                _exprs.append(LiteralExpression._from_value(c))
-
-        return Column(UnresolvedFunction("in", _exprs))
+        return Column(
+            UnresolvedFunction(
+                "in", [self._expr] + [cast(Expression, F.lit(c)._expr) for c 
in _cols]
+            )
+        )
 
     def between(
         self,
@@ -554,10 +551,8 @@ class Column(ParentColumn):
                     message_parameters={},
                 )
             return self.substr(k.start, k.stop)
-        elif isinstance(k, Column):
-            return Column(UnresolvedExtractValue(self._expr, k._expr))
         else:
-            return Column(UnresolvedExtractValue(self._expr, 
LiteralExpression._from_value(k)))
+            return Column(UnresolvedExtractValue(self._expr, cast(Expression, 
F.lit(k)._expr)))
 
     def __iter__(self) -> None:
         raise PySparkTypeError(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to