This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 42b2bef4c4 [Flink] Refactor conversion of row to vector (#9701)
42b2bef4c4 is described below
commit 42b2bef4c44eaca986580a4302e1b69ce469ddd2
Author: lgbo <[email protected]>
AuthorDate: Fri May 23 15:01:47 2025 +0800
[Flink] Refactor conversion of row to vector (#9701)
---
.../gluten/vectorized/ArrowVectorWriter.java | 262 +++++++++++++++++++++
.../vectorized/FlinkRowToVLVectorConvertor.java | 81 +------
.../table/runtime/stream/custom/ScanTest.java | 16 +-
3 files changed, 290 insertions(+), 69 deletions(-)
diff --git
a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
new file mode 100644
index 0000000000..9f3f41369b
--- /dev/null
+++
b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.gluten.vectorized;
+
+import io.github.zhztheplayer.velox4j.type.*;
+
+import org.apache.arrow.flatbuf.Int;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.types.TimeUnit;
+import org.apache.arrow.vector.types.pojo.*;
+import org.apache.arrow.vector.*;
+import org.apache.flink.table.data.RowData;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public abstract class ArrowVectorWriter {
+ public static ArrowVectorWriter create(
+ String fieldName, Type fieldType, BufferAllocator allocator) {
+ return create(fieldName, fieldType, allocator, null);
+ }
+
+ protected static ArrowVectorWriter create(
+ String fieldName, Type fieldType, BufferAllocator allocator,
FieldVector vector) {
+ if (vector == null) {
+ // Build an empty vector
+ vector = FieldVectorCreator.create(
+ fieldName, fieldType, false, allocator, null);
+ }
+ if (fieldType instanceof IntegerType) {
+ return new IntVectorWriter(fieldType, allocator, vector);
+ } else if (fieldType instanceof BooleanType) {
+ return new BooleanVectorWriter(fieldType, allocator, vector);
+ } else if (fieldType instanceof BigIntType) {
+ return new BigIntVectorWriter(fieldType, allocator, vector);
+ } else if (fieldType instanceof VarCharType) {
+ return new VarCharVectorWriter(fieldType, allocator, vector);
+ } else if (fieldType instanceof TimestampType) {
+ return new TimestampVectorWriter(fieldType, allocator, vector);
+ } else if (fieldType instanceof RowType) {
+ return new StructVectorWriter(fieldType, allocator, vector);
+ } else {
+ throw new UnsupportedOperationException("Unsupported type: " +
fieldType);
+ }
+ }
+ public void write(int fieldIndex, RowData rowData) {
+ throw new UnsupportedOperationException("assign is not supported");
+ }
+
+ public void write(int fieldIndex, List<RowData> rowData) {
+ for (RowData row : rowData) {
+ write(fieldIndex, row);
+ }
+ }
+
+ protected FieldVector vector = null;
+ protected int valueCount = 0;
+ ArrowVectorWriter(FieldVector vector) {
+ this.vector = vector;
+ }
+
+ FieldVector getVector() {
+ return vector;
+ }
+
+ void finish() {
+ vector.setValueCount(valueCount);
+ }
+}
+// Build FieldVector from Type.
+class FieldVectorCreator {
+ public static FieldVector create(
+ String name, Type dataType, boolean nullable, BufferAllocator
allocator, String timeZoneId) {
+ Field field = toArrowField(name, dataType, nullable, timeZoneId);
+ return field.createVector(allocator);
+ }
+
+ private static ArrowType toArroyType(Type dataType, String timeZoneId) {
+ if (dataType instanceof BooleanType) {
+ return ArrowType.Bool.INSTANCE;
+ } else if (dataType instanceof IntegerType) {
+ return new ArrowType.Int(8 * 4, true);
+ } else if (dataType instanceof BigIntType) {
+ return new ArrowType.Int(8 * 8, true);
+ } else if (dataType instanceof VarCharType) {
+ return ArrowType.Utf8.INSTANCE;
+ } else if (dataType instanceof TimestampType) {
+ if (timeZoneId == null) {
+ return new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC");
+ } else {
+ return new ArrowType.Timestamp(TimeUnit.MILLISECOND,
timeZoneId);
+ }
+ }
+ else {
+ throw new UnsupportedOperationException("Unsupported type: " +
dataType);
+ }
+ }
+
+ private static Field toArrowField(
+ String name, Type dataType, boolean nullable, String timeZoneId) {
+ if (dataType instanceof ArrayType) {
+ throw new UnsupportedOperationException("ArrayType is not
supported");
+ } else if (dataType instanceof MapType) {
+ throw new UnsupportedOperationException("MapType is not
supported");
+ } else if (dataType instanceof RowType) {
+ RowType structType = (RowType) dataType;
+ List<String> fieldNames = structType.getNames();
+ List<Type> fieldTypes = structType.getChildren();
+ List<Field> subFields = new ArrayList<>();
+ for (int i = 0; i < structType.getChildren().size(); ++i) {
+ subFields.add(
+ toArrowField(fieldNames.get(i), fieldTypes.get(i),
nullable, timeZoneId));
+ }
+ FieldType strcutType =
+ new FieldType(nullable, ArrowType.Struct.INSTANCE, null);
+ return new Field(name, strcutType, subFields);
+ } else {
+ // TODO: support nullable
+ ArrowType arrowType = toArroyType(dataType, timeZoneId);
+ FieldType fieldType = new FieldType(nullable, arrowType, null);
+ return new Field(name, fieldType, new ArrayList<>());
+ }
+ }
+}
+
+class IntVectorWriter extends ArrowVectorWriter {
+ private final IntVector intVector;
+
+ public IntVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
+ super(vector);
+ this.intVector = (IntVector) vector;
+ }
+
+ @Override
+ public void write(int fieldIndex, RowData rowData) {
+ intVector.setSafe(valueCount, rowData.getInt(fieldIndex));
+ valueCount++;
+ }
+}
+
+class BooleanVectorWriter extends ArrowVectorWriter {
+ private final BitVector bitVector;
+
+ public BooleanVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
+ super(vector);
+ this.bitVector = (BitVector) vector;
+ }
+
+ @Override
+ public void write(int fieldIndex, RowData rowData) {
+ bitVector.setSafe(valueCount, rowData.getBoolean(fieldIndex) ? 1 : 0);
+ valueCount++;
+ }
+}
+
+
+class BigIntVectorWriter extends ArrowVectorWriter {
+ private final BigIntVector bigIntvector;
+
+ public BigIntVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
+ super(vector);
+ this.bigIntvector = (BigIntVector) vector;
+ }
+
+ @Override
+ public void write(int fieldIndex, RowData rowData) {
+ bigIntvector.setSafe(valueCount, rowData.getLong(fieldIndex));
+ valueCount++;
+ }
+}
+
+class VarCharVectorWriter extends ArrowVectorWriter {
+ private final VarCharVector varCharVector;
+
+ public VarCharVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
+ super(vector);
+ this.varCharVector = (VarCharVector) vector;
+ }
+
+ @Override
+ public void write(int fieldIndex, RowData rowData) {
+ varCharVector.setSafe(valueCount,
rowData.getString(fieldIndex).toBytes());
+ valueCount++;
+ }
+}
+
+class TimestampVectorWriter extends ArrowVectorWriter {
+ private final TimeStampMilliVector tsVector;
+
+ public TimestampVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
+ super(vector);
+ this.tsVector = (TimeStampMilliVector) vector;
+ }
+
+ @Override
+ public void write(int fieldIndex, RowData rowData) {
+ // TODO: support precision
+ tsVector.setSafe(valueCount, rowData.getTimestamp(fieldIndex,
3).getMillisecond());
+ valueCount++;
+ }
+
+}
+
+class StructVectorWriter extends ArrowVectorWriter {
+ private int fieldCounts = 0;
+ BufferAllocator allocator;
+ private List<ArrowVectorWriter> subFieldWriters;
+ private StructVector strctVector;
+
+ public StructVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
+ super(vector);
+ this.strctVector = (StructVector) vector;
+ RowType rowType = (RowType) fieldType;
+ List<String> subFieldNames = rowType.getNames();
+ subFieldWriters = new ArrayList<>();
+ for (int i = 0; i < subFieldNames.size(); ++i) {
+ subFieldWriters.add(
+ ArrowVectorWriter.create(
+ subFieldNames.get(i),
+ rowType.getChildren().get(i),
+ allocator,
+ (FieldVector)(this.strctVector.getChildByOrdinal(i))
+ ));
+ }
+ fieldCounts = subFieldNames.size();
+ }
+
+ @Override
+ public void write(int fieldIndex, RowData rowData) {
+ // TODO: support nullable
+ RowData subRowData = rowData.getRow(fieldIndex, fieldCounts);
+ strctVector.setIndexDefined(valueCount);
+ for (int i = 0; i < fieldCounts; i++) {
+ subFieldWriters.get(i).write(i, subRowData);
+ }
+ valueCount++;
+ }
+
+ @Override
+ public void finish() {
+ strctVector.setValueCount(valueCount);
+ for (int i = 0; i < fieldCounts; i++) {
+ subFieldWriters.get(i).finish();
+ }
+ }
+}
diff --git
a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/FlinkRowToVLVectorConvertor.java
b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/FlinkRowToVLVectorConvertor.java
index 264e33ec67..631681eda0 100644
---
a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/FlinkRowToVLVectorConvertor.java
+++
b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/FlinkRowToVLVectorConvertor.java
@@ -38,6 +38,7 @@ import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.table.Table;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.gluten.vectorized.ArrowVectorWriter;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryStringData;
@@ -45,6 +46,7 @@ import org.apache.flink.table.data.binary.BinaryStringData;
import java.lang.reflect.Field;
import java.util.ArrayList;
+import java.util.stream.Collectors;
import java.util.List;
/** Converter between velox RowVector and Flink RowData. */
@@ -54,78 +56,21 @@ public class FlinkRowToVLVectorConvertor {
BufferAllocator allocator,
Session session,
RowType rowType) {
- // TODO: support more types
List<FieldVector> arrowVectors = new ArrayList<>(rowType.size());
+ List<Type> fieldTypes = rowType.getChildren();
+ List<String> fieldNames = rowType.getNames();
for (int i = 0; i < rowType.size(); i++) {
Type fieldType = rowType.getChildren().get(i);
- if (fieldType instanceof IntegerType) {
- IntVector intVector = new IntVector(rowType.getNames().get(i),
allocator);
- intVector.setSafe(0, row.getInt(i));
- intVector.setValueCount(1);
- arrowVectors.add(i, intVector);
- } else if (fieldType instanceof BigIntType) {
- BigIntVector bigIntVector = new
BigIntVector(rowType.getNames().get(i), allocator);
- bigIntVector.setSafe(0, row.getLong(i));
- bigIntVector.setValueCount(1);
- arrowVectors.add(i, bigIntVector);
- } else if (fieldType instanceof VarCharType) {
- VarCharVector stringVector = new
VarCharVector(rowType.getNames().get(i), allocator);
- stringVector.setSafe(0, row.getString(i).toBytes());
- stringVector.setValueCount(1);
- arrowVectors.add(i, stringVector);
- } else if (fieldType instanceof RowType) {
- // TODO: refine this
- StructVector structVector =
- StructVector.empty(
- rowType.getNames().get(i),
- allocator);
- RowType subRowType = (RowType) fieldType;
- RowData subRow = row.getRow(i, subRowType.size());
- if (subRow != null) {
- for (int j = 0; j < subRowType.size(); j++) {
- Type subFieldType = subRowType.getChildren().get(j);
- if (subFieldType instanceof IntegerType) {
- IntVector intVector = structVector.addOrGet(
- subRowType.getNames().get(j),
-
FieldType.nullable(MinorType.INT.getType()),
- IntVector.class);
- intVector.setSafe(0, subRow.getInt(j));
- intVector.setValueCount(1);
- } else if (subFieldType instanceof BigIntType) {
- BigIntVector bigIntVector = structVector.addOrGet(
- subRowType.getNames().get(j),
-
FieldType.nullable(MinorType.BIGINT.getType()),
- BigIntVector.class);
- bigIntVector.setSafe(0, subRow.getLong(j));
- bigIntVector.setValueCount(1);
- } else if (subFieldType instanceof VarCharType) {
- VarCharVector stringVector = structVector.addOrGet(
- subRowType.getNames().get(j),
-
FieldType.nullable(MinorType.VARCHAR.getType()),
- VarCharVector.class);
- stringVector.setSafe(0,
subRow.getString(j).toBytes());
- stringVector.setValueCount(1);
- } else if (subFieldType instanceof TimestampType) {
- // TODO: support precision
- TimeStampMilliVector timestampVector =
structVector.addOrGet(
- subRowType.getNames().get(j),
-
FieldType.nullable(MinorType.TIMESTAMPMILLI.getType()),
- TimeStampMilliVector.class);
- timestampVector.setSafe(
- 0,
- subRow.getTimestamp(j,
3).getMillisecond());
- timestampVector.setValueCount(1);
- } else {
- throw new RuntimeException("Unsupported field
type: " + subFieldType);
- }
- }
- structVector.setValueCount(1);
- }
- arrowVectors.add(i, structVector);
- } else {
- throw new RuntimeException("Unsupported field type: " +
fieldType);
- }
+ ArrowVectorWriter writer =
+ ArrowVectorWriter.create(
+ fieldNames.get(i),
+ fieldTypes.get(i),
+ allocator);
+ writer.write(i, row);
+ writer.finish();
+ arrowVectors.add(i, writer.getVector());
}
+
return session.arrowOps().fromArrowTable(allocator, new
Table(arrowVectors));
}
diff --git
a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
index 64ed5ba60d..b0dd716e77 100644
---
a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
+++
b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
@@ -39,8 +39,16 @@ class ScanTest extends GlutenStreamingTestBase {
public void before() throws Exception {
super.before();
List<Row> rows =
- Arrays.asList(Row.of(1, 1L, "1"), Row.of(2, 2L, "2"),
Row.of(3, 3L, "3"));
+ Arrays.asList(Row.of(1, 1L, "1"),
+ Row.of(2, 2L, "2"),
+ Row.of(3, 3L, "3"));
createSimpleBoundedValuesTable("MyTable", "a int, b bigint, c string",
rows);
+
+ List<Row> structRows =
+ Arrays.asList(Row.of(1, Row.of(2L, "abc")),
+ Row.of(2, Row.of(6L, "def")),
+ Row.of(3, Row.of(8L, "ghi")));
+ createSimpleBoundedValuesTable("MyTable2", "a int, b ROW<d bigint, e
string>", structRows);
}
@Test
@@ -49,5 +57,11 @@ class ScanTest extends GlutenStreamingTestBase {
LOG.info("execution plan: {}", explainExecutionPlan(query));
runAndCheck(query, Arrays.asList("+I[1, 1, 1, false]", "+I[2, 2, 2,
false]", "+I[3, 3, 3, true]"));
}
+
+ @Test
+ void testFilterWithStruct() {
+ String query = "select a, b.d, b.e from MyTable2 where a > 0";
+ runAndCheck(query, Arrays.asList("+I[1, 2, abc]", "+I[2, 6, def]",
"+I[3, 8, ghi]"));
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]