This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 8537fa634cd [SPARK-29497][CONNECT] Throw error when UDF is not deserializable 8537fa634cd is described below commit 8537fa634cd02f46e7b42afd6b35f877f3a2c161 Author: Herman van Hovell <hvanhov...@databricks.com> AuthorDate: Tue Aug 1 14:53:54 2023 -0400 [SPARK-29497][CONNECT] Throw error when UDF is not deserializable ### What changes were proposed in this pull request? This PR adds a better error message when a JVM UDF cannot be deserialized. ### Why are the changes needed? In some cases a UDF cannot be deserialized. This happens when a lambda references itself (typically through the capturing class). Java cannot deserialize such an object graph because SerializedLambdas are serialization proxies which need the full graph to be deserialized before they can be transformed into the actual lambda. This is not possible if there is such a cycle. This PR adds a more readable and understandable error when this happens; the original Java one is a `ClassCastExcep [...] ### Does this PR introduce _any_ user-facing change? Yes. It will throw an error on the client when a UDF is not deserializable. The error is better and more actionable than what we got before. ### How was this patch tested? Added tests. Closes #42245 from hvanhovell/SPARK-29497. 
Lead-authored-by: Herman van Hovell <hvanhov...@databricks.com> Co-authored-by: Herman van Hovell <her...@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> (cherry picked from commit f54b402021785e0b0ec976ec889de67d3b2fdc6e) Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../org/apache/spark/util/SparkSerDeUtils.scala | 21 ++++++++++- .../sql/expressions/UserDefinedFunction.scala | 24 +++++++++++- .../spark/sql/UserDefinedFunctionSuite.scala | 44 ++++++++++++++++++++-- .../main/scala/org/apache/spark/util/Utils.scala | 23 +---------- 4 files changed, 85 insertions(+), 27 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala index 3069e4c36a7..9b6174c47bd 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala @@ -16,9 +16,9 @@ */ package org.apache.spark.util -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream, ObjectStreamClass} -object SparkSerDeUtils { +trait SparkSerDeUtils { /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -34,4 +34,21 @@ object SparkSerDeUtils { val ois = new ObjectInputStream(bis) ois.readObject.asInstanceOf[T] } + + /** + * Deserialize an object using Java serialization and the given ClassLoader + */ + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) { + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + // scalastyle:off classforname + Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } + } + 
ois.readObject.asInstanceOf[T] + } } + +object SparkSerDeUtils extends SparkSerDeUtils diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 3a38029c265..e060dba0b7e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -18,16 +18,18 @@ package org.apache.spark.sql.expressions import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal import com.google.protobuf.ByteString +import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.SparkSerDeUtils +import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils} /** * A user-defined function. To create one, use the `udf` functions in `functions`. @@ -144,6 +146,25 @@ case class ScalarUserDefinedFunction private[sql] ( } object ScalarUserDefinedFunction { + private val LAMBDA_DESERIALIZATION_ERR_MSG: String = + "cannot assign instance of java.lang.invoke.SerializedLambda to field" + + private def checkDeserializable(bytes: Array[Byte]): Unit = { + try { + SparkSerDeUtils.deserialize(bytes, SparkClassUtils.getContextOrSparkClassLoader) + } catch { + case e: ClassCastException if e.getMessage.contains(LAMBDA_DESERIALIZATION_ERR_MSG) => + throw new SparkException( + "UDF cannot be executed on a Spark cluster: it cannot be deserialized. 
" + + "This is very likely to be caused by the lambda function (the UDF) having a " + + "self-reference. This is not supported by java serialization.") + case NonFatal(e) => + throw new SparkException( + "UDF cannot be executed on a Spark cluster: it cannot be deserialized.", + e) + } + } + private[sql] def apply( function: AnyRef, returnType: TypeTag[_], @@ -164,6 +185,7 @@ object ScalarUserDefinedFunction { outputEncoder: AgnosticEncoder[_]): ScalarUserDefinedFunction = { val udfPacketBytes = SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder)) + checkDeserializable(udfPacketBytes) ScalarUserDefinedFunction( serializedUdfPacket = udfPacketBytes, inputTypes = inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType), diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala index 684f5671e48..76608559866 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.sql import scala.reflect.runtime.universe.typeTag -import org.scalatest.BeforeAndAfterEach - +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.connect.common.UdfPacket import org.apache.spark.sql.functions.udf import org.apache.spark.util.SparkSerDeUtils -class UserDefinedFunctionSuite extends ConnectFunSuite with BeforeAndAfterEach { +class UserDefinedFunctionSuite extends ConnectFunSuite { test("udf and encoder serialization") { def func(x: Int): Int = x + 1 @@ -48,4 +47,43 @@ class UserDefinedFunctionSuite extends ConnectFunSuite with BeforeAndAfterEach { assert(deSer.outputEncoder == 
ScalaReflection.encoderFor(typeTag[Int])) assert(deSer.inputEncoders == Seq(ScalaReflection.encoderFor(typeTag[Int]))) } + + private def testNonDeserializable(f: Int => Int): Unit = { + val e = intercept[SparkException](udf(f)) + assert( + e.getMessage.contains( + "UDF cannot be executed on a Spark cluster: it cannot be deserialized.")) + assert(e.getMessage.contains("This is not supported by java serialization.")) + } + + test("non deserializable UDFs") { + testNonDeserializable(Command2(Command1()).indirect) + testNonDeserializable(MultipleLambdas().indirect) + testNonDeserializable(SelfRef(22).method) + } + + test("serializable UDFs") { + val direct = (i: Int) => i + 1 + val indirect = (i: Int) => direct(i) + udf(indirect) + udf(Command1().direct) + udf(MultipleLambdas().direct) + } +} + +case class Command1() extends Serializable { + val direct: Int => Int = (i: Int) => i + 1 +} + +case class Command2(prev: Command1) extends Serializable { + val indirect: Int => Int = (i: Int) => prev.direct(i) +} + +case class SelfRef(start: Int) extends Serializable { + val method: Int => Int = (i: Int) => i + start +} + +case class MultipleLambdas() extends Serializable { + val direct: Int => Int = (i: Int) => i + 1 + val indirect: Int => Int = (i: Int) => direct(i) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a3002eb40f4..a556f03dc09 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -95,7 +95,8 @@ private[spark] object Utils extends Logging with SparkClassUtils with SparkErrorUtils - with SparkFileUtils { + with SparkFileUtils + with SparkSerDeUtils { private val sparkUncaughtExceptionHandler = new SparkUncaughtExceptionHandler @volatile private var cachedLocalDir: String = "" @@ -121,26 +122,6 @@ private[spark] object Utils private val copyBuffer = ThreadLocal.withInitial[Array[Byte]](() => { new 
Array[Byte](COPY_BUFFER_LEN) }) - - /** Serialize an object using Java serialization */ - def serialize[T](o: T): Array[Byte] = SparkSerDeUtils.serialize(o) - - /** Deserialize an object using Java serialization */ - def deserialize[T](bytes: Array[Byte]): T = SparkSerDeUtils.deserialize(bytes) - - /** Deserialize an object using Java serialization and the given ClassLoader */ - def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { - val bis = new ByteArrayInputStream(bytes) - val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = { - // scalastyle:off classforname - Class.forName(desc.getName, false, loader) - // scalastyle:on classforname - } - } - ois.readObject.asInstanceOf[T] - } - /** Deserialize a Long value (used for [[org.apache.spark.api.python.PythonPartitioner]]) */ def deserializeLongValue(bytes: Array[Byte]) : Long = { // Note: we assume that we are given a Long value encoded in network (big-endian) byte order --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org