This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 70e3b2444 [CELEBORN-1495] CelebornColumnDictionary supports dictionary of float and double column type
70e3b2444 is described below

commit 70e3b24448b739d697a725305db9ce5bb874ad6f
Author: SteNicholas <[email protected]>
AuthorDate: Thu Jul 11 10:53:00 2024 +0800

    [CELEBORN-1495] CelebornColumnDictionary supports dictionary of float and double column type
    
    ### What changes were proposed in this pull request?
    
    `CelebornColumnDictionary` now supports dictionaries for float and double
    column types.
    
    ### Why are the changes needed?
    
    `CelebornColumnDictionary` currently supports dictionaries only for int, long,
    and string column types. Columnar shuffle should also be able to
    dictionary-encode float and double columns, so this change adds support for
    both types.
    
    Backport https://github.com/apache/spark/pull/42850.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    No.
    
    Closes #2607 from SteNicholas/CELEBORN-1495.
    
    Authored-by: SteNicholas <[email protected]>
    Signed-off-by: mingji <[email protected]>
---
 .../columnar/CelebornColumnDictionary.java         | 14 ++++-
 .../columnar/CelebornCompressionSchemes.scala      | 71 +++++++++++-----------
 2 files changed, 46 insertions(+), 39 deletions(-)

diff --git 
a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/sql/execution/columnar/CelebornColumnDictionary.java
 
