This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 4e5b7a2e76 [GLUTEN-8761][VL] Fix mode in
UnsafeColumnarBuildSideRelation not get properly serialize (#8762)
4e5b7a2e76 is described below
commit 4e5b7a2e76ff929761be388db83e153e76fad434
Author: Terry Wang <[email protected]>
AuthorDate: Thu Feb 20 10:53:36 2025 +0800
[GLUTEN-8761][VL] Fix mode in UnsafeColumnarBuildSideRelation not get
properly serialize (#8762)
Closes #8761
---
.../unsafe/UnsafeColumnarBuildSideRelation.scala | 6 +-
.../UnsafeColumnarBuildSideRelationTest.scala | 101 +++++++++++++++++++++
2 files changed, 104 insertions(+), 3 deletions(-)
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index ac9aef3bdd..47f659d189 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -104,7 +104,7 @@ case class UnsafeColumnarBuildSideRelation(
override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
kryo.writeObject(out, output.toList)
- kryo.writeObject(out, mode)
+ kryo.writeClassAndObject(out, mode)
out.writeInt(batches.arraySize)
kryo.writeObject(out, batches.bytesBufferLengths)
out.writeLong(batches.totalBytes)
@@ -136,14 +136,14 @@ case class UnsafeColumnarBuildSideRelation(
for (i <- 0 until totalArraySize) {
val length = bytesBufferLengths(i)
val tmpBuffer = new Array[Byte](length)
- in.read(tmpBuffer)
+ in.readFully(tmpBuffer)
batches.putBytesBuffer(i, tmpBuffer)
}
}
override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]]
- mode = kryo.readObject(in, classOf[BroadcastMode])
+ mode = kryo.readClassAndObject(in).asInstanceOf[BroadcastMode]
val totalArraySize = in.readInt()
val bytesBufferLengths = kryo.readObject(in, classOf[Array[Int]])
val totalBytes = in.readLong()
diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
new file mode 100644
index 0000000000..f47c8bd562
--- /dev/null
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.unsafe;
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
+import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StringType;
+
+/**
+ * Round-trips [[UnsafeColumnarBuildSideRelation]] through Spark's Java and Kryo serializers,
+ * checking that the output attributes and the broadcast mode survive (GLUTEN-8761).
+ */
+class UnsafeColumnarBuildSideRelationTest extends SharedSparkSession {
+  // Enables off-heap memory, presumably for UnsafeBytesBufferArray's allocations -- TODO confirm.
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.memory.offHeap.size", "200M")
+      .set("spark.memory.offHeap.enabled", "true")
+  }
+
+  // Fixtures built in beforeAll(): same output and payload, differing only in broadcast mode.
+  private var unsafeRelWithIdentityMode: UnsafeColumnarBuildSideRelation = null
+  private var unsafeRelWithHashMode: UnsafeColumnarBuildSideRelation = null
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val taskMemoryManager = new TaskMemoryManager(
+      new UnifiedMemoryManager(SparkEnv.get.conf, Long.MaxValue, Long.MaxValue / 2, 1),
+      0)
+    // Rely on AttributeReference's default metadata rather than passing null explicitly.
+    val a = AttributeReference("a", StringType, nullable = false)()
+    val output = Seq(a)
+    // A single 10-byte buffer; its length doubles as the total byte count.
+    val payload = "1234567890".getBytes()
+    val bytesArray =
+      UnsafeBytesBufferArray(1, Array(payload.length), payload.length, taskMemoryManager)
+    bytesArray.putBytesBuffer(0, payload)
+    unsafeRelWithIdentityMode =
+      UnsafeColumnarBuildSideRelation(output, bytesArray, IdentityBroadcastMode)
+    unsafeRelWithHashMode = UnsafeColumnarBuildSideRelation(
+      output,
+      bytesArray,
+      HashedRelationBroadcastMode(output, isNullAware = false))
+  }
+
+  /**
+   * Serializes `rel` with `ser` and deserializes it, asserting the copy is non-null and still
+   * exposes the single output attribute `a`. Returns the deserialized relation.
+   */
+  private def roundTrip(
+      ser: org.apache.spark.serializer.SerializerInstance,
+      rel: UnsafeColumnarBuildSideRelation): UnsafeColumnarBuildSideRelation = {
+    val restored = ser.deserialize[UnsafeColumnarBuildSideRelation](ser.serialize(rel))
+    assert(restored != null)
+    assert(restored.output.map(_.name) == Seq("a"))
+    restored
+  }
+
+  /**
+   * Asserts both fixtures keep their broadcast mode through `ser`. The hashed mode is checked
+   * field by field (one key, not null-aware) rather than only by class, so a serializer that
+   * loses the mode's contents is caught as well.
+   */
+  private def checkModesSurvive(ser: org.apache.spark.serializer.SerializerInstance): Unit = {
+    assert(roundTrip(ser, unsafeRelWithIdentityMode).mode == IdentityBroadcastMode)
+    roundTrip(ser, unsafeRelWithHashMode).mode match {
+      case hashed: HashedRelationBroadcastMode =>
+        assert(hashed.key.size == 1)
+        assert(!hashed.isNullAware)
+      case other => fail(s"Expected HashedRelationBroadcastMode, got $other")
+    }
+  }
+
+  test("Java default serialization") {
+    checkModesSurvive(new JavaSerializer(SparkEnv.get.conf).newInstance())
+  }
+
+  test("Kryo serialization") {
+    checkModesSurvive(new KryoSerializer(SparkEnv.get.conf).newInstance())
+  }
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]