This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new c92de88bd9 GH-34636: [C#] Reduce allocations when using ArrayPool (#39166)
c92de88bd9 is described below

commit c92de88bd96f9a6d250498e7e2024dddc5a7d6a6
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Sun Dec 10 09:59:15 2023 -0800

    GH-34636: [C#] Reduce allocations when using ArrayPool (#39166)
    
    ### Rationale for this change
    
    GH-34636 is a great suggestion for simplifying the code and making it more
    efficient: the delegate-based RentReturn pattern is replaced with a
    "using"-based one, which avoids allocating a closure and delegate on every
    call. Most of the affected call sites were also the ones not passing the
    CancellationToken through properly, so this was a good time to fix that as
    well.
    
    ### Are these changes tested?
    
    This is a refactoring that adds no new functionality, so it is covered by
    the existing tests.
    
    Closes #39144
    * Closes: #34636
    
    Authored-by: Curt Hagenlocher <[email protected]>
    Signed-off-by: Curt Hagenlocher <[email protected]>
---
 .../Apache.Arrow/Extensions/ArrayPoolExtensions.cs | 40 +++++++------------
 .../Ipc/ArrowFileReaderImplementation.cs           | 46 +++++++++++-----------
 csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs     |  8 ++--
 .../Ipc/ArrowStreamReaderImplementation.cs         | 40 +++++++++----------
 csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs   |  8 ++--
 5 files changed, 65 insertions(+), 77 deletions(-)

diff --git a/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs 
b/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs
index 95a39439f7..51287674b2 100644
--- a/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs
+++ b/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs
@@ -16,46 +16,36 @@
 using System;
 using System.Buffers;
 using System.Runtime.CompilerServices;
-using System.Threading.Tasks;
 
 namespace Apache.Arrow
 {
     internal static class ArrayPoolExtensions
     {
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        public static void RentReturn(this ArrayPool<byte> pool, int length, 
Action<Memory<byte>> action)
+        public static ArrayLease RentReturn(this ArrayPool<byte> pool, int 
length, out Memory<byte> buffer)
         {
-            byte[] array = null;
-
-            try
-            {
-                array = pool.Rent(length);
-                action(array.AsMemory(0, length));
-            }
-            finally
-            {
-                if (array != null)
-                {
-                    pool.Return(array);
-                }
-            }
+            byte[] array = pool.Rent(length);
+            buffer = array.AsMemory(0, length);
+            return new ArrayLease(pool, array);
         }
 
-        [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        public static async ValueTask RentReturnAsync(this ArrayPool<byte> 
pool, int length, Func<Memory<byte>, ValueTask> action)
+        internal struct ArrayLease : IDisposable
         {
-            byte[] array = null;
+            private readonly ArrayPool<byte> _pool;
+            private byte[] _array;
 
-            try
+            public ArrayLease(ArrayPool<byte> pool, byte[] array)
             {
-                array = pool.Rent(length);
-                await action(array.AsMemory(0, length));
+                _pool = pool;
+                _array = array;
             }
-            finally
+
+            public void Dispose()
             {
-                if (array != null)
+                if (_array != null)
                 {
-                    pool.Return(array);
+                    _pool.Return(_array);
+                    _array = null;
                 }
             }
         }
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs 
b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
index 3ae475885f..02f36b0793 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
@@ -42,47 +42,47 @@ namespace Apache.Arrow.Ipc
         {
         }
 
-        public async ValueTask<int> RecordBatchCountAsync()
+        public async ValueTask<int> RecordBatchCountAsync(CancellationToken 
cancellationToken = default)
         {
             if (!HasReadSchema)
             {
-                await ReadSchemaAsync().ConfigureAwait(false);
+                await ReadSchemaAsync(cancellationToken).ConfigureAwait(false);
             }
 
             return _footer.RecordBatchCount;
         }
 
-        protected override async ValueTask ReadSchemaAsync()
+        protected override async ValueTask ReadSchemaAsync(CancellationToken 
cancellationToken = default)
         {
             if (HasReadSchema)
             {
                 return;
             }
 
-            await ValidateFileAsync().ConfigureAwait(false);
+            await ValidateFileAsync(cancellationToken).ConfigureAwait(false);
 
             int footerLength = 0;
-            await ArrayPool<byte>.Shared.RentReturnAsync(4, async (buffer) =>
+            using (ArrayPool<byte>.Shared.RentReturn(4, out Memory<byte> 
buffer))
             {
                 BaseStream.Position = GetFooterLengthPosition();
 
-                int bytesRead = await 
BaseStream.ReadFullBufferAsync(buffer).ConfigureAwait(false);
+                int bytesRead = await BaseStream.ReadFullBufferAsync(buffer, 
cancellationToken).ConfigureAwait(false);
                 EnsureFullRead(buffer, bytesRead);
 
                 footerLength = ReadFooterLength(buffer);
-            }).ConfigureAwait(false);
+            }
 
-            await ArrayPool<byte>.Shared.RentReturnAsync(footerLength, async 
(buffer) =>
+            using (ArrayPool<byte>.Shared.RentReturn(footerLength, out 
Memory<byte> buffer))
             {
                 long footerStartPosition = GetFooterLengthPosition() - 
footerLength;
 
                 BaseStream.Position = footerStartPosition;
 
-                int bytesRead = await 
BaseStream.ReadFullBufferAsync(buffer).ConfigureAwait(false);
+                int bytesRead = await BaseStream.ReadFullBufferAsync(buffer, 
cancellationToken).ConfigureAwait(false);
                 EnsureFullRead(buffer, bytesRead);
 
                 ReadSchema(buffer);
-            }).ConfigureAwait(false);
+            }
         }
 
         protected override void ReadSchema()
@@ -95,7 +95,7 @@ namespace Apache.Arrow.Ipc
             ValidateFile();
 
             int footerLength = 0;
-            ArrayPool<byte>.Shared.RentReturn(4, (buffer) =>
+            using (ArrayPool<byte>.Shared.RentReturn(4, out Memory<byte> 
buffer))
             {
                 BaseStream.Position = GetFooterLengthPosition();
 
@@ -103,9 +103,9 @@ namespace Apache.Arrow.Ipc
                 EnsureFullRead(buffer, bytesRead);
 
                 footerLength = ReadFooterLength(buffer);
-            });
+            }
 
-            ArrayPool<byte>.Shared.RentReturn(footerLength, (buffer) =>
+            using (ArrayPool<byte>.Shared.RentReturn(footerLength, out 
Memory<byte> buffer))
             {
                 long footerStartPosition = GetFooterLengthPosition() - 
footerLength;
 
@@ -115,7 +115,7 @@ namespace Apache.Arrow.Ipc
                 EnsureFullRead(buffer, bytesRead);
 
                 ReadSchema(buffer);
-            });
+            }
         }
 
         private long GetFooterLengthPosition()
