This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new eff142c5e8 [GLUTEN-9752][Flink] Support array/map in row/vector conversion (#9757)
eff142c5e8 is described below
commit eff142c5e8bacac476f43669dc27a4ea565467ba
Author: lgbo <[email protected]>
AuthorDate: Fri May 30 08:32:49 2025 +0800
[GLUTEN-9752][Flink] Support array/map in row/vector conversion (#9757)
Co-authored-by: PHILO-HE <[email protected]>
---
.../apache/gluten/util/LogicalTypeConverter.java | 99 +++--
.../gluten/vectorized/ArrowVectorAccessor.java | 96 ++++-
.../gluten/vectorized/ArrowVectorWriter.java | 440 ++++++++++++++++-----
.../table/runtime/stream/custom/ScanTest.java | 61 +++
4 files changed, 547 insertions(+), 149 deletions(-)
diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java
index b39138ceab..30c3299990 100644
--- a/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java
+++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java
@@ -16,9 +16,9 @@
*/
package org.apache.gluten.util;
-import io.github.zhztheplayer.velox4j.type.IntegerType;
import io.github.zhztheplayer.velox4j.type.Type;
+import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.BooleanType;
import org.apache.flink.table.types.logical.DayTimeIntervalType;
@@ -26,47 +26,84 @@ import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.table.types.logical.DoubleType;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.TimestampType;
import org.apache.flink.table.types.logical.VarCharType;
import java.util.List;
+import java.util.Map;
import java.util.stream.Collectors;
-/** Convertor to convert Flink LogicalType to velox data Type */
+// Convertor to convert Flink LogicalType to velox data Type
public class LogicalTypeConverter {
+ private interface VLTypeConverter {
+ Type build(LogicalType logicalType);
+ }
+
+ // Exact class matches
+ private static Map<Class<?>, VLTypeConverter> converters =
+ Map.ofEntries(
+ Map.entry(
+ BooleanType.class,
+ logicalType -> new
io.github.zhztheplayer.velox4j.type.BooleanType()),
+ Map.entry(
+ IntType.class, logicalType -> new
io.github.zhztheplayer.velox4j.type.IntegerType()),
+ Map.entry(
+ BigIntType.class,
+ logicalType -> new
io.github.zhztheplayer.velox4j.type.BigIntType()),
+ Map.entry(
+ DoubleType.class,
+ logicalType -> new
io.github.zhztheplayer.velox4j.type.DoubleType()),
+ Map.entry(
+ VarCharType.class,
+ logicalType -> new
io.github.zhztheplayer.velox4j.type.VarCharType()),
+ // TODO: may need precision
+ Map.entry(
+ TimestampType.class,
+ logicalType -> new
io.github.zhztheplayer.velox4j.type.TimestampType()),
+ Map.entry(
+ DecimalType.class,
+ logicalType -> {
+ DecimalType decimalType = (DecimalType) logicalType;
+ return new io.github.zhztheplayer.velox4j.type.DecimalType(
+ decimalType.getPrecision(), decimalType.getScale());
+ }),
+ Map.entry(
+ DayTimeIntervalType.class,
+ logicalType -> new
io.github.zhztheplayer.velox4j.type.BigIntType()),
+ Map.entry(
+ RowType.class,
+ logicalType -> {
+ RowType flinkRowType = (RowType) logicalType;
+ List<Type> fieldTypes =
+ flinkRowType.getChildren().stream()
+ .map(LogicalTypeConverter::toVLType)
+ .collect(Collectors.toList());
+ return new io.github.zhztheplayer.velox4j.type.RowType(
+ flinkRowType.getFieldNames(), fieldTypes);
+ }),
+ Map.entry(
+ ArrayType.class,
+ logicalType -> {
+ ArrayType arrayType = (ArrayType) logicalType;
+ Type elementType = toVLType(arrayType.getElementType());
+ return
io.github.zhztheplayer.velox4j.type.ArrayType.create(elementType);
+ }),
+ Map.entry(
+ MapType.class,
+ logicalType -> {
+ MapType mapType = (MapType) logicalType;
+ Type keyType = toVLType(mapType.getKeyType());
+ Type valueType = toVLType(mapType.getValueType());
+ return
io.github.zhztheplayer.velox4j.type.MapType.create(keyType, valueType);
+ }));
public static Type toVLType(LogicalType logicalType) {
- if (logicalType instanceof RowType) {
- RowType flinkRowType = (RowType) logicalType;
- List<Type> fieldTypes =
- flinkRowType.getChildren().stream()
- .map(LogicalTypeConverter::toVLType)
- .collect(Collectors.toList());
- return new io.github.zhztheplayer.velox4j.type.RowType(
- flinkRowType.getFieldNames(), fieldTypes);
- } else if (logicalType instanceof BooleanType) {
- return new io.github.zhztheplayer.velox4j.type.BooleanType();
- } else if (logicalType instanceof IntType) {
- return new IntegerType();
- } else if (logicalType instanceof BigIntType) {
- return new io.github.zhztheplayer.velox4j.type.BigIntType();
- } else if (logicalType instanceof DoubleType) {
- return new io.github.zhztheplayer.velox4j.type.DoubleType();
- } else if (logicalType instanceof VarCharType) {
- return new io.github.zhztheplayer.velox4j.type.VarCharType();
- } else if (logicalType instanceof TimestampType) {
- // TODO: may need precision
- return new io.github.zhztheplayer.velox4j.type.TimestampType();
- } else if (logicalType instanceof DecimalType) {
- DecimalType decimalType = (DecimalType) logicalType;
- return new io.github.zhztheplayer.velox4j.type.DecimalType(
- decimalType.getPrecision(), decimalType.getScale());
- } else if (logicalType instanceof DayTimeIntervalType) {
- // TODO: it seems interval now can be used as bigint for nexmark.
- return new io.github.zhztheplayer.velox4j.type.BigIntType();
- } else {
+ VLTypeConverter converter = converters.get(logicalType.getClass());
+ if (converter == null) {
throw new RuntimeException("Unsupported logical type: " + logicalType);
}
+ return converter.build(logicalType);
}
}
diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java
index 5138172a7d..5c4975b302 100644
--- a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java
+++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java
@@ -18,6 +18,8 @@ package org.apache.gluten.vectorized;
import io.github.zhztheplayer.velox4j.type.*;
+import org.apache.flink.table.data.GenericArrayData;
+import org.apache.flink.table.data.GenericMapData;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.binary.BinaryStringData;
@@ -27,33 +29,47 @@ import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import java.util.ArrayList;
+import java.util.LinkedHashMap;
import java.util.List;
+import java.util.Map;
/*
* This module is used to convert column vector to flink generic rows.
* BinaryRowData is not supported here.
*/
public abstract class ArrowVectorAccessor {
+ private interface AccessorBuilder {
+ ArrowVectorAccessor build(FieldVector vector);
+ };
+
+ // Exact class matches
+ private static final Map<Class<? extends FieldVector>, AccessorBuilder>
accessorBuilders =
+ Map.ofEntries(
+ Map.entry(BitVector.class, vector -> new
BooleanVectorAccessor(vector)),
+ Map.entry(IntVector.class, vector -> new IntVectorAccessor(vector)),
+ Map.entry(BigIntVector.class, vector -> new
BigIntVectorAccessor(vector)),
+ Map.entry(Float8Vector.class, vector -> new
DoubleVectorAccessor(vector)),
+ Map.entry(VarCharVector.class, vector -> new
VarCharVectorAccessor(vector)),
+ Map.entry(StructVector.class, vector -> new
StructVectorAccessor(vector)),
+ Map.entry(ListVector.class, vector -> new
ListVectorAccessor(vector)),
+ Map.entry(MapVector.class, vector -> new MapVectorAccessor(vector)));
+
public static ArrowVectorAccessor create(FieldVector vector) {
- if (vector instanceof BitVector) {
- return new BooleanVectorAccessor(vector);
- } else if (vector instanceof IntVector) {
- return new IntVectorAccessor(vector);
- } else if (vector instanceof BigIntVector) {
- return new BigIntVectorAccessor(vector);
- } else if (vector instanceof Float8Vector) {
- return new DoubleVectorAccessor(vector);
- } else if (vector instanceof VarCharVector) {
- return new VarCharVectorAccessor(vector);
- } else if (vector instanceof StructVector) {
- return new StructVectorAccessor(vector);
- } else {
+ if (vector == null) {
+ throw new IllegalArgumentException(
+ "ArrowVectorAccessor. Cannot create accessor for null vector.");
+ }
+ AccessorBuilder builder = accessorBuilders.get(vector.getClass());
+ if (builder == null) {
throw new UnsupportedOperationException(
- "ArrowVectorAccessor. Unsupported type: " +
vector.getClass().getName());
+ "ArrowVectorAccessor. Unsupported vector type: " +
vector.getClass().getName());
}
+ return builder.build(vector);
}
// A general method to extract values from the vector.
@@ -153,3 +169,55 @@ class StructVectorAccessor extends ArrowVectorAccessor {
return GenericRowData.of(fieldValues);
}
}
+
+class ListVectorAccessor extends ArrowVectorAccessor {
+ private ListVector vector;
+ private ArrowVectorAccessor elementAccessor;
+
+ public ListVectorAccessor(FieldVector vector) {
+ this.vector = (ListVector) vector;
+ FieldVector elementVector = this.vector.getDataVector();
+ this.elementAccessor = ArrowVectorAccessor.create(elementVector);
+ }
+
+ @Override
+ public Object get(int rowIndex) {
+ int startIndex = vector.getElementStartIndex(rowIndex);
+ int endIndex = vector.getElementEndIndex(rowIndex);
+ Object[] elements = new Object[endIndex - startIndex];
+ for (int i = startIndex; i < endIndex; i++) {
+ elements[i - startIndex] = elementAccessor.get(i);
+ }
+ return new GenericArrayData(elements);
+ }
+}
+
+// In Arrow, the internal implementation of a map vector is an array vector.
+class MapVectorAccessor extends ArrowVectorAccessor {
+ private final MapVector vector;
+ private StructVector entriesVector;
+ private ArrowVectorAccessor keyAccessor;
+ private ArrowVectorAccessor valueAccessor;
+
+ public MapVectorAccessor(FieldVector vector) {
+ this.vector = (MapVector) vector;
+ this.entriesVector = (StructVector) this.vector.getDataVector();
+ FieldVector keyVector = this.entriesVector.getChild(MapVector.KEY_NAME);
+ FieldVector valueVector =
this.entriesVector.getChild(MapVector.VALUE_NAME);
+ this.keyAccessor = ArrowVectorAccessor.create(keyVector);
+ this.valueAccessor = ArrowVectorAccessor.create(valueVector);
+ }
+
+ @Override
+ public Object get(int rowIndex) {
+ int startIndex = vector.getElementStartIndex(rowIndex);
+ int endIndex = vector.getElementEndIndex(rowIndex);
+ Map<Object, Object> mapEntries = new LinkedHashMap<>();
+ for (int i = startIndex; i < endIndex; i++) {
+ Object key = keyAccessor.get(i);
+ Object value = valueAccessor.get(i);
+ mapEntries.put(key, value);
+ }
+ return new GenericMapData(mapEntries);
+ }
+}
diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
index 9c6c845697..52c700c255 100644
--- a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
+++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
@@ -18,19 +18,67 @@ package org.apache.gluten.vectorized;
import io.github.zhztheplayer.velox4j.type.*;
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.MapData;
import org.apache.flink.table.data.RowData;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.*;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.*;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
+import java.util.Map;
public abstract class ArrowVectorWriter {
+ private interface WriterBuilder {
+ ArrowVectorWriter build(Type fieldType, BufferAllocator allocator,
FieldVector vector);
+ };
+
+ // Exact class matches
+ private static Map<Class<? extends Type>, WriterBuilder> writerBuilders =
+ Map.ofEntries(
+ Map.entry(
+ IntegerType.class,
+ (fieldType, allocator, vector) -> new IntVectorWriter(fieldType,
allocator, vector)),
+ Map.entry(
+ BooleanType.class,
+ (fieldType, allocator, vector) ->
+ new BooleanVectorWriter(fieldType, allocator, vector)),
+ Map.entry(
+ BigIntType.class,
+ (fieldType, allocator, vector) ->
+ new BigIntVectorWriter(fieldType, allocator, vector)),
+ Map.entry(
+ DoubleType.class,
+ (fieldType, allocator, vector) ->
+ new Float8VectorWriter(fieldType, allocator, vector)),
+ Map.entry(
+ VarCharType.class,
+ (fieldType, allocator, vector) ->
+ new VarCharVectorWriter(fieldType, allocator, vector)),
+ Map.entry(
+ TimestampType.class,
+ (fieldType, allocator, vector) ->
+ new TimestampVectorWriter(fieldType, allocator, vector)),
+ Map.entry(
+ RowType.class,
+ (fieldType, allocator, vector) ->
+ new StructVectorWriter(fieldType, allocator, vector)),
+ Map.entry(
+ ArrayType.class,
+ (fieldType, allocator, vector) ->
+ new ArrayVectorWriter(fieldType, allocator, vector)),
+ Map.entry(
+ MapType.class,
+ (fieldType, allocator, vector) -> new MapVectorWriter(fieldType,
allocator, vector)));
+
public static ArrowVectorWriter create(
String fieldName, Type fieldType, BufferAllocator allocator) {
return create(fieldName, fieldType, allocator, null);
@@ -42,40 +90,31 @@ public abstract class ArrowVectorWriter {
// Build an empty vector
vector = FieldVectorCreator.create(fieldName, fieldType, false,
allocator, null);
}
- if (fieldType instanceof IntegerType) {
- return new IntVectorWriter(fieldType, allocator, vector);
- } else if (fieldType instanceof BooleanType) {
- return new BooleanVectorWriter(fieldType, allocator, vector);
- } else if (fieldType instanceof BigIntType) {
- return new BigIntVectorWriter(fieldType, allocator, vector);
- } else if (fieldType instanceof DoubleType) {
- return new Float8VectorWriter(fieldType, allocator, vector);
- } else if (fieldType instanceof VarCharType) {
- return new VarCharVectorWriter(fieldType, allocator, vector);
- } else if (fieldType instanceof TimestampType) {
- return new TimestampVectorWriter(fieldType, allocator, vector);
- } else if (fieldType instanceof RowType) {
- return new StructVectorWriter(fieldType, allocator, vector);
- } else {
- throw new UnsupportedOperationException("ArrowVectorWriter. Unsupported
type: " + fieldType);
+ WriterBuilder builder = writerBuilders.get(fieldType.getClass());
+ if (builder == null) {
+ throw new UnsupportedOperationException(
+ "ArrowVectorWriter. Unsupported type: " +
fieldType.getClass().getName());
}
+ return builder.build(fieldType, allocator, vector);
+ }
+
+ protected FieldVector vector = null;
+ protected int valueCount = 0;
+
+ ArrowVectorWriter(FieldVector vector) {
+ this.vector = vector;
}
public void write(int fieldIndex, RowData rowData) {
throw new UnsupportedOperationException("assign is not supported");
}
- public void write(int fieldIndex, List<RowData> rowData) {
- for (RowData row : rowData) {
- write(fieldIndex, row);
- }
+ public void writeArray(ArrayData arrayData) {
+ throw new UnsupportedOperationException("writeArray is not supported");
}
- protected FieldVector vector = null;
- protected int valueCount = 0;
-
- ArrowVectorWriter(FieldVector vector) {
- this.vector = vector;
+ int getValueCount() {
+ return valueCount;
}
FieldVector getVector() {
@@ -86,6 +125,7 @@ public abstract class ArrowVectorWriter {
vector.setValueCount(valueCount);
}
}
+
// Build FieldVector from Type.
class FieldVectorCreator {
public static FieldVector create(
@@ -94,34 +134,60 @@ class FieldVectorCreator {
return field.createVector(allocator);
}
+ private interface ArrowTypeConverter {
+ ArrowType convert(Type dataType, String timeZoneId);
+ }
+
+ // Exact class matches
+ private static Map<Class<? extends Type>, ArrowTypeConverter>
arrowTypeConverters =
+ Map.ofEntries(
+ Map.entry(BooleanType.class, (dataType, timeZoneId) ->
ArrowType.Bool.INSTANCE),
+ Map.entry(IntegerType.class, (dataType, timeZoneId) -> new
ArrowType.Int(8 * 4, true)),
+ Map.entry(BigIntType.class, (dataType, timeZoneId) -> new
ArrowType.Int(8 * 8, true)),
+ Map.entry(
+ DoubleType.class,
+ (dataType, timeZoneId) -> new
ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ Map.entry(VarCharType.class, (dataType, timeZoneId) ->
ArrowType.Utf8.INSTANCE),
+ Map.entry(
+ TimestampType.class,
+ (dataType, timeZoneId) ->
+ new ArrowType.Timestamp(
+ TimeUnit.MILLISECOND, timeZoneId == null ? "UTC" :
timeZoneId)));
+
private static ArrowType toArrowType(Type dataType, String timeZoneId) {
- if (dataType instanceof BooleanType) {
- return ArrowType.Bool.INSTANCE;
- } else if (dataType instanceof IntegerType) {
- return new ArrowType.Int(8 * 4, true);
- } else if (dataType instanceof BigIntType) {
- return new ArrowType.Int(8 * 8, true);
- } else if (dataType instanceof DoubleType) {
- return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
- } else if (dataType instanceof VarCharType) {
- return ArrowType.Utf8.INSTANCE;
- } else if (dataType instanceof TimestampType) {
- if (timeZoneId == null) {
- return new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC");
- } else {
- return new ArrowType.Timestamp(TimeUnit.MILLISECOND, timeZoneId);
- }
- } else {
- throw new UnsupportedOperationException("Unsupported type: " + dataType);
+ ArrowTypeConverter converter =
arrowTypeConverters.get(dataType.getClass());
+ if (converter == null) {
+ throw new UnsupportedOperationException("Unsupported type: " +
dataType.getClass().getName());
}
+ return converter.convert(dataType, timeZoneId);
}
private static Field toArrowField(
String name, Type dataType, boolean nullable, String timeZoneId) {
if (dataType instanceof ArrayType) {
- throw new UnsupportedOperationException("ArrayType is not supported");
+ List<Type> elementTypes = ((ArrayType) dataType).getChildren();
+ if (elementTypes.size() != 1) {
+ throw new UnsupportedOperationException("ArrayType should have exactly
one element type");
+ }
+
+ FieldType fieldType = new FieldType(nullable, ArrowType.List.INSTANCE,
null);
+ List<Field> elementFields = new ArrayList<>();
+ elementFields.add(toArrowField("element", elementTypes.get(0), nullable,
timeZoneId));
+
+ return new Field(name, fieldType, elementFields);
+
} else if (dataType instanceof MapType) {
- throw new UnsupportedOperationException("MapType is not supported");
+ MapType mapType = (MapType) dataType;
+ FieldType mapFieldType = new FieldType(nullable, new
ArrowType.Map(false), null);
+
+ List<String> fieldNames = Arrays.asList(MapVector.KEY_NAME,
MapVector.VALUE_NAME);
+ List<Type> fieldTypes = mapType.getChildren();
+ RowType structType = new RowType(fieldNames, fieldTypes);
+ Field structField =
+ toArrowField(MapVector.DATA_VECTOR_NAME, structType, nullable,
timeZoneId);
+
+ return new Field(name, mapFieldType, Arrays.asList(structField));
+
} else if (dataType instanceof RowType) {
RowType structType = (RowType) dataType;
List<String> fieldNames = structType.getNames();
@@ -141,136 +207,302 @@ class FieldVectorCreator {
}
}
-class IntVectorWriter extends ArrowVectorWriter {
- private final IntVector intVector;
+abstract class BaseVectorWriter<T extends FieldVector, V> extends
ArrowVectorWriter {
+ protected final T typedVector;
- public IntVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
+ protected BaseVectorWriter(FieldVector vector) {
super(vector);
- this.intVector = (IntVector) vector;
+ this.typedVector = (T) vector;
}
+ protected abstract V getValue(RowData rowData, int fieldIndex);
+
+ protected abstract V getValue(ArrayData arrayData, int index);
+
+ protected abstract void setValue(int index, V value);
+
@Override
public void write(int fieldIndex, RowData rowData) {
- intVector.setSafe(valueCount, rowData.getInt(fieldIndex));
+ setValue(valueCount, getValue(rowData, fieldIndex));
valueCount++;
}
+
+ @Override
+ public void writeArray(ArrayData arrayData) {
+ for (int i = 0; i < arrayData.size(); i++) {
+ setValue(valueCount, getValue(arrayData, i));
+ valueCount++;
+ }
+ }
}
-class BooleanVectorWriter extends ArrowVectorWriter {
- private final BitVector bitVector;
+class IntVectorWriter extends BaseVectorWriter<IntVector, Integer> {
+ public IntVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
+ super(vector);
+ }
+
+ @Override
+ protected Integer getValue(RowData rowData, int fieldIndex) {
+ return rowData.getInt(fieldIndex);
+ }
+
+ @Override
+ protected Integer getValue(ArrayData arrayData, int index) {
+ return arrayData.getInt(index);
+ }
+
+ @Override
+ protected void setValue(int index, Integer value) {
+ this.typedVector.setSafe(index, value);
+ }
+}
+class BooleanVectorWriter extends BaseVectorWriter<BitVector, Boolean> {
public BooleanVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
super(vector);
- this.bitVector = (BitVector) vector;
}
@Override
- public void write(int fieldIndex, RowData rowData) {
- bitVector.setSafe(valueCount, rowData.getBoolean(fieldIndex) ? 1 : 0);
- valueCount++;
+ protected Boolean getValue(RowData rowData, int fieldIndex) {
+ return rowData.getBoolean(fieldIndex);
+ }
+
+ @Override
+ protected Boolean getValue(ArrayData arrayData, int index) {
+ return arrayData.getBoolean(index);
+ }
+
+ @Override
+ protected void setValue(int index, Boolean value) {
+ this.typedVector.setSafe(index, value ? 1 : 0);
}
}
-class BigIntVectorWriter extends ArrowVectorWriter {
- private final BigIntVector bigIntvector;
+class BigIntVectorWriter extends BaseVectorWriter<BigIntVector, Long> {
public BigIntVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
super(vector);
- this.bigIntvector = (BigIntVector) vector;
}
@Override
- public void write(int fieldIndex, RowData rowData) {
- bigIntvector.setSafe(valueCount, rowData.getLong(fieldIndex));
- valueCount++;
+ protected Long getValue(RowData rowData, int fieldIndex) {
+ return rowData.getLong(fieldIndex);
+ }
+
+ @Override
+ protected Long getValue(ArrayData arrayData, int index) {
+ return arrayData.getLong(index);
+ }
+
+ @Override
+ protected void setValue(int index, Long value) {
+ this.typedVector.setSafe(index, value);
}
}
-class Float8VectorWriter extends ArrowVectorWriter {
- private final Float8Vector float8Vector;
+class Float8VectorWriter extends BaseVectorWriter<Float8Vector, Double> {
public Float8VectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
super(vector);
- this.float8Vector = (Float8Vector) vector;
}
@Override
- public void write(int fieldIndex, RowData rowData) {
- float8Vector.setSafe(valueCount, rowData.getDouble(fieldIndex));
- valueCount++;
+ protected Double getValue(RowData rowData, int fieldIndex) {
+ return rowData.getDouble(fieldIndex);
+ }
+
+ @Override
+ protected Double getValue(ArrayData arrayData, int index) {
+ return arrayData.getDouble(index);
+ }
+
+ @Override
+ protected void setValue(int index, Double value) {
+ this.typedVector.setSafe(index, value);
}
}
-class VarCharVectorWriter extends ArrowVectorWriter {
- private final VarCharVector varCharVector;
+class VarCharVectorWriter extends BaseVectorWriter<VarCharVector, byte[]> {
public VarCharVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
super(vector);
- this.varCharVector = (VarCharVector) vector;
}
@Override
- public void write(int fieldIndex, RowData rowData) {
- varCharVector.setSafe(valueCount, rowData.getString(fieldIndex).toBytes());
- valueCount++;
+ protected byte[] getValue(RowData rowData, int fieldIndex) {
+ return rowData.getString(fieldIndex).toBytes();
+ }
+
+ @Override
+ protected byte[] getValue(ArrayData arrayData, int index) {
+ return arrayData.getString(index).toBytes();
+ }
+
+ @Override
+ protected void setValue(int index, byte[] value) {
+ this.typedVector.setSafe(index, value);
}
}
-class TimestampVectorWriter extends ArrowVectorWriter {
- private final TimeStampMilliVector tsVector;
+class TimestampVectorWriter extends BaseVectorWriter<TimeStampMilliVector,
Long> {
+ private final int precision = 3; // Millisecond precision
public TimestampVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
super(vector);
- this.tsVector = (TimeStampMilliVector) vector;
}
@Override
- public void write(int fieldIndex, RowData rowData) {
- // TODO: support precision
- tsVector.setSafe(valueCount, rowData.getTimestamp(fieldIndex,
3).getMillisecond());
- valueCount++;
+ protected Long getValue(RowData rowData, int fieldIndex) {
+ return rowData.getTimestamp(fieldIndex, precision).getMillisecond();
+ }
+
+ @Override
+ protected Long getValue(ArrayData arrayData, int index) {
+ return arrayData.getTimestamp(index, precision).getMillisecond();
+ }
+
+ @Override
+ protected void setValue(int index, Long value) {
+ this.typedVector.setSafe(index, value);
}
}
-class StructVectorWriter extends ArrowVectorWriter {
- private int fieldCounts = 0;
- BufferAllocator allocator;
- private List<ArrowVectorWriter> subFieldWriters;
- private StructVector strctVector;
+class StructVectorWriter extends BaseVectorWriter<StructVector, RowData> {
+ private final int fieldCount;
+ private BufferAllocator allocator;
+ private final List<ArrowVectorWriter> fieldWriters;
public StructVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
super(vector);
- this.strctVector = (StructVector) vector;
RowType rowType = (RowType) fieldType;
- List<String> subFieldNames = rowType.getNames();
- subFieldWriters = new ArrayList<>();
- for (int i = 0; i < subFieldNames.size(); ++i) {
- subFieldWriters.add(
+ List<String> fieldNames = rowType.getNames();
+ fieldCount = fieldNames.size();
+ fieldWriters = new ArrayList<>();
+ for (int i = 0; i < fieldCount; ++i) {
+ fieldWriters.add(
ArrowVectorWriter.create(
- subFieldNames.get(i),
+ fieldNames.get(i),
rowType.getChildren().get(i),
allocator,
- (FieldVector) (this.strctVector.getChildByOrdinal(i))));
+ (FieldVector) (this.typedVector.getChildByOrdinal(i))));
}
- fieldCounts = subFieldNames.size();
}
@Override
- public void write(int fieldIndex, RowData rowData) {
- // TODO: support nullable
- RowData subRowData = rowData.getRow(fieldIndex, fieldCounts);
- strctVector.setIndexDefined(valueCount);
- for (int i = 0; i < fieldCounts; i++) {
- subFieldWriters.get(i).write(i, subRowData);
+ protected RowData getValue(RowData rowData, int fieldIndex) {
+ return rowData.getRow(fieldIndex, fieldCount);
+ }
+
+ @Override
+ protected RowData getValue(ArrayData arrayData, int index) {
+ return arrayData.getRow(index, fieldCount);
+ }
+
+ @Override
+ protected void setValue(int index, RowData value) {
+ this.typedVector.setIndexDefined(index);
+ for (int i = 0; i < fieldCount; ++i) {
+ fieldWriters.get(i).write(i, value);
}
- valueCount++;
}
@Override
public void finish() {
- strctVector.setValueCount(valueCount);
- for (int i = 0; i < fieldCounts; i++) {
- subFieldWriters.get(i).finish();
+ this.typedVector.setValueCount(valueCount);
+ for (int i = 0; i < fieldCount; ++i) {
+ fieldWriters.get(i).finish();
}
}
}
+
+class ArrayVectorWriter extends BaseVectorWriter<ListVector, ArrayData> {
+ private final ArrowVectorWriter elementWriter;
+
+ public ArrayVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
+ super(vector);
+
+ FieldVector elementVector = (FieldVector) this.typedVector.getDataVector();
+ List<Type> elementTypes = ((ArrayType) fieldType).getChildren();
+ if (elementTypes.size() != 1) {
+ throw new UnsupportedOperationException("ArrayType should have exactly
one element type");
+ }
+ Type elementType = elementTypes.get(0);
+ this.elementWriter = ArrowVectorWriter.create("element", elementType,
allocator, elementVector);
+ }
+
+ @Override
+ protected ArrayData getValue(RowData rowData, int fieldIndex) {
+ return rowData.getArray(fieldIndex);
+ }
+
+ @Override
+ protected ArrayData getValue(ArrayData arrayData, int index) {
+ return arrayData.getArray(index);
+ }
+
+ @Override
+ protected void setValue(int index, ArrayData value) {
+ this.typedVector.startNewValue(valueCount);
+ elementWriter.writeArray(value);
+ this.typedVector.endValue(valueCount, value.size());
+ }
+
+ @Override
+ public void finish() {
+ this.typedVector.setValueCount(valueCount);
+ elementWriter.finish();
+ }
+}
+
+class MapVectorWriter extends BaseVectorWriter<MapVector, MapData> {
+ private final ArrowVectorWriter keyWriter;
+ private final ArrowVectorWriter valueWriter;
+ private final StructVector entriesVector;
+
+ public MapVectorWriter(Type fieldType, BufferAllocator allocator,
FieldVector vector) {
+ super(vector);
+
+ entriesVector = (StructVector) this.typedVector.getDataVector();
+
+ FieldVector keyVector = (FieldVector)
entriesVector.getChild(MapVector.KEY_NAME);
+ FieldVector valueVector = (FieldVector)
entriesVector.getChild(MapVector.VALUE_NAME);
+
+ MapType mapType = (MapType) fieldType;
+ this.keyWriter =
+ ArrowVectorWriter.create(
+ MapVector.KEY_NAME, mapType.getChildren().get(0), allocator,
keyVector);
+ this.valueWriter =
+ ArrowVectorWriter.create(
+ MapVector.VALUE_NAME, mapType.getChildren().get(1), allocator,
valueVector);
+ }
+
+ @Override
+ protected MapData getValue(RowData rowData, int fieldIndex) {
+ return rowData.getMap(fieldIndex);
+ }
+
+ @Override
+ protected MapData getValue(ArrayData arrayData, int index) {
+ return arrayData.getMap(index);
+ }
+
+ @Override
+ protected void setValue(int index, MapData value) {
+ this.typedVector.startNewValue(valueCount);
+ int arrayValueCount = keyWriter.getValueCount();
+ for (int i = 0; i < value.size(); i++) {
+ entriesVector.setIndexDefined(arrayValueCount + i);
+ }
+ keyWriter.writeArray(value.keyArray());
+ valueWriter.writeArray(value.valueArray());
+ this.typedVector.endValue(valueCount, value.size());
+ }
+
+ @Override
+ public void finish() {
+ this.typedVector.setValueCount(valueCount);
+ entriesVector.setValueCount(keyWriter.getValueCount());
+ keyWriter.finish();
+ valueWriter.finish();
+ }
+}
diff --git a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
index 4a713ef1be..47eefd5fae 100644
--- a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
+++ b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
@@ -27,6 +27,7 @@ import org.slf4j.LoggerFactory;
import java.util.Arrays;
import java.util.List;
+import java.util.Map;
class ScanTest extends GlutenStreamingTestBase {
private static final Logger LOG = LoggerFactory.getLogger(ScanTest.class);
@@ -69,4 +70,64 @@ class ScanTest extends GlutenStreamingTestBase {
String query = "select a, b from floatTbl where a > 0";
runAndCheck(query, Arrays.asList("+I[1, 1.0]", "+I[2, 2.0]"));
}
+
+ @Test
+ void testArrayScan() {
+ List<Row> rows =
+ Arrays.asList(
+ Row.of(1, new Integer[] {1, 2, 3}),
+ Row.of(2, new Integer[] {4, 5, 6}),
+ Row.of(3, new Integer[] {7, 8, 9}));
+ createSimpleBoundedValuesTable("arrayTbl1", "a int, b array<int>", rows);
+ String query = "select a, b from arrayTbl1 where a > 0";
+ runAndCheck(query, Arrays.asList("+I[1, [1, 2, 3]]", "+I[2, [4, 5, 6]]",
"+I[3, [7, 8, 9]]"));
+
+ rows =
+ Arrays.asList(
+ Row.of(1, new String[] {"a", "b", "c"}),
+ Row.of(2, new String[] {"d", "e", "f"}),
+ Row.of(3, new String[] {"g", "h", "i"}));
+ createSimpleBoundedValuesTable("arrayTbl2", "a int, b array<string>",
rows);
+ query = "select a, b from arrayTbl2 where a > 0";
+ runAndCheck(query, Arrays.asList("+I[1, [a, b, c]]", "+I[2, [d, e, f]]",
"+I[3, [g, h, i]]"));
+
+ rows =
+ Arrays.asList(
+ Row.of(1, new Row[] {Row.of(1, 2), Row.of(3, 4)}), Row.of(3, new
Row[] {Row.of(5, 6)}));
+ createSimpleBoundedValuesTable("arrayTbl3", "a int, b array<ROW<x int, y
int>>", rows);
+ query = "select a, b from arrayTbl3 where a > 0";
+ runAndCheck(query, Arrays.asList("+I[1, [+I[1, 2], +I[3, 4]]]", "+I[3,
[+I[5, 6]]]"));
+
+ rows =
+ Arrays.asList(
+ Row.of(1, new Integer[][] {new Integer[] {1, 3}}),
+ Row.of(3, new Integer[][] {new Integer[] {4, 5}}));
+ createSimpleBoundedValuesTable("arrayTbl4", "a int, b array<array<int>>",
rows);
+ query = "select a, b from arrayTbl4 where a > 0";
+ runAndCheck(query, Arrays.asList("+I[1, [[1, 3]]]", "+I[3, [[4, 5]]]"));
+ }
+
+ @Test
+ void testMapScan() {
+ List<Row> rows =
+ Arrays.asList(
+ Row.of(1, Map.of(1, "a")),
+ Row.of(2, Map.of(2, "b", 3, "c")),
+ Row.of(3, Map.of(4, "d", 5, "e", 6, "f")));
+ createSimpleBoundedValuesTable("mapTbl1", "a int, b map<int, string>",
rows);
+ String query = "select a, b from mapTbl1 where a > 0";
+ runAndCheck(
+ query, Arrays.asList("+I[1, {1=a}]", "+I[2, {2=b, 3=c}]", "+I[3, {4=d,
5=e, 6=f}]"));
+
+ rows =
+ Arrays.asList(
+ Row.of(1, new Map[] {Map.of("a", 1), Map.of("b", 2)}),
+ Row.of(2, new Map[] {Map.of("b", 2, "c", 3)}),
+ Row.of(3, new Map[] {Map.of("d", 4, "e", 5, "f", 6)}));
+ createSimpleBoundedValuesTable("mapTbl2", "a int, b array<map<string,
int>>", rows);
+ query = "select a, b from mapTbl2 where a > 0";
+ runAndCheck(
+ query,
+ Arrays.asList("+I[1, [{a=1}, {b=2}]]", "+I[2, [{b=2, c=3}]]", "+I[3,
[{d=4, e=5, f=6}]]"));
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]