This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new a5a366d9ca GH-32662: [C#] Make dictionaries in file and memory
implementations work correctly and support integration tests (#39146)
a5a366d9ca is described below
commit a5a366d9ca08334d5c01fdc013ae7954a4912c4c
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Sun Dec 10 06:51:07 2023 -0800
GH-32662: [C#] Make dictionaries in file and memory implementations work
correctly and support integration tests (#39146)
### Rationale for this change
While dictionary support was implemented for C# in #6870 for streams,
support did not extend to files or memory buffers. This change rectifies that.
### What changes are included in this PR?
- Changes to the memory and file implementations to support reading and
  writing of dictionaries, including nested dictionaries.
- Changes to the integration tests so that they work with dictionaries.
- Enabling the dictionary tests in CI.
### Are these changes tested?
Yes, both directly and indirectly via the integration tests.
### Are there any user-facing changes?
No.
* Closes: #32662
Authored-by: Curt Hagenlocher <[email protected]>
Signed-off-by: Curt Hagenlocher <[email protected]>
---
.../Ipc/ArrowFileReaderImplementation.cs | 34 ++++++
csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs | 54 ++++++++-
csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs | 4 +-
.../Ipc/ArrowMemoryReaderImplementation.cs | 60 +++++-----
.../Ipc/ArrowStreamReaderImplementation.cs | 128 +++++++++++++--------
csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 78 ++++++-------
csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs | 10 +-
csharp/src/Apache.Arrow/Types/DictionaryType.cs | 3 +-
.../ArrowReaderBenchmark.cs | 2 +-
.../IntegrationCommand.cs | 23 ++--
.../test/Apache.Arrow.IntegrationTest/JsonFile.cs | 106 ++++++++++++++---
.../Apache.Arrow.Tests/ArrowFileReaderTests.cs | 6 +-
csharp/test/Apache.Arrow.Tests/TestData.cs | 2 +-
dev/archery/archery/integration/datagen.py | 5 +-
dev/archery/archery/integration/tester_csharp.py | 4 +-
docs/source/status.rst | 2 +-
16 files changed, 352 insertions(+), 169 deletions(-)
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
index d88665e496..3ae475885f 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
@@ -35,6 +35,8 @@ namespace Apache.Arrow.Ipc
private ArrowFooter _footer;
+ private bool HasReadDictionaries => HasReadSchema &&
DictionaryMemo.LoadedDictionaryCount >= _footer.DictionaryCount;
+
public ArrowFileReaderImplementation(Stream stream, MemoryAllocator
allocator, ICompressionCodecFactory compressionCodecFactory, bool leaveOpen)
: base(stream, allocator, compressionCodecFactory, leaveOpen)
{
@@ -143,6 +145,7 @@ namespace Apache.Arrow.Ipc
public async ValueTask<RecordBatch> ReadRecordBatchAsync(int index,
CancellationToken cancellationToken)
{
await ReadSchemaAsync().ConfigureAwait(false);
+ await
ReadDictionariesAsync(cancellationToken).ConfigureAwait(false);
if (index >= _footer.RecordBatchCount)
{
@@ -159,6 +162,7 @@ namespace Apache.Arrow.Ipc
public RecordBatch ReadRecordBatch(int index)
{
ReadSchema();
+ ReadDictionaries();
if (index >= _footer.RecordBatchCount)
{
@@ -175,6 +179,7 @@ namespace Apache.Arrow.Ipc
public override async ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
await ReadSchemaAsync().ConfigureAwait(false);
+ await
ReadDictionariesAsync(cancellationToken).ConfigureAwait(false);
if (_recordBatchIndex >= _footer.RecordBatchCount)
{
@@ -190,6 +195,7 @@ namespace Apache.Arrow.Ipc
public override RecordBatch ReadNextRecordBatch()
{
ReadSchema();
+ ReadDictionaries();
if (_recordBatchIndex >= _footer.RecordBatchCount)
{
@@ -202,6 +208,34 @@ namespace Apache.Arrow.Ipc
return result;
}
+ private async ValueTask ReadDictionariesAsync(CancellationToken
cancellationToken = default)
+ {
+ if (HasReadDictionaries)
+ {
+ return;
+ }
+
+ foreach (Block block in _footer.Dictionaries)
+ {
+ BaseStream.Position = block.Offset;
+ await ReadMessageAsync(cancellationToken);
+ }
+ }
+
+ private void ReadDictionaries()
+ {
+ if (HasReadDictionaries)
+ {
+ return;
+ }
+
+ foreach (Block block in _footer.Dictionaries)
+ {
+ BaseStream.Position = block.Offset;
+ ReadMessage();
+ }
+ }
+
/// <summary>
/// Check if file format is valid. If it's valid don't run the
validation again.
/// </summary>
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
index 4fefb121cb..95b9f60fff 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
@@ -23,10 +23,12 @@ using System.Threading.Tasks;
namespace Apache.Arrow.Ipc
{
- public class ArrowFileWriter: ArrowStreamWriter
+ public class ArrowFileWriter : ArrowStreamWriter
{
private long _currentRecordBatchOffset = -1;
+ private long _currentDictionaryOffset = -1;
+ private List<Block> DictionaryBlocks { get; set; }
private List<Block> RecordBatchBlocks { get; }
public ArrowFileWriter(Stream stream, Schema schema)
@@ -105,6 +107,34 @@ namespace Apache.Arrow.Ipc
_currentRecordBatchOffset = -1;
}
+ private protected override void StartingWritingDictionary()
+ {
+ if (DictionaryBlocks == null) { DictionaryBlocks = new
List<Block>(); }
+ _currentDictionaryOffset = BaseStream.Position;
+ }
+
+ private protected override void FinishedWritingDictionary(long
bodyLength, long metadataLength)
+ {
+ // Dictionaries only appear after a Schema is written, so the
dictionary offsets must
+ // always be greater than 0.
+ Debug.Assert(_currentDictionaryOffset > 0,
"_currentDictionaryOffset must be positive.");
+
+ int metadataLengthInt = checked((int)metadataLength);
+
+ Debug.Assert(BitUtility.IsMultipleOf8(_currentDictionaryOffset));
+ Debug.Assert(BitUtility.IsMultipleOf8(metadataLengthInt));
+ Debug.Assert(BitUtility.IsMultipleOf8(bodyLength));
+
+ var block = new Block(
+ offset: _currentDictionaryOffset,
+ length: bodyLength,
+ metadataLength: metadataLengthInt);
+
+ DictionaryBlocks.Add(block);
+
+ _currentDictionaryOffset = -1;
+ }
+
private protected override void WriteEndInternal()
{
base.WriteEndInternal();
@@ -161,9 +191,16 @@ namespace Apache.Arrow.Ipc
Google.FlatBuffers.VectorOffset recordBatchesVectorOffset =
Builder.EndVector();
// Serialize all dictionaries
- // NOTE: Currently unsupported.
- Flatbuf.Footer.StartDictionariesVector(Builder, 0);
+ int dictionaryCount = DictionaryBlocks?.Count ?? 0;
+ Flatbuf.Footer.StartDictionariesVector(Builder, dictionaryCount);
+
+ for (int i = dictionaryCount - 1; i >= 0; i--)
+ {
+ Block dictionary = DictionaryBlocks[i];
+ Flatbuf.Block.CreateBlock(
+ Builder, dictionary.Offset, dictionary.MetadataLength,
dictionary.BodyLength);
+ }
Google.FlatBuffers.VectorOffset dictionaryBatchesOffset =
Builder.EndVector();
@@ -221,9 +258,16 @@ namespace Apache.Arrow.Ipc
Google.FlatBuffers.VectorOffset recordBatchesVectorOffset =
Builder.EndVector();
// Serialize all dictionaries
- // NOTE: Currently unsupported.
- Flatbuf.Footer.StartDictionariesVector(Builder, 0);
+ int dictionaryCount = DictionaryBlocks?.Count ?? 0;
+ Flatbuf.Footer.StartDictionariesVector(Builder, dictionaryCount);
+
+ for (int i = dictionaryCount - 1; i >= 0; i--)
+ {
+ Block dictionary = DictionaryBlocks[i];
+ Flatbuf.Block.CreateBlock(
+ Builder, dictionary.Offset, dictionary.MetadataLength,
dictionary.BodyLength);
+ }
Google.FlatBuffers.VectorOffset dictionaryBatchesOffset =
Builder.EndVector();
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs
index db269ae019..600624ef9e 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowFooter.cs
@@ -25,8 +25,8 @@ namespace Apache.Arrow.Ipc
private readonly List<Block> _dictionaries;
private readonly List<Block> _recordBatches;
- public IEnumerable<Block> Dictionaries => _dictionaries;
- public IEnumerable<Block> RecordBatches => _recordBatches;
+ public IReadOnlyList<Block> Dictionaries => _dictionaries;
+ public IReadOnlyList<Block> RecordBatches => _recordBatches;
public Block GetRecordBatchBlock(int i) => _recordBatches[i];
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
index af4f963ee5..6e2336a591 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
@@ -43,30 +43,17 @@ namespace Apache.Arrow.Ipc
{
ReadSchema();
- if (_buffer.Length <= _bufferPosition + sizeof(int))
+ RecordBatch batch = null;
+ while (batch == null)
{
- // reached the end
- return null;
- }
-
- // Get Length of record batch for message header.
- int messageLength =
BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
- _bufferPosition += sizeof(int);
-
- if (messageLength == 0)
- {
- //reached the end
- return null;
- }
- else if (messageLength == MessageSerializer.IpcContinuationToken)
- {
- // ARROW-6313, if the first 4 bytes are continuation message,
read the next 4 for the length
if (_buffer.Length <= _bufferPosition + sizeof(int))
{
- throw new InvalidDataException("Corrupted IPC message.
Received a continuation token at the end of the message.");
+ // reached the end
+ return null;
}
- messageLength =
BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
+ // Get Length of record batch for message header.
+ int messageLength =
BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
_bufferPosition += sizeof(int);
if (messageLength == 0)
@@ -74,17 +61,36 @@ namespace Apache.Arrow.Ipc
//reached the end
return null;
}
- }
+ else if (messageLength ==
MessageSerializer.IpcContinuationToken)
+ {
+ // ARROW-6313, if the first 4 bytes are continuation
message, read the next 4 for the length
+ if (_buffer.Length <= _bufferPosition + sizeof(int))
+ {
+ throw new InvalidDataException("Corrupted IPC message.
Received a continuation token at the end of the message.");
+ }
+
+ messageLength =
BinaryPrimitives.ReadInt32LittleEndian(_buffer.Span.Slice(_bufferPosition));
+ _bufferPosition += sizeof(int);
+
+ if (messageLength == 0)
+ {
+ //reached the end
+ return null;
+ }
+ }
+
+ Message message = Message.GetRootAsMessage(
+ CreateByteBuffer(_buffer.Slice(_bufferPosition,
messageLength)));
+ _bufferPosition += messageLength;
- Message message = Message.GetRootAsMessage(
- CreateByteBuffer(_buffer.Slice(_bufferPosition,
messageLength)));
- _bufferPosition += messageLength;
+ int bodyLength = (int)message.BodyLength;
+ ByteBuffer bodybb =
CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength));
+ _bufferPosition += bodyLength;
- int bodyLength = (int)message.BodyLength;
- ByteBuffer bodybb =
CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength));
- _bufferPosition += bodyLength;
+ batch = CreateArrowObjectFromMessage(message, bodybb,
memoryOwner: null);
+ }
- return CreateArrowObjectFromMessage(message, bodybb, memoryOwner:
null);
+ return batch;
}
private void ReadSchema()
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
index df80ffe1e0..184e0348e5 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
@@ -57,79 +57,93 @@ namespace Apache.Arrow.Ipc
{
await ReadSchemaAsync().ConfigureAwait(false);
- RecordBatch result = null;
-
- while (result == null)
+ ReadResult result = default;
+ do
{
- int messageLength = await
ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken)
- .ConfigureAwait(false);
+ result = await
ReadMessageAsync(cancellationToken).ConfigureAwait(false);
+ } while (result.Batch == null && result.MessageLength > 0);
- if (messageLength == 0)
- {
- // reached end
- return null;
- }
+ return result.Batch;
+ }
- await ArrayPool<byte>.Shared.RentReturnAsync(messageLength,
async (messageBuff) =>
- {
- int bytesRead = await
BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken)
- .ConfigureAwait(false);
- EnsureFullRead(messageBuff, bytesRead);
+ protected async ValueTask<ReadResult>
ReadMessageAsync(CancellationToken cancellationToken)
+ {
+ int messageLength = await ReadMessageLengthAsync(throwOnFullRead:
false, cancellationToken)
+ .ConfigureAwait(false);
- Flatbuf.Message message =
Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff));
+ if (messageLength == 0)
+ {
+ // reached end
+ return default;
+ }
- int bodyLength = checked((int)message.BodyLength);
+ RecordBatch result = null;
+ await ArrayPool<byte>.Shared.RentReturnAsync(messageLength, async
(messageBuff) =>
+ {
+ int bytesRead = await
BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken)
+ .ConfigureAwait(false);
+ EnsureFullRead(messageBuff, bytesRead);
- IMemoryOwner<byte> bodyBuffOwner =
_allocator.Allocate(bodyLength);
- Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0,
bodyLength);
- bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff,
cancellationToken)
- .ConfigureAwait(false);
- EnsureFullRead(bodyBuff, bytesRead);
+ Flatbuf.Message message =
Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff));
- Google.FlatBuffers.ByteBuffer bodybb =
CreateByteBuffer(bodyBuff);
- result = CreateArrowObjectFromMessage(message, bodybb,
bodyBuffOwner);
- }).ConfigureAwait(false);
- }
+ int bodyLength = checked((int)message.BodyLength);
+
+ IMemoryOwner<byte> bodyBuffOwner =
_allocator.Allocate(bodyLength);
+ Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0,
bodyLength);
+ bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff,
cancellationToken)
+ .ConfigureAwait(false);
+ EnsureFullRead(bodyBuff, bytesRead);
+
+ Google.FlatBuffers.ByteBuffer bodybb =
CreateByteBuffer(bodyBuff);
+ result = CreateArrowObjectFromMessage(message, bodybb,
bodyBuffOwner);
+ }).ConfigureAwait(false);
- return result;
+ return new ReadResult(messageLength, result);
}
protected RecordBatch ReadRecordBatch()
{
ReadSchema();
- RecordBatch result = null;
-
- while (result == null)
+ ReadResult result = default;
+ do
{
- int messageLength = ReadMessageLength(throwOnFullRead: false);
+ result = ReadMessage();
+ } while (result.Batch == null && result.MessageLength > 0);
- if (messageLength == 0)
- {
- // reached end
- return null;
- }
+ return result.Batch;
+ }
- ArrayPool<byte>.Shared.RentReturn(messageLength, messageBuff =>
- {
- int bytesRead = BaseStream.ReadFullBuffer(messageBuff);
- EnsureFullRead(messageBuff, bytesRead);
+ protected ReadResult ReadMessage()
+ {
+ int messageLength = ReadMessageLength(throwOnFullRead: false);
- Flatbuf.Message message =
Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff));
+ if (messageLength == 0)
+ {
+ // reached end
+ return default;
+ }
- int bodyLength = checked((int)message.BodyLength);
+ RecordBatch result = null;
+ ArrayPool<byte>.Shared.RentReturn(messageLength, messageBuff =>
+ {
+ int bytesRead = BaseStream.ReadFullBuffer(messageBuff);
+ EnsureFullRead(messageBuff, bytesRead);
- IMemoryOwner<byte> bodyBuffOwner =
_allocator.Allocate(bodyLength);
- Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0,
bodyLength);
- bytesRead = BaseStream.ReadFullBuffer(bodyBuff);
- EnsureFullRead(bodyBuff, bytesRead);
+ Flatbuf.Message message =
Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff));
- Google.FlatBuffers.ByteBuffer bodybb =
CreateByteBuffer(bodyBuff);
- result = CreateArrowObjectFromMessage(message, bodybb,
bodyBuffOwner);
- });
- }
+ int bodyLength = checked((int)message.BodyLength);
+
+ IMemoryOwner<byte> bodyBuffOwner =
_allocator.Allocate(bodyLength);
+ Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0,
bodyLength);
+ bytesRead = BaseStream.ReadFullBuffer(bodyBuff);
+ EnsureFullRead(bodyBuff, bytesRead);
- return result;
+ Google.FlatBuffers.ByteBuffer bodybb =
CreateByteBuffer(bodyBuff);
+ result = CreateArrowObjectFromMessage(message, bodybb,
bodyBuffOwner);
+ });
+
+ return new ReadResult(messageLength, result);
}
protected virtual async ValueTask ReadSchemaAsync()
@@ -264,5 +278,17 @@ namespace Apache.Arrow.Ipc
throw new InvalidOperationException("Unexpectedly reached the
end of the stream before a full buffer was read.");
}
}
+
+ internal struct ReadResult
+ {
+ public readonly int MessageLength;
+ public readonly RecordBatch Batch;
+
+ public ReadResult(int messageLength, RecordBatch batch)
+ {
+ MessageLength = messageLength;
+ Batch = batch;
+ }
+ }
}
}
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
index dcb8852bc1..d4e8bb48df 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
@@ -270,7 +270,6 @@ namespace Apache.Arrow.Ipc
_options = options ?? IpcOptions.Default;
}
-
private void CreateSelfAndChildrenFieldNodes(ArrayData data)
{
if (data.DataType is NestedType)
@@ -319,7 +318,7 @@ namespace Apache.Arrow.Ipc
if (!HasWrittenDictionaryBatch)
{
DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
- WriteDictionaries(recordBatch);
+ WriteDictionaries(_dictionaryMemo);
HasWrittenDictionaryBatch = true;
}
@@ -358,7 +357,7 @@ namespace Apache.Arrow.Ipc
if (!HasWrittenDictionaryBatch)
{
DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
- await WriteDictionariesAsync(recordBatch,
cancellationToken).ConfigureAwait(false);
+ await WriteDictionariesAsync(_dictionaryMemo,
cancellationToken).ConfigureAwait(false);
HasWrittenDictionaryBatch = true;
}
@@ -492,74 +491,65 @@ namespace Apache.Arrow.Ipc
return Tuple.Create(recordBatchBuilder, fieldNodesVectorOffset);
}
+ private protected virtual void StartingWritingDictionary()
+ {
+ }
- private protected void WriteDictionaries(RecordBatch recordBatch)
+ private protected virtual void FinishedWritingDictionary(long
bodyLength, long metadataLength)
{
- foreach (Field field in recordBatch.Schema.FieldsList)
- {
- WriteDictionary(field);
- }
}
- private protected void WriteDictionary(Field field)
+ private protected void WriteDictionaries(DictionaryMemo dictionaryMemo)
{
- if (field.DataType.TypeId != ArrowTypeId.Dictionary)
+ int fieldCount = dictionaryMemo?.DictionaryCount ?? 0;
+ for (int i = 0; i < fieldCount; i++)
{
- if (field.DataType is NestedType nestedType)
- {
- foreach (Field child in nestedType.Fields)
- {
- WriteDictionary(child);
- }
- }
- return;
+ WriteDictionary(i, dictionaryMemo.GetDictionaryType(i),
dictionaryMemo.GetDictionary(i));
}
+ }
+
+ private protected void WriteDictionary(long id, IArrowType valueType,
IArrowArray dictionary)
+ {
+ StartingWritingDictionary();
(ArrowRecordBatchFlatBufferBuilder recordBatchBuilder,
Offset<Flatbuf.DictionaryBatch> dictionaryBatchOffset) =
- CreateDictionaryBatchOffset(field);
+ CreateDictionaryBatchOffset(id, valueType, dictionary);
- WriteMessage(Flatbuf.MessageHeader.DictionaryBatch,
+ long metadataLength =
WriteMessage(Flatbuf.MessageHeader.DictionaryBatch,
dictionaryBatchOffset, recordBatchBuilder.TotalLength);
- WriteBufferData(recordBatchBuilder.Buffers);
+ long bufferLength = WriteBufferData(recordBatchBuilder.Buffers);
+
+ FinishedWritingDictionary(bufferLength, metadataLength);
}
- private protected async Task WriteDictionariesAsync(RecordBatch
recordBatch, CancellationToken cancellationToken)
+ private protected async Task WriteDictionariesAsync(DictionaryMemo
dictionaryMemo, CancellationToken cancellationToken)
{
- foreach (Field field in recordBatch.Schema.FieldsList)
+ int fieldCount = dictionaryMemo?.DictionaryCount ?? 0;
+ for (int i = 0; i < fieldCount; i++)
{
- await WriteDictionaryAsync(field,
cancellationToken).ConfigureAwait(false);
+ await WriteDictionaryAsync(i,
dictionaryMemo.GetDictionaryType(i), dictionaryMemo.GetDictionary(i),
cancellationToken).ConfigureAwait(false);
}
}
- private protected async Task WriteDictionaryAsync(Field field,
CancellationToken cancellationToken)
+ private protected async Task WriteDictionaryAsync(long id, IArrowType
valueType, IArrowArray dictionary, CancellationToken cancellationToken)
{
- if (field.DataType.TypeId != ArrowTypeId.Dictionary)
- {
- if (field.DataType is NestedType nestedType)
- {
- foreach (Field child in nestedType.Fields)
- {
- await WriteDictionaryAsync(child,
cancellationToken).ConfigureAwait(false);
- }
- }
- return;
- }
+ StartingWritingDictionary();
(ArrowRecordBatchFlatBufferBuilder recordBatchBuilder,
Offset<Flatbuf.DictionaryBatch> dictionaryBatchOffset) =
- CreateDictionaryBatchOffset(field);
+ CreateDictionaryBatchOffset(id, valueType, dictionary);
- await WriteMessageAsync(Flatbuf.MessageHeader.DictionaryBatch,
+ long metadataLength = await
WriteMessageAsync(Flatbuf.MessageHeader.DictionaryBatch,
dictionaryBatchOffset, recordBatchBuilder.TotalLength,
cancellationToken).ConfigureAwait(false);
- await WriteBufferDataAsync(recordBatchBuilder.Buffers,
cancellationToken).ConfigureAwait(false);
+ long bufferLength = await
WriteBufferDataAsync(recordBatchBuilder.Buffers,
cancellationToken).ConfigureAwait(false);
+
+ FinishedWritingDictionary(bufferLength, metadataLength);
}
- private Tuple<ArrowRecordBatchFlatBufferBuilder,
Offset<Flatbuf.DictionaryBatch>> CreateDictionaryBatchOffset(Field field)
+ private Tuple<ArrowRecordBatchFlatBufferBuilder,
Offset<Flatbuf.DictionaryBatch>> CreateDictionaryBatchOffset(long id,
IArrowType valueType, IArrowArray dictionary)
{
- Field dictionaryField = new Field("dummy",
((DictionaryType)field.DataType).ValueType, false);
- long id = DictionaryMemo.GetId(field);
- IArrowArray dictionary = DictionaryMemo.GetDictionary(id);
+ Field dictionaryField = new Field("dummy", valueType, false);
var fields = new Field[] { dictionaryField };
@@ -987,12 +977,12 @@ namespace Apache.Arrow.Ipc
arrayData.Dictionary.EnsureDataType(dictionaryType.ValueType.TypeId);
IArrowArray dictionary =
ArrowArrayFactory.BuildArray(arrayData.Dictionary);
+ WalkChildren(dictionary.Data, ref dictionaryMemo);
dictionaryMemo ??= new DictionaryMemo();
long id = dictionaryMemo.GetOrAssignId(field);
dictionaryMemo.AddOrReplaceDictionary(id, dictionary);
- WalkChildren(dictionary.Data, ref dictionaryMemo);
}
else
{
diff --git a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs
b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs
index 24f25a1429..b107cc65bf 100644
--- a/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs
+++ b/csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs
@@ -33,6 +33,9 @@ namespace Apache.Arrow.Ipc
_fieldToId = new Dictionary<Field, long>();
}
+ public int DictionaryCount => _fieldToId.Count;
+ public int LoadedDictionaryCount => _idToDictionary.Count;
+
public IArrowType GetDictionaryType(long id)
{
if (!_idToValueType.TryGetValue(id, out IArrowType type))
@@ -72,9 +75,12 @@ namespace Apache.Arrow.Ipc
throw new ArgumentException($"Field type
{field.DataType.Name} does not match the existing type {valueTypeInDic})");
}
}
+ else
+ {
+ _idToValueType.Add(id, valueType);
+ }
_fieldToId.Add(field, id);
- _idToValueType.Add(id, valueType);
}
public long GetId(Field field)
@@ -90,7 +96,7 @@ namespace Apache.Arrow.Ipc
{
if (!_fieldToId.TryGetValue(field, out long id))
{
- id = _fieldToId.Count + 1;
+ id = _fieldToId.Count;
AddField(id, field);
}
return id;
diff --git a/csharp/src/Apache.Arrow/Types/DictionaryType.cs
b/csharp/src/Apache.Arrow/Types/DictionaryType.cs
index 5c1dd4095e..6316578aa6 100644
--- a/csharp/src/Apache.Arrow/Types/DictionaryType.cs
+++ b/csharp/src/Apache.Arrow/Types/DictionaryType.cs
@@ -20,6 +20,7 @@ namespace Apache.Arrow.Types
{
public sealed class DictionaryType : FixedWidthType
{
+ [Obsolete]
public static readonly DictionaryType Default = new
DictionaryType(Int64Type.Default, Int64Type.Default, false);
public DictionaryType(IArrowType indexType, IArrowType valueType, bool
ordered)
@@ -36,7 +37,7 @@ namespace Apache.Arrow.Types
public override ArrowTypeId TypeId => ArrowTypeId.Dictionary;
public override string Name => "dictionary";
- public override int BitWidth => 64;
+ public override int BitWidth => ((IntegerType)IndexType).BitWidth;
public override void Accept(IArrowTypeVisitor visitor) => Accept(this,
visitor);
public IArrowType IndexType { get; private set; }
diff --git a/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
b/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
index 4e491a2a6b..cd8198d434 100644
--- a/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
+++ b/csharp/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
@@ -38,7 +38,7 @@ namespace Apache.Arrow.Benchmarks
[GlobalSetup]
public async Task GlobalSetup()
{
- RecordBatch batch = TestData.CreateSampleRecordBatch(length:
Count);
+ RecordBatch batch = TestData.CreateSampleRecordBatch(length:
Count, createDictionaryArray: false);
_memoryStream = new MemoryStream();
ArrowStreamWriter writer = new ArrowStreamWriter(_memoryStream,
batch.Schema);
diff --git a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs
b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs
index d19d19f1ce..6a1e912409 100644
--- a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs
+++ b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs
@@ -14,14 +14,8 @@
// limitations under the License.
using System;
-using System.Collections.Generic;
-using System.Globalization;
using System.IO;
-using System.Numerics;
-using System.Text;
-using System.Text.Json;
using System.Threading.Tasks;
-using Apache.Arrow.Arrays;
using Apache.Arrow.Ipc;
using Apache.Arrow.Tests;
using Apache.Arrow.Types;
@@ -49,6 +43,7 @@ namespace Apache.Arrow.IntegrationTest
"json-to-arrow" => JsonToArrow,
"stream-to-file" => StreamToFile,
"file-to-stream" => FileToStream,
+ "round-trip-json-arrow" => RoundTripJsonArrow,
_ => () =>
{
Console.WriteLine($"Mode '{Mode}' is not supported.");
@@ -58,6 +53,14 @@ namespace Apache.Arrow.IntegrationTest
return await commandDelegate();
}
+ private async Task<int> RoundTripJsonArrow()
+ {
+ int status = await JsonToArrow();
+ if (status != 0) { return status; }
+
+ return await Validate();
+ }
+
private async Task<int> Validate()
{
JsonFile jsonFile = await ParseJsonFile();
@@ -72,7 +75,7 @@ namespace Apache.Arrow.IntegrationTest
return -1;
}
- Schema jsonFileSchema = jsonFile.Schema.ToArrow();
+ Schema jsonFileSchema = jsonFile.GetSchemaAndDictionaries(out
Func<DictionaryType, IArrowArray> dictionaries);
Schema arrowFileSchema = reader.Schema;
SchemaComparer.Compare(jsonFileSchema, arrowFileSchema);
@@ -80,7 +83,7 @@ namespace Apache.Arrow.IntegrationTest
for (int i = 0; i < batchCount; i++)
{
RecordBatch arrowFileRecordBatch =
reader.ReadNextRecordBatch();
- RecordBatch jsonFileRecordBatch =
jsonFile.Batches[i].ToArrow(jsonFileSchema);
+ RecordBatch jsonFileRecordBatch =
jsonFile.Batches[i].ToArrow(jsonFileSchema, dictionaries);
ArrowReaderVerifier.CompareBatches(jsonFileRecordBatch,
arrowFileRecordBatch, strictCompare: false);
}
@@ -98,7 +101,7 @@ namespace Apache.Arrow.IntegrationTest
private async Task<int> JsonToArrow()
{
JsonFile jsonFile = await ParseJsonFile();
- Schema schema = jsonFile.Schema.ToArrow();
+ Schema schema = jsonFile.GetSchemaAndDictionaries(out
Func<DictionaryType, IArrowArray> dictionaries);
using (FileStream fs = ArrowFileInfo.Create())
{
@@ -107,7 +110,7 @@ namespace Apache.Arrow.IntegrationTest
foreach (var jsonRecordBatch in jsonFile.Batches)
{
- RecordBatch batch = jsonRecordBatch.ToArrow(schema);
+ RecordBatch batch = jsonRecordBatch.ToArrow(schema,
dictionaries);
await writer.WriteRecordBatchAsync(batch);
}
await writer.WriteEndAsync();
diff --git a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs
b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs
index 987a236a10..bdb9e2682b 100644
--- a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs
+++ b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs
@@ -15,9 +15,9 @@
using System;
using System.Collections.Generic;
-using System.Diagnostics;
using System.Globalization;
using System.IO;
+using System.Linq;
using System.Numerics;
using System.Text;
using System.Text.Json;
@@ -31,8 +31,10 @@ namespace Apache.Arrow.IntegrationTest
public class JsonFile
{
public JsonSchema Schema { get; set; }
+
+ public List<JsonDictionary> Dictionaries { get; set; }
+
public List<JsonRecordBatch> Batches { get; set; }
- //public List<DictionaryBatch> Dictionaries {get;set;}
public static async ValueTask<JsonFile> ParseAsync(FileInfo fileInfo)
{
@@ -48,6 +50,33 @@ namespace Apache.Arrow.IntegrationTest
return JsonSerializer.Deserialize<JsonFile>(fileStream, options);
}
+ public Schema GetSchemaAndDictionaries(out Func<DictionaryType,
IArrowArray> dictionaries)
+ {
+ Schema schema = Schema.ToArrow(out Dictionary<DictionaryType, int>
dictionaryIndexes);
+
+ Func<DictionaryType, IArrowArray> lookup = null;
+ lookup = type => Dictionaries.Single(d => d.Id ==
dictionaryIndexes[type]).Data.ToArrow(type.ValueType, lookup);
+ dictionaries = lookup;
+
+ return schema;
+ }
+
+ /// <summary>
+ /// Return both the schema and a specific batch number.
+ /// This method is used by C Data Interface integration testing.
+ /// </summary>
+ public Schema ToArrow(int batchNumber, out RecordBatch batch)
+ {
+ Schema schema = Schema.ToArrow(out Dictionary<DictionaryType, int>
dictionaryIndexes);
+
+ Func<DictionaryType, IArrowArray> lookup = null;
+ lookup = type => Dictionaries.Single(d => d.Id ==
dictionaryIndexes[type]).Data.ToArrow(type.ValueType, lookup);
+
+ batch = Batches[batchNumber].ToArrow(schema, lookup);
+
+ return schema;
+ }
+
private static JsonSerializerOptions GetJsonOptions()
{
JsonSerializerOptions options = new JsonSerializerOptions()
@@ -67,22 +96,33 @@ namespace Apache.Arrow.IntegrationTest
/// <summary>
/// Decode this JSON schema as a Schema instance.
/// </summary>
+ public Schema ToArrow(out Dictionary<DictionaryType, int>
dictionaryIndexes)
+ {
+ dictionaryIndexes = new Dictionary<DictionaryType, int>();
+ return CreateSchema(this, dictionaryIndexes);
+ }
+
+ /// <summary>
+ /// Decode this JSON schema as a Schema instance without computing
dictionaries.
+ /// This method is used by C Data Interface integration testing.
+ /// </summary>
public Schema ToArrow()
{
- return CreateSchema(this);
+ Dictionary<DictionaryType, int> dictionaryIndexes = new
Dictionary<DictionaryType, int>();
+ return CreateSchema(this, dictionaryIndexes);
}
- private static Schema CreateSchema(JsonSchema jsonSchema)
+ private static Schema CreateSchema(JsonSchema jsonSchema,
Dictionary<DictionaryType, int> dictionaryIndexes)
{
Schema.Builder builder = new Schema.Builder();
for (int i = 0; i < jsonSchema.Fields.Count; i++)
{
- builder.Field(f => CreateField(f, jsonSchema.Fields[i]));
+ builder.Field(f => CreateField(f, jsonSchema.Fields[i],
dictionaryIndexes));
}
return builder.Build();
}
- private static void CreateField(Field.Builder builder, JsonField
jsonField)
+ private static void CreateField(Field.Builder builder, JsonField
jsonField, Dictionary<DictionaryType, int> dictionaryIndexes)
{
Field[] children = null;
if (jsonField.Children?.Count > 0)
@@ -91,13 +131,26 @@ namespace Apache.Arrow.IntegrationTest
for (int i = 0; i < jsonField.Children.Count; i++)
{
Field.Builder field = new Field.Builder();
- CreateField(field, jsonField.Children[i]);
+ CreateField(field, jsonField.Children[i],
dictionaryIndexes);
children[i] = field.Build();
}
}
+ IArrowType type = ToArrowType(jsonField.Type, children);
+
+ if (jsonField.Dictionary != null)
+ {
+ DictionaryType dictType = new DictionaryType(
+ ToArrowType(jsonField.Dictionary.IndexType, new Field[0]),
+ type,
+ jsonField.Dictionary.IsOrdered);
+
+ dictionaryIndexes[dictType] = jsonField.Dictionary.Id;
+ type = dictType;
+ }
+
builder.Name(jsonField.Name)
- .DataType(ToArrowType(jsonField.Type, children))
+ .DataType(type)
.Nullable(jsonField.Nullable);
if (jsonField.Metadata != null)
@@ -300,10 +353,18 @@ namespace Apache.Arrow.IntegrationTest
public class JsonDictionaryIndex
{
public int Id { get; set; }
- public JsonArrowType Type { get; set; }
+ public JsonArrowType IndexType { get; set; }
public bool IsOrdered { get; set; }
}
+ public class JsonDictionary
+ {
+ public int Id { get; set; }
+
+ [JsonPropertyName("data")]
+ public JsonRecordBatch Data { get; set; }
+ }
+
public class JsonMetadata : List<KeyValuePair<string, string>>
{
}
@@ -316,12 +377,19 @@ namespace Apache.Arrow.IntegrationTest
/// <summary>
/// Decode this JSON record batch as a RecordBatch instance.
/// </summary>
- public RecordBatch ToArrow(Schema schema)
+ public RecordBatch ToArrow(Schema schema, Func<DictionaryType,
IArrowArray> dictionaries)
+ {
+ return CreateRecordBatch(schema, dictionaries, this);
+ }
+
+ public IArrowArray ToArrow(IArrowType arrowType, Func<DictionaryType,
IArrowArray> dictionaries)
{
- return CreateRecordBatch(schema, this);
+ ArrayCreator creator = new ArrayCreator(this.Columns[0],
dictionaries);
+ arrowType.Accept(creator);
+ return creator.Array;
}
- private RecordBatch CreateRecordBatch(Schema schema, JsonRecordBatch
jsonRecordBatch)
+ private RecordBatch CreateRecordBatch(Schema schema,
Func<DictionaryType, IArrowArray> dictionaries, JsonRecordBatch jsonRecordBatch)
{
if (schema.FieldsList.Count != jsonRecordBatch.Columns.Count)
{
@@ -333,7 +401,7 @@ namespace Apache.Arrow.IntegrationTest
{
JsonFieldData data = jsonRecordBatch.Columns[i];
Field field = schema.FieldsList[i];
- ArrayCreator creator = new ArrayCreator(data);
+ ArrayCreator creator = new ArrayCreator(data, dictionaries);
field.DataType.Accept(creator);
arrays.Add(creator.Array);
}
@@ -369,14 +437,18 @@ namespace Apache.Arrow.IntegrationTest
IArrowTypeVisitor<StructType>,
IArrowTypeVisitor<UnionType>,
IArrowTypeVisitor<MapType>,
+ IArrowTypeVisitor<DictionaryType>,
IArrowTypeVisitor<NullType>
{
private JsonFieldData JsonFieldData { get; set; }
public IArrowArray Array { get; private set; }
- public ArrayCreator(JsonFieldData jsonFieldData)
+ private readonly Func<DictionaryType, IArrowArray> dictionaries;
+
+ public ArrayCreator(JsonFieldData jsonFieldData,
Func<DictionaryType, IArrowArray> dictionaries)
{
JsonFieldData = jsonFieldData;
+ this.dictionaries = dictionaries;
}
public void Visit(BooleanType type)
@@ -656,6 +728,12 @@ namespace Apache.Arrow.IntegrationTest
Array = new MapArray(arrayData);
}
+ public void Visit(DictionaryType type)
+ {
+ type.IndexType.Accept(this);
+ Array = new DictionaryArray(type, Array,
this.dictionaries(type));
+ }
+
private ArrayData[] GetChildren(NestedType type)
{
ArrayData[] children = new ArrayData[type.Fields.Count];
diff --git a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs
b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs
index 2f2229ded4..585b1acc27 100644
--- a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs
+++ b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs
@@ -66,7 +66,7 @@ namespace Apache.Arrow.Tests
ArrowFileReader reader = new ArrowFileReader(stream,
memoryPool, leaveOpen: shouldLeaveOpen);
reader.ReadNextRecordBatch();
- Assert.Equal(1, memoryPool.Statistics.Allocations);
+ Assert.Equal(2, memoryPool.Statistics.Allocations);
Assert.True(memoryPool.Statistics.BytesAllocated > 0);
reader.Dispose();
@@ -132,8 +132,8 @@ namespace Apache.Arrow.Tests
[Fact]
public async Task TestReadMultipleRecordBatchAsync()
{
- RecordBatch originalBatch1 =
TestData.CreateSampleRecordBatch(length: 100);
- RecordBatch originalBatch2 =
TestData.CreateSampleRecordBatch(length: 50);
+ RecordBatch originalBatch1 =
TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+ RecordBatch originalBatch2 =
TestData.CreateSampleRecordBatch(length: 50, createDictionaryArray: false);
using (MemoryStream stream = new MemoryStream())
{
diff --git a/csharp/test/Apache.Arrow.Tests/TestData.cs
b/csharp/test/Apache.Arrow.Tests/TestData.cs
index 3af6efb97b..79e886f0de 100644
--- a/csharp/test/Apache.Arrow.Tests/TestData.cs
+++ b/csharp/test/Apache.Arrow.Tests/TestData.cs
@@ -23,7 +23,7 @@ namespace Apache.Arrow.Tests
{
public static class TestData
{
- public static RecordBatch CreateSampleRecordBatch(int length, bool
createDictionaryArray = false)
+ public static RecordBatch CreateSampleRecordBatch(int length, bool
createDictionaryArray = true)
{
return CreateSampleRecordBatch(length, columnSetCount: 1,
createDictionaryArray);
}
diff --git a/dev/archery/archery/integration/datagen.py
b/dev/archery/archery/integration/datagen.py
index 80cc1c1e76..341b48117a 100644
--- a/dev/archery/archery/integration/datagen.py
+++ b/dev/archery/archery/integration/datagen.py
@@ -1836,15 +1836,12 @@ def get_generated_json_files(tempdir=None):
.skip_tester('C#')
.skip_tester('JS'),
- generate_dictionary_case()
- .skip_tester('C#'),
+ generate_dictionary_case(),
generate_dictionary_unsigned_case()
- .skip_tester('C#')
.skip_tester('Java'), # TODO(ARROW-9377)
generate_nested_dictionary_case()
- .skip_tester('C#')
.skip_tester('Java'), # TODO(ARROW-7779)
generate_run_end_encoded_case()
diff --git a/dev/archery/archery/integration/tester_csharp.py
b/dev/archery/archery/integration/tester_csharp.py
index 4f77656411..9aab5b0b28 100644
--- a/dev/archery/archery/integration/tester_csharp.py
+++ b/dev/archery/archery/integration/tester_csharp.py
@@ -78,9 +78,7 @@ class _CDataBase:
def _read_batch_from_json(self, json_path, num_batch):
from Apache.Arrow.IntegrationTest import CDataInterface
- jf = CDataInterface.ParseJsonFile(json_path)
- schema = jf.Schema.ToArrow()
- return schema, jf.Batches[num_batch].ToArrow(schema)
+ return CDataInterface.ParseJsonFile(json_path).ToArrow(num_batch)
def _run_gc(self):
from Apache.Arrow.IntegrationTest import CDataInterface
diff --git a/docs/source/status.rst b/docs/source/status.rst
index 6167d3037b..140e15f44c 100644
--- a/docs/source/status.rst
+++ b/docs/source/status.rst
@@ -100,7 +100,7 @@ Data Types
 | Data type         | C++   | Java  | Go    | JavaScript | C#    | Rust  | Julia | Swift |
 | (special)         |       |       |       |            |       |       |       |       |
 +===================+=======+=======+=======+============+=======+=======+=======+=======+
-| Dictionary        | ✓     | ✓ (3) | ✓     | ✓          | ✓ (3) | ✓ (3) | ✓     |       |
+| Dictionary        | ✓     | ✓ (3) | ✓     | ✓          | ✓     | ✓ (3) | ✓     |       |
 +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+
 | Extension         | ✓     | ✓     | ✓     |            |       | ✓     | ✓     |       |
 +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+