@@ -239,14 +239,14 @@ namespace Apache.Arrow.Ipc
         /// <summary>
         /// Check if file format is valid. If it's valid don't run the 
validation again.
         /// </summary>
-        private async ValueTask ValidateFileAsync()
+        private async ValueTask ValidateFileAsync(CancellationToken 
cancellationToken = default)
         {
             if (IsFileValid)
             {
                 return;
             }
 
-            await ValidateMagicAsync().ConfigureAwait(false);
+            await ValidateMagicAsync(cancellationToken).ConfigureAwait(false);
 
             IsFileValid = true;
         }
@@ -266,20 +266,20 @@ namespace Apache.Arrow.Ipc
             IsFileValid = true;
         }
 
-        private async ValueTask ValidateMagicAsync()
+        private async ValueTask ValidateMagicAsync(CancellationToken 
cancellationToken = default)
         {
             long startingPosition = BaseStream.Position;
             int magicLength = ArrowFileConstants.Magic.Length;
 
             try
             {
-                await ArrayPool<byte>.Shared.RentReturnAsync(magicLength, 
async (buffer) =>
+                using (ArrayPool<byte>.Shared.RentReturn(magicLength, out 
Memory<byte> buffer))
                 {
                     // Seek to the beginning of the stream
                     BaseStream.Position = 0;
 
                     // Read beginning of stream
-                    await BaseStream.ReadAsync(buffer).ConfigureAwait(false);
+                    await BaseStream.ReadAsync(buffer, 
cancellationToken).ConfigureAwait(false);
 
                     VerifyMagic(buffer);
 
@@ -287,10 +287,10 @@ namespace Apache.Arrow.Ipc
                     BaseStream.Position = BaseStream.Length - magicLength;
 
                     // Read the end of the stream
-                    await BaseStream.ReadAsync(buffer).ConfigureAwait(false);
+                    await BaseStream.ReadAsync(buffer, 
cancellationToken).ConfigureAwait(false);
 
                     VerifyMagic(buffer);
-                }).ConfigureAwait(false);
+                }
             }
             finally
             {
@@ -305,7 +305,7 @@ namespace Apache.Arrow.Ipc
 
             try
             {
-                ArrayPool<byte>.Shared.RentReturn(magicLength, buffer =>
+                using (ArrayPool<byte>.Shared.RentReturn(magicLength, out 
Memory<byte> buffer))
                 {
                     // Seek to the beginning of the stream
                     BaseStream.Position = 0;
@@ -322,7 +322,7 @@ namespace Apache.Arrow.Ipc
                     BaseStream.Read(buffer);
 
                     VerifyMagic(buffer);
-                });
+                }
             }
             finally
             {
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs 
b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
index 95b9f60fff..547fa800ec 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
@@ -215,7 +215,7 @@ namespace Apache.Arrow.Ipc
 
             // Write footer length
 
-            Buffers.RentReturn(4, (buffer) =>
+            using (Buffers.RentReturn(4, out Memory<byte> buffer))
             {
                 int footerLength;
                 checked
@@ -226,7 +226,7 @@ namespace Apache.Arrow.Ipc
                 BinaryPrimitives.WriteInt32LittleEndian(buffer.Span, 
footerLength);
 
                 BaseStream.Write(buffer);
-            });
+            }
 
             // Write magic
 
@@ -286,7 +286,7 @@ namespace Apache.Arrow.Ipc
 
             cancellationToken.ThrowIfCancellationRequested();
 
-            await Buffers.RentReturnAsync(4, async (buffer) =>
+            using (Buffers.RentReturn(4, out Memory<byte> buffer))
             {
                 int footerLength;
                 checked
@@ -297,7 +297,7 @@ namespace Apache.Arrow.Ipc
                 BinaryPrimitives.WriteInt32LittleEndian(buffer.Span, 
footerLength);
 
                 await BaseStream.WriteAsync(buffer, 
cancellationToken).ConfigureAwait(false);
-            }).ConfigureAwait(false);
+            }
 
             // Write magic
 
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs 
b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
index 184e0348e5..5428c88c27 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
@@ -78,7 +78,7 @@ namespace Apache.Arrow.Ipc
             }
 
             RecordBatch result = null;
-            await ArrayPool<byte>.Shared.RentReturnAsync(messageLength, async 
(messageBuff) =>
+            using (ArrayPool<byte>.Shared.RentReturn(messageLength, out 
Memory<byte> messageBuff))
             {
                 int bytesRead = await 
BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken)
                     .ConfigureAwait(false);
@@ -96,7 +96,7 @@ namespace Apache.Arrow.Ipc
 
                 Google.FlatBuffers.ByteBuffer bodybb = 
CreateByteBuffer(bodyBuff);
                 result = CreateArrowObjectFromMessage(message, bodybb, 
bodyBuffOwner);
-            }).ConfigureAwait(false);
+            }
 
             return new ReadResult(messageLength, result);
         }
@@ -125,7 +125,7 @@ namespace Apache.Arrow.Ipc
             }
 
             RecordBatch result = null;
-            ArrayPool<byte>.Shared.RentReturn(messageLength, messageBuff =>
+            using (ArrayPool<byte>.Shared.RentReturn(messageLength, out 
Memory<byte> messageBuff))
             {
                 int bytesRead = BaseStream.ReadFullBuffer(messageBuff);
                 EnsureFullRead(messageBuff, bytesRead);
@@ -141,12 +141,12 @@ namespace Apache.Arrow.Ipc
 
                 Google.FlatBuffers.ByteBuffer bodybb = 
CreateByteBuffer(bodyBuff);
                 result = CreateArrowObjectFromMessage(message, bodybb, 
bodyBuffOwner);
-            });
+            }
 
             return new ReadResult(messageLength, result);
         }
 
-        protected virtual async ValueTask ReadSchemaAsync()
+        protected virtual async ValueTask ReadSchemaAsync(CancellationToken 
cancellationToken = default)
         {
             if (HasReadSchema)
             {
@@ -154,18 +154,18 @@ namespace Apache.Arrow.Ipc
             }
 
             // Figure out length of schema
-            int schemaMessageLength = await 
ReadMessageLengthAsync(throwOnFullRead: true)
+            int schemaMessageLength = await 
ReadMessageLengthAsync(throwOnFullRead: true, cancellationToken)
                 .ConfigureAwait(false);
 
-            await ArrayPool<byte>.Shared.RentReturnAsync(schemaMessageLength, 
async (buff) =>
+            using (ArrayPool<byte>.Shared.RentReturn(schemaMessageLength, out 
Memory<byte> buff))
             {
                 // Read in schema
-                int bytesRead = await 
BaseStream.ReadFullBufferAsync(buff).ConfigureAwait(false);
+                int bytesRead = await BaseStream.ReadFullBufferAsync(buff, 
cancellationToken).ConfigureAwait(false);
                 EnsureFullRead(buff, bytesRead);
 
                 Google.FlatBuffers.ByteBuffer schemabb = 
CreateByteBuffer(buff);
                 Schema = 
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb), ref 
_dictionaryMemo);
-            }).ConfigureAwait(false);
+            }
         }
 
         protected virtual void ReadSchema()
@@ -178,20 +178,20 @@ namespace Apache.Arrow.Ipc
             // Figure out length of schema
             int schemaMessageLength = ReadMessageLength(throwOnFullRead: true);
 
-            ArrayPool<byte>.Shared.RentReturn(schemaMessageLength, buff =>
+            using (ArrayPool<byte>.Shared.RentReturn(schemaMessageLength, out 
Memory<byte> buff))
             {
                 int bytesRead = BaseStream.ReadFullBuffer(buff);
                 EnsureFullRead(buff, bytesRead);
 
                 Google.FlatBuffers.ByteBuffer schemabb = 
CreateByteBuffer(buff);
                 Schema = 
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb), ref 
_dictionaryMemo);
-            });
+            }
         }
 
         private async ValueTask<int> ReadMessageLengthAsync(bool 
throwOnFullRead, CancellationToken cancellationToken = default)
         {
             int messageLength = 0;
-            await ArrayPool<byte>.Shared.RentReturnAsync(4, async 
(lengthBuffer) =>
+            using (ArrayPool<byte>.Shared.RentReturn(4, out Memory<byte> 
lengthBuffer))
             {
                 int bytesRead = await 
BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken)
                     .ConfigureAwait(false);
@@ -201,7 +201,7 @@ namespace Apache.Arrow.Ipc
                 }
                 else if (bytesRead != 4)
                 {
-                    return;
+                    return 0;
                 }
 
                 messageLength = BitUtility.ReadInt32(lengthBuffer);
@@ -217,13 +217,12 @@ namespace Apache.Arrow.Ipc
                     }
                     else if (bytesRead != 4)
                     {
-                        messageLength = 0;
-                        return;
+                        return 0;
                     }
 
                     messageLength = BitUtility.ReadInt32(lengthBuffer);
                 }
-            }).ConfigureAwait(false);
+            };
 
             return messageLength;
         }
@@ -231,7 +230,7 @@ namespace Apache.Arrow.Ipc
         private int ReadMessageLength(bool throwOnFullRead)
         {
             int messageLength = 0;
-            ArrayPool<byte>.Shared.RentReturn(4, lengthBuffer =>
+            using (ArrayPool<byte>.Shared.RentReturn(4, out Memory<byte> 
lengthBuffer))
             {
                 int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer);
                 if (throwOnFullRead)
@@ -240,7 +239,7 @@ namespace Apache.Arrow.Ipc
                 }
                 else if (bytesRead != 4)
                 {
-                    return;
+                    return 0;
                 }
 
                 messageLength = BitUtility.ReadInt32(lengthBuffer);
@@ -255,13 +254,12 @@ namespace Apache.Arrow.Ipc
                     }
                     else if (bytesRead != 4)
                     {
-                        messageLength = 0;
-                        return;
+                        return 0;
                     }
 
                     messageLength = BitUtility.ReadInt32(lengthBuffer);
                 }
-            });
+            }
 
             return messageLength;
         }
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs 
b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
index 483dcea898..5f490019b2 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
@@ -890,7 +890,7 @@ namespace Apache.Arrow.Ipc
 
         private void WriteIpcMessageLength(int length)
         {
-            Buffers.RentReturn(_options.SizeOfIpcLength, (buffer) =>
+            using (Buffers.RentReturn(_options.SizeOfIpcLength, out 
Memory<byte> buffer))
             {
                 Memory<byte> currentBufferPosition = buffer;
                 if (!_options.WriteLegacyIpcFormat)
@@ -902,12 +902,12 @@ namespace Apache.Arrow.Ipc
 
                 
BinaryPrimitives.WriteInt32LittleEndian(currentBufferPosition.Span, length);
                 BaseStream.Write(buffer);
-            });
+            }
         }
 
         private async ValueTask WriteIpcMessageLengthAsync(int length, 
CancellationToken cancellationToken)
         {
-            await Buffers.RentReturnAsync(_options.SizeOfIpcLength, async 
(buffer) =>
+            using (Buffers.RentReturn(_options.SizeOfIpcLength, out 
Memory<byte> buffer))
             {
                 Memory<byte> currentBufferPosition = buffer;
                 if (!_options.WriteLegacyIpcFormat)
@@ -919,7 +919,7 @@ namespace Apache.Arrow.Ipc
 
                 
BinaryPrimitives.WriteInt32LittleEndian(currentBufferPosition.Span, length);
                 await BaseStream.WriteAsync(buffer, 
cancellationToken).ConfigureAwait(false);
-            }).ConfigureAwait(false);
+            }
         }
 
         protected int CalculatePadding(long offset, int alignment = 8)

Reply via email to