zhengruifeng commented on code in PR #38956:
URL: https://github.com/apache/spark/pull/38956#discussion_r1041716023


##########
python/pyspark/sql/connect/column.py:
##########
@@ -129,6 +140,53 @@ def name(self) -> str:
         ...
 
 
+class CaseWhen(Expression):
+    def __init__(
+        self, branches: Sequence[Tuple[Expression, Expression]], else_value: 
Optional[Expression]
+    ):
+
+        assert isinstance(branches, list)
+        for branch in branches:
+            assert (
+                isinstance(branch, tuple)
+                and len(branch) == 2
+                and all(isinstance(expr, Expression) for expr in branch)
+            )
+        self._branches = branches
+
+        if else_value is not None:
+            assert isinstance(else_value, Expression)
+
+        self._else_value = else_value
+
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+        # fun = proto.Expression()
+        # fun.unresolved_function.function_name = "when"
+        # for condition, value in self._branches:
+        #     
fun.unresolved_function.arguments.extend(condition.to_plan(session))
+        #     fun.unresolved_function.arguments.extend(value.to_plan(session))
+        # if self._else_value is not None:
+        #     
fun.unresolved_function.arguments.extend(self._else_value.to_plan(session))
+        # return fun
+
+        args: Sequence[Expression] = []
+        for condition, value in self._branches:
+            args.append(condition)
+            args.append(value)
+
+        if self._else_value is not None:
+            args.append(self._else_value)
+
+        unresolved_function = UnresolvedFunction(name="when", args=args)
+
+        return unresolved_function.to_plan(session)

Review Comment:
   both `cdf.select(CF.when(cdf.a == 0, 1.0))` and `cdf.select(CF.when(cdf.a == 0, 1.0).otherwise(2.0))` fail with this error:
   
   ```
   grpc._channel._MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
           status = StatusCode.UNKNOWN
           details = "Invalid arguments for function when."
           debug_error_string = "{"created":"@1670381875.743425000","description":"Error received from peer ipv6:[::1]:15002","file":"src/core/lib/surface/call.cc","file_line":1064,"grpc_message":"Invalid arguments for function when.","grpc_status":2}"
   ```
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to