This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e5b7a2e76 [GLUTEN-8761][VL] Fix mode in UnsafeColumnarBuildSideRelation not getting properly serialized (#8762)
4e5b7a2e76 is described below

commit 4e5b7a2e76ff929761be388db83e153e76fad434
Author: Terry Wang <[email protected]>
AuthorDate: Thu Feb 20 10:53:36 2025 +0800

    [GLUTEN-8761][VL] Fix mode in UnsafeColumnarBuildSideRelation not getting properly serialized (#8762)
    
    Closes #8761
---
 .../unsafe/UnsafeColumnarBuildSideRelation.scala   |   6 +-
 .../UnsafeColumnarBuildSideRelationTest.scala      | 101 +++++++++++++++++++++
 2 files changed, 104 insertions(+), 3 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
index ac9aef3bdd..47f659d189 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala
@@ -104,7 +104,7 @@ case class UnsafeColumnarBuildSideRelation(
 
   override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException {
     kryo.writeObject(out, output.toList)
-    kryo.writeObject(out, mode)
+    kryo.writeClassAndObject(out, mode)
     out.writeInt(batches.arraySize)
     kryo.writeObject(out, batches.bytesBufferLengths)
     out.writeLong(batches.totalBytes)
@@ -136,14 +136,14 @@ case class UnsafeColumnarBuildSideRelation(
     for (i <- 0 until totalArraySize) {
       val length = bytesBufferLengths(i)
       val tmpBuffer = new Array[Byte](length)
-      in.read(tmpBuffer)
+      in.readFully(tmpBuffer)
       batches.putBytesBuffer(i, tmpBuffer)
     }
   }
 
   override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
     output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]]
-    mode = kryo.readObject(in, classOf[BroadcastMode])
+    mode = kryo.readClassAndObject(in).asInstanceOf[BroadcastMode]
     val totalArraySize = in.readInt()
     val bytesBufferLengths = kryo.readObject(in, classOf[Array[Int]])
     val totalBytes = in.readLong()
diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
new file mode 100644
index 0000000000..f47c8bd562
--- /dev/null
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelationTest.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.unsafe;
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
+import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StringType;
+
+class UnsafeColumnarBuildSideRelationTest extends SharedSparkSession {
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.memory.offHeap.size", "200M")
+      .set("spark.memory.offHeap.enabled", "true")
+  }
+
+  var unsafeRelWithIdentityMode: UnsafeColumnarBuildSideRelation = null
+  var unsafeRelWithHashMode: UnsafeColumnarBuildSideRelation = null
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val taskMemoryManager = new TaskMemoryManager(
+      new UnifiedMemoryManager(SparkEnv.get.conf, Long.MaxValue, Long.MaxValue / 2, 1),
+      0)
+    val a = AttributeReference("a", StringType, nullable = false, null)()
+    val output = Seq(a)
+    val totalArraySize = 1
+    val perArraySize = new Array[Int](totalArraySize)
+    perArraySize(0) = 10
+    val bytesArray = UnsafeBytesBufferArray(
+      1,
+      perArraySize,
+      10,
+      taskMemoryManager
+    )
+    bytesArray.putBytesBuffer(0, "1234567890".getBytes())
+    unsafeRelWithIdentityMode = UnsafeColumnarBuildSideRelation(
+      output,
+      bytesArray,
+      IdentityBroadcastMode
+    )
+    unsafeRelWithHashMode = UnsafeColumnarBuildSideRelation(
+      output,
+      bytesArray,
+      HashedRelationBroadcastMode(output, isNullAware = false)
+    )
+  }
+
+  test("Java default serialization") {
+    val javaSerialization = new JavaSerializer(SparkEnv.get.conf)
+    val serializerInstance = javaSerialization.newInstance()
+
+    // test unsafeRelWithIdentityMode
+    val buffer = serializerInstance.serialize(unsafeRelWithIdentityMode)
+    val obj = serializerInstance.deserialize[UnsafeColumnarBuildSideRelation](buffer)
+    assert(obj != null)
+    assert(obj.mode == IdentityBroadcastMode)
+
+    // test unsafeRelWithHashMode
+    val buffer2 = serializerInstance.serialize(unsafeRelWithHashMode)
+    val obj2 = 
serializerInstance.deserialize[UnsafeColumnarBuildSideRelation](buffer2)
+    assert(obj2 != null)
+    assert(obj2.mode.isInstanceOf[HashedRelationBroadcastMode])
+  }
+
+  test("Kryo serialization") {
+    val kryo = new KryoSerializer(SparkEnv.get.conf)
+    val serializerInstance = kryo.newInstance()
+
+    // test unsafeRelWithIdentityMode
+    val buffer = serializerInstance.serialize(unsafeRelWithIdentityMode)
+    val obj = serializerInstance.deserialize[UnsafeColumnarBuildSideRelation](buffer)
+    assert(obj != null)
+    assert(obj.mode == IdentityBroadcastMode)
+
+    // test unsafeRelWithHashMode
+    val buffer2 = serializerInstance.serialize(unsafeRelWithHashMode)
+    val obj2 = serializerInstance.deserialize[UnsafeColumnarBuildSideRelation](buffer2)
+    assert(obj2 != null)
+    assert(obj2.mode.isInstanceOf[HashedRelationBroadcastMode])
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to