grundprinzip commented on code in PR #39068:
URL: https://github.com/apache/spark/pull/39068#discussion_r1051232035


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -565,6 +565,47 @@ class SparkConnectPlanner(session: SparkSession) {
         val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
         Some(In(children.head, children.tail))
 
+      case "___lambda_function___" =>
+        // UnresolvedFunction[___lambda_function___, ["x, y -> x < y", "x", "y"]]

Review Comment:
   This branch deserves its own function, please.



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -565,6 +565,47 @@ class SparkConnectPlanner(session: SparkSession) {
         val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
         Some(In(children.head, children.tail))
 
+      case "___lambda_function___" =>
+        // UnresolvedFunction[___lambda_function___, ["x, y -> x < y", "x", "y"]]
+
+        if (fun.getArgumentsCount < 2) {
+          throw InvalidPlanInput(
+            "LambdaFunction requires at least 2 child expressions: LambdaFunction, Arguments")
+        }
+
+        val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
+
+        val function = children.head
+
+        val variableNames = children.tail.map {
+          case variable: UnresolvedAttribute if variable.nameParts.length == 1 =>
+            variable.nameParts.head
+          case other =>
+            throw InvalidPlanInput(
+              "LambdaFunction requires all arguments to be UnresolvedAttribute with " +
+                s"single name part, but got $other")

Review Comment:
   There are two interesting issues here:
   
    1. When someone submits an expression to the API that does not transform into an `UnresolvedAttribute`, this throws a confusing error message about name parts, although the actual problem is that the type does not match.
    2. Why the restriction to single-part names? Is this a Spark limitation?
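To illustrate issue 1: a minimal sketch, written in Python purely for illustration (the planner itself is Scala), that checks the expression type first and the name shape second, so each failure gets a message describing the actual problem. `InvalidPlanInput`, `UnresolvedAttribute`, and `extract_variable_names` below are hypothetical stand-ins, not the Spark Connect API.

```python
from dataclasses import dataclass
from typing import List, Sequence


class InvalidPlanInput(Exception):
    """Hypothetical stand-in for the server's InvalidPlanInput error."""


@dataclass
class UnresolvedAttribute:
    """Hypothetical stand-in for an attribute reference such as `x` or `a.b`."""
    name_parts: List[str]


def extract_variable_names(arguments: Sequence[object]) -> List[str]:
    names = []
    for arg in arguments:
        # Type mismatch: report the type, not the name parts.
        if not isinstance(arg, UnresolvedAttribute):
            raise InvalidPlanInput(
                f"LambdaFunction arguments must be UnresolvedAttribute, but got {type(arg).__name__}"
            )
        # Correct type, wrong shape: report the multi-part name.
        if len(arg.name_parts) != 1:
            raise InvalidPlanInput(
                f"LambdaFunction argument names must have a single name part, but got {'.'.join(arg.name_parts)}"
            )
        names.append(arg.name_parts[0])
    return names
```

For example, `extract_variable_names([UnresolvedAttribute(["x"]), UnresolvedAttribute(["y"])])` returns `["x", "y"]`, while passing `42` reports the type `int` rather than a name-part problem.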



##########
python/pyspark/sql/connect/column.py:
##########
@@ -543,6 +543,39 @@ def __repr__(self) -> str:
         return f"({self._col} ({self._data_type}))"
 
 
+class LambdaFunction(Expression):
+    def __init__(
+        self,
+        function: Expression,
+        arguments: Sequence[Expression],
+    ) -> None:
+        super().__init__()
+
+        assert isinstance(function, Expression)
+
+        assert (
+            isinstance(arguments, list)
+            and len(arguments) > 0
+            and all(isinstance(arg, ColumnReference) for arg in arguments)
+        )

Review Comment:
   Adding these assertions here is helpful in the Python client, but the server side does not perform the same assertion. If we dropped the assertion on `ColumnReference`, what would happen on the server?
   
   Isn't the analysis exception better than the Python assertion?
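One further point in favor of not relying on the client-side `assert`: Python removes `assert` statements when the interpreter runs with `-O`, so they cannot serve as real input validation. A minimal sketch with explicit checks that survive `-O`, assuming hypothetical stand-ins `Expression`, `ColumnReference`, and `validate_lambda_arguments` (not the actual classes in `column.py`):

```python
from typing import Sequence


class Expression:
    """Hypothetical stand-in for the Connect Expression base class."""


class ColumnReference(Expression):
    """Hypothetical stand-in for an unresolved column reference."""

    def __init__(self, name: str) -> None:
        self.name = name


def validate_lambda_arguments(function: object, arguments: Sequence[object]) -> None:
    # Explicit raises are kept under `python -O`, unlike `assert`.
    if not isinstance(function, Expression):
        raise TypeError(f"function must be an Expression, but got {type(function).__name__}")
    if not isinstance(arguments, list) or len(arguments) == 0:
        raise ValueError("arguments must be a non-empty list")
    for arg in arguments:
        if not isinstance(arg, ColumnReference):
            raise TypeError(f"lambda arguments must be ColumnReference, but got {type(arg).__name__}")
```

Whether the client should validate at all, or defer entirely to the server's analysis exception, is the open question in the comment; this sketch only shows the client side done without `assert`.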



##########
python/pyspark/sql/connect/functions.py:
##########
@@ -80,6 +84,80 @@ def _invoke_binary_math_function(name: str, col1: Any, col2: Any) -> Column:
     return _invoke_function(name, *_cols)
 
 
+def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]:
+    signature = inspect.signature(f)
+    parameters = signature.parameters.values()
+
+    # We should exclude functions that use
+    # variable args and keyword argnames
+    # as well as keyword only args
+    supported_parameter_types = {
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        inspect.Parameter.POSITIONAL_ONLY,
+    }
+
+    # Validate that
+    # function arity is between 1 and 3

Review Comment:
   ```suggestion
       # Validate that the function arity is between 1 and 3.
   ```



##########
python/pyspark/sql/connect/functions.py:
##########
@@ -80,6 +84,80 @@ def _invoke_binary_math_function(name: str, col1: Any, col2: Any) -> Column:
     return _invoke_function(name, *_cols)
 
 
+def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]:
+    signature = inspect.signature(f)
+    parameters = signature.parameters.values()
+
+    # We should exclude functions that use
+    # variable args and keyword argnames
+    # as well as keyword only args
+    supported_parameter_types = {
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        inspect.Parameter.POSITIONAL_ONLY,
+    }
+
+    # Validate that
+    # function arity is between 1 and 3

Review Comment:
   one line?



##########
python/pyspark/sql/connect/functions.py:
##########
@@ -80,6 +84,80 @@ def _invoke_binary_math_function(name: str, col1: Any, col2: Any) -> Column:
     return _invoke_function(name, *_cols)
 
 
+def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]:
+    signature = inspect.signature(f)
+    parameters = signature.parameters.values()
+
+    # We should exclude functions that use
+    # variable args and keyword argnames
+    # as well as keyword only args
+    supported_parameter_types = {
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        inspect.Parameter.POSITIONAL_ONLY,
+    }
+
+    # Validate that
+    # function arity is between 1 and 3
+    if not (1 <= len(parameters) <= 3):
+        raise ValueError(
+            "f should take between 1 and 3 arguments, but provided function takes {}".format(
+                len(parameters)
+            )
+        )
+
+    # and all arguments can be used as positional

Review Comment:
   ```suggestion
       # Verify that all arguments can be used as positional arguments.
   ```



##########
python/pyspark/sql/connect/functions.py:
##########
@@ -80,6 +84,80 @@ def _invoke_binary_math_function(name: str, col1: Any, col2: Any) -> Column:
     return _invoke_function(name, *_cols)
 
 
+def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]:
+    signature = inspect.signature(f)
+    parameters = signature.parameters.values()
+
+    # We should exclude functions that use
+    # variable args and keyword argnames
+    # as well as keyword only args
+    supported_parameter_types = {
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        inspect.Parameter.POSITIONAL_ONLY,
+    }
+
+    # Validate that
+    # function arity is between 1 and 3
+    if not (1 <= len(parameters) <= 3):
+        raise ValueError(
+            "f should take between 1 and 3 arguments, but provided function takes {}".format(
+                len(parameters)
+            )
+        )
+
+    # and all arguments can be used as positional
+    if not all(p.kind in supported_parameter_types for p in parameters):
+        raise ValueError("f should use only POSITIONAL or POSITIONAL OR KEYWORD arguments")

Review Comment:
   This error message is really hard to parse. I know what you want to say (and I know it's probably copied from plain PySpark), but still.
   
   Maybe better: "All arguments of f must be usable as POSITIONAL arguments."



##########
python/pyspark/sql/connect/functions.py:
##########
@@ -80,6 +84,80 @@ def _invoke_binary_math_function(name: str, col1: Any, col2: Any) -> Column:
     return _invoke_function(name, *_cols)
 
 
+def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]:
+    signature = inspect.signature(f)
+    parameters = signature.parameters.values()
+
+    # We should exclude functions that use
+    # variable args and keyword argnames
+    # as well as keyword only args

Review Comment:
   ```suggestion
       # We should exclude functions that use variable args and keyword argument
       # names, as well as keyword-only args.
   ```
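The three excluded kinds can be derived directly from the `inspect` module: every `inspect.Parameter` kind is a member of one enum (`inspect._ParameterKind`, a private name this sketch only reaches via `type(...)`), so subtracting the supported set leaves exactly the kinds the comment lists.

```python
import inspect

SUPPORTED = {
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
    inspect.Parameter.POSITIONAL_ONLY,
}

# All parameter kinds are members of the same enum; iterate over it.
ALL_KINDS = set(type(inspect.Parameter.POSITIONAL_ONLY))
EXCLUDED = ALL_KINDS - SUPPORTED

# The kind enum is an IntEnum, so sorting follows declaration order.
for kind in sorted(EXCLUDED):
    print(kind.name)
```

This prints `VAR_POSITIONAL`, `KEYWORD_ONLY`, and `VAR_KEYWORD`, i.e. `*args`, parameters declared after a bare `*` (or after `*args`), and `**kwargs`.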



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

