This is an automated email from the ASF dual-hosted git repository.

kerwinzhang pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.3 by this push:
     new c67c6e404 [CELEBORN-1123] Support fallback to non-columnar shuffle for schema that cannot be obtained from shuffle dependency (#2110)
c67c6e404 is described below

commit c67c6e404a4cab5c7d1c62b4ff4c2de19dd47146
Author: Nicholas Jiang <[email protected]>
AuthorDate: Mon Nov 27 13:54:07 2023 +0800

    [CELEBORN-1123] Support fallback to non-columnar shuffle for schema that cannot be obtained from shuffle dependency (#2110)
---
 .../shuffle/celeborn/HashBasedShuffleWriter.java   |  6 +-
 .../apache/spark/shuffle/celeborn/SparkUtils.java  | 14 ++--
 .../shuffle/celeborn/CelebornShuffleReader.scala   |  9 +--
 .../execution/columnar/CelebornBatchBuilder.scala  |  4 +-
 .../columnar/CelebornColumnAccessor.scala          |  9 +--
 .../execution/columnar/CelebornColumnBuilder.scala |  8 +--
 .../execution/columnar/CelebornColumnStats.scala   | 83 +++++++---------------
 .../execution/columnar/CelebornColumnType.scala    | 83 ++++++----------------
 .../columnar/CelebornColumnarBatchBuilder.scala    | 28 ++------
 .../CelebornColumnarBatchCodeGenBuild.scala        | 15 ----
 .../columnar/CelebornColumnarBatchSerializer.scala | 11 +--
 .../CelebornCompressibleColumnBuilder.scala        |  6 +-
 .../columnar/CelebornCompressionScheme.scala       |  9 +--
 .../columnar/CelebornCompressionSchemes.scala      | 16 ++---
 14 files changed, 91 insertions(+), 210 deletions(-)

diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index dee8d897a..4ddb8e98d 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -70,6 +70,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
   private final ShuffleDependency<K, V, C> dep;
   private final Partitioner partitioner;
   private final ShuffleWriteMetricsReporter writeMetrics;
+  private final int stageId;
   private final int shuffleId;
   private final int mapId;
   private final TaskContext taskContext;
@@ -132,6 +133,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
       throws IOException {
     this.mapId = taskContext.partitionId();
     this.dep = handle.dependency();
+    this.stageId = taskContext.stageId();
     this.shuffleId = dep.shuffleId();
     SerializerInstance serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
@@ -185,7 +187,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
       columnarShuffleDictionaryMaxFactor = conf.columnarShuffleDictionaryMaxFactor();
       this.schema = SparkUtils.getSchema(dep);
       this.celebornBatchBuilders = new CelebornBatchBuilder[numPartitions];
-      this.isColumnarShuffle = CelebornBatchBuilder.supportsColumnarType(schema);
+      this.isColumnarShuffle = schema != null && CelebornBatchBuilder.supportsColumnarType(schema);
     }
   }
 
@@ -194,6 +196,8 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     try {
       if (canUseFastWrite()) {
         if (isColumnarShuffle) {
+          logger.info(
+              "Fast columnar write of columnar shuffle {} for stage {}.", shuffleId, stageId);
           fastColumnarWrite0(records);
         } else {
           fastWrite0(records);
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 8f8600b9a..7f4a541cb 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.shuffle.celeborn;
 
-import java.io.IOException;
 import java.util.concurrent.atomic.LongAdder;
 
 import scala.Tuple2;
@@ -162,10 +161,17 @@ public class SparkUtils {
   private static final DynFields.UnboundField<StructType> SCHEMA_FIELD =
       DynFields.builder().hiddenImpl(ShuffleDependency.class, "schema").defaultAlwaysNull().build();
 
-  public static StructType getSchema(ShuffleDependency<?, ?, ?> dep) throws IOException {
-    StructType schema = SCHEMA_FIELD.bind(dep).get();
+  public static StructType getSchema(ShuffleDependency<?, ?, ?> dep) {
+    StructType schema = null;
+    try {
+      schema = SCHEMA_FIELD.bind(dep).get();
+    } catch (Exception e) {
+      LOG.error("Failed to bind shuffle dependency of shuffle {}.", dep.shuffleId(), e);
+    }
     if (schema == null) {
-      throw new IOException("Failed to get Schema, columnar shuffle won't work properly.");
+      LOG.warn(
+          "Failed to get Schema of shuffle {}, columnar shuffle won't work properly.",
+          dep.shuffleId());
     }
     return schema;
   }
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index ecbf51f80..e27cfbfc1 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -62,14 +62,11 @@ class CelebornShuffleReader[K, C](
     var serializerInstance = dep.serializer.newInstance()
     if (conf.columnarShuffleEnabled) {
       val schema = SparkUtils.getSchema(dep)
-      if (CelebornBatchBuilder.supportsColumnarType(
-          schema)) {
-        val dataSize = SparkUtils.getDataSize(
-          dep.serializer.asInstanceOf[UnsafeRowSerializer])
+      if (schema != null && CelebornBatchBuilder.supportsColumnarType(schema)) {
+        logInfo(s"Creating column batch serializer of columnar shuffle ${dep.shuffleId}.")
+        val dataSize = SparkUtils.getDataSize(dep.serializer.asInstanceOf[UnsafeRowSerializer])
         serializerInstance = new CelebornColumnarBatchSerializer(
           schema,
-          conf.columnarShuffleBatchSize,
-          conf.columnarShuffleDictionaryEnabled,
           conf.columnarShuffleOffHeapEnabled,
           dataSize).newInstance()
       }
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala
index 7ae77fec0..bc93c10b5 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornBatchBuilder.scala
@@ -28,7 +28,7 @@ abstract class CelebornBatchBuilder {
 
   def writeRow(row: InternalRow): Unit
 
-  def getRowCnt(): Int
+  def getRowCnt: Int
 
   def int2ByteArray(i: Int): Array[Byte] = {
     val result = new Array[Byte](4)
@@ -46,7 +46,7 @@ object CelebornBatchBuilder {
       f.dataType match {
         case BooleanType | ByteType | ShortType | IntegerType | LongType |
             FloatType | DoubleType | StringType => true
-        case dt: DecimalType => true
+        case _: DecimalType => true
         case _ => false
       })
   }
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala
index 064bbefc6..a75c8d32a 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnAccessor.scala
@@ -61,13 +61,9 @@ abstract class CelebornBasicColumnAccessor[JvmType](
     columnType.extract(buffer, row, ordinal)
   }
 
-  protected def underlyingBuffer = buffer
+  protected def underlyingBuffer: ByteBuffer = buffer
 }
 
-class CelebornNullColumnAccessor(buffer: ByteBuffer)
-  extends CelebornBasicColumnAccessor[Any](buffer, CELEBORN_NULL)
-  with CelebornNullableColumnAccessor
-
 abstract class CelebornNativeColumnAccessor[T <: AtomicType](
     override protected val buffer: ByteBuffer,
     override protected val columnType: NativeCelebornColumnType[T])
@@ -112,7 +108,6 @@ private[sql] object CelebornColumnAccessor {
     val buf = buffer.order(ByteOrder.nativeOrder)
 
     dataType match {
-      case NullType => new CelebornNullColumnAccessor(buf)
       case BooleanType => new CelebornBooleanColumnAccessor(buf)
       case ByteType => new CelebornByteColumnAccessor(buf)
       case ShortType => new CelebornShortColumnAccessor(buf)
@@ -135,7 +130,7 @@ private[sql] object CelebornColumnAccessor {
     columnAccessor match {
       case nativeAccessor: CelebornNativeColumnAccessor[_] =>
         nativeAccessor.decompress(columnVector, numRows)
-      case d: CelebornDecimalColumnAccessor =>
+      case _: CelebornDecimalColumnAccessor =>
         (0 until 
numRows).foreach(columnAccessor.extractToColumnVector(columnVector, _))
       case _ =>
         throw new RuntimeException("Not support non-primitive type now")
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala
index 0abfdd0cd..f65a5fd86 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnBuilder.scala
@@ -88,10 +88,6 @@ class CelebornBasicColumnBuilder[JvmType](
   }
 }
 
-class CelebornNullColumnBuilder
-  extends CelebornBasicColumnBuilder[Any](new 
CelebornObjectColumnStats(NullType), CELEBORN_NULL)
-  with CelebornNullableColumnBuilder
-
 abstract class CelebornComplexColumnBuilder[JvmType](
     columnStats: CelebornColumnStats,
     columnType: CelebornColumnType[JvmType])
@@ -318,7 +314,6 @@ class CelebornDecimalCodeGenColumnBuilder(dataType: 
DecimalType)
 }
 
 object CelebornColumnBuilder {
-  val MAX_BATCH_SIZE_IN_BYTE: Long = 4 * 1024 * 1024L
 
   def ensureFreeSpace(orig: ByteBuffer, size: Int): ByteBuffer = {
     if (orig.remaining >= size) {
@@ -343,7 +338,6 @@ object CelebornColumnBuilder {
       encodingEnabled: Boolean,
       encoder: Encoder[_ <: AtomicType]): CelebornColumnBuilder = {
     val builder: CelebornColumnBuilder = dataType match {
-      case NullType => new CelebornNullColumnBuilder
       case ByteType => new CelebornByteColumnBuilder
       case BooleanType => new CelebornBooleanColumnBuilder
       case ShortType => new CelebornShortColumnBuilder
@@ -367,7 +361,7 @@ object CelebornColumnBuilder {
         new CelebornCompactDecimalColumnBuilder(dt)
       case dt: DecimalType => new CelebornDecimalColumnBuilder(dt)
       case other =>
-        throw new Exception(s"not support type: $other")
+        throw new Exception(s"Unsupported type: $other")
     }
 
     builder.initialize(rowCnt, columnName, encodingEnabled)
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala
index 6c2aa0f7b..80e883bf0 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnStats.scala
@@ -63,7 +63,7 @@ final private[columnar] class CelebornBooleanColumnStats 
extends CelebornColumnS
       val value = row.getBoolean(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -79,15 +79,15 @@ final private[columnar] class CelebornBooleanColumnStats 
extends CelebornColumnS
 }
 
 final private[columnar] class CelebornByteColumnStats extends 
CelebornColumnStats {
-  protected var upper = Byte.MinValue
-  protected var lower = Byte.MaxValue
+  protected var upper: Byte = Byte.MinValue
+  protected var lower: Byte = Byte.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getByte(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -103,15 +103,15 @@ final private[columnar] class CelebornByteColumnStats 
extends CelebornColumnStat
 }
 
 final private[columnar] class CelebornShortColumnStats extends 
CelebornColumnStats {
-  protected var upper = Short.MinValue
-  protected var lower = Short.MaxValue
+  protected var upper: Short = Short.MinValue
+  protected var lower: Short = Short.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getShort(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -127,15 +127,15 @@ final private[columnar] class CelebornShortColumnStats 
extends CelebornColumnSta
 }
 
 final private[columnar] class CelebornIntColumnStats extends 
CelebornColumnStats {
-  protected var upper = Int.MinValue
-  protected var lower = Int.MaxValue
+  protected var upper: Int = Int.MinValue
+  protected var lower: Int = Int.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getInt(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -151,15 +151,15 @@ final private[columnar] class CelebornIntColumnStats 
extends CelebornColumnStats
 }
 
 final private[columnar] class CelebornLongColumnStats extends 
CelebornColumnStats {
-  protected var upper = Long.MinValue
-  protected var lower = Long.MaxValue
+  protected var upper: Long = Long.MinValue
+  protected var lower: Long = Long.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getLong(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -175,15 +175,15 @@ final private[columnar] class CelebornLongColumnStats 
extends CelebornColumnStat
 }
 
 final private[columnar] class CelebornFloatColumnStats extends 
CelebornColumnStats {
-  protected var upper = Float.MinValue
-  protected var lower = Float.MaxValue
+  protected var upper: Float = Float.MinValue
+  protected var lower: Float = Float.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getFloat(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -199,15 +199,15 @@ final private[columnar] class CelebornFloatColumnStats 
extends CelebornColumnSta
 }
 
 final private[columnar] class CelebornDoubleColumnStats extends 
CelebornColumnStats {
-  protected var upper = Double.MinValue
-  protected var lower = Double.MaxValue
+  protected var upper: Double = Double.MinValue
+  protected var lower: Double = Double.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getDouble(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -223,8 +223,8 @@ final private[columnar] class CelebornDoubleColumnStats 
extends CelebornColumnSt
 }
 
 final private[columnar] class CelebornStringColumnStats extends 
CelebornColumnStats {
-  protected var upper: UTF8String = null
-  protected var lower: UTF8String = null
+  protected var upper: UTF8String = _
+  protected var lower: UTF8String = _
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
@@ -232,7 +232,7 @@ final private[columnar] class CelebornStringColumnStats 
extends CelebornColumnSt
       val size = CELEBORN_STRING.actualSize(row, ordinal)
       gatherValueStats(value, size)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -247,34 +247,19 @@ final private[columnar] class CelebornStringColumnStats 
extends CelebornColumnSt
     Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-final private[columnar] class CelebornBinaryColumnStats extends 
CelebornColumnStats {
-  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    if (!row.isNullAt(ordinal)) {
-      val size = CELEBORN_BINARY.actualSize(row, ordinal)
-      sizeInBytes += size
-      count += 1
-    } else {
-      gatherNullStats
-    }
-  }
-
-  override def collectedStatistics: Array[Any] =
-    Array[Any](null, null, nullCount, count, sizeInBytes)
-}
-
 final private[columnar] class CelebornDecimalColumnStats(precision: Int, 
scale: Int)
   extends CelebornColumnStats {
   def this(dt: DecimalType) = this(dt.precision, dt.scale)
 
-  protected var upper: Decimal = null
-  protected var lower: Decimal = null
+  protected var upper: Decimal = _
+  protected var lower: Decimal = _
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getDecimal(ordinal, precision, scale)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -294,21 +279,3 @@ final private[columnar] class 
CelebornDecimalColumnStats(precision: Int, scale:
   override def collectedStatistics: Array[Any] =
     Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
-
-final private[columnar] class CelebornObjectColumnStats(dataType: DataType)
-  extends CelebornColumnStats {
-  val columnType = CelebornColumnType(dataType)
-
-  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    if (!row.isNullAt(ordinal)) {
-      val size = columnType.actualSize(row, ordinal)
-      sizeInBytes += size
-      count += 1
-    } else {
-      gatherNullStats
-    }
-  }
-
-  override def collectedStatistics: Array[Any] =
-    Array[Any](null, null, nullCount, count, sizeInBytes)
-}
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala
index d1d5461a4..69cf10a2e 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnType.scala
@@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.columnar
 import java.math.{BigDecimal, BigInteger}
 import java.nio.ByteBuffer
 
-import scala.reflect.runtime.universe.TypeTag
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
@@ -177,26 +175,10 @@ sealed abstract private[columnar] class 
CelebornColumnType[JvmType] {
   override def toString: String = getClass.getSimpleName.stripSuffix("$")
 }
 
-private[columnar] object CELEBORN_NULL extends CelebornColumnType[Any] {
-
-  override def dataType: DataType = NullType
-  override def defaultSize: Int = 0
-  override def append(v: Any, buffer: ByteBuffer): Unit = {}
-  override def extract(buffer: ByteBuffer): Any = null
-  override def setField(row: InternalRow, ordinal: Int, value: Any): Unit = 
row.setNullAt(ordinal)
-  override def getField(row: InternalRow, ordinal: Int): Any = null
-}
-
 abstract private[columnar] class NativeCelebornColumnType[T <: AtomicType](
     val dataType: T,
     val defaultSize: Int)
-  extends CelebornColumnType[T#InternalType] {
-
-  /**
-   * Scala TypeTag. Can be used to create primitive arrays and hash tables.
-   */
-  def scalaTag: TypeTag[dataType.InternalType] = dataType.tag
-}
+  extends CelebornColumnType[T#InternalType] {}
 
 private[columnar] object CELEBORN_INT extends 
NativeCelebornColumnType(IntegerType, 4) {
   override def append(v: Int, buffer: ByteBuffer): Unit = {
@@ -428,26 +410,28 @@ private[columnar] trait 
DirectCopyCelebornColumnType[JvmType] extends CelebornCo
 
   // copy the bytes from ByteBuffer to UnsafeRow
   override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): 
Unit = {
-    if (row.isInstanceOf[MutableUnsafeRow]) {
-      val numBytes = buffer.getInt
-      val cursor = buffer.position()
-      buffer.position(cursor + numBytes)
-      row.asInstanceOf[MutableUnsafeRow].writer.write(
-        ordinal,
-        buffer.array(),
-        buffer.arrayOffset() + cursor,
-        numBytes)
-    } else {
-      setField(row, ordinal, extract(buffer))
+    row match {
+      case r: MutableUnsafeRow =>
+        val numBytes = buffer.getInt
+        val cursor = buffer.position()
+        buffer.position(cursor + numBytes)
+        r.writer.write(
+          ordinal,
+          buffer.array(),
+          buffer.arrayOffset() + cursor,
+          numBytes)
+      case _ =>
+        setField(row, ordinal, extract(buffer))
     }
   }
 
   // copy the bytes from UnsafeRow to ByteBuffer
   override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): 
Unit = {
-    if (row.isInstanceOf[UnsafeRow]) {
-      row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer)
-    } else {
-      super.append(row, ordinal, buffer)
+    row match {
+      case r: UnsafeRow =>
+        r.writeFieldTo(ordinal, buffer)
+      case _ =>
+        super.append(row, ordinal, buffer)
     }
   }
 }
@@ -472,10 +456,11 @@ private[columnar] object CELEBORN_STRING
   }
 
   override def setField(row: InternalRow, ordinal: Int, value: UTF8String): 
Unit = {
-    if (row.isInstanceOf[MutableUnsafeRow]) {
-      row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value)
-    } else {
-      row.update(ordinal, value.clone())
+    row match {
+      case r: MutableUnsafeRow =>
+        r.writer.write(ordinal, value)
+      case _ =>
+        row.update(ordinal, value.clone())
     }
   }
 
@@ -617,26 +602,6 @@ sealed abstract private[columnar] class 
ByteArrayCelebornColumnType[JvmType](val
   }
 }
 
-private[columnar] object CELEBORN_BINARY extends 
ByteArrayCelebornColumnType[Array[Byte]](16) {
-
-  def dataType: DataType = BinaryType
-
-  override def setField(row: InternalRow, ordinal: Int, value: Array[Byte]): 
Unit = {
-    row.update(ordinal, value)
-  }
-
-  override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
-    row.getBinary(ordinal)
-  }
-
-  override def actualSize(row: InternalRow, ordinal: Int): Int = {
-    row.getBinary(ordinal).length + 4
-  }
-
-  def serialize(value: Array[Byte]): Array[Byte] = value
-  def deserialize(bytes: Array[Byte]): Array[Byte] = bytes
-}
-
 private[columnar] case class CELEBORN_LARGE_DECIMAL(precision: Int, scale: Int)
   extends ByteArrayCelebornColumnType[Decimal](12) {
 
@@ -673,7 +638,6 @@ private[columnar] object CELEBORN_LARGE_DECIMAL {
 private[columnar] object CelebornColumnType {
   def apply(dataType: DataType): CelebornColumnType[_] = {
     dataType match {
-      case NullType => CELEBORN_NULL
       case BooleanType => CELEBORN_BOOLEAN
       case ByteType => CELEBORN_BYTE
       case ShortType => CELEBORN_SHORT
@@ -682,7 +646,6 @@ private[columnar] object CelebornColumnType {
       case FloatType => CELEBORN_FLOAT
       case DoubleType => CELEBORN_DOUBLE
       case StringType => CELEBORN_STRING
-      case BinaryType => CELEBORN_BINARY
       case dt: DecimalType if dt.precision <= Decimal.MAX_INT_DIGITS =>
         CELEBORN_COMPACT_MINI_DECIMAL(dt)
       case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala
index 159b15e32..ab6f60072 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchBuilder.scala
@@ -30,7 +30,8 @@ class CelebornColumnarBatchBuilder(
     encodingEnabled: Boolean = false) extends CelebornBatchBuilder {
   var rowCnt = 0
 
-  val typeConversion: PartialFunction[DataType, NativeCelebornColumnType[_ <: 
AtomicType]] = {
+  private val typeConversion
+      : PartialFunction[DataType, NativeCelebornColumnType[_ <: AtomicType]] = 
{
     case IntegerType => CELEBORN_INT
     case LongType => CELEBORN_LONG
     case StringType => CELEBORN_STRING
@@ -45,7 +46,7 @@ class CelebornColumnarBatchBuilder(
     case _ => null
   }
 
-  val encodersArr: Array[Encoder[_ <: AtomicType]] = schema.map { attribute =>
+  private val encodersArr: Array[Encoder[_ <: AtomicType]] = schema.map { 
attribute =>
     val nativeColumnType = typeConversion(attribute.dataType)
     if (nativeColumnType == null) {
       null
@@ -63,14 +64,13 @@ class CelebornColumnarBatchBuilder(
   var columnBuilders: Array[CelebornColumnBuilder] = _
 
   def newBuilders(): Unit = {
-    totalSize = 0
     rowCnt = 0
     var i = -1
     columnBuilders = schema.map { attribute =>
       i += 1
       encodersArr(i) match {
         case encoder: CelebornDictionaryEncoding.CelebornEncoder[_] if 
!encoder.overflow =>
-          encoder.cleanBatch
+          encoder.cleanBatch()
         case _ =>
       }
       CelebornColumnBuilder(
@@ -100,8 +100,6 @@ class CelebornColumnarBatchBuilder(
     giantBuffer.toByteArray
   }
 
-  var totalSize = 0
-
   def writeRow(row: InternalRow): Unit = {
     var i = 0
     while (i < row.numFields) {
@@ -111,21 +109,5 @@ class CelebornColumnarBatchBuilder(
     rowCnt += 1
   }
 
-  def getTotalSize(): Int = {
-    var i = 0
-    var tempTotalSize = 0
-    while (i < schema.length) {
-      columnBuilders(i) match {
-        case builder: CelebornCompressibleColumnBuilder[_] =>
-          tempTotalSize += builder.getTotalSize.toInt
-        case builder: CelebornNullableColumnBuilder => tempTotalSize += 
builder.getTotalSize.toInt
-        case _ =>
-      }
-      i += 1
-    }
-    totalSize = tempTotalSize + 4 + 4 * schema.length
-    totalSize
-  }
-
-  def getRowCnt(): Int = rowCnt
+  def getRowCnt: Int = rowCnt
 }
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala
index 1c15d163a..e510e6452 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchCodeGenBuild.scala
@@ -102,21 +102,6 @@ class CelebornColumnarBatchCodeGenBuild {
     val writeRowCode = new mutable.StringBuilder()
     for (index <- schema.indices) {
       schema.fields(index).dataType match {
-        case NullType =>
-          initCode.append(
-            s"""
-               |  ${classOf[CelebornNullColumnBuilder].getName} b$index;
-          """.stripMargin)
-          buildCode.append(
-            s"""
-               |  b$index = new 
${classOf[CelebornNullColumnBuilder].getName}();
-               |  builder.initialize($batchSize, 
"${schema.fields(index).name}", false);
-          """.stripMargin)
-          writeCode.append(genWriteCode(index))
-          writeRowCode.append(
-            s"""
-               |  b$index.appendFrom(row, $index);
-          """.stripMargin)
         case ByteType =>
           initCode.append(
             s"""
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala
index c4be15c0e..3018c0edf 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornColumnarBatchSerializer.scala
@@ -34,15 +34,11 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, 
ColumnVector}
 
 class CelebornColumnarBatchSerializer(
     schema: StructType,
-    columnBatchSize: Int,
-    encodingEnabled: Boolean,
     offHeapColumnVectorEnabled: Boolean,
     dataSize: SQLMetric = null) extends Serializer with Serializable {
   override def newInstance(): SerializerInstance =
     new CelebornColumnarBatchSerializerInstance(
       schema,
-      columnBatchSize,
-      encodingEnabled,
       offHeapColumnVectorEnabled,
       dataSize)
   override def supportsRelocationOfSerializedObjects: Boolean = true
@@ -50,8 +46,6 @@ class CelebornColumnarBatchSerializer(
 
 private class CelebornColumnarBatchSerializerInstance(
     schema: StructType,
-    columnBatchSize: Int,
-    encodingEnabled: Boolean,
     offHeapColumnVectorEnabled: Boolean,
     dataSize: SQLMetric) extends SerializerInstance {
 
@@ -93,7 +87,8 @@ private class CelebornColumnarBatchSerializerInstance(
     }
   }
 
-  val toUnsafe: UnsafeProjection = UnsafeProjection.create(schema.fields.map(f 
=> f.dataType))
+  private val toUnsafe: UnsafeProjection =
+    UnsafeProjection.create(schema.fields.map(f => f.dataType))
 
   override def deserializeStream(in: InputStream): DeserializationStream = {
     val numFields = schema.fields.length
@@ -160,7 +155,7 @@ private class CelebornColumnarBatchSerializerInstance(
         try {
           dIn.readInt()
         } catch {
-          case e: EOFException =>
+          case _: EOFException =>
             dIn.close()
             EOF
         }
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala
index 2d87856c1..a0cc2be2a 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressibleColumnBuilder.scala
@@ -29,7 +29,7 @@ trait CelebornCompressibleColumnBuilder[T <: AtomicType]
 
   this: CelebornNativeColumnBuilder[T] with WithCelebornCompressionSchemes =>
 
-  var compressionEncoder: Encoder[T] = CelebornPassThrough.encoder(columnType)
+  private var compressionEncoder: Encoder[T] = 
CelebornPassThrough.encoder(columnType)
 
   def init(encoder: Encoder[T]): Unit = {
     compressionEncoder = encoder
@@ -46,7 +46,7 @@ trait CelebornCompressibleColumnBuilder[T <: AtomicType]
   // the row to become unaligned, thus causing crashes.  Until a way of fixing 
the compression
   // is found to also allow aligned accesses this must be disabled for SPARK.
 
-  protected def isWorthCompressing(encoder: Encoder[T]) = {
+  protected def isWorthCompressing(encoder: Encoder[T]): Boolean = {
     CelebornCompressibleColumnBuilder.unaligned && encoder.compressionRatio < 
0.8
   }
 
@@ -103,5 +103,5 @@ trait CelebornCompressibleColumnBuilder[T <: AtomicType]
 }
 
 object CelebornCompressibleColumnBuilder {
-  val unaligned = Platform.unaligned()
+  val unaligned: Boolean = Platform.unaligned()
 }
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala
index a6ba31176..1e7ebae0e 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionScheme.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.columnar
 
-import java.nio.{ByteBuffer, ByteOrder}
+import java.nio.ByteBuffer
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector
@@ -76,11 +76,4 @@ object CelebornCompressionScheme {
       typeId,
       throw new UnsupportedOperationException(s"Unrecognized compression 
scheme type ID: $typeId"))
   }
-
-  def columnHeaderSize(columnBuffer: ByteBuffer): Int = {
-    val header = columnBuffer.duplicate().order(ByteOrder.nativeOrder)
-    val nullCount = header.getInt()
-    // null count + null positions
-    4 + 4 * nullCount
-  }
 }
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala
index 316e213c8..c2dfb53c2 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala
@@ -33,7 +33,7 @@ case object CelebornPassThrough extends 
CelebornCompressionScheme {
   override def supports(columnType: CelebornColumnType[_]): Boolean = true
 
   override def encoder[T <: AtomicType](columnType: 
NativeCelebornColumnType[T]): Encoder[T] = {
-    new this.CelebornEncoder[T](columnType)
+    new this.CelebornEncoder[T]()
   }
 
   override def decoder[T <: AtomicType](
@@ -42,7 +42,7 @@ case object CelebornPassThrough extends 
CelebornCompressionScheme {
     new this.CelebornDecoder(buffer, columnType)
   }
 
-  class CelebornEncoder[T <: AtomicType](columnType: 
NativeCelebornColumnType[T])
+  class CelebornEncoder[T <: AtomicType]()
     extends Encoder[T] {
     override def uncompressedSize: Int = 0
 
@@ -247,7 +247,7 @@ case object CelebornDictionaryEncoding extends 
CelebornCompressionScheme {
   override val typeId = 1
 
   // 32K unique values allowed
-  var MAX_DICT_SIZE = Short.MaxValue
+  var MAX_DICT_SIZE: Short = Short.MaxValue
 
   override def decoder[T <: AtomicType](
       buffer: ByteBuffer,
@@ -277,7 +277,7 @@ case object CelebornDictionaryEncoding extends 
CelebornCompressionScheme {
     // Total number of elements.
     private var count = 0
 
-    def cleanBatch: Unit = {
+    def cleanBatch(): Unit = {
       count = 0
       _uncompressedSize = 0
     }
@@ -341,11 +341,11 @@ case object CelebornDictionaryEncoding extends 
CelebornCompressionScheme {
       buffer: ByteBuffer,
       columnType: NativeCelebornColumnType[T])
     extends Decoder[T] {
-    val elementNum = ByteBufferHelper.getInt(buffer)
+    private val elementNum: Int = ByteBufferHelper.getInt(buffer)
     private val dictionary: Array[Any] = new Array[Any](elementNum)
-    private var intDictionary: Array[Int] = null
-    private var longDictionary: Array[Long] = null
-    private var stringDictionary: Array[String] = null
+    private var intDictionary: Array[Int] = _
+    private var longDictionary: Array[Long] = _
+    private var stringDictionary: Array[String] = _
 
     columnType.dataType match {
       case _: IntegerType =>


Reply via email to