This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new f0bb1391fe4 [SPARK-44791][CONNECT] Make ArrowDeserializer work with REPL generated classes
f0bb1391fe4 is described below

commit f0bb1391fe460fee886bce9151a47e89e75de671
Author:     Herman van Hovell <her...@databricks.com>
AuthorDate: Mon Aug 14 02:38:54 2023 +0200

    [SPARK-44791][CONNECT] Make ArrowDeserializer work with REPL generated classes

    ### What changes were proposed in this pull request?
    Connect's arrow deserialization currently does not work with REPL generated classes. For example, the following code fails:
    ```scala
    case class MyTestClass(value: Int) {
      override def toString: String = value.toString
    }
    spark.range(10).map(i => MyTestClass(i.toInt)).collect()
    ```
    The problem is that instantiating `MyTestClass` requires an instance of the class it was defined in (its outer scope). Spark has a mechanism called `OuterScopes` for registering these instances, but the `ArrowDeserializer` was not resolving this outer instance. This PR fixes that. We have a similar issue on the executor/driver side; that will be fixed in a different PR.

    ### Why are the changes needed?
    It is a bug.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    I have added tests to `ReplE2ESuite` and to the `ArrowEncoderSuite`.

    Closes #42473 from hvanhovell/SPARK-44791.

    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit dcf3d582293c3dbb3820d12fa15b41e8bd5fe6ad)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../org/apache/spark/util/SparkClassUtils.scala    | 28 +++++++
 .../connect/client/arrow/ArrowDeserializer.scala   | 14 +++-
 .../spark/sql/application/ReplE2ESuite.scala       | 33 ++++-----
 .../connect/client/arrow/ArrowEncoderSuite.scala   | 12 ++-
 .../main/scala/org/apache/spark/util/Utils.scala   | 28 -------
 .../spark/sql/catalyst/encoders/OuterScopes.scala  | 85 +++++++++++++++++-----
 .../apache/spark/sql/errors/ExecutionErrors.scala  |  7 ++
 .../spark/sql/errors/QueryExecutionErrors.scala    |  7 --
 8 files changed, 138 insertions(+), 76 deletions(-)
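Editor's note for reviewers: the failure above comes from how the JVM compiles nested classes. Here is a minimal, self-contained sketch (hypothetical names `Repl`/`OuterScopeDemo`, not Spark code) of the mechanism `OuterScopes` exists to serve — a class nested in another class compiles to a constructor whose first, synthetic parameter is the enclosing instance:

```scala
// A nested case class can only be constructed with its enclosing instance.
class Repl {
  case class MyTestClass(value: Int)
}

object OuterScopeDemo extends App {
  val outer = new Repl // stands in for the registered outer-scope instance
  val cls = classOf[Repl#MyTestClass]
  // The constructor's first parameter is the enclosing `Repl` instance.
  val ctor = cls.getConstructor(classOf[Repl], classOf[Int])
  println(ctor.newInstance(outer, Integer.valueOf(42))) // prints: MyTestClass(42)
}
```

Without an outer instance to pass, `newInstance` cannot produce the object; that is why the deserializer change below threads the registered outer scope into `ScalaReflection.findConstructor`.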
diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala
index a237869aef3..679d546d04c 100644
--- a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala
+++ b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala
@@ -50,6 +50,34 @@ trait SparkClassUtils {
   def classIsLoadable(clazz: String): Boolean = {
     Try { classForName(clazz, initialize = false) }.isSuccess
   }
+
+  /**
+   * Returns true if and only if the underlying class is a member class.
+   *
+   * Note: jdk8u throws a "Malformed class name" error if a given class is a deeply-nested
+   * inner class (see SPARK-34607 for details). This issue has already been fixed in jdk9+, so
+   * we can remove this helper method safely once we drop support for jdk8u.
+   */
+  def isMemberClass(cls: Class[_]): Boolean = {
+    try {
+      cls.isMemberClass
+    } catch {
+      case _: InternalError =>
+        // We emulate jdk8u `Class.isMemberClass` below:
+        //   public boolean isMemberClass() {
+        //     return getSimpleBinaryName() != null && !isLocalOrAnonymousClass();
+        //   }
+        // `getSimpleBinaryName()` returns null if a given class is a top-level class,
+        // so we replace it with `cls.getEnclosingClass != null`. The second condition checks
+        // that a given class is not a local or an anonymous class, so we replace it with
+        // `cls.getEnclosingMethod == null`, because `cls.getEnclosingMethod()` returns a value
+        // only in those two cases (JVM Spec 4.8.6).
+        //
+        // Note: newer jdks evaluate `!isLocalOrAnonymousClass()` first, so
+        // we reorder the conditions to match.
+        cls.getEnclosingMethod == null && cls.getEnclosingClass != null
+    }
+  }
 }

 object SparkClassUtils extends SparkClassUtils
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 55dd640f1b6..82086b9d47a 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -34,7 +34,7 @@ import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.arrow.vector.util.Text

 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.connect.client.CloseableIterator
@@ -290,15 +290,23 @@ object ArrowDeserializers {
     case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) =>
       // We should try to make this work with MethodHandles.
+      val outer = Option(OuterScopes.getOuterScope(tag.runtimeClass)).map(_()).toSeq
       val Some(constructor) =
-        ScalaReflection.findConstructor(tag.runtimeClass, fields.map(_.enc.clsTag.runtimeClass))
+        ScalaReflection.findConstructor(
+          tag.runtimeClass,
+          outer.map(_.getClass) ++ fields.map(_.enc.clsTag.runtimeClass))
       val deserializers = if (isTuple(tag.runtimeClass)) {
         fields.zip(vectors).map { case (field, vector) =>
           deserializerFor(field.enc, vector, timeZoneId)
         }
       } else {
+        val outerDeserializer = outer.map { value =>
+          new Deserializer[Any] {
+            override def get(i: Int): Any = value
+          }
+        }
         val lookup = createFieldLookup(vectors)
-        fields.map { field =>
+        outerDeserializer ++ fields.map { field =>
           deserializerFor(field.enc, lookup(field.name), timeZoneId)
         }
       }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 0c19b8b7df1..0e69b5afa45 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -134,20 +134,6 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
     assertContains("Array[Int] = Array(19, 24, 29, 34, 39)", output)
   }

-  // SPARK-43198: Switching REPL to CodeClass generation mode causes UDFs defined through lambda
-  // expressions to hit deserialization issues.
-  // TODO(SPARK-43227): Enable test after fixing deserialization issue.
- ignore("UDF containing lambda expression") { - val input = """ - |class A(x: Int) { def get = x * 20 + 5 } - |val dummyUdf = (x: Int) => new A(x).get - |val myUdf = udf(dummyUdf) - |spark.range(5).select(myUdf(col("id"))).as[Int].collect() - """.stripMargin - val output = runCommandsInShell(input) - assertContains("Array[Int] = Array(5, 25, 45, 65, 85)", output) - } - test("UDF containing in-place lambda") { val input = """ |class A(x: Int) { def get = x * 42 + 5 } @@ -238,9 +224,8 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { } test("UDF Registration") { - // TODO SPARK-44449 make this long again when upcasting is in. val input = """ - |class A(x: Int) { def get: Long = x * 100 } + |class A(x: Int) { def get = x * 100 } |val myUdf = udf((x: Int) => new A(x).get) |spark.udf.register("dummyUdf", myUdf) |spark.sql("select dummyUdf(id) from range(5)").as[Long].collect() @@ -250,9 +235,8 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { } test("UDF closure registration") { - // TODO SPARK-44449 make this int again when upcasting is in. val input = """ - |class A(x: Int) { def get: Long = x * 15 } + |class A(x: Int) { def get = x * 15 } |spark.udf.register("directUdf", (x: Int) => new A(x).get) |spark.sql("select directUdf(id) from range(5)").as[Long].collect() """.stripMargin @@ -279,4 +263,17 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { val output = runCommandsInShell(input) assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output) } + + test("Collect REPL generated class") { + val input = """ + |case class MyTestClass(value: Int) + |spark.range(4). + | filter($"id" % 2 === 1). + | select($"id".cast("int").as("value")). + | as[MyTestClass]. + | collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[MyTestClass] = Array(MyTestClass(1), MyTestClass(3))", output) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 7a8e8465a70..2a499cc548f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, Primi [...] 
 import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder}
 import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkStringUtils, TimestampFormatter}
@@ -759,6 +759,16 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
     }
   }

+  case class MyTestClass(value: Int)
+  OuterScopes.addOuterScope(this)
+
+  test("REPL generated classes") {
+    val encoder = ScalaReflection.encoderFor[MyTestClass]
+    roundTripAndCheckIdentical(encoder) { () =>
+      Iterator.tabulate(10)(MyTestClass)
+    }
+  }
+
   /* ******************************************************************** *
    * Arrow deserialization upcasting
    * ******************************************************************** */
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index a556f03dc09..85a8ffc6d2f 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2866,34 +2866,6 @@ private[spark] object Utils
     Hex.encodeHexString(secretBytes)
   }

-  /**
-   * Returns true if and only if the underlying class is a member class.
-   *
-   * Note: jdk8u throws a "Malformed class name" error if a given class is a deeply-nested
-   * inner class (See SPARK-34607 for details). This issue has already been fixed in jdk9+, so
-   * we can remove this helper method safely if we drop the support of jdk8u.
-   */
-  def isMemberClass(cls: Class[_]): Boolean = {
-    try {
-      cls.isMemberClass
-    } catch {
-      case _: InternalError =>
-        // We emulate jdk8u `Class.isMemberClass` below:
-        //   public boolean isMemberClass() {
-        //     return getSimpleBinaryName() != null && !isLocalOrAnonymousClass();
-        //   }
-        // `getSimpleBinaryName()` returns null if a given class is a top-level class,
-        // so we replace it with `cls.getEnclosingClass != null`. The second condition checks
-        // if a given class is not a local or an anonymous class, so we replace it with
-        // `cls.getEnclosingMethod == null` because `cls.getEnclosingMethod()` return a value
-        // only in either case (JVM Spec 4.8.6).
-        //
-        // Note: The newer jdk evaluates `!isLocalOrAnonymousClass()` first,
-        // we reorder the conditions to follow it.
-        cls.getEnclosingMethod == null && cls.getEnclosingClass != null
-    }
-  }
-
   /**
    * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala.
    * This method mimics scalatest's getSimpleNameOfAnObjectsClass.
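Editor's note: the jdk8u fallback moved above (from `Utils` into `SparkClassUtils`) rests on how `getEnclosingClass`/`getEnclosingMethod` behave for the three kinds of nested classes. A small standalone sketch (hypothetical demo object, not part of the patch) of that behavior — only a member class has an enclosing class but no enclosing method:

```scala
object MemberClassDemo extends App {
  class Member // member class: declared directly in the enclosing type

  def show(label: String, cls: Class[_]): Unit =
    println(s"$label: enclosingClass=${cls.getEnclosingClass != null}, " +
      s"enclosingMethod=${cls.getEnclosingMethod != null}, " +
      // The emulated check from the patch:
      s"member=${cls.getEnclosingMethod == null && cls.getEnclosingClass != null}")

  def run(): Unit = {
    class Local // local class: declared inside a method
    val anon = new Runnable { def run(): Unit = () } // anonymous class
    show("member", classOf[Member])  // enclosingClass=true, enclosingMethod=false, member=true
    show("local", classOf[Local])    // enclosingClass=true, enclosingMethod=true,  member=false
    show("anonymous", anon.getClass) // enclosingClass=true, enclosingMethod=true,  member=false
  }
  run()
}
```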
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
similarity index 58%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
index 6f7150d8d33..c2ac504c846 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -17,17 +17,53 @@
 package org.apache.spark.sql.catalyst.encoders

-import java.util.concurrent.ConcurrentMap
+import java.lang.ref._
+import java.util.Objects
+import java.util.concurrent.ConcurrentHashMap

-import com.google.common.collect.MapMaker
-
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.errors.ExecutionErrors
+import org.apache.spark.util.SparkClassUtils

 object OuterScopes {
-  @transient
-  lazy val outerScopes: ConcurrentMap[String, AnyRef] =
-    new MapMaker().weakValues().makeMap()
+  private[this] val queue = new ReferenceQueue[AnyRef]
+
+  private class HashableWeakReference(v: AnyRef) extends WeakReference[AnyRef](v, queue) {
+    private[this] val hash = v.hashCode()
+    override def hashCode(): Int = hash
+    override def equals(obj: Any): Boolean = {
+      obj match {
+        case other: HashableWeakReference =>
+          // Note that referential equality is used to identify & purge
+          // references from the map whose referent went out of scope.
+          if (this eq other) {
+            true
+          } else {
+            val referent = get()
+            val otherReferent = other.get()
+            referent != null && otherReferent != null && Objects.equals(referent, otherReferent)
+          }
+        case _ => false
+      }
+    }
+  }
+
+  private def classLoaderRef(c: Class[_]): HashableWeakReference = {
+    new HashableWeakReference(c.getClassLoader)
+  }
+
+  private[this] val outerScopes = {
+    new ConcurrentHashMap[HashableWeakReference, ConcurrentHashMap[String, WeakReference[AnyRef]]]
+  }
+
+  /**
+   * Clean the outer scopes that have been garbage collected.
+   */
+  private def cleanOuterScopes(): Unit = {
+    var entry = queue.poll()
+    while (entry != null) {
+      outerScopes.remove(entry)
+      entry = queue.poll()
+    }
+  }

   /**
    * Adds a new outer scope to this context that can be used when instantiating an `inner class`
@@ -40,7 +76,11 @@ object OuterScopes {
    * given wrapper class.
    */
   def addOuterScope(outer: AnyRef): Unit = {
-    outerScopes.putIfAbsent(outer.getClass.getName, outer)
+    cleanOuterScopes()
+    val clz = outer.getClass
+    outerScopes
+      .computeIfAbsent(classLoaderRef(clz), _ => new ConcurrentHashMap)
+      .putIfAbsent(clz.getName, new WeakReference(outer))
   }

   /**
@@ -49,16 +89,24 @@
    * Returns a function which can get the outer scope for the given inner class. By using function
    * as return type, we can delay the process of getting outer pointer to execution time, which is
    * useful for inner class defined in REPL.
    */
   def getOuterScope(innerCls: Class[_]): () => AnyRef = {
-    assert(Utils.isMemberClass(innerCls))
-    val outerClassName = innerCls.getDeclaringClass.getName
-    val outer = outerScopes.get(outerClassName)
+    if (!SparkClassUtils.isMemberClass(innerCls)) {
+      return null
+    }
+    val outerClass = innerCls.getDeclaringClass
+    val outerClassName = outerClass.getName
+    val outer = Option(outerScopes.get(classLoaderRef(outerClass)))
+      .flatMap(map => Option(map.get(outerClassName)))
+      .map(_.get())
+      .orNull
     if (outer == null) {
       outerClassName match {
         case AmmoniteREPLClass(cellClassName) =>
           () => {
-            val objClass = Utils.classForName(cellClassName)
+            val objClass = SparkClassUtils.classForName(cellClassName)
             val objInstance = objClass.getField("MODULE$").get(null)
-            objClass.getMethod("instance").invoke(objInstance)
+            val obj = objClass.getMethod("instance").invoke(objInstance)
+            addOuterScope(obj)
+            obj
           }
         // If the outer class is generated by REPL, users don't need to register it as it has
         // only one instance and there is a way to retrieve it: get the `$read` object, call the
         // `INSTANCE()` method to get the single instance of class `$read`, and then call the
         // `$iw()` method multiple times to get the single instance of the innermost `$iw` class.
         case REPLClass(baseClassName) =>
           () => {
-            val objClass = Utils.classForName(baseClassName + "$")
+            val objClass = SparkClassUtils.classForName(baseClassName + "$")
             val objInstance = objClass.getField("MODULE$").get(null)
             val baseInstance = objClass.getMethod("INSTANCE").invoke(objInstance)
-            val baseClass = Utils.classForName(baseClassName)
+            val baseClass = SparkClassUtils.classForName(baseClassName)

             var getter = iwGetter(baseClass)
             var obj = baseInstance
             while (getter != null) {
               obj = getter.invoke(obj)
               getter = iwGetter(getter.getReturnType)
             }

             if (obj == null) {
-              throw QueryExecutionErrors.cannotGetOuterPointerForInnerClassError(innerCls)
+              throw ExecutionErrors.cannotGetOuterPointerForInnerClassError(innerCls)
             }
-
-            outerScopes.putIfAbsent(outerClassName, obj)
+            addOuterScope(obj)
             obj
           }
         case _ => null
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index 1e8e0ef5f6a..c8321e81027 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -206,6 +206,13 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
       errorClass = "_LEGACY_ERROR_TEMP_2021",
       messageParameters = Map("cls" -> cls.toString))
   }
+
+  def cannotGetOuterPointerForInnerClassError(innerCls: Class[_]): SparkRuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "_LEGACY_ERROR_TEMP_2154",
+      messageParameters = Map(
+        "innerCls" -> innerCls.getName))
+  }
 }

 private[sql] object ExecutionErrors extends ExecutionErrors
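Editor's note: the error above is what surfaces when no outer scope can be resolved. For a class defined inside a non-REPL enclosing instance, callers avoid it by registering that instance explicitly. A short sketch (hypothetical suite name; `addOuterScope` is the real API from this patch), mirroring what `ArrowEncoderSuite` does in the diff above:

```scala
import org.apache.spark.sql.catalyst.encoders.OuterScopes

class MyEncoderSuite {
  // `MyTestClass` compiles to an inner class of `MyEncoderSuite`, so
  // deserialization needs this suite instance as its outer pointer.
  case class MyTestClass(value: Int)

  // Without this registration, getOuterScope has nothing to return and the
  // lookup ends in cannotGetOuterPointerForInnerClassError.
  OuterScopes.addOuterScope(this)
}
```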
"innerCls" -> innerCls.getName)) - } - def unsupportedOperandTypeForSizeFunctionError( dataType: DataType): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org