grundprinzip commented on PR #39585:
URL: https://github.com/apache/spark/pull/39585#issuecomment-1387304156

   I did a quick test of the PR to see how far we are from supporting pandas UDFs with the same code. After a small change to the pandas UDF code, I was able to execute pandas UDFs as well.
   
   This is my full patch:
   
   ```
   diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
   index a8d51aed30..46ca300d86 100644
   --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
   +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
   @@ -19,13 +19,15 @@ package org.apache.spark.sql.connect.planner
    
    import scala.collection.JavaConverters._
    import scala.collection.mutable
   +
    import com.google.common.collect.{Lists, Maps}
    import com.google.protobuf.{Any => ProtoAny}
   +
    import org.apache.spark.TaskContext
   -import org.apache.spark.api.python.{PythonEvalType, PythonFunction, SimplePythonFunction}
   +import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
    import org.apache.spark.connect.proto
    import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
   -import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, expressions}
   +import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
    import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.optimizer.CombineUnions
   @@ -819,8 +821,9 @@ class SparkConnectPlanner(session: SparkSession) {
      private def transformScalarInlineUserDefinedFunction(
          fun: proto.Expression.ScalarInlineUserDefinedFunction): Expression = {
        fun.getFunctionCase match {
   -      case proto.Expression.ScalarInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
   +          case proto.Expression.ScalarInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
            transformPythonUDF(fun)
   +          case _ => throw InvalidPlanInput("Arg")
        }
      }
    
   @@ -834,7 +837,7 @@ class SparkConnectPlanner(session: SparkSession) {
       */
      private def transformPythonUDF(
          fun: proto.Expression.ScalarInlineUserDefinedFunction): PythonUDF = {
   -    val udf = fun.getPythonUDF
   +    val udf = fun.getPythonUdf
        PythonUDF(
          fun.getFunctionName,
          transformPythonFunction(udf),
   diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
   index a65b8109a4..fc6c398de1 100644
   --- a/python/pyspark/sql/connect/udf.py
   +++ b/python/pyspark/sql/connect/udf.py
   @@ -108,7 +108,7 @@ class UserDefinedFunction:
            py_udf = PythonUDF(
                output_type=data_type_str,
                eval_type=self.evalType,
   -            command=cloudpickle.dumps(self.func),
   +            command=cloudpickle.dumps((self.func, self._returnType)),
            )
            return Column(
                ScalarInlineUserDefinedFunction(
   diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
   index d0f81e2f63..c053a1d783 100644
   --- a/python/pyspark/sql/pandas/functions.py
   +++ b/python/pyspark/sql/pandas/functions.py
   @@ -25,6 +25,7 @@ from pyspark.sql.pandas.typehints import infer_eval_type
    from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
    from pyspark.sql.types import DataType
    from pyspark.sql.udf import _create_udf
   +from pyspark.sql.utils import is_remote
    
    
    class PandasUDFType:
   @@ -449,4 +450,8 @@ def _create_pandas_udf(f, returnType, evalType):
                "or three arguments (key, left, right)."
            )
    
   -    return _create_udf(f, returnType, evalType)
   +    if is_remote():
   +        from pyspark.sql.connect.udf import _create_udf as _create_connect_udf
   +        return _create_connect_udf(f, returnType, evalType)
   +    else:
   +        return _create_udf(f, returnType, evalType)
   
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to