hvanhovell commented on code in PR #40581: URL: https://github.com/apache/spark/pull/40581#discussion_r1156174966
########## connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala: ########## @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import java.lang.{Long => JLong} +import java.util.{Iterator => JIterator} +import java.util.Arrays + +import scala.collection.JavaConverters._ + +import org.apache.spark.api.java.function._ +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder} +import org.apache.spark.sql.connect.client.util.RemoteSparkSession +import org.apache.spark.sql.functions.udf + +/** + * All tests in this class requires client UDF artifacts synced with the server. TODO: It means + * these tests only works with SBT for now. 
+ */ +class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession { + test("simple udf") { + def dummyUdf(x: Int): Int = x + 5 + val myUdf = udf(dummyUdf _) + val df = spark.range(5).select(myUdf(Column("id"))) + val result = df.collect() + assert(result.length == 5) + result.zipWithIndex.foreach { case (v, idx) => + assert(v.getInt(0) == idx + 5) + } + } + + test("Dataset typed filter") { + val rows = spark.range(10).filter(n => n % 2 == 0).collectAsList() + assert(rows == Arrays.asList[Long](0, 2, 4, 6, 8)) + } + + test("Dataset typed filter - java") { + val rows = spark + .range(10) + .filter(new FilterFunction[JLong] { + override def call(value: JLong): Boolean = value % 2 == 0 + }) + .collectAsList() + assert(rows == Arrays.asList[Long](0, 2, 4, 6, 8)) + } + + test("Dataset typed map") { + val rows = spark.range(10).map(n => n / 2)(PrimitiveLongEncoder).collectAsList() Review Comment: Can you import `spark.implicits` and see if we don't have to pass the encoder? ########## connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala: ########## @@ -482,27 +482,66 @@ class SparkConnectPlanner(val session: SparkSession) { } private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = { + val baseRel = transformRelation(rel.getInput) val commonUdf = rel.getFunc - val pythonUdf = transformPythonUDF(commonUdf) - val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false - pythonUdf.evalType match { - case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF => - logical.MapInPandas( - pythonUdf, - pythonUdf.dataType.asInstanceOf[StructType].toAttributes, - transformRelation(rel.getInput), - isBarrier) - case PythonEvalType.SQL_MAP_ARROW_ITER_UDF => - logical.PythonMapInArrow( - pythonUdf, - pythonUdf.dataType.asInstanceOf[StructType].toAttributes, - transformRelation(rel.getInput), - isBarrier) + commonUdf.getFunctionCase match { + case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF => + 
transformTypedMapPartitions(commonUdf, baseRel) + case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => + val pythonUdf = transformPythonUDF(commonUdf) + val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false + pythonUdf.evalType match { + case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF => + logical.MapInPandas( + pythonUdf, + pythonUdf.dataType.asInstanceOf[StructType].toAttributes, + baseRel, + isBarrier) + case PythonEvalType.SQL_MAP_ARROW_ITER_UDF => + logical.PythonMapInArrow( + pythonUdf, + pythonUdf.dataType.asInstanceOf[StructType].toAttributes, + baseRel, + isBarrier) + case _ => + throw InvalidPlanInput( + s"Function with EvalType: ${pythonUdf.evalType} is not supported") + } case _ => - throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported") + throw InvalidPlanInput( + s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not supported") } } + private def generateObjAttr[T](enc: ExpressionEncoder[T]): Attribute = { + val dataType = enc.deserializer.dataType + val nullable = !enc.clsTag.runtimeClass.isPrimitive + AttributeReference("obj", dataType, nullable)() + } + + private def transformTypedMapPartitions( + fun: proto.CommonInlineUserDefinedFunction, + child: LogicalPlan): LogicalPlan = { + val udf = fun.getScalarScalaUdf + val udfPacket = + Utils.deserialize[UdfPacket]( + udf.getPayload.toByteArray, + SparkConnectArtifactManager.classLoaderWithArtifacts) + assert(udfPacket.inputEncoders.size == 1) + implicit val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head) + val rEnc = ExpressionEncoder(udfPacket.outputEncoder) + + val deserializer = UnresolvedDeserializer(iEnc.deserializer) + val deserialized = DeserializeToObject(deserializer, generateObjAttr(iEnc), child) + val mapped = MapPartitions( + udfPacket.function.asInstanceOf[Iterator[Any] => Iterator[Any]], + generateObjAttr(rEnc), + deserialized) + val serialized = SerializeFromObject(rEnc.namedExpressions, mapped) + + new 
Dataset(session, serialized, rEnc).logicalPlan Review Comment: Why construct a dataframe here? You should be able to return `serialized`. ########## connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala: ########## @@ -482,27 +482,66 @@ class SparkConnectPlanner(val session: SparkSession) { } private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = { + val baseRel = transformRelation(rel.getInput) val commonUdf = rel.getFunc - val pythonUdf = transformPythonUDF(commonUdf) - val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false - pythonUdf.evalType match { - case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF => - logical.MapInPandas( - pythonUdf, - pythonUdf.dataType.asInstanceOf[StructType].toAttributes, - transformRelation(rel.getInput), - isBarrier) - case PythonEvalType.SQL_MAP_ARROW_ITER_UDF => - logical.PythonMapInArrow( - pythonUdf, - pythonUdf.dataType.asInstanceOf[StructType].toAttributes, - transformRelation(rel.getInput), - isBarrier) + commonUdf.getFunctionCase match { + case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF => + transformTypedMapPartitions(commonUdf, baseRel) + case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => + val pythonUdf = transformPythonUDF(commonUdf) + val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false + pythonUdf.evalType match { + case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF => + logical.MapInPandas( + pythonUdf, + pythonUdf.dataType.asInstanceOf[StructType].toAttributes, + baseRel, + isBarrier) + case PythonEvalType.SQL_MAP_ARROW_ITER_UDF => + logical.PythonMapInArrow( + pythonUdf, + pythonUdf.dataType.asInstanceOf[StructType].toAttributes, + baseRel, + isBarrier) + case _ => + throw InvalidPlanInput( + s"Function with EvalType: ${pythonUdf.evalType} is not supported") + } case _ => - throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported") + throw InvalidPlanInput( + 
s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not supported") } } + private def generateObjAttr[T](enc: ExpressionEncoder[T]): Attribute = { + val dataType = enc.deserializer.dataType + val nullable = !enc.clsTag.runtimeClass.isPrimitive + AttributeReference("obj", dataType, nullable)() + } + + private def transformTypedMapPartitions( + fun: proto.CommonInlineUserDefinedFunction, + child: LogicalPlan): LogicalPlan = { + val udf = fun.getScalarScalaUdf + val udfPacket = Review Comment: Shall we put this in a helper function? It is also needed for UDFs. ########## connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala: ########## @@ -846,7 +885,28 @@ class SparkConnectPlanner(val session: SparkSession) { private def transformFilter(rel: proto.Filter): LogicalPlan = { assert(rel.hasInput) val baseRel = transformRelation(rel.getInput) - logical.Filter(condition = transformExpression(rel.getCondition), child = baseRel) + val cond = rel.getCondition + cond.getExprTypeCase match { + case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION + if cond.getCommonInlineUserDefinedFunction.getFunctionCase == + proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF => + transformTypedFilter(cond.getCommonInlineUserDefinedFunction, baseRel) + case _ => + logical.Filter(condition = transformExpression(cond), child = baseRel) + } + } + + private def transformTypedFilter( + fun: proto.CommonInlineUserDefinedFunction, + child: LogicalPlan): TypedFilter = { + val udf = fun.getScalarScalaUdf + val udfPacket = + Utils.deserialize[UdfPacket]( + udf.getPayload.toByteArray, + SparkConnectArtifactManager.classLoaderWithArtifacts) + assert(udfPacket.inputEncoders.size == 1) + implicit val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head) Review Comment: You can also explicitly pass the encoder. 
########## connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala: ########## @@ -482,27 +482,66 @@ class SparkConnectPlanner(val session: SparkSession) { } private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = { + val baseRel = transformRelation(rel.getInput) val commonUdf = rel.getFunc - val pythonUdf = transformPythonUDF(commonUdf) - val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false - pythonUdf.evalType match { - case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF => - logical.MapInPandas( - pythonUdf, - pythonUdf.dataType.asInstanceOf[StructType].toAttributes, - transformRelation(rel.getInput), - isBarrier) - case PythonEvalType.SQL_MAP_ARROW_ITER_UDF => - logical.PythonMapInArrow( - pythonUdf, - pythonUdf.dataType.asInstanceOf[StructType].toAttributes, - transformRelation(rel.getInput), - isBarrier) + commonUdf.getFunctionCase match { + case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF => + transformTypedMapPartitions(commonUdf, baseRel) + case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => + val pythonUdf = transformPythonUDF(commonUdf) + val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false + pythonUdf.evalType match { + case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF => + logical.MapInPandas( + pythonUdf, + pythonUdf.dataType.asInstanceOf[StructType].toAttributes, + baseRel, + isBarrier) + case PythonEvalType.SQL_MAP_ARROW_ITER_UDF => + logical.PythonMapInArrow( + pythonUdf, + pythonUdf.dataType.asInstanceOf[StructType].toAttributes, + baseRel, + isBarrier) + case _ => + throw InvalidPlanInput( + s"Function with EvalType: ${pythonUdf.evalType} is not supported") + } case _ => - throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported") + throw InvalidPlanInput( + s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not supported") } } + private def generateObjAttr[T](enc: 
ExpressionEncoder[T]): Attribute = { + val dataType = enc.deserializer.dataType + val nullable = !enc.clsTag.runtimeClass.isPrimitive + AttributeReference("obj", dataType, nullable)() + } + + private def transformTypedMapPartitions( + fun: proto.CommonInlineUserDefinedFunction, + child: LogicalPlan): LogicalPlan = { + val udf = fun.getScalarScalaUdf + val udfPacket = + Utils.deserialize[UdfPacket]( + udf.getPayload.toByteArray, + SparkConnectArtifactManager.classLoaderWithArtifacts) + assert(udfPacket.inputEncoders.size == 1) + implicit val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head) Review Comment: Why implicit? ########## connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala: ########## @@ -846,7 +885,28 @@ class SparkConnectPlanner(val session: SparkSession) { private def transformFilter(rel: proto.Filter): LogicalPlan = { assert(rel.hasInput) val baseRel = transformRelation(rel.getInput) - logical.Filter(condition = transformExpression(rel.getCondition), child = baseRel) + val cond = rel.getCondition + cond.getExprTypeCase match { + case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION Review Comment: Does this mean the following would hit this code path: ```scala import spark.implicits._ val under5 = udf((i: Long) => i < 5) spark.range(10).filter(under5($"id")).collect() ``` If it does, then that would be wrong. We should check whether it has an unresolved star as its input.
########## connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala: ########## @@ -482,27 +482,66 @@ class SparkConnectPlanner(val session: SparkSession) { } private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = { + val baseRel = transformRelation(rel.getInput) val commonUdf = rel.getFunc - val pythonUdf = transformPythonUDF(commonUdf) - val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false - pythonUdf.evalType match { - case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF => - logical.MapInPandas( - pythonUdf, - pythonUdf.dataType.asInstanceOf[StructType].toAttributes, - transformRelation(rel.getInput), - isBarrier) - case PythonEvalType.SQL_MAP_ARROW_ITER_UDF => - logical.PythonMapInArrow( - pythonUdf, - pythonUdf.dataType.asInstanceOf[StructType].toAttributes, - transformRelation(rel.getInput), - isBarrier) + commonUdf.getFunctionCase match { + case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF => + transformTypedMapPartitions(commonUdf, baseRel) + case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => + val pythonUdf = transformPythonUDF(commonUdf) + val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false + pythonUdf.evalType match { + case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF => + logical.MapInPandas( + pythonUdf, + pythonUdf.dataType.asInstanceOf[StructType].toAttributes, + baseRel, + isBarrier) + case PythonEvalType.SQL_MAP_ARROW_ITER_UDF => + logical.PythonMapInArrow( + pythonUdf, + pythonUdf.dataType.asInstanceOf[StructType].toAttributes, + baseRel, + isBarrier) + case _ => + throw InvalidPlanInput( + s"Function with EvalType: ${pythonUdf.evalType} is not supported") + } case _ => - throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported") + throw InvalidPlanInput( + s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not supported") } } + private def generateObjAttr[T](enc: 
ExpressionEncoder[T]): Attribute = { + val dataType = enc.deserializer.dataType + val nullable = !enc.clsTag.runtimeClass.isPrimitive + AttributeReference("obj", dataType, nullable)() Review Comment: You can give it any name; an empty name also works. What matters is that the attribute has a stable ID (which it does). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
