This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 42b2bef4c4 [Flink] Refactor convertion of row to vector (#9701)
42b2bef4c4 is described below

commit 42b2bef4c44eaca986580a4302e1b69ce469ddd2
Author: lgbo <[email protected]>
AuthorDate: Fri May 23 15:01:47 2025 +0800

    [Flink] Refactor convertion of row to vector (#9701)
---
 .../gluten/vectorized/ArrowVectorWriter.java       | 262 +++++++++++++++++++++
 .../vectorized/FlinkRowToVLVectorConvertor.java    |  81 +------
 .../table/runtime/stream/custom/ScanTest.java      |  16 +-
 3 files changed, 290 insertions(+), 69 deletions(-)

diff --git 
a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
 
b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
new file mode 100644
index 0000000000..9f3f41369b
--- /dev/null
+++ 
b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.gluten.vectorized;
+
+import io.github.zhztheplayer.velox4j.type.*;
+
+import org.apache.arrow.flatbuf.Int;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.types.TimeUnit;
+import org.apache.arrow.vector.types.pojo.*;
+import org.apache.arrow.vector.*;
+import org.apache.flink.table.data.RowData;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public abstract class ArrowVectorWriter {
+    public static ArrowVectorWriter create(
+        String fieldName, Type fieldType, BufferAllocator allocator) {
+        return create(fieldName, fieldType, allocator, null);
+    }
+
+    protected static ArrowVectorWriter create(
+        String fieldName, Type fieldType, BufferAllocator allocator, 
FieldVector vector) {
+        if (vector == null) {
+            // Build an empty vector
+            vector = FieldVectorCreator.create(
+                fieldName, fieldType, false, allocator, null);
+        }
+        if (fieldType instanceof IntegerType) {
+            return new IntVectorWriter(fieldType, allocator, vector);
+        } else if (fieldType instanceof BooleanType) {
+            return new BooleanVectorWriter(fieldType, allocator, vector);
+        } else if (fieldType instanceof BigIntType) {
+            return new BigIntVectorWriter(fieldType, allocator, vector);
+        } else if (fieldType instanceof VarCharType) {
+            return new VarCharVectorWriter(fieldType, allocator, vector);
+        } else if (fieldType instanceof TimestampType) {
+            return new TimestampVectorWriter(fieldType, allocator, vector);
+        } else if (fieldType instanceof RowType) {
+            return new StructVectorWriter(fieldType, allocator, vector);
+        } else {
+            throw new UnsupportedOperationException("Unsupported type: " + 
fieldType);
+        }
+    }
+    public void write(int fieldIndex, RowData rowData) {
+        throw new UnsupportedOperationException("assign is not supported");
+    }
+
+    public void write(int fieldIndex, List<RowData> rowData) {
+        for (RowData row : rowData) {
+            write(fieldIndex, row);
+        }
+    }
+
+    protected FieldVector vector = null;
+    protected int valueCount = 0;
+    ArrowVectorWriter(FieldVector vector) {
+        this.vector = vector;
+    }
+
+    FieldVector getVector() {
+        return vector;
+    }
+
+    void finish() {
+        vector.setValueCount(valueCount);
+    }
+}
+// Build FieldVector from Type.
+class FieldVectorCreator {
+    public static FieldVector create(
+        String name, Type dataType, boolean nullable, BufferAllocator 
allocator, String timeZoneId) {
+        Field field = toArrowField(name, dataType, nullable, timeZoneId);
+        return field.createVector(allocator);
+    }
+
+    private static ArrowType toArroyType(Type dataType, String timeZoneId) {
+        if (dataType instanceof BooleanType) {
+            return ArrowType.Bool.INSTANCE;
+        } else if (dataType instanceof IntegerType) {
+            return new ArrowType.Int(8 * 4, true);
+        } else if (dataType instanceof BigIntType) {
+            return new ArrowType.Int(8 * 8, true);
+        } else if (dataType instanceof VarCharType) {
+            return ArrowType.Utf8.INSTANCE;
+        } else if (dataType instanceof TimestampType) {
+            if (timeZoneId == null) {
+                return new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC");
+            } else {
+                return new ArrowType.Timestamp(TimeUnit.MILLISECOND, 
timeZoneId);
+            }
+        }
+        else {
+            throw new UnsupportedOperationException("Unsupported type: " + 
dataType);
+        }
+    }
+
+    private static Field toArrowField(
+        String name, Type dataType, boolean nullable, String timeZoneId) {
+        if (dataType instanceof ArrayType) {
+            throw new UnsupportedOperationException("ArrayType is not 
supported");
+        } else if (dataType instanceof MapType) {
+            throw new UnsupportedOperationException("MapType is not 
supported");
+        } else if (dataType instanceof RowType) {
+            RowType structType = (RowType) dataType;
+            List<String> fieldNames = structType.getNames();
+            List<Type> fieldTypes = structType.getChildren();
+            List<Field> subFields = new ArrayList<>();
+            for (int i = 0; i < structType.getChildren().size(); ++i) {
+                subFields.add(
+                    toArrowField(fieldNames.get(i), fieldTypes.get(i), 
nullable, timeZoneId));
+            }
+            FieldType strcutType =
+                new FieldType(nullable, ArrowType.Struct.INSTANCE, null);
+            return new Field(name, strcutType, subFields);
+        } else {
+            // TODO: support nullable
+            ArrowType arrowType = toArroyType(dataType, timeZoneId);
+            FieldType fieldType = new FieldType(nullable, arrowType, null);
+            return new Field(name, fieldType, new ArrayList<>());
+        }
+    }
+}
+
+class IntVectorWriter extends ArrowVectorWriter {
+    private final IntVector intVector;
+
+    public IntVectorWriter(Type fieldType, BufferAllocator allocator, 
FieldVector vector) {
+        super(vector);
+        this.intVector = (IntVector) vector;
+    }
+
+    @Override
+    public void write(int fieldIndex, RowData rowData) {
+        intVector.setSafe(valueCount, rowData.getInt(fieldIndex));
+        valueCount++;
+    }
+}
+
+class BooleanVectorWriter extends ArrowVectorWriter {
+    private final BitVector bitVector;
+
+    public BooleanVectorWriter(Type fieldType, BufferAllocator allocator, 
FieldVector vector) {
+        super(vector);
+        this.bitVector = (BitVector) vector;
+    }
+
+    @Override
+    public void write(int fieldIndex, RowData rowData) {
+        bitVector.setSafe(valueCount, rowData.getBoolean(fieldIndex) ? 1 : 0);
+        valueCount++;
+    }
+}
+
+
+class BigIntVectorWriter extends ArrowVectorWriter {
+    private final BigIntVector bigIntvector;
+
+    public BigIntVectorWriter(Type fieldType, BufferAllocator allocator, 
FieldVector vector) {
+        super(vector);
+        this.bigIntvector = (BigIntVector) vector;
+    }
+
+    @Override
+    public void write(int fieldIndex, RowData rowData) {
+        bigIntvector.setSafe(valueCount, rowData.getLong(fieldIndex));
+        valueCount++;
+    }
+}
+
+class VarCharVectorWriter extends ArrowVectorWriter {
+    private final VarCharVector varCharVector;
+
+    public VarCharVectorWriter(Type fieldType, BufferAllocator allocator, 
FieldVector vector) {
+        super(vector);
+        this.varCharVector = (VarCharVector) vector;
+    }
+
+    @Override
+    public void write(int fieldIndex, RowData rowData) {
+        varCharVector.setSafe(valueCount, 
rowData.getString(fieldIndex).toBytes());
+        valueCount++;
+    }
+}
+
+class TimestampVectorWriter extends ArrowVectorWriter {
+    private final TimeStampMilliVector tsVector;
+
+    public TimestampVectorWriter(Type fieldType, BufferAllocator allocator, 
FieldVector vector) {
+        super(vector);
+        this.tsVector = (TimeStampMilliVector) vector;
+    }
+
+    @Override
+    public void write(int fieldIndex, RowData rowData) {
+        // TODO: support precision
+        tsVector.setSafe(valueCount, rowData.getTimestamp(fieldIndex, 
3).getMillisecond());
+        valueCount++;
+    }
+
+}
+
+class StructVectorWriter extends ArrowVectorWriter {
+    private int fieldCounts = 0;
+    BufferAllocator allocator;
+    private List<ArrowVectorWriter> subFieldWriters;
+    private StructVector strctVector;
+
+    public StructVectorWriter(Type fieldType, BufferAllocator allocator, 
FieldVector vector) {
+        super(vector);
+        this.strctVector = (StructVector) vector;
+        RowType rowType = (RowType) fieldType;
+        List<String> subFieldNames = rowType.getNames();
+        subFieldWriters = new ArrayList<>();
+        for (int i = 0; i < subFieldNames.size(); ++i) {
+            subFieldWriters.add(
+                ArrowVectorWriter.create(
+                    subFieldNames.get(i),
+                    rowType.getChildren().get(i),
+                    allocator,
+                    (FieldVector)(this.strctVector.getChildByOrdinal(i))
+                    ));
+        }
+        fieldCounts = subFieldNames.size();
+    }
+
+    @Override
+    public void write(int fieldIndex, RowData rowData) {
+        // TODO: support nullable
+        RowData subRowData = rowData.getRow(fieldIndex, fieldCounts);
+        strctVector.setIndexDefined(valueCount);
+        for (int i = 0; i < fieldCounts; i++) {
+            subFieldWriters.get(i).write(i, subRowData);
+        }
+        valueCount++;
+    }
+
+    @Override
+    public void finish() {
+        strctVector.setValueCount(valueCount);
+        for (int i = 0; i < fieldCounts; i++) {
+            subFieldWriters.get(i).finish();
+        }
+    }
+}
diff --git 
a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/FlinkRowToVLVectorConvertor.java
 
b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/FlinkRowToVLVectorConvertor.java
index 264e33ec67..631681eda0 100644
--- 
a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/FlinkRowToVLVectorConvertor.java
+++ 
b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/FlinkRowToVLVectorConvertor.java
@@ -38,6 +38,7 @@ import org.apache.arrow.vector.complex.StructVector;
 import org.apache.arrow.vector.table.Table;
 import org.apache.arrow.vector.types.Types.MinorType;
 import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.gluten.vectorized.ArrowVectorWriter;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.binary.BinaryStringData;
@@ -45,6 +46,7 @@ import org.apache.flink.table.data.binary.BinaryStringData;
 
 import java.lang.reflect.Field;
 import java.util.ArrayList;
+import java.util.stream.Collectors;
 import java.util.List;
 
 /** Converter between velox RowVector and Flink RowData. */
@@ -54,78 +56,21 @@ public class FlinkRowToVLVectorConvertor {
             BufferAllocator allocator,
             Session session,
             RowType rowType) {
-        // TODO: support more types
         List<FieldVector> arrowVectors = new ArrayList<>(rowType.size());
+        List<Type> fieldTypes = rowType.getChildren();
+        List<String> fieldNames = rowType.getNames();
         for (int i = 0; i < rowType.size(); i++) {
             Type fieldType = rowType.getChildren().get(i);
-            if (fieldType instanceof IntegerType) {
-                IntVector intVector = new IntVector(rowType.getNames().get(i), 
allocator);
-                intVector.setSafe(0, row.getInt(i));
-                intVector.setValueCount(1);
-                arrowVectors.add(i, intVector);
-            } else if (fieldType instanceof BigIntType) {
-                BigIntVector bigIntVector = new 
BigIntVector(rowType.getNames().get(i), allocator);
-                bigIntVector.setSafe(0, row.getLong(i));
-                bigIntVector.setValueCount(1);
-                arrowVectors.add(i, bigIntVector);
-            } else if (fieldType instanceof VarCharType) {
-                VarCharVector stringVector = new 
VarCharVector(rowType.getNames().get(i), allocator);
-                stringVector.setSafe(0, row.getString(i).toBytes());
-                stringVector.setValueCount(1);
-                arrowVectors.add(i, stringVector);
-            } else if (fieldType instanceof RowType) {
-                // TODO: refine this
-                StructVector structVector =
-                        StructVector.empty(
-                                rowType.getNames().get(i),
-                                allocator);
-                RowType subRowType = (RowType) fieldType;
-                RowData subRow = row.getRow(i, subRowType.size());
-                if (subRow != null) {
-                    for (int j = 0; j < subRowType.size(); j++) {
-                        Type subFieldType = subRowType.getChildren().get(j);
-                        if (subFieldType instanceof IntegerType) {
-                            IntVector intVector = structVector.addOrGet(
-                                    subRowType.getNames().get(j),
-                                    
FieldType.nullable(MinorType.INT.getType()),
-                                    IntVector.class);
-                            intVector.setSafe(0, subRow.getInt(j));
-                            intVector.setValueCount(1);
-                        } else if (subFieldType instanceof BigIntType) {
-                            BigIntVector bigIntVector = structVector.addOrGet(
-                                    subRowType.getNames().get(j),
-                                    
FieldType.nullable(MinorType.BIGINT.getType()),
-                                    BigIntVector.class);
-                            bigIntVector.setSafe(0, subRow.getLong(j));
-                            bigIntVector.setValueCount(1);
-                        } else if (subFieldType instanceof VarCharType) {
-                            VarCharVector stringVector = structVector.addOrGet(
-                                    subRowType.getNames().get(j),
-                                    
FieldType.nullable(MinorType.VARCHAR.getType()),
-                                    VarCharVector.class);
-                            stringVector.setSafe(0, 
subRow.getString(j).toBytes());
-                            stringVector.setValueCount(1);
-                        } else if (subFieldType instanceof TimestampType) {
-                            // TODO: support precision
-                            TimeStampMilliVector timestampVector = 
structVector.addOrGet(
-                                    subRowType.getNames().get(j),
-                                    
FieldType.nullable(MinorType.TIMESTAMPMILLI.getType()),
-                                    TimeStampMilliVector.class);
-                            timestampVector.setSafe(
-                                    0,
-                                    subRow.getTimestamp(j, 
3).getMillisecond());
-                            timestampVector.setValueCount(1);
-                        } else {
-                            throw new RuntimeException("Unsupported field 
type: " + subFieldType);
-                        }
-                    }
-                    structVector.setValueCount(1);
-                }
-                arrowVectors.add(i, structVector);
-            } else {
-                throw new RuntimeException("Unsupported field type: " + 
fieldType);
-            }
+            ArrowVectorWriter writer =
+            ArrowVectorWriter.create(
+                            fieldNames.get(i),
+                            fieldTypes.get(i),
+                            allocator);
+            writer.write(i, row);
+            writer.finish();
+            arrowVectors.add(i, writer.getVector());
         }
+
         return session.arrowOps().fromArrowTable(allocator, new 
Table(arrowVectors));
     }
 
diff --git 
a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
 
b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
index 64ed5ba60d..b0dd716e77 100644
--- 
a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
+++ 
b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
@@ -39,8 +39,16 @@ class ScanTest extends GlutenStreamingTestBase {
     public void before() throws Exception {
         super.before();
         List<Row> rows =
-                Arrays.asList(Row.of(1, 1L, "1"), Row.of(2, 2L, "2"), 
Row.of(3, 3L, "3"));
+                Arrays.asList(Row.of(1, 1L, "1"),
+                    Row.of(2, 2L, "2"),
+                    Row.of(3, 3L, "3"));
         createSimpleBoundedValuesTable("MyTable", "a int, b bigint, c string", 
rows);
+
+        List<Row> structRows =
+                Arrays.asList(Row.of(1, Row.of(2L, "abc")),
+                    Row.of(2,  Row.of(6L, "def")),
+                    Row.of(3,  Row.of(8L, "ghi")));
+        createSimpleBoundedValuesTable("MyTable2", "a int, b ROW<d bigint, e 
string>", structRows);
     }
 
     @Test
@@ -49,5 +57,11 @@ class ScanTest extends GlutenStreamingTestBase {
         LOG.info("execution plan: {}", explainExecutionPlan(query));
         runAndCheck(query, Arrays.asList("+I[1, 1, 1, false]", "+I[2, 2, 2, 
false]", "+I[3, 3, 3, true]"));
     }
+
+    @Test
+    void testFilterWithStruct() {
+        String query = "select a, b.d, b.e from MyTable2 where a > 0";
+        runAndCheck(query, Arrays.asList("+I[1, 2, abc]", "+I[2, 6, def]", 
"+I[3, 8, ghi]"));
+    }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to