http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index 074b0aa..a12440e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -24,6 +24,10 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ObjectArrays; + +import io.netty.buffer.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.OutOfMemoryException; import org.apache.arrow.vector.AddOrGetResult; @@ -42,16 +46,12 @@ import org.apache.arrow.vector.complex.writer.FieldWriter; import org.apache.arrow.vector.schema.ArrowFieldNode; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.util.CallBack; import org.apache.arrow.vector.util.JsonStringArrayList; import org.apache.arrow.vector.util.TransferPair; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ObjectArrays; - -import io.netty.buffer.ArrowBuf; - public class ListVector extends BaseRepeatedValueVector implements FieldVector { final UInt4Vector offsets; @@ -62,14 +62,16 @@ public class ListVector extends BaseRepeatedValueVector implements FieldVector { private UnionListWriter writer; private UnionListReader reader; private CallBack callBack; + private final DictionaryEncoding dictionary; - public ListVector(String name, BufferAllocator allocator, CallBack callBack) { + public ListVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack) { super(name, allocator); this.bits = new BitVector("$bits$", allocator); this.offsets = getOffsetVector(); this.innerVectors = Collections.unmodifiableList(Arrays.<BufferBacked>asList(bits, offsets)); this.writer = new UnionListWriter(this); this.reader = new UnionListReader(this); + this.dictionary = dictionary; this.callBack = callBack; } @@ -80,7 +82,7 @@ public class ListVector extends BaseRepeatedValueVector implements FieldVector { } Field field = children.get(0); MinorType minorType = Types.getMinorTypeForArrowType(field.getType()); - AddOrGetResult<FieldVector> addOrGetVector = addOrGetVector(minorType); + AddOrGetResult<FieldVector> addOrGetVector = addOrGetVector(minorType, field.getDictionary()); if (!addOrGetVector.isCreated()) { throw new IllegalArgumentException("Child vector already existed: " + addOrGetVector.getVector()); } @@ -151,16 +153,16 @@ public class ListVector extends BaseRepeatedValueVector implements FieldVector { TransferPair pairs[] = new TransferPair[3]; public TransferImpl(String name, BufferAllocator allocator) { - this(new ListVector(name, allocator, null)); + this(new ListVector(name, allocator, dictionary, null)); } public TransferImpl(ListVector to) { this.to = to; - to.addOrGetVector(vector.getMinorType()); + to.addOrGetVector(vector.getMinorType(), vector.getField().getDictionary()); pairs[0] = offsets.makeTransferPair(to.offsets); pairs[1] = bits.makeTransferPair(to.bits); if (to.getDataVector() instanceof ZeroVector) { - to.addOrGetVector(vector.getMinorType()); + to.addOrGetVector(vector.getMinorType(), vector.getField().getDictionary()); } pairs[2] = getDataVector().makeTransferPair(to.getDataVector()); } @@ -232,8 +234,8 @@ public class ListVector extends BaseRepeatedValueVector implements FieldVector { return success; } - public <T extends ValueVector> AddOrGetResult<T> addOrGetVector(MinorType minorType) { - AddOrGetResult<T> result = super.addOrGetVector(minorType); + public <T extends ValueVector> AddOrGetResult<T> addOrGetVector(MinorType minorType, DictionaryEncoding dictionary) { + AddOrGetResult<T> result = super.addOrGetVector(minorType, dictionary); reader = new UnionListReader(this); return result; }
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java index 31a1bb7..4d750ca 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java @@ -160,7 +160,7 @@ public class MapVector extends AbstractMapVector { // (This is similar to what happens in ScanBatch where the children cannot be added till they are // read). To take care of this, we ensure that the hashCode of the MaterializedField does not // include the hashCode of the children but is based only on MaterializedField$key. - final FieldVector newVector = to.addOrGet(child, vector.getMinorType(), vector.getClass()); + final FieldVector newVector = to.addOrGet(child, vector.getMinorType(), vector.getClass(), vector.getField().getDictionary()); if (allocate && to.size() != preSize) { newVector.allocateNew(); } @@ -314,12 +314,11 @@ public class MapVector extends AbstractMapVector { public void initializeChildrenFromFields(List<Field> children) { for (Field field : children) { MinorType minorType = Types.getMinorTypeForArrowType(field.getType()); - FieldVector vector = (FieldVector)this.add(field.getName(), minorType); + FieldVector vector = (FieldVector)this.add(field.getName(), minorType, field.getDictionary()); vector.initializeChildrenFromFields(field.getChildren()); } } - public List<FieldVector> getChildrenFromFields() { return getChildren(); } http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java index 5fa3530..bb1fdf8 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java @@ -34,6 +34,7 @@ import org.apache.arrow.vector.complex.impl.NullableMapReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.holders.ComplexHolder; import org.apache.arrow.vector.schema.ArrowFieldNode; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.util.CallBack; import org.apache.arrow.vector.util.TransferPair; @@ -48,14 +49,16 @@ public class NullableMapVector extends MapVector implements FieldVector { protected final BitVector bits; private final List<BufferBacked> innerVectors; + private final DictionaryEncoding dictionary; private final Accessor accessor; private final Mutator mutator; - public NullableMapVector(String name, BufferAllocator allocator, CallBack callBack) { + public NullableMapVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack) { super(name, checkNotNull(allocator), callBack); this.bits = new BitVector("$bits$", allocator); this.innerVectors = Collections.unmodifiableList(Arrays.<BufferBacked>asList(bits)); + this.dictionary = dictionary; this.accessor = new Accessor(); this.mutator = new Mutator(); } @@ -83,7 +86,7 @@ public class NullableMapVector extends MapVector implements FieldVector { @Override public TransferPair getTransferPair(BufferAllocator allocator) { - return new NullableMapTransferPair(this, new NullableMapVector(name, allocator, callBack), false); + return new NullableMapTransferPair(this, new NullableMapVector(name, allocator, dictionary, callBack), false); } @Override @@ -93,7 +96,7 @@ public class NullableMapVector extends MapVector implements FieldVector { @Override public TransferPair getTransferPair(String ref, BufferAllocator allocator) { - return new NullableMapTransferPair(this, new NullableMapVector(ref, allocator, callBack), false); + return new NullableMapTransferPair(this, new NullableMapVector(ref, allocator, dictionary, callBack), false); } protected class NullableMapTransferPair extends MapTransferPair { http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java index dbdd205..6d05316 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java @@ -149,7 +149,8 @@ public class ComplexWriterImpl extends AbstractFieldWriter implements ComplexWri switch(mode){ case INIT: - NullableMapVector map = container.addOrGet(name, MinorType.MAP, NullableMapVector.class); + // TODO allow dictionaries in complex types + NullableMapVector map = container.addOrGet(name, MinorType.MAP, NullableMapVector.class, null); mapRoot = nullableMapWriterFactory.build(map); mapRoot.setPosition(idx()); mode = Mode.MAP; @@ -180,7 +181,8 @@ public class ComplexWriterImpl extends AbstractFieldWriter implements ComplexWri case INIT: int vectorCount = container.size(); - ListVector listVector = container.addOrGet(name, MinorType.LIST, ListVector.class); + // TODO allow dictionaries in complex types + ListVector listVector = container.addOrGet(name, MinorType.LIST, ListVector.class, null); if (container.size() > vectorCount) { listVector.allocateNew(); } http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java index 1f7253b..e33319a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java @@ -125,7 +125,7 @@ public class PromotableWriter extends AbstractPromotableFieldWriter { // ??? return null; } - ValueVector v = listVector.addOrGetVector(type).getVector(); + ValueVector v = listVector.addOrGetVector(type, null).getVector(); v.allocateNew(); setWriter(v); writer.setPosition(position); @@ -150,7 +150,8 @@ public class PromotableWriter extends AbstractPromotableFieldWriter { TransferPair tp = vector.getTransferPair(vector.getMinorType().name().toLowerCase(), vector.getAllocator()); tp.transfer(); if (parentContainer != null) { - unionVector = parentContainer.addOrGet(name, MinorType.UNION, UnionVector.class); + // TODO allow dictionaries in complex types + unionVector = parentContainer.addOrGet(name, MinorType.UNION, UnionVector.class, null); unionVector.allocateNew(); } else if (listVector != null) { unionVector = listVector.promoteToUnion(); http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java new file mode 100644 index 0000000..0c1cadf --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java @@ -0,0 +1,66 @@ +/******************************************************************************* + + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ +package org.apache.arrow.vector.dictionary; + +import java.util.Objects; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; + +public class Dictionary { + + private final DictionaryEncoding encoding; + private final FieldVector dictionary; + + public Dictionary(FieldVector dictionary, DictionaryEncoding encoding) { + this.dictionary = dictionary; + this.encoding = encoding; + } + + public FieldVector getVector() { + return dictionary; + } + + public DictionaryEncoding getEncoding() { + return encoding; + } + + public ArrowType getVectorType() { + return dictionary.getField().getType(); + } + + @Override + public String toString() { + return "Dictionary " + encoding + " " + dictionary; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Dictionary that = (Dictionary) o; + return Objects.equals(encoding, that.encoding) && Objects.equals(dictionary, that.dictionary); + } + + @Override + public int hashCode() { + return Objects.hash(encoding, dictionary); + } +} http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java new file mode 100644 index 0000000..0666bc4 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java @@ -0,0 +1,144 @@ +/******************************************************************************* + + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ +package org.apache.arrow.vector.dictionary; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.Map; + +import com.google.common.collect.ImmutableList; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.util.TransferPair; + +public class DictionaryEncoder { + + // TODO recursively examine fields? + + /** + * Dictionary encodes a vector with a provided dictionary. The dictionary must contain all values in the vector. + * + * @param vector vector to encode + * @param dictionary dictionary used for encoding + * @return dictionary encoded vector + */ + public static ValueVector encode(ValueVector vector, Dictionary dictionary) { + validateType(vector.getMinorType()); + // load dictionary values into a hashmap for lookup + ValueVector.Accessor dictionaryAccessor = dictionary.getVector().getAccessor(); + Map<Object, Integer> lookUps = new HashMap<>(dictionaryAccessor.getValueCount()); + for (int i = 0; i < dictionaryAccessor.getValueCount(); i++) { + // for primitive array types we need a wrapper that implements equals and hashcode appropriately + lookUps.put(dictionaryAccessor.getObject(i), i); + } + + Field valueField = vector.getField(); + Field indexField = new Field(valueField.getName(), valueField.isNullable(), + dictionary.getEncoding().getIndexType(), dictionary.getEncoding(), null); + + // vector to hold our indices (dictionary encoded values) + FieldVector indices = indexField.createVector(vector.getAllocator()); + ValueVector.Mutator mutator = indices.getMutator(); + + // use reflection to pull out the set method + // TODO implement a common interface for int vectors + Method setter = null; + for (Class<?> c: ImmutableList.of(int.class, long.class)) { + try { + setter = mutator.getClass().getMethod("set", int.class, c); + break; + } catch(NoSuchMethodException e) { + // ignore + } + } + if (setter == null) { + throw new IllegalArgumentException("Dictionary encoding does not have a valid int type:" + indices.getClass()); + } + + ValueVector.Accessor accessor = vector.getAccessor(); + int count = accessor.getValueCount(); + + indices.allocateNew(); + + try { + for (int i = 0; i < count; i++) { + Object value = accessor.getObject(i); + if (value != null) { // if it's null leave it null + // note: this may fail if value was not included in the dictionary + Object encoded = lookUps.get(value); + if (encoded == null) { + throw new IllegalArgumentException("Dictionary encoding not defined for value:" + value); + } + setter.invoke(mutator, i, encoded); + } + } + } catch (IllegalAccessException e) { + throw new RuntimeException("IllegalAccessException invoking vector mutator set():", e); + } catch (InvocationTargetException e) { + throw new RuntimeException("InvocationTargetException invoking vector mutator set():", e.getCause()); + } + + mutator.setValueCount(count); + + return indices; + } + + /** + * Decodes a dictionary encoded array using the provided dictionary. + * + * @param indices dictionary encoded values, must be int type + * @param dictionary dictionary used to decode the values + * @return vector with values restored from dictionary + */ + public static ValueVector decode(ValueVector indices, Dictionary dictionary) { + ValueVector.Accessor accessor = indices.getAccessor(); + int count = accessor.getValueCount(); + ValueVector dictionaryVector = dictionary.getVector(); + int dictionaryCount = dictionaryVector.getAccessor().getValueCount(); + // copy the dictionary values into the decoded vector + TransferPair transfer = dictionaryVector.getTransferPair(indices.getAllocator()); + transfer.getTo().allocateNewSafe(); + for (int i = 0; i < count; i++) { + Object index = accessor.getObject(i); + if (index != null) { + int indexAsInt = ((Number) index).intValue(); + if (indexAsInt > dictionaryCount) { + throw new IllegalArgumentException("Provided dictionary does not contain value for index " + indexAsInt); + } + transfer.copyValueSafe(indexAsInt, i); + } + } + // TODO do we need to worry about the field? + ValueVector decoded = transfer.getTo(); + decoded.getMutator().setValueCount(count); + return decoded; + } + + private static void validateType(MinorType type) { + // byte arrays don't work as keys in our dictionary map - we could wrap them with something to + // implement equals and hashcode if we want that functionality + if (type == MinorType.VARBINARY || type == MinorType.LIST || type == MinorType.MAP || type == MinorType.UNION) { + throw new IllegalArgumentException("Dictionary encoding for complex types not implemented: type " + type); + } + } +} http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java new file mode 100644 index 0000000..63fde25 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.dictionary; + +import java.util.HashMap; +import java.util.Map; + +public interface DictionaryProvider { + + public Dictionary lookup(long id); + + public static class MapDictionaryProvider implements DictionaryProvider { + + private final Map<Long, Dictionary> map; + + public MapDictionaryProvider(Dictionary... dictionaries) { + this.map = new HashMap<>(); + for (Dictionary dictionary: dictionaries) { + put(dictionary); + } + } + + public void put(Dictionary dictionary) { + map.put(dictionary.getEncoding().getId(), dictionary); + } + + @Override + public Dictionary lookup(long id) { + return map.get(id); + } + } +} http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java new file mode 100644 index 0000000..28440a1 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java @@ -0,0 +1,142 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.file; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.SeekableByteChannel; +import java.util.Arrays; +import java.util.List; + +import org.apache.arrow.flatbuf.Footer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.schema.ArrowDictionaryBatch; +import org.apache.arrow.vector.schema.ArrowMessage; +import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.stream.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ArrowFileReader extends ArrowReader<SeekableReadChannel> { + + private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFileReader.class); + + private ArrowFooter footer; + private int currentDictionaryBatch = 0; + private int currentRecordBatch = 0; + + public ArrowFileReader(SeekableByteChannel in, BufferAllocator allocator) { + super(new SeekableReadChannel(in), allocator); + } + + public ArrowFileReader(SeekableReadChannel in, BufferAllocator allocator) { + super(in, allocator); + } + + @Override + protected Schema readSchema(SeekableReadChannel in) throws IOException { + if (footer == null) { + if (in.size() <= (ArrowMagic.MAGIC_LENGTH * 2 + 4)) { + throw new InvalidArrowFileException("file too small: " + in.size()); + } + ByteBuffer buffer = ByteBuffer.allocate(4 + ArrowMagic.MAGIC_LENGTH); + long footerLengthOffset = in.size() - buffer.remaining(); + in.setPosition(footerLengthOffset); + in.readFully(buffer); + buffer.flip(); + byte[] array = buffer.array(); + if (!ArrowMagic.validateMagic(Arrays.copyOfRange(array, 4, array.length))) { + throw new InvalidArrowFileException("missing Magic number " + Arrays.toString(buffer.array())); + } + int footerLength = MessageSerializer.bytesToInt(array); + if (footerLength <= 0 || footerLength + ArrowMagic.MAGIC_LENGTH * 2 + 4 > in.size()) { + throw new InvalidArrowFileException("invalid footer length: " + footerLength); + } + long footerOffset = footerLengthOffset - footerLength; + LOGGER.debug(String.format("Footer starts at %d, length: %d", footerOffset, footerLength)); + ByteBuffer footerBuffer = ByteBuffer.allocate(footerLength); + in.setPosition(footerOffset); + in.readFully(footerBuffer); + footerBuffer.flip(); + Footer footerFB = Footer.getRootAsFooter(footerBuffer); + this.footer = new ArrowFooter(footerFB); + } + return footer.getSchema(); + } + + @Override + protected ArrowMessage readMessage(SeekableReadChannel in, BufferAllocator allocator) throws IOException { + if (currentDictionaryBatch < footer.getDictionaries().size()) { + ArrowBlock block = footer.getDictionaries().get(currentDictionaryBatch++); + return readDictionaryBatch(in, block, allocator); + } else if (currentRecordBatch < footer.getRecordBatches().size()) { + ArrowBlock block = footer.getRecordBatches().get(currentRecordBatch++); + return readRecordBatch(in, block, allocator); + } else { + return null; + } + } + + public List<ArrowBlock> getDictionaryBlocks() throws IOException { + ensureInitialized(); + return footer.getDictionaries(); + } + + public List<ArrowBlock> getRecordBlocks() throws IOException { + ensureInitialized(); + return footer.getRecordBatches(); + } + + public void loadRecordBatch(ArrowBlock block) throws IOException { + ensureInitialized(); + int blockIndex = footer.getRecordBatches().indexOf(block); + if (blockIndex == -1) { + throw new IllegalArgumentException("Arrow bock does not exist in record batches: " + block); + } + currentRecordBatch = blockIndex; + loadNextBatch(); + } + + private ArrowDictionaryBatch readDictionaryBatch(SeekableReadChannel in, + ArrowBlock block, + BufferAllocator allocator) throws IOException { + LOGGER.debug(String.format("DictionaryRecordBatch at %d, metadata: %d, body: %d", + block.getOffset(), block.getMetadataLength(), block.getBodyLength())); + in.setPosition(block.getOffset()); + ArrowDictionaryBatch batch = MessageSerializer.deserializeDictionaryBatch(in, block, allocator); + if (batch == null) { + throw new IOException("Invalid file. No batch at offset: " + block.getOffset()); + } + return batch; + } + + private ArrowRecordBatch readRecordBatch(SeekableReadChannel in, + ArrowBlock block, + BufferAllocator allocator) throws IOException { + LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d", + block.getOffset(), block.getMetadataLength(), + block.getBodyLength())); + in.setPosition(block.getOffset()); + ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(in, block, allocator); + if (batch == null) { + throw new IOException("Invalid file. No batch at offset: " + block.getOffset()); + } + return batch; + } +} http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileWriter.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileWriter.java new file mode 100644 index 0000000..23d210a --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileWriter.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.file; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; +import java.util.List; + +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.types.pojo.Schema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ArrowFileWriter extends ArrowWriter { + + private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFileWriter.class); + + public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + super(root, provider, out); + } + + @Override + protected void startInternal(WriteChannel out) throws IOException { + ArrowMagic.writeMagic(out); + } + + @Override + protected void endInternal(WriteChannel out, + Schema schema, + List<ArrowBlock> dictionaries, + List<ArrowBlock> records) throws IOException { + long footerStart = out.getCurrentPosition(); + out.write(new ArrowFooter(schema, dictionaries, records), false); + int footerLength = (int)(out.getCurrentPosition() - footerStart); + if (footerLength <= 0) { + throw new InvalidArrowFileException("invalid footer"); + } + out.writeIntLittleEndian(footerLength); + LOGGER.debug(String.format("Footer starts at %d, length: %d", footerStart, footerLength)); + ArrowMagic.writeMagic(out); + LOGGER.debug(String.format("magic written, now at %d", out.getCurrentPosition())); + } +} http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java index 3890306..1c0008a 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java @@ -38,7 +38,6 @@ public class ArrowFooter implements FBSerializable { private final List<ArrowBlock> recordBatches; public ArrowFooter(Schema schema, List<ArrowBlock> dictionaries, List<ArrowBlock> recordBatches) { - super(); this.schema = schema; this.dictionaries = dictionaries; this.recordBatches = recordBatches; http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowMagic.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowMagic.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowMagic.java new file mode 100644 index 0000000..99ea96b --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowMagic.java @@ -0,0 +1,37 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.file; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; + +public class ArrowMagic { + + private static final byte[] MAGIC = "ARROW1".getBytes(StandardCharsets.UTF_8); + + public static final int MAGIC_LENGTH = MAGIC.length; + + public static void writeMagic(WriteChannel out) throws IOException { + out.write(MAGIC); + } + + public static boolean validateMagic(byte[] array) { + return Arrays.equals(MAGIC, array); + } +} http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java index 8f4f497..1646fbe 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java @@ -18,90 +18,188 @@ package org.apache.arrow.vector.file; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.SeekableByteChannel; -import java.util.Arrays; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.google.common.collect.ImmutableList; -import org.apache.arrow.flatbuf.Footer; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.schema.ArrowDictionaryBatch; +import org.apache.arrow.vector.schema.ArrowMessage; +import org.apache.arrow.vector.schema.ArrowMessage.ArrowMessageVisitor; import org.apache.arrow.vector.schema.ArrowRecordBatch; -import org.apache.arrow.vector.stream.MessageSerializer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class ArrowReader implements AutoCloseable { - private static final Logger LOGGER = LoggerFactory.getLogger(ArrowReader.class); - - public static final byte[] MAGIC = "ARROW1".getBytes(); +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.Int; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; - private final SeekableByteChannel in; +public abstract class ArrowReader<T extends ReadChannel> implements DictionaryProvider, AutoCloseable { + private final T in; private final BufferAllocator allocator; - private ArrowFooter footer; + private VectorLoader loader; + private VectorSchemaRoot root; + private Map<Long, Dictionary> dictionaries; - public ArrowReader(SeekableByteChannel in, BufferAllocator allocator) { - super(); + private boolean initialized = false; + + protected ArrowReader(T in, BufferAllocator allocator) { this.in = in; this.allocator = allocator; } - private int readFully(ByteBuffer buffer) throws IOException { - int total = 0; - int n; - do { - n = in.read(buffer); - total += n; - } while (n >= 0 && buffer.remaining() > 0); - buffer.flip(); - return total; + /** + * Returns the vector schema root. This will be loaded with new values on every call to loadNextBatch + * + * @return the vector schema root + * @throws IOException if reading of schema fails + */ + public VectorSchemaRoot getVectorSchemaRoot() throws IOException { + ensureInitialized(); + return root; } - public ArrowFooter readFooter() throws IOException { - if (footer == null) { - if (in.size() <= (MAGIC.length * 2 + 4)) { - throw new InvalidArrowFileException("file too small: " + in.size()); - } - ByteBuffer buffer = ByteBuffer.allocate(4 + MAGIC.length); - long footerLengthOffset = in.size() - buffer.remaining(); - in.position(footerLengthOffset); - readFully(buffer); - byte[] array = buffer.array(); - if (!Arrays.equals(MAGIC, Arrays.copyOfRange(array, 4, array.length))) { - throw new InvalidArrowFileException("missing Magic number " + Arrays.toString(buffer.array())); - } - int footerLength = MessageSerializer.bytesToInt(array); - if (footerLength <= 0 || footerLength + MAGIC.length * 2 + 4 > in.size()) { - throw new InvalidArrowFileException("invalid footer length: " + footerLength); - } - long footerOffset = footerLengthOffset - footerLength; - LOGGER.debug(String.format("Footer starts at %d, length: %d", footerOffset, footerLength)); - ByteBuffer footerBuffer = ByteBuffer.allocate(footerLength); - in.position(footerOffset); - readFully(footerBuffer); - Footer footerFB = Footer.getRootAsFooter(footerBuffer); - this.footer = new ArrowFooter(footerFB); + /** + * Returns any dictionaries + * + * @return dictionaries, if any + * @throws IOException if reading of schema fails + */ + public Map<Long, Dictionary> getDictionaryVectors() throws IOException { + ensureInitialized(); + return dictionaries; + } + + @Override + public Dictionary lookup(long id) { + if (initialized) { + return dictionaries.get(id); + } else { + return null; } - return footer; } - // TODO: read dictionaries - - public ArrowRecordBatch readRecordBatch(ArrowBlock block) throws IOException { - LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d", - block.getOffset(), block.getMetadataLength(), - block.getBodyLength())); - in.position(block.getOffset()); - ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch( - new ReadChannel(in, block.getOffset()), block, allocator); - if (batch == null) { - throw new IOException("Invalid file. No batch at offset: " + block.getOffset()); + public void loadNextBatch() throws IOException { + ensureInitialized(); + // read in all dictionary batches, then stop after our first record batch + ArrowMessageVisitor<Boolean> visitor = new ArrowMessageVisitor<Boolean>() { + @Override + public Boolean visit(ArrowDictionaryBatch message) { + try { load(message); } finally { message.close(); } + return true; + } + @Override + public Boolean visit(ArrowRecordBatch message) { + try { loader.load(message); } finally { message.close(); } + return false; + } + }; + root.setRowCount(0); + ArrowMessage message = readMessage(in, allocator); + while (message != null && message.accepts(visitor)) { + message = readMessage(in, allocator); } - return batch; } + public long bytesRead() { return in.bytesRead(); } + @Override public void close() throws IOException { + if (initialized) { + root.close(); + for (Dictionary dictionary: dictionaries.values()) { + dictionary.getVector().close(); + } + } in.close(); } + + protected abstract Schema readSchema(T in) throws IOException; + + protected abstract ArrowMessage readMessage(T in, BufferAllocator allocator) throws IOException; + + protected void ensureInitialized() throws IOException { + if (!initialized) { + initialize(); + initialized = true; + } + } + + /** + * Reads the schema and initializes the vectors + */ + private void initialize() throws IOException { + Schema schema = readSchema(in); + List<Field> fields = new ArrayList<>(); + List<FieldVector> vectors = new ArrayList<>(); + Map<Long, Dictionary> dictionaries = new HashMap<>(); + + for (Field field: schema.getFields()) { + Field updated = toMemoryFormat(field, dictionaries); + fields.add(updated); + vectors.add(updated.createVector(allocator)); + } + + this.root = new VectorSchemaRoot(fields, vectors, 0); + this.loader = new VectorLoader(root); + this.dictionaries = Collections.unmodifiableMap(dictionaries); + } + + // in the message format, fields have the dictionary type + // in the memory format, they have the index type + private Field toMemoryFormat(Field field, Map<Long, Dictionary> dictionaries) { + DictionaryEncoding encoding = field.getDictionary(); + List<Field> children = field.getChildren(); + + if (encoding == null && children.isEmpty()) { + return field; + } + + List<Field> updatedChildren = new ArrayList<>(children.size()); + for (Field child: children) { + updatedChildren.add(toMemoryFormat(child, dictionaries)); + } + + ArrowType type; + if (encoding == null) { + type = field.getType(); + } else { + // re-type the field for in-memory format + type = encoding.getIndexType(); + if (type == null) { + type = new Int(32, true); + } + // get existing or create dictionary vector + if (!dictionaries.containsKey(encoding.getId())) { + // create a new dictionary vector for the values + Field dictionaryField = new Field(field.getName(), field.isNullable(), field.getType(), null, children); + FieldVector dictionaryVector = dictionaryField.createVector(allocator); + dictionaries.put(encoding.getId(), new Dictionary(dictionaryVector, encoding)); + } + } + + return new Field(field.getName(), field.isNullable(), type, encoding, updatedChildren); + } + + private void load(ArrowDictionaryBatch dictionaryBatch) { + long id = dictionaryBatch.getDictionaryId(); + Dictionary dictionary = dictionaries.get(id); + if (dictionary == null) { + throw new IllegalArgumentException("Dictionary ID " + id + " not defined in schema"); + } + FieldVector vector = dictionary.getVector(); + VectorSchemaRoot root = new VectorSchemaRoot(ImmutableList.of(vector.getField()), ImmutableList.of(vector), 0); + VectorLoader loader = new VectorLoader(root); + loader.load(dictionaryBatch.getDictionary()); + } } http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java index 24c667e..60a6afb 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -21,77 +21,172 @@ import java.io.IOException; import java.nio.channels.WritableByteChannel; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import com.google.common.collect.ImmutableList; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.schema.ArrowDictionaryBatch; import org.apache.arrow.vector.schema.ArrowRecordBatch; import org.apache.arrow.vector.stream.MessageSerializer; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class ArrowWriter implements AutoCloseable { +public abstract class ArrowWriter implements AutoCloseable { + private static final Logger LOGGER = LoggerFactory.getLogger(ArrowWriter.class); + // schema with fields in message format, not memory format + private final Schema schema; private final WriteChannel out; - private final Schema schema; + private final VectorUnloader unloader; + private final List<ArrowDictionaryBatch> dictionaries; + + private final List<ArrowBlock> dictionaryBlocks = new ArrayList<>(); + private final List<ArrowBlock> recordBlocks = new ArrayList<>(); - private final List<ArrowBlock> recordBatches = new ArrayList<>(); private boolean started = false; + private boolean ended = false; - public ArrowWriter(WritableByteChannel out, Schema schema) { + /** + * Note: fields are not closed when the writer is closed + * + * @param root + * @param provider + * @param out + */ + protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + this.unloader = new VectorUnloader(root); this.out = new WriteChannel(out); - this.schema = schema; + + List<Field> fields = new ArrayList<>(root.getSchema().getFields().size()); + Map<Long, ArrowDictionaryBatch> dictionaryBatches = new HashMap<>(); + + for (Field field: root.getSchema().getFields()) { + fields.add(toMessageFormat(field, provider, dictionaryBatches)); + } + + this.schema = new Schema(fields); + this.dictionaries = Collections.unmodifiableList(new ArrayList<>(dictionaryBatches.values())); + } + + // in the message format, fields have the dictionary type + // in the memory format, they have the index type + private Field toMessageFormat(Field field, DictionaryProvider provider, Map<Long, ArrowDictionaryBatch> batches) { + DictionaryEncoding encoding = field.getDictionary(); + List<Field> children = field.getChildren(); + + if (encoding == null && children.isEmpty()) { + return field; + } + + List<Field> updatedChildren = new ArrayList<>(children.size()); + for (Field child: children) { + updatedChildren.add(toMessageFormat(child, provider, batches)); + } + + ArrowType type; + if (encoding == null) { + type = field.getType(); + } else { + long id = encoding.getId(); + Dictionary dictionary = provider.lookup(id); + if (dictionary == null) { + throw new IllegalArgumentException("Could not find dictionary with ID " + id); + } + type = dictionary.getVectorType(); + + if (!batches.containsKey(id)) { + FieldVector vector = dictionary.getVector(); + int count = vector.getAccessor().getValueCount(); + VectorSchemaRoot root = new VectorSchemaRoot(ImmutableList.of(field), ImmutableList.of(vector), count); + VectorUnloader unloader = new VectorUnloader(root); + ArrowRecordBatch batch = unloader.getRecordBatch(); + batches.put(id, new ArrowDictionaryBatch(id, batch)); + } + } + + return new Field(field.getName(), field.isNullable(), type, encoding, updatedChildren); } - private void start() throws IOException { - writeMagic(); - MessageSerializer.serialize(out, schema); + public void start() throws IOException { + ensureStarted(); } - // TODO: write dictionaries + public void writeBatch() throws IOException { + ensureStarted(); + try (ArrowRecordBatch batch = unloader.getRecordBatch()) { + writeRecordBatch(batch); + } + } - public void writeRecordBatch(ArrowRecordBatch recordBatch) throws IOException { - checkStarted(); - ArrowBlock batchDesc = MessageSerializer.serialize(out, recordBatch); + protected void writeRecordBatch(ArrowRecordBatch batch) throws IOException { + ArrowBlock block = MessageSerializer.serialize(out, batch); LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d", - batchDesc.getOffset(), batchDesc.getMetadataLength(), batchDesc.getBodyLength())); + block.getOffset(), block.getMetadataLength(), block.getBodyLength())); + recordBlocks.add(block); + } - // add metadata to footer - recordBatches.add(batchDesc); + public void end() throws IOException { + ensureStarted(); + ensureEnded(); } - private void checkStarted() throws IOException { + public long bytesWritten() { return out.getCurrentPosition(); } + + private void ensureStarted() throws IOException { if (!started) { started = true; - start(); + startInternal(out); + // write the schema - for file formats this is duplicated in the footer, but matches + // the streaming format + MessageSerializer.serialize(out, schema); + // write out any dictionaries + for (ArrowDictionaryBatch batch : dictionaries) { + try { + ArrowBlock block = MessageSerializer.serialize(out, batch); + LOGGER.debug(String.format("DictionaryRecordBatch at %d, metadata: %d, body: %d", + block.getOffset(), block.getMetadataLength(), block.getBodyLength())); + dictionaryBlocks.add(block); + } finally { + batch.close(); + } + } } } - @Override - public void close() throws IOException { - try { - long footerStart = out.getCurrentPosition(); - writeFooter(); - int footerLength = (int)(out.getCurrentPosition() - footerStart); - if (footerLength <= 0 ) { - throw new InvalidArrowFileException("invalid footer"); - } - out.writeIntLittleEndian(footerLength); - LOGGER.debug(String.format("Footer starts at %d, length: %d", footerStart, footerLength)); - writeMagic(); - } finally { - out.close(); + private void ensureEnded() throws IOException { + if (!ended) { + ended = true; + endInternal(out, schema, dictionaryBlocks, recordBlocks); } } - private void writeMagic() throws IOException { - out.write(ArrowReader.MAGIC); - LOGGER.debug(String.format("magic written, now at %d", out.getCurrentPosition())); - } + protected abstract void startInternal(WriteChannel out) throws IOException; + + protected abstract void endInternal(WriteChannel out, + Schema schema, + List<ArrowBlock> dictionaries, + List<ArrowBlock> records) throws IOException; - private void writeFooter() throws IOException { - // TODO: dictionaries - out.write(new ArrowFooter(schema, Collections.<ArrowBlock>emptyList(), recordBatches), false); + @Override + public void close() { + try { + end(); + out.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } } } http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java index a9dc129..b062f38 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java @@ -32,16 +32,9 @@ public class ReadChannel implements AutoCloseable { private ReadableByteChannel in; private long bytesRead = 0; - // The starting byte offset into 'in'. - private final long startByteOffset; - - public ReadChannel(ReadableByteChannel in, long startByteOffset) { - this.in = in; - this.startByteOffset = startByteOffset; - } public ReadChannel(ReadableByteChannel in) { - this(in, 0); + this.in = in; } public long bytesRead() { return bytesRead; } @@ -72,8 +65,6 @@ public class ReadChannel implements AutoCloseable { return n; } - public long getCurrentPositiion() { return startByteOffset + bytesRead; } - @Override public void close() throws IOException { if (this.in != null) { http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/SeekableReadChannel.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/SeekableReadChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/file/SeekableReadChannel.java new file mode 100644 index 0000000..914c3cb --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/SeekableReadChannel.java @@ -0,0 +1,39 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.file; + +import java.io.IOException; +import java.nio.channels.SeekableByteChannel; + +public class SeekableReadChannel extends ReadChannel { + + private final SeekableByteChannel in; + + public SeekableReadChannel(SeekableByteChannel in) { + super(in); + this.in = in; + } + + public void setPosition(long position) throws IOException { + in.position(position); + } + + public long size() throws IOException { + return in.size(); + } +} http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java index d99c9a6..42104d1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java @@ -21,13 +21,12 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; -import org.apache.arrow.vector.schema.FBSerializable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.google.flatbuffers.FlatBufferBuilder; import io.netty.buffer.ArrowBuf; +import org.apache.arrow.vector.schema.FBSerializable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Wrapper around a WritableByteChannel that maintains the position as well adding http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java index 24fdc18..bdb63b9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java @@ -88,10 +88,34 @@ public class JsonFileReader implements AutoCloseable { } } + public void read(VectorSchemaRoot root) throws IOException { + JsonToken t = parser.nextToken(); + if (t == START_OBJECT) { + { + int count = readNextField("count", Integer.class); + root.setRowCount(count); + nextFieldIs("columns"); + readToken(START_ARRAY); + { + for (Field field : schema.getFields()) { + FieldVector vector = root.getVector(field.getName()); + readVector(field, vector); + } + } + readToken(END_ARRAY); + } + readToken(END_OBJECT); + } else if (t == END_ARRAY) { + root.setRowCount(0); + } else { + throw new IllegalArgumentException("Invalid token: " + t); + } + } + public VectorSchemaRoot read() throws IOException { JsonToken t = parser.nextToken(); if (t == START_OBJECT) { - VectorSchemaRoot recordBatch = new VectorSchemaRoot(schema, allocator); + VectorSchemaRoot recordBatch = VectorSchemaRoot.create(schema, allocator); { int count = readNextField("count", Integer.class); recordBatch.setRowCount(count); http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowDictionaryBatch.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowDictionaryBatch.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowDictionaryBatch.java new file mode 100644 index 0000000..901877b --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowDictionaryBatch.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.schema; + +import com.google.flatbuffers.FlatBufferBuilder; +import org.apache.arrow.flatbuf.DictionaryBatch; + +public class ArrowDictionaryBatch implements ArrowMessage { + + private final long dictionaryId; + private final ArrowRecordBatch dictionary; + + public ArrowDictionaryBatch(long dictionaryId, ArrowRecordBatch dictionary) { + this.dictionaryId = dictionaryId; + this.dictionary = dictionary; + } + + public long getDictionaryId() { return dictionaryId; } + public ArrowRecordBatch getDictionary() { return dictionary; } + + @Override + public int writeTo(FlatBufferBuilder builder) { + int dataOffset = dictionary.writeTo(builder); + DictionaryBatch.startDictionaryBatch(builder); + DictionaryBatch.addId(builder, dictionaryId); + DictionaryBatch.addData(builder, dataOffset); + return DictionaryBatch.endDictionaryBatch(builder); + } + + @Override + public int computeBodyLength() { return dictionary.computeBodyLength(); } + + @Override + public <T> T accepts(ArrowMessageVisitor<T> visitor) { return visitor.visit(this); } + + @Override + public String toString() { + return "ArrowDictionaryBatch [dictionaryId=" + dictionaryId + ", dictionary=" + dictionary + "]"; + } + + @Override + public void close() { + dictionary.close(); + } +} http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowMessage.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowMessage.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowMessage.java new file mode 100644 index 0000000..d307428 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowMessage.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.schema; + +public interface ArrowMessage extends FBSerializable, AutoCloseable { + + public int computeBodyLength(); + + public <T> T accepts(ArrowMessageVisitor<T> visitor); + + public static interface ArrowMessageVisitor<T> { + public T visit(ArrowDictionaryBatch message); + public T visit(ArrowRecordBatch message); + } +} http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java index 40c2fbf..6ef514e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java @@ -32,7 +32,8 @@ import com.google.flatbuffers.FlatBufferBuilder; import io.netty.buffer.ArrowBuf; -public class ArrowRecordBatch implements FBSerializable, AutoCloseable { +public class ArrowRecordBatch implements ArrowMessage { + private static final Logger LOGGER = LoggerFactory.getLogger(ArrowRecordBatch.class); /** number of records */ @@ -113,9 +114,13 @@ public class ArrowRecordBatch implements FBSerializable, AutoCloseable { return RecordBatch.endRecordBatch(builder); } + @Override + public <T> T accepts(ArrowMessageVisitor<T> visitor) { return visitor.visit(this); } + /** * releases the buffers */ + @Override public void close() { if (!closed) { closed = true; @@ -134,6 +139,7 @@ public class ArrowRecordBatch implements FBSerializable, AutoCloseable { /** * Computes the size of the serialized body for this recordBatch. */ + @Override public int computeBodyLength() { int size = 0; http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java index f32966c..2deef37 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java @@ -17,79 +17,43 @@ */ package org.apache.arrow.vector.stream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.channels.Channels; -import java.nio.channels.ReadableByteChannel; - import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.file.ArrowReader; import org.apache.arrow.vector.file.ReadChannel; -import org.apache.arrow.vector.schema.ArrowRecordBatch; +import org.apache.arrow.vector.schema.ArrowMessage; import org.apache.arrow.vector.types.pojo.Schema; -import com.google.common.base.Preconditions; +import java.io.IOException; +import java.io.InputStream; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; /** * This classes reads from an input stream and produces ArrowRecordBatches. */ -public class ArrowStreamReader implements AutoCloseable { - private ReadChannel in; - private final BufferAllocator allocator; - private Schema schema; - - /** - * Constructs a streaming read, reading bytes from 'in'. Non-blocking. - */ - public ArrowStreamReader(ReadableByteChannel in, BufferAllocator allocator) { - super(); - this.in = new ReadChannel(in); - this.allocator = allocator; - } - - public ArrowStreamReader(InputStream in, BufferAllocator allocator) { - this(Channels.newChannel(in), allocator); - } - - /** - * Initializes the reader. Must be called before the other APIs. This is blocking. - */ - public void init() throws IOException { - Preconditions.checkState(this.schema == null, "Cannot call init() more than once."); - this.schema = readSchema(); - } +public class ArrowStreamReader extends ArrowReader<ReadChannel> { - /** - * Returns the schema for all records in this stream. - */ - public Schema getSchema () { - Preconditions.checkState(this.schema != null, "Must call init() first."); - return schema; - } - - public long bytesRead() { return in.bytesRead(); } + /** + * Constructs a streaming read, reading bytes from 'in'. Non-blocking. + */ + public ArrowStreamReader(ReadableByteChannel in, BufferAllocator allocator) { + super(new ReadChannel(in), allocator); + } - /** - * Reads and returns the next ArrowRecordBatch. Returns null if this is the end - * of stream. - */ - public ArrowRecordBatch nextRecordBatch() throws IOException { - Preconditions.checkState(this.in != null, "Cannot call after close()"); - Preconditions.checkState(this.schema != null, "Must call init() first."); - return MessageSerializer.deserializeRecordBatch(in, allocator); - } + public ArrowStreamReader(InputStream in, BufferAllocator allocator) { + this(Channels.newChannel(in), allocator); + } - @Override - public void close() throws IOException { - if (this.in != null) { - in.close(); - in = null; + /** + * Reads the schema message from the beginning of the stream. + */ + @Override + protected Schema readSchema(ReadChannel in) throws IOException { + return MessageSerializer.deserializeSchema(in); } - } - /** - * Reads the schema message from the beginning of the stream. - */ - private Schema readSchema() throws IOException { - return MessageSerializer.deserializeSchema(in); - } + @Override + protected ArrowMessage readMessage(ReadChannel in, BufferAllocator allocator) throws IOException { + return MessageSerializer.deserializeMessageBatch(in, allocator); + } } http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java index 60dc586..ea29cd9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java @@ -17,63 +17,40 @@ */ package org.apache.arrow.vector.stream; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.file.ArrowBlock; +import org.apache.arrow.vector.file.ArrowWriter; +import org.apache.arrow.vector.file.WriteChannel; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + import java.io.IOException; import java.io.OutputStream; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; +import java.util.List; -import org.apache.arrow.vector.file.WriteChannel; -import org.apache.arrow.vector.schema.ArrowRecordBatch; -import org.apache.arrow.vector.types.pojo.Schema; - -public class ArrowStreamWriter implements AutoCloseable { - private final WriteChannel out; - private final Schema schema; - private boolean headerSent = false; +public class ArrowStreamWriter extends ArrowWriter { - /** - * Creates the stream writer. non-blocking. - * totalBatches can be set if the writer knows beforehand. Can be -1 if unknown. - */ - public ArrowStreamWriter(WritableByteChannel out, Schema schema) { - this.out = new WriteChannel(out); - this.schema = schema; - } - - public ArrowStreamWriter(OutputStream out, Schema schema) - throws IOException { - this(Channels.newChannel(out), schema); - } - - public long bytesWritten() { return out.getCurrentPosition(); } - - public void writeRecordBatch(ArrowRecordBatch batch) throws IOException { - // Send the header if we have not yet. - checkAndSendHeader(); - MessageSerializer.serialize(out, batch); - } + public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, OutputStream out) { + this(root, provider, Channels.newChannel(out)); + } - /** - * End the stream. This is not required and this object can simply be closed. - */ - public void end() throws IOException { - checkAndSendHeader(); - out.writeIntLittleEndian(0); - } + public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + super(root, provider, out); + } - @Override - public void close() throws IOException { - // The header might not have been sent if this is an empty stream. Send it even in - // this case so readers see a valid empty stream. - checkAndSendHeader(); - out.close(); - } + @Override + protected void startInternal(WriteChannel out) throws IOException {} - private void checkAndSendHeader() throws IOException { - if (!headerSent) { - MessageSerializer.serialize(out, schema); - headerSent = true; + @Override + protected void endInternal(WriteChannel out, + Schema schema, + List<ArrowBlock> dictionaries, + List<ArrowBlock> records) throws IOException { + out.writeIntLittleEndian(0); } - } } - http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java index 92df250..92a6c0c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java @@ -22,7 +22,11 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; +import com.google.flatbuffers.FlatBufferBuilder; + +import io.netty.buffer.ArrowBuf; import org.apache.arrow.flatbuf.Buffer; +import org.apache.arrow.flatbuf.DictionaryBatch; import org.apache.arrow.flatbuf.FieldNode; import org.apache.arrow.flatbuf.Message; import org.apache.arrow.flatbuf.MessageHeader; @@ -33,14 +37,12 @@ import org.apache.arrow.vector.file.ArrowBlock; import org.apache.arrow.vector.file.ReadChannel; import org.apache.arrow.vector.file.WriteChannel; import org.apache.arrow.vector.schema.ArrowBuffer; +import org.apache.arrow.vector.schema.ArrowDictionaryBatch; import org.apache.arrow.vector.schema.ArrowFieldNode; +import org.apache.arrow.vector.schema.ArrowMessage; import org.apache.arrow.vector.schema.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.Schema; -import com.google.flatbuffers.FlatBufferBuilder; - -import io.netty.buffer.ArrowBuf; - /** * Utility class for serializing Messages. Messages are all serialized a similar way. * 1. 4 byte little endian message header prefix @@ -81,35 +83,39 @@ public class MessageSerializer { * Deserializes a schema object. Format is from serialize(). */ public static Schema deserializeSchema(ReadChannel in) throws IOException { - Message message = deserializeMessage(in, MessageHeader.Schema); + Message message = deserializeMessage(in); if (message == null) { throw new IOException("Unexpected end of input. Missing schema."); } + if (message.headerType() != MessageHeader.Schema) { + throw new IOException("Expected schema but header was " + message.headerType()); + } return Schema.convertSchema((org.apache.arrow.flatbuf.Schema) message.header(new org.apache.arrow.flatbuf.Schema())); } + /** * Serializes an ArrowRecordBatch. Returns the offset and length of the written batch. */ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) - throws IOException { + throws IOException { + long start = out.getCurrentPosition(); int bodyLength = batch.computeBodyLength(); FlatBufferBuilder builder = new FlatBufferBuilder(); int batchOffset = batch.writeTo(builder); - ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch, - batchOffset, bodyLength); + ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch, batchOffset, bodyLength); int metadataLength = serializedMessage.remaining(); - // Add extra padding bytes so that length prefix + metadata is a multiple - // of 8 after alignment - if ((start + metadataLength + 4) % 8 != 0) { - metadataLength += 8 - (start + metadataLength + 4) % 8; + // calculate alignment bytes so that metadata length points to the correct location after alignment + int padding = (int)((start + metadataLength + 4) % 8); + if (padding != 0) { + metadataLength += (8 - padding); } out.writeIntLittleEndian(metadataLength); @@ -118,6 +124,13 @@ public class MessageSerializer { // Align the output to 8 byte boundary. out.align(); + long bufferLength = writeBatchBuffers(out, batch); + + // Metadata size in the Block account for the size prefix + return new ArrowBlock(start, metadataLength + 4, bufferLength); + } + + private static long writeBatchBuffers(WriteChannel out, ArrowRecordBatch batch) throws IOException { long bufferStart = out.getCurrentPosition(); List<ArrowBuf> buffers = batch.getBuffers(); List<ArrowBuffer> buffersLayout = batch.getBuffersLayout(); @@ -135,22 +148,14 @@ public class MessageSerializer { " != " + startPosition + layout.getSize()); } } - // Metadata size in the Block account for the size prefix - return new ArrowBlock(start, metadataLength + 4, out.getCurrentPosition() - bufferStart); + return out.getCurrentPosition() - bufferStart; } /** * Deserializes a RecordBatch */ - public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, - BufferAllocator alloc) throws IOException { - Message message = deserializeMessage(in, MessageHeader.RecordBatch); - if (message == null) return null; - - if (message.bodyLength() > Integer.MAX_VALUE) { - throw new IOException("Cannot currently deserialize record batches over 2GB"); - } - + private static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, Message message, BufferAllocator alloc) + throws IOException { RecordBatch recordBatchFB = (RecordBatch) message.header(new RecordBatch()); int bodyLength = (int) message.bodyLength(); @@ -191,9 +196,7 @@ public class MessageSerializer { // Now read the body final ArrowBuf body = buffer.slice(block.getMetadataLength(), (int) totalLen - block.getMetadataLength()); - ArrowRecordBatch result = deserializeRecordBatch(recordBatchFB, body); - - return result; + return deserializeRecordBatch(recordBatchFB, body); } // Deserializes a record batch given the Flatbuffer metadata and in-memory body @@ -219,6 +222,106 @@ public class MessageSerializer { } /** + * Serializes a dictionary ArrowRecordBatch. Returns the offset and length of the written batch. + */ + public static ArrowBlock serialize(WriteChannel out, ArrowDictionaryBatch batch) throws IOException { + long start = out.getCurrentPosition(); + int bodyLength = batch.computeBodyLength(); + + FlatBufferBuilder builder = new FlatBufferBuilder(); + int batchOffset = batch.writeTo(builder); + + ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.DictionaryBatch, batchOffset, bodyLength); + + int metadataLength = serializedMessage.remaining(); + + // Add extra padding bytes so that length prefix + metadata is a multiple + // of 8 after alignment + if ((start + metadataLength + 4) % 8 != 0) { + metadataLength += 8 - (start + metadataLength + 4) % 8; + } + + out.writeIntLittleEndian(metadataLength); + out.write(serializedMessage); + + // Align the output to 8 byte boundary. + out.align(); + + // write the embedded record batch + long bufferLength = writeBatchBuffers(out, batch.getDictionary()); + + // Metadata size in the Block account for the size prefix + return new ArrowBlock(start, metadataLength + 4, bufferLength + 8); + } + + /** + * Deserializes a DictionaryBatch + */ + private static ArrowDictionaryBatch deserializeDictionaryBatch(ReadChannel in, + Message message, + BufferAllocator alloc) throws IOException { + DictionaryBatch dictionaryBatchFB = (DictionaryBatch) message.header(new DictionaryBatch()); + + int bodyLength = (int) message.bodyLength(); + + // Now read the record batch body + ArrowBuf body = alloc.buffer(bodyLength); + if (in.readFully(body, bodyLength) != bodyLength) { + throw new IOException("Unexpected end of input trying to read batch."); + } + ArrowRecordBatch recordBatch = deserializeRecordBatch(dictionaryBatchFB.data(), body); + return new ArrowDictionaryBatch(dictionaryBatchFB.id(), recordBatch); + } + + /** + * Deserializes a DictionaryBatch knowing the size of the entire message up front. This + * minimizes the number of reads to the underlying stream. + */ + public static ArrowDictionaryBatch deserializeDictionaryBatch(ReadChannel in, + ArrowBlock block, + BufferAllocator alloc) throws IOException { + // Metadata length contains integer prefix plus byte padding + long totalLen = block.getMetadataLength() + block.getBodyLength(); + + if (totalLen > Integer.MAX_VALUE) { + throw new IOException("Cannot currently deserialize record batches over 2GB"); + } + + ArrowBuf buffer = alloc.buffer((int) totalLen); + if (in.readFully(buffer, (int) totalLen) != totalLen) { + throw new IOException("Unexpected end of input trying to read batch."); + } + + ArrowBuf metadataBuffer = buffer.slice(4, block.getMetadataLength() - 4); + + Message messageFB = + Message.getRootAsMessage(metadataBuffer.nioBuffer().asReadOnlyBuffer()); + + DictionaryBatch dictionaryBatchFB = (DictionaryBatch) messageFB.header(new DictionaryBatch()); + + // Now read the body + final ArrowBuf body = buffer.slice(block.getMetadataLength(), + (int) totalLen - block.getMetadataLength()); + ArrowRecordBatch recordBatch = deserializeRecordBatch(dictionaryBatchFB.data(), body); + return new ArrowDictionaryBatch(dictionaryBatchFB.id(), recordBatch); + } + + public static ArrowMessage deserializeMessageBatch(ReadChannel in, BufferAllocator alloc) throws IOException { + Message message = deserializeMessage(in); + if (message == null) { + return null; + } else if (message.bodyLength() > Integer.MAX_VALUE) { + throw new IOException("Cannot currently deserialize record batches over 2GB"); + } + + switch (message.headerType()) { + case MessageHeader.RecordBatch: return deserializeRecordBatch(in, message, alloc); + case MessageHeader.DictionaryBatch: return deserializeDictionaryBatch(in, message, alloc); + default: throw new IOException("Unexpected message header type " + message.headerType()); + } + } + + /** * Serializes a message header. */ private static ByteBuffer serializeMessage(FlatBufferBuilder builder, byte headerType, @@ -232,7 +335,7 @@ public class MessageSerializer { return builder.dataBuffer(); } - private static Message deserializeMessage(ReadChannel in, byte headerType) throws IOException { + private static Message deserializeMessage(ReadChannel in) throws IOException { // Read the message size. There is an i32 little endian prefix. ByteBuffer buffer = ByteBuffer.allocate(4); if (in.readFully(buffer) != 4) return null; @@ -246,11 +349,6 @@ public class MessageSerializer { } buffer.rewind(); - Message message = Message.getRootAsMessage(buffer); - if (message.headerType() != headerType) { - throw new IOException("Invalid message: expecting " + headerType + - ". Message contained: " + message.headerType()); - } - return message; + return Message.getRootAsMessage(buffer); } } http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/types/Dictionary.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/types/Dictionary.java deleted file mode 100644 index fbe1345..0000000 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/Dictionary.java +++ /dev/null @@ -1,40 +0,0 @@ -/******************************************************************************* - - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ******************************************************************************/ -package org.apache.arrow.vector.types; - -import org.apache.arrow.vector.ValueVector; - -public class Dictionary { - - private ValueVector dictionary; - private boolean ordered; - - public Dictionary(ValueVector dictionary, boolean ordered) { - this.dictionary = dictionary; - this.ordered = ordered; - } - - public ValueVector getDictionary() { - return dictionary; - } - - public boolean isOrdered() { - return ordered; - } -}
