durgaprasadml commented on code in PR #38706: URL: https://github.com/apache/beam/pull/38706#discussion_r3334647575
########## sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/SerializableRow.java: ########## @@ -0,0 +1,534 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.delta; + +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.Row; +import io.delta.kernel.types.ArrayType; +import io.delta.kernel.types.BinaryType; +import io.delta.kernel.types.BooleanType; +import io.delta.kernel.types.ByteType; +import io.delta.kernel.types.DataType; +import io.delta.kernel.types.DecimalType; +import io.delta.kernel.types.DoubleType; +import io.delta.kernel.types.FloatType; +import io.delta.kernel.types.IntegerType; +import io.delta.kernel.types.LongType; +import io.delta.kernel.types.MapType; +import io.delta.kernel.types.ShortType; +import io.delta.kernel.types.StringType; +import io.delta.kernel.types.StructType; +import java.io.Serializable; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A serializable wrapper for Delta {@link Row} that implements the {@link Row} interface itself, + * allowing worker nodes to access serialized Row objects using standard Delta Kernel APIs. + */ +public class SerializableRow implements Row, Serializable { + private static final long serialVersionUID = 1L; + + private final SerializableStructType schema; + private final @Nullable Object[] values; + + public SerializableRow(Row row) { + this.schema = new SerializableStructType(row.getSchema()); + StructType structType = row.getSchema(); + int numFields = structType.fields().size(); + this.values = new Object[numFields]; + for (int i = 0; i < numFields; i++) { + DataType type = structType.fields().get(i).getDataType(); + this.values[i] = getValue(row, i, type); + } + } + + @Override + public StructType getSchema() { + return schema.get(); + } + + @Override + public boolean isNullAt(int ord) { + return values == null || values[ord] == null; + } + + @Override + public boolean getBoolean(int ord) { + Object val = Objects.requireNonNull(values)[ord]; + return val != null ? (Boolean) val : false; + } + + @Override + public byte getByte(int ord) { + Object val = Objects.requireNonNull(values)[ord]; + return val != null ? (Byte) val : 0; + } + + @Override + public short getShort(int ord) { + Object val = Objects.requireNonNull(values)[ord]; + return val != null ? (Short) val : 0; + } + + @Override + public int getInt(int ord) { + Object val = Objects.requireNonNull(values)[ord]; + return val != null ? (Integer) val : 0; + } + + @Override + public long getLong(int ord) { + Object val = Objects.requireNonNull(values)[ord]; + return val != null ? (Long) val : 0L; + } + + @Override + public float getFloat(int ord) { + Object val = Objects.requireNonNull(values)[ord]; + return val != null ? (Float) val : 0.0f; + } + + @Override + public double getDouble(int ord) { + Object val = Objects.requireNonNull(values)[ord]; + return val != null ? (Double) val : 0.0d; + } + + @Override + @SuppressWarnings("nullness") + public String getString(int ord) { + return (String) Objects.requireNonNull(values)[ord]; + } + + @Override + @SuppressWarnings("nullness") + public byte[] getBinary(int ord) { + return (byte[]) Objects.requireNonNull(values)[ord]; + } + + @Override + @SuppressWarnings("nullness") + public BigDecimal getDecimal(int ord) { + return (BigDecimal) Objects.requireNonNull(values)[ord]; + } + + @Override + @SuppressWarnings("nullness") + public Row getStruct(int ord) { + return (Row) Objects.requireNonNull(values)[ord]; + } + + @Override + @SuppressWarnings({"unchecked", "nullness"}) + public ArrayValue getArray(int ord) { + Object val = Objects.requireNonNull(values)[ord]; + if (val == null) { + return null; + } + DataType elementType = + ((ArrayType) getSchema().fields().get(ord).getDataType()).getElementType(); + return new SerializableArrayValue((List<@Nullable Object>) val, elementType); + } + + @Override + @SuppressWarnings({"unchecked", "nullness"}) + public MapValue getMap(int ord) { + Object val = Objects.requireNonNull(values)[ord]; + if (val == null) { + return null; + } + MapType mapType = (MapType) getSchema().fields().get(ord).getDataType(); + return new SerializableMapValue( + (Map<Object, @Nullable Object>) val, mapType.getKeyType(), mapType.getValueType()); + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SerializableRow)) { + return false; + } + SerializableRow that = (SerializableRow) o; + return Objects.equals(schema, that.schema) && java.util.Arrays.deepEquals(values, that.values); + } + + @Override + public int hashCode() { + return Objects.hash(schema, java.util.Arrays.deepHashCode(values)); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("SerializableRow{schema=").append(schema).append(", values=["); + if (values != null) { + for (int i = 0; i < values.length; i++) { + if (i > 0) { + sb.append(", "); + } + sb.append(values[i]); + } + } + sb.append("]}"); + return sb.toString(); + } + + private static @Nullable Object getValue(Row row, int index, DataType type) { + if (row.isNullAt(index)) { + return null; + } + if (type instanceof BooleanType) { + return row.getBoolean(index); + } else if (type instanceof ByteType) { + return row.getByte(index); + } else if (type instanceof ShortType) { + return row.getShort(index); + } else if (type instanceof IntegerType) { + return row.getInt(index); + } else if (type instanceof LongType) { + return row.getLong(index); + } else if (type instanceof FloatType) { + return row.getFloat(index); + } else if (type instanceof DoubleType) { + return row.getDouble(index); + } else if (type instanceof StringType) { + return row.getString(index); + } else if (type instanceof BinaryType) { + return row.getBinary(index); + } else if (type instanceof DecimalType) { + return row.getDecimal(index); + } else if (type instanceof StructType) { + return new SerializableRow(row.getStruct(index)); + } else if (type instanceof ArrayType) { + ArrayValue arr = row.getArray(index); + return convertArray(arr, (ArrayType) type); + } else if (type instanceof MapType) { + MapValue map = row.getMap(index); + return convertMap(map, (MapType) type); + } + throw new IllegalArgumentException("Unsupported type: " + type); + } + + private static @Nullable Object getVectorValue(ColumnVector vector, int index, DataType type) { + if (vector.isNullAt(index)) { + return null; + } + if (type instanceof BooleanType) { + return vector.getBoolean(index); + } else if (type instanceof ByteType) { + return vector.getByte(index); + } else if (type instanceof ShortType) { + return vector.getShort(index); + } else if (type instanceof IntegerType) { + return vector.getInt(index); + } else if (type instanceof LongType) { + return vector.getLong(index); + } else if (type instanceof FloatType) { + return vector.getFloat(index); + } else if (type instanceof DoubleType) { + return vector.getDouble(index); + } else if (type instanceof StringType) { + return vector.getString(index); + } else if (type instanceof BinaryType) { + return vector.getBinary(index); + } else if (type instanceof DecimalType) { + return vector.getDecimal(index); + } else if (type instanceof StructType) { + StructType structType = (StructType) type; + int numFields = structType.fields().size(); + ColumnVector[] childFields = new ColumnVector[numFields]; + for (int j = 0; j < numFields; j++) { + childFields[j] = vector.getChild(j); + } + return new SerializableRow(new VectorRow(structType, childFields, index)); + } else if (type instanceof ArrayType) { + ArrayValue arr = vector.getArray(index); + return convertArray(arr, (ArrayType) type); + } else if (type instanceof MapType) { + MapValue map = vector.getMap(index); + return convertMap(map, (MapType) type); + } + throw new IllegalArgumentException("Unsupported vector type: " + type); Review Comment: It may also be useful to support DateType / TimestampType in the vector conversion path here, since nested arrays/maps/partition metadata can contain these logical types and currently fall through to the unsupported-type exception. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