b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/sql/execution/columnar/CelebornColumnDictionary.java
index c264b257b..ce5addc56 100644
--- 
a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/sql/execution/columnar/CelebornColumnDictionary.java
+++ 
b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/sql/execution/columnar/CelebornColumnDictionary.java
@@ -24,6 +24,8 @@ import org.apache.spark.sql.execution.vectorized.Dictionary;
 public class CelebornColumnDictionary implements Dictionary {
   private int[] intDictionary;
   private long[] longDictionary;
+  private float[] floatDictionary;
+  private double[] doubleDictionary;
   private String[] stringDictionary;
 
   public CelebornColumnDictionary(int[] dictionary) {
@@ -34,6 +36,14 @@ public class CelebornColumnDictionary implements Dictionary {
     this.longDictionary = dictionary;
   }
 
+  public CelebornColumnDictionary(float[] dictionary) {
+    this.floatDictionary = dictionary;
+  }
+
+  public CelebornColumnDictionary(double[] dictionary) {
+    this.doubleDictionary = dictionary;
+  }
+
   public CelebornColumnDictionary(String[] dictionary) {
     this.stringDictionary = dictionary;
   }
@@ -50,12 +60,12 @@ public class CelebornColumnDictionary implements Dictionary 
{
 
   @Override
   public float decodeToFloat(int id) {
-    throw new UnsupportedOperationException("Dictionary encoding does not 
support float");
+    return floatDictionary[id];
   }
 
   @Override
   public double decodeToDouble(int id) {
-    throw new UnsupportedOperationException("Dictionary encoding does not 
support double");
+    return doubleDictionary[id];
   }
 
   @Override
diff --git 
a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala
 
b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala
index c2dfb53c2..d1033b63b 100644
--- 
a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala
+++ 
b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/sql/execution/columnar/CelebornCompressionSchemes.scala
@@ -260,7 +260,7 @@ case object CelebornDictionaryEncoding extends 
CelebornCompressionScheme {
   }
 
   override def supports(columnType: CelebornColumnType[_]): Boolean = 
columnType match {
-    case CELEBORN_INT | CELEBORN_LONG | CELEBORN_STRING => true
+    case CELEBORN_INT | CELEBORN_LONG | CELEBORN_FLOAT | CELEBORN_DOUBLE | 
CELEBORN_STRING => true
     case _ => false
   }
 
@@ -345,6 +345,8 @@ case object CelebornDictionaryEncoding extends 
CelebornCompressionScheme {
     private val dictionary: Array[Any] = new Array[Any](elementNum)
     private var intDictionary: Array[Int] = _
     private var longDictionary: Array[Long] = _
+    private var floatDictionary: Array[Float] = _
+    private var doubleDictionary: Array[Double] = _
     private var stringDictionary: Array[String] = _
 
     columnType.dataType match {
@@ -362,6 +364,20 @@ case object CelebornDictionaryEncoding extends 
CelebornCompressionScheme {
           longDictionary(i) = v
           dictionary(i) = v
         }
+      case _: FloatType =>
+        floatDictionary = new Array[Float](elementNum)
+        for (i <- 0 until elementNum) {
+          val v = columnType.extract(buffer).asInstanceOf[Float]
+          floatDictionary(i) = v
+          dictionary(i) = v
+        }
+      case _: DoubleType =>
+        doubleDictionary = new Array[Double](elementNum)
+        for (i <- 0 until elementNum) {
+          val v = columnType.extract(buffer).asInstanceOf[Double]
+          doubleDictionary(i) = v
+          dictionary(i) = v
+        }
       case _: StringType =>
         stringDictionary = new Array[String](elementNum)
         for (i <- 0 until elementNum) {
@@ -384,51 +400,32 @@ case object CelebornDictionaryEncoding extends 
CelebornCompressionScheme {
       var nextNullIndex = if (nullCount > 0) 
ByteBufferHelper.getInt(nullsBuffer) else -1
       var pos = 0
       var seenNulls = 0
+      val dictionaryIds = columnVector.reserveDictionaryIds(capacity)
       columnType.dataType match {
         case _: IntegerType =>
-          val dictionaryIds = columnVector.reserveDictionaryIds(capacity)
           columnVector.setDictionary(new 
CelebornColumnDictionary(intDictionary))
-          while (pos < capacity) {
-            if (pos != nextNullIndex) {
-              dictionaryIds.putInt(pos, buffer.getShort())
-            } else {
-              seenNulls += 1
-              if (seenNulls < nullCount) nextNullIndex = 
ByteBufferHelper.getInt(nullsBuffer)
-              columnVector.putNull(pos)
-            }
-            pos += 1
-          }
         case _: LongType =>
-          val dictionaryIds = columnVector.reserveDictionaryIds(capacity)
           columnVector.setDictionary(new 
CelebornColumnDictionary(longDictionary))
-          while (pos < capacity) {
-            if (pos != nextNullIndex) {
-              dictionaryIds.putInt(pos, buffer.getShort())
-            } else {
-              seenNulls += 1
-              if (seenNulls < nullCount) {
-                nextNullIndex = ByteBufferHelper.getInt(nullsBuffer)
-              }
-              columnVector.putNull(pos)
-            }
-            pos += 1
-          }
+        case _: FloatType =>
+          columnVector.setDictionary(new 
CelebornColumnDictionary(floatDictionary))
+        case _: DoubleType =>
+          columnVector.setDictionary(new 
CelebornColumnDictionary(doubleDictionary))
         case _: StringType =>
-          val dictionaryIds = columnVector.reserveDictionaryIds(capacity)
           columnVector.setDictionary(new 
CelebornColumnDictionary(stringDictionary))
-          while (pos < capacity) {
-            if (pos != nextNullIndex) {
-              dictionaryIds.putInt(pos, buffer.getShort())
-            } else {
-              seenNulls += 1
-              if (seenNulls < nullCount) nextNullIndex = 
ByteBufferHelper.getInt(nullsBuffer)
-              columnVector.putNull(pos)
-            }
-            pos += 1
-          }
-
         case _ => throw new IllegalStateException("Not supported type in 
DictionaryEncoding.")
       }
+      while (pos < capacity) {
+        if (pos != nextNullIndex) {
+          dictionaryIds.putInt(pos, buffer.getShort())
+        } else {
+          seenNulls += 1
+          if (seenNulls < nullCount) {
+            nextNullIndex = ByteBufferHelper.getInt(nullsBuffer)
+          }
+          columnVector.putNull(pos)
+        }
+        pos += 1
+      }
     }
   }
 }

Reply via email to