This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dcf3d582293 [SPARK-44791][CONNECT] Make ArrowDeserializer work with REPL generated classes
dcf3d582293 is described below
commit dcf3d582293c3dbb3820d12fa15b41e8bd5fe6ad
Author: Herman van Hovell <[email protected]>
AuthorDate: Mon Aug 14 02:38:54 2023 +0200
[SPARK-44791][CONNECT] Make ArrowDeserializer work with REPL generated classes
### What changes were proposed in this pull request?
Connect's Arrow deserialization currently does not work with REPL-generated
classes. For example, the following code would fail:
```scala
case class MyTestClass(value: Int) {
  override def toString: String = value.toString
}
spark.range(10).map(i => MyTestClass(i.toInt)).collect()
```
The problem is that instantiating the `MyTestClass` class requires an instance
of the class it was defined in (its outer scope). Spark has a mechanism called
`OuterScopes` for registering these instances, but the `ArrowDeserializer` was
not resolving this outer instance. This PR fixes that; see the sketch below.
We have a similar issue on the executor/driver side, which will be fixed in a
different PR.
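As a rough sketch of the mechanism (this is not the actual deserializer code;
`ReplCell` is a hypothetical stand-in for a REPL-generated wrapper class):
```scala
import org.apache.spark.sql.catalyst.encoders.OuterScopes

// A case class defined in a REPL cell is compiled as a member class of the
// cell, so constructing it requires the cell instance (its outer scope).
class ReplCell {
  case class MyTestClass(value: Int)
}

val cell = new ReplCell
// Register the enclosing instance so deserializers can resolve it later.
OuterScopes.addOuterScope(cell)

// Conceptually what the fixed ArrowDeserializer does: resolve the outer
// instance for the member class and prepend it to the constructor arguments.
val outer = Option(OuterScopes.getOuterScope(classOf[ReplCell#MyTestClass])).map(_())
assert(outer.contains(cell))
```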
### Why are the changes needed?
It fixes a bug: without this change, collecting REPL-defined case classes over Connect fails.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
I have added tests to `ReplE2ESuite` and to the `ArrowEncoderSuite`.
Closes #42473 from hvanhovell/SPARK-44791.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../org/apache/spark/util/SparkClassUtils.scala | 28 +++++++
.../connect/client/arrow/ArrowDeserializer.scala | 14 +++-
.../spark/sql/application/ReplE2ESuite.scala | 33 ++++-----
.../connect/client/arrow/ArrowEncoderSuite.scala | 12 ++-
.../main/scala/org/apache/spark/util/Utils.scala | 28 -------
.../spark/sql/catalyst/encoders/OuterScopes.scala | 85 +++++++++++++++++-----
.../apache/spark/sql/errors/ExecutionErrors.scala | 7 ++
.../spark/sql/errors/QueryExecutionErrors.scala | 7 --
8 files changed, 138 insertions(+), 76 deletions(-)
diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala
index a237869aef3..679d546d04c 100644
--- a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala
+++ b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala
@@ -50,6 +50,34 @@ trait SparkClassUtils {
def classIsLoadable(clazz: String): Boolean = {
Try { classForName(clazz, initialize = false) }.isSuccess
}
+
+ /**
+ * Returns true if and only if the underlying class is a member class.
+ *
+ * Note: jdk8u throws a "Malformed class name" error if a given class is a deeply-nested
+ * inner class (See SPARK-34607 for details). This issue has already been fixed in jdk9+, so
+ * we can remove this helper method safely if we drop the support of jdk8u.
+ */
+ def isMemberClass(cls: Class[_]): Boolean = {
+ try {
+ cls.isMemberClass
+ } catch {
+ case _: InternalError =>
+ // We emulate jdk8u `Class.isMemberClass` below:
+ // public boolean isMemberClass() {
+ // return getSimpleBinaryName() != null && !isLocalOrAnonymousClass();
+ // }
+ // `getSimpleBinaryName()` returns null if a given class is a top-level class,
+ // so we replace it with `cls.getEnclosingClass != null`. The second condition checks
+ // if a given class is not a local or an anonymous class, so we replace it with
+ // `cls.getEnclosingMethod == null` because `cls.getEnclosingMethod()` return a value
+ // only in either case (JVM Spec 4.8.6).
+ //
+ // Note: The newer jdk evaluates `!isLocalOrAnonymousClass()` first,
+ // we reorder the conditions to follow it.
+ cls.getEnclosingMethod == null && cls.getEnclosingClass != null
+ }
+ }
}
object SparkClassUtils extends SparkClassUtils
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 55dd640f1b6..82086b9d47a 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -34,7 +34,7 @@ import org.apache.arrow.vector.ipc.ArrowReader
import org.apache.arrow.vector.util.Text
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.connect.client.CloseableIterator
@@ -290,15 +290,23 @@ object ArrowDeserializers {
case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) =>
// We should try to make this work with MethodHandles.
+ val outer = Option(OuterScopes.getOuterScope(tag.runtimeClass)).map(_()).toSeq
val Some(constructor) =
- ScalaReflection.findConstructor(tag.runtimeClass, fields.map(_.enc.clsTag.runtimeClass))
+ ScalaReflection.findConstructor(
+ tag.runtimeClass,
+ outer.map(_.getClass) ++ fields.map(_.enc.clsTag.runtimeClass))
val deserializers = if (isTuple(tag.runtimeClass)) {
fields.zip(vectors).map { case (field, vector) =>
deserializerFor(field.enc, vector, timeZoneId)
}
} else {
+ val outerDeserializer = outer.map { value =>
+ new Deserializer[Any] {
+ override def get(i: Int): Any = value
+ }
+ }
val lookup = createFieldLookup(vectors)
- fields.map { field =>
+ outerDeserializer ++ fields.map { field =>
deserializerFor(field.enc, lookup(field.name), timeZoneId)
}
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 0c19b8b7df1..0e69b5afa45 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -134,20 +134,6 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
assertContains("Array[Int] = Array(19, 24, 29, 34, 39)", output)
}
- // SPARK-43198: Switching REPL to CodeClass generation mode causes UDFs defined through lambda
- // expressions to hit deserialization issues.
- // TODO(SPARK-43227): Enable test after fixing deserialization issue.
- ignore("UDF containing lambda expression") {
- val input = """
- |class A(x: Int) { def get = x * 20 + 5 }
- |val dummyUdf = (x: Int) => new A(x).get
- |val myUdf = udf(dummyUdf)
- |spark.range(5).select(myUdf(col("id"))).as[Int].collect()
- """.stripMargin
- val output = runCommandsInShell(input)
- assertContains("Array[Int] = Array(5, 25, 45, 65, 85)", output)
- }
-
test("UDF containing in-place lambda") {
val input = """
|class A(x: Int) { def get = x * 42 + 5 }
@@ -238,9 +224,8 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
}
test("UDF Registration") {
- // TODO SPARK-44449 make this long again when upcasting is in.
val input = """
- |class A(x: Int) { def get: Long = x * 100 }
+ |class A(x: Int) { def get = x * 100 }
|val myUdf = udf((x: Int) => new A(x).get)
|spark.udf.register("dummyUdf", myUdf)
|spark.sql("select dummyUdf(id) from range(5)").as[Long].collect()
@@ -250,9 +235,8 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
}
test("UDF closure registration") {
- // TODO SPARK-44449 make this int again when upcasting is in.
val input = """
- |class A(x: Int) { def get: Long = x * 15 }
+ |class A(x: Int) { def get = x * 15 }
|spark.udf.register("directUdf", (x: Int) => new A(x).get)
|spark.sql("select directUdf(id) from range(5)").as[Long].collect()
""".stripMargin
@@ -279,4 +263,17 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
val output = runCommandsInShell(input)
assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output)
}
+
+ test("Collect REPL generated class") {
+ val input = """
+ |case class MyTestClass(value: Int)
+ |spark.range(4).
+ | filter($"id" % 2 === 1).
+ | select($"id".cast("int").as("value")).
+ | as[MyTestClass].
+ | collect()
+ """.stripMargin
+ val output = runCommandsInShell(input)
+ assertContains("Array[MyTestClass] = Array(MyTestClass(1), MyTestClass(3))", output)
+ }
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index 7a8e8465a70..2a499cc548f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, Primi [...]
import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder}
import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkStringUtils, TimestampFormatter}
@@ -759,6 +759,16 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
}
}
+ case class MyTestClass(value: Int)
+ OuterScopes.addOuterScope(this)
+
+ test("REPL generated classes") {
+ val encoder = ScalaReflection.encoderFor[MyTestClass]
+ roundTripAndCheckIdentical(encoder) { () =>
+ Iterator.tabulate(10)(MyTestClass)
+ }
+ }
+
/* ******************************************************************** *
* Arrow deserialization upcasting
* ******************************************************************** */
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index a35fb3c0078..35e99785f74 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2885,34 +2885,6 @@ private[spark] object Utils
Hex.encodeHexString(secretBytes)
}
- /**
- * Returns true if and only if the underlying class is a member class.
- *
- * Note: jdk8u throws a "Malformed class name" error if a given class is a deeply-nested
- * inner class (See SPARK-34607 for details). This issue has already been fixed in jdk9+, so
- * we can remove this helper method safely if we drop the support of jdk8u.
- */
- def isMemberClass(cls: Class[_]): Boolean = {
- try {
- cls.isMemberClass
- } catch {
- case _: InternalError =>
- // We emulate jdk8u `Class.isMemberClass` below:
- // public boolean isMemberClass() {
- // return getSimpleBinaryName() != null && !isLocalOrAnonymousClass();
- // }
- // `getSimpleBinaryName()` returns null if a given class is a top-level class,
- // so we replace it with `cls.getEnclosingClass != null`. The second condition checks
- // if a given class is not a local or an anonymous class, so we replace it with
- // `cls.getEnclosingMethod == null` because `cls.getEnclosingMethod()` return a value
- // only in either case (JVM Spec 4.8.6).
- //
- // Note: The newer jdk evaluates `!isLocalOrAnonymousClass()` first,
- // we reorder the conditions to follow it.
- cls.getEnclosingMethod == null && cls.getEnclosingClass != null
- }
- }
-
/**
* Safer than Class obj's getSimpleName which may throw Malformed class name error in scala.
* This method mimics scalatest's getSimpleNameOfAnObjectsClass.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
similarity index 58%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
index 6f7150d8d33..c2ac504c846 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -17,17 +17,53 @@
package org.apache.spark.sql.catalyst.encoders
-import java.util.concurrent.ConcurrentMap
+import java.lang.ref._
+import java.util.Objects
+import java.util.concurrent.ConcurrentHashMap
-import com.google.common.collect.MapMaker
-
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.errors.ExecutionErrors
+import org.apache.spark.util.SparkClassUtils
object OuterScopes {
- @transient
- lazy val outerScopes: ConcurrentMap[String, AnyRef] =
- new MapMaker().weakValues().makeMap()
+ private[this] val queue = new ReferenceQueue[AnyRef]
+ private class HashableWeakReference(v: AnyRef) extends WeakReference[AnyRef](v, queue) {
+ private[this] val hash = v.hashCode()
+ override def hashCode(): Int = hash
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case other: HashableWeakReference =>
+ // Note that referential equality is used to identify & purge
+ // references from the map whose' referent went out of scope.
+ if (this eq other) {
+ true
+ } else {
+ val referent = get()
+ val otherReferent = other.get()
+ referent != null && otherReferent != null && Objects.equals(referent, otherReferent)
+ }
+ case _ => false
+ }
+ }
+ }
+
+ private def classLoaderRef(c: Class[_]): HashableWeakReference = {
+ new HashableWeakReference(c.getClassLoader)
+ }
+
+ private[this] val outerScopes = {
+ new ConcurrentHashMap[HashableWeakReference, ConcurrentHashMap[String, WeakReference[AnyRef]]]
+ }
+
+ /**
+ * Clean the outer scopes that have been garbage collected.
+ */
+ private def cleanOuterScopes(): Unit = {
+ var entry = queue.poll()
+ while (entry != null) {
+ outerScopes.remove(entry)
+ entry = queue.poll()
+ }
+ }
/**
* Adds a new outer scope to this context that can be used when instantiating an `inner class`
@@ -40,7 +76,11 @@ object OuterScopes {
* given wrapper class.
*/
def addOuterScope(outer: AnyRef): Unit = {
- outerScopes.putIfAbsent(outer.getClass.getName, outer)
+ cleanOuterScopes()
+ val clz = outer.getClass
+ outerScopes
+ .computeIfAbsent(classLoaderRef(clz), _ => new ConcurrentHashMap)
+ .putIfAbsent(clz.getName, new WeakReference(outer))
}
/**
@@ -49,16 +89,24 @@ object OuterScopes {
* useful for inner class defined in REPL.
*/
def getOuterScope(innerCls: Class[_]): () => AnyRef = {
- assert(Utils.isMemberClass(innerCls))
- val outerClassName = innerCls.getDeclaringClass.getName
- val outer = outerScopes.get(outerClassName)
+ if (!SparkClassUtils.isMemberClass(innerCls)) {
+ return null
+ }
+ val outerClass = innerCls.getDeclaringClass
+ val outerClassName = outerClass.getName
+ val outer = Option(outerScopes.get(classLoaderRef(outerClass)))
+ .flatMap(map => Option(map.get(outerClassName)))
+ .map(_.get())
+ .orNull
if (outer == null) {
outerClassName match {
case AmmoniteREPLClass(cellClassName) =>
() => {
- val objClass = Utils.classForName(cellClassName)
+ val objClass = SparkClassUtils.classForName(cellClassName)
val objInstance = objClass.getField("MODULE$").get(null)
- objClass.getMethod("instance").invoke(objInstance)
+ val obj = objClass.getMethod("instance").invoke(objInstance)
+ addOuterScope(obj)
+ obj
}
// If the outer class is generated by REPL, users don't need to register it as it has
// only one instance and there is a way to retrieve it: get the `$read` object, call the
@@ -66,10 +114,10 @@ object OuterScopes {
// method multiply times to get the single instance of the inner most `$iw` class.
case REPLClass(baseClassName) =>
() => {
- val objClass = Utils.classForName(baseClassName + "$")
+ val objClass = SparkClassUtils.classForName(baseClassName + "$")
val objInstance = objClass.getField("MODULE$").get(null)
val baseInstance = objClass.getMethod("INSTANCE").invoke(objInstance)
- val baseClass = Utils.classForName(baseClassName)
+ val baseClass = SparkClassUtils.classForName(baseClassName)
var getter = iwGetter(baseClass)
var obj = baseInstance
@@ -79,10 +127,9 @@ object OuterScopes {
}
if (obj == null) {
- throw QueryExecutionErrors.cannotGetOuterPointerForInnerClassError(innerCls)
+ throw ExecutionErrors.cannotGetOuterPointerForInnerClassError(innerCls)
}
-
- outerScopes.putIfAbsent(outerClassName, obj)
+ addOuterScope(obj)
obj
}
case _ => null
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index 1e8e0ef5f6a..c8321e81027 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -206,6 +206,13 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
errorClass = "_LEGACY_ERROR_TEMP_2021",
messageParameters = Map("cls" -> cls.toString))
}
+
+ def cannotGetOuterPointerForInnerClassError(innerCls: Class[_]): SparkRuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "_LEGACY_ERROR_TEMP_2154",
+ messageParameters = Map(
+ "innerCls" -> innerCls.getName))
+ }
}
private[sql] object ExecutionErrors extends ExecutionErrors
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 1cc79a92c4c..953d9713c7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1365,13 +1365,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
"objSerializer" -> objSerializer.toString()))
}
- def cannotGetOuterPointerForInnerClassError(innerCls: Class[_]): SparkRuntimeException = {
- new SparkRuntimeException(
- errorClass = "_LEGACY_ERROR_TEMP_2154",
- messageParameters = Map(
- "innerCls" -> innerCls.getName))
- }
-
def unsupportedOperandTypeForSizeFunctionError(
dataType: DataType): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]