This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new e2b02d5944 [GLUTEN-10050][FLINK] Support decimal type (#10049)
e2b02d5944 is described below

commit e2b02d59443fcadb422dfd11649e71d91522c29e
Author: lgbo <[email protected]>
AuthorDate: Mon Jul 7 08:43:50 2025 +0800

    [GLUTEN-10050][FLINK] Support decimal type (#10049)
---
 .../gluten/vectorized/ArrowVectorAccessor.java     | 22 ++++++++++++
 .../gluten/vectorized/ArrowVectorWriter.java       | 40 ++++++++++++++++++++--
 .../table/runtime/stream/custom/ScanTest.java      | 11 ++++++
 3 files changed, 71 insertions(+), 2 deletions(-)

diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java
index 6a9f2f5d05..c373c77faa 100644
--- a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java
+++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java
@@ -18,6 +18,7 @@ package org.apache.gluten.vectorized;
 
 import io.github.zhztheplayer.velox4j.type.*;
 
+import org.apache.flink.table.data.DecimalData;
 import org.apache.flink.table.data.GenericArrayData;
 import org.apache.flink.table.data.GenericMapData;
 import org.apache.flink.table.data.GenericRowData;
@@ -27,6 +28,7 @@ import org.apache.flink.table.data.binary.BinaryStringData;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
 import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
@@ -36,6 +38,7 @@ import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.complex.MapVector;
 import org.apache.arrow.vector.complex.StructVector;
 
+import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -57,6 +60,7 @@ public abstract class ArrowVectorAccessor {
           Map.entry(IntVector.class, vector -> new IntVectorAccessor(vector)),
           Map.entry(BigIntVector.class, vector -> new BigIntVectorAccessor(vector)),
           Map.entry(Float8Vector.class, vector -> new DoubleVectorAccessor(vector)),
+          Map.entry(DecimalVector.class, vector -> new DecimalVectorAccessor(vector)),
           Map.entry(VarCharVector.class, vector -> new VarCharVectorAccessor(vector)),
           Map.entry(StructVector.class, vector -> new StructVectorAccessor(vector)),
           Map.entry(ListVector.class, vector -> new ListVectorAccessor(vector)),
@@ -150,6 +154,24 @@ class DoubleVectorAccessor extends BaseArrowVectorAccessor<Float8Vector> {
   }
 }
 
+class DecimalVectorAccessor extends BaseArrowVectorAccessor<DecimalVector> {
+
+  private int precision = 0;
+  private int scale = 0;
+
+  public DecimalVectorAccessor(FieldVector vector) {
+    super(vector);
+    this.precision = typedVector.getPrecision();
+    this.scale = typedVector.getScale();
+  }
+
+  @Override
+  protected Object getImpl(int rowIndex) {
+    BigDecimal decimalData = (BigDecimal) typedVector.getObject(rowIndex);
+    return DecimalData.fromBigDecimal(decimalData, precision, scale);
+  }
+}
+
 class DateDayVectorAccessor extends BaseArrowVectorAccessor<DateDayVector> {
 
   public DateDayVectorAccessor(FieldVector vector) {
diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
index c20f5bbe9c..d8a50081f8 100644
--- a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
+++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
@@ -19,6 +19,7 @@ package org.apache.gluten.vectorized;
 import io.github.zhztheplayer.velox4j.type.*;
 
 import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.DecimalData;
 import org.apache.flink.table.data.MapData;
 import org.apache.flink.table.data.RowData;
 
@@ -83,7 +84,11 @@ public abstract class ArrowVectorWriter {
                   new ArrayVectorWriter(fieldType, allocator, vector)),
           Map.entry(
               MapType.class,
-              (fieldType, allocator, vector) -> new MapVectorWriter(fieldType, allocator, vector)));
+              (fieldType, allocator, vector) -> new MapVectorWriter(fieldType, allocator, vector)),
+          Map.entry(
+              DecimalType.class,
+              (fieldType, allocator, vector) ->
+                  new DecimalVectorWriter(fieldType, allocator, vector)));
 
   public static ArrowVectorWriter create(
       String fieldName, Type fieldType, BufferAllocator allocator) {
@@ -159,7 +164,14 @@ class FieldVectorCreator {
               (dataType, timeZoneId) ->
                   new ArrowType.Timestamp(
                       TimeUnit.MILLISECOND, timeZoneId == null ? "UTC" : timeZoneId)),
-          Map.entry(DateType.class, (dataType, timeZoneId) -> new ArrowType.Date(DateUnit.DAY)));
+          Map.entry(DateType.class, (dataType, timeZoneId) -> new ArrowType.Date(DateUnit.DAY)),
+          Map.entry(
+              DecimalType.class,
+              (dataType, timeZoneId) -> {
+                DecimalType decimalType = (DecimalType) dataType;
+                return new ArrowType.Decimal(
+                    decimalType.getPrecision(), decimalType.getScale(), 128);
+              }));
 
   private static ArrowType toArrowType(Type dataType, String timeZoneId) {
     ArrowTypeConverter converter = arrowTypeConverters.get(dataType.getClass());
@@ -337,6 +349,30 @@ class Float8VectorWriter extends BaseVectorWriter<Float8Vector, Double> {
   }
 }
 
+class DecimalVectorWriter extends BaseVectorWriter<DecimalVector, DecimalData> {
+  private final DecimalType decimalType;
+
+  public DecimalVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
+    super(vector);
+    this.decimalType = (DecimalType) fieldType;
+  }
+
+  @Override
+  protected DecimalData getValue(RowData rowData, int fieldIndex) {
+    return rowData.getDecimal(fieldIndex, decimalType.getPrecision(), decimalType.getScale());
+  }
+
+  @Override
+  protected DecimalData getValue(ArrayData arrayData, int index) {
+    return arrayData.getDecimal(index, decimalType.getPrecision(), decimalType.getScale());
+  }
+
+  @Override
+  protected void setValue(int index, DecimalData value) {
+    this.typedVector.setSafe(index, value.toBigDecimal());
+  }
+}
+
 class VarCharVectorWriter extends BaseVectorWriter<VarCharVector, byte[]> {
 
   public VarCharVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
diff --git a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
index db4a2d00ec..f9cbe2968a 100644
--- a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
+++ b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
@@ -26,6 +26,7 @@ import org.junit.jupiter.api.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.math.BigDecimal;
 import java.time.LocalDate;
 import java.util.Arrays;
 import java.util.LinkedHashMap;
@@ -189,4 +190,14 @@ class ScanTest extends GlutenStreamingTestBase {
     String query = "select a, b from dateTbl where a > 0";
     runAndCheck(query, Arrays.asList("+I[1, 2023-01-01]", "+I[2, 2023-01-02]"));
   }
+
+  @Test
+  void testDecimal() {
+    List<Row> rows =
+        Arrays.asList(
+            Row.of(1, new BigDecimal("1.23")), Row.of(2, null), Row.of(3, new BigDecimal("7.89")));
+    createSimpleBoundedValuesTable("decimalTbl", "a int, b decimal(5, 2)", rows);
+    String query = "select a, b from decimalTbl where a > 0";
+    runAndCheck(query, Arrays.asList("+I[1, 1.23]", "+I[2, null]", "+I[3, 7.89]"));
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to