This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new a4e89f1b1 fix(csharp): Resolve memory leaks described by #1690 (#1695)
a4e89f1b1 is described below

commit a4e89f1b1f9c65543192d4f0fbf0a0cf3011481e
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Mon Apr 1 12:43:20 2024 -0700

    fix(csharp): Resolve memory leaks described by #1690 (#1695)
    
    Fixes memory leaks. Resolves #1690.
---
 csharp/src/Apache.Arrow.Adbc/C/CAdbcDriver.cs      |  2 +-
 .../src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs | 11 +--
 .../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs | 78 ++++++++++++----------
 3 files changed, 52 insertions(+), 39 deletions(-)

diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriver.cs 
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriver.cs
index 7ca50c877..8f0050b96 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriver.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriver.cs
@@ -148,7 +148,7 @@ namespace Apache.Arrow.Adbc.C
         /// unrecognized codes (the row will be omitted from the result).
         /// </summary>
 #if NET5_0_OR_GREATER
-        internal delegate* unmanaged<CAdbcConnection*, byte*, int, 
CArrowArrayStream*, CAdbcError*, AdbcStatusCode> ConnectionGetInfo;
+        internal delegate* unmanaged<CAdbcConnection*, int*, int, 
CArrowArrayStream*, CAdbcError*, AdbcStatusCode> ConnectionGetInfo;
 #else
         internal IntPtr ConnectionGetInfo;
 #endif
diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs 
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
index fe06d1a1a..d95915ff1 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
@@ -51,7 +51,7 @@ namespace Apache.Arrow.Adbc.C
         private static unsafe delegate* unmanaged<CAdbcConnection*, 
CAdbcError*, AdbcStatusCode> ConnectionRollbackPtr => &RollbackConnection;
         private static unsafe delegate* unmanaged<CAdbcConnection*, 
CAdbcError*, AdbcStatusCode> ConnectionCommitPtr => &CommitConnection;
         private static unsafe delegate* unmanaged<CAdbcConnection*, 
CAdbcError*, AdbcStatusCode> ConnectionReleasePtr => &ReleaseConnection;
-        private static unsafe delegate* unmanaged<CAdbcConnection*, byte*, 
int, CArrowArrayStream*, CAdbcError*, AdbcStatusCode> ConnectionGetInfoPtr => 
&GetConnectionInfo;
+        private static unsafe delegate* unmanaged<CAdbcConnection*, int*, int, 
CArrowArrayStream*, CAdbcError*, AdbcStatusCode> ConnectionGetInfoPtr => 
&GetConnectionInfo;
         private static unsafe delegate* unmanaged<CAdbcConnection*, byte*, 
int, CArrowArrayStream*, CAdbcError*, AdbcStatusCode> 
ConnectionReadPartitionPtr => &ReadConnectionPartition;
         private static unsafe delegate* unmanaged<CAdbcConnection*, byte*, 
byte*, CAdbcError*, AdbcStatusCode> ConnectionSetOptionPtr => 
&SetConnectionOption;
 
@@ -93,7 +93,7 @@ namespace Apache.Arrow.Adbc.C
         private static IntPtr ConnectionCommitPtr => 
s_connectionCommit.Pointer;
         private static unsafe readonly NativeDelegate<ConnectionFn> 
s_connectionRelease = new NativeDelegate<ConnectionFn>(ReleaseConnection);
         private static IntPtr ConnectionReleasePtr => 
s_connectionRelease.Pointer;
-        internal unsafe delegate AdbcStatusCode 
ConnectionGetInfo(CAdbcConnection* connection, byte* info_codes, int 
info_codes_length, CArrowArrayStream* stream, CAdbcError* error);
+        internal unsafe delegate AdbcStatusCode 
ConnectionGetInfo(CAdbcConnection* connection, int* info_codes, int 
info_codes_length, CArrowArrayStream* stream, CAdbcError* error);
         private static unsafe readonly NativeDelegate<ConnectionGetInfo> 
s_connectionGetInfo = new NativeDelegate<ConnectionGetInfo>(GetConnectionInfo);
         private static IntPtr ConnectionGetInfoPtr => 
s_connectionGetInfo.Pointer;
         private unsafe delegate AdbcStatusCode 
ConnectionReadPartition(CAdbcConnection* connection, byte* 
serialized_partition, int serialized_length, CArrowArrayStream* stream, 
CAdbcError* error);
@@ -452,7 +452,7 @@ namespace Apache.Arrow.Adbc.C
 #if NET5_0_OR_GREATER
         [UnmanagedCallersOnly]
 #endif
-        private unsafe static AdbcStatusCode 
GetConnectionInfo(CAdbcConnection* nativeConnection, byte* info_codes, int 
info_codes_length, CArrowArrayStream* stream, CAdbcError* error)
+        private unsafe static AdbcStatusCode 
GetConnectionInfo(CAdbcConnection* nativeConnection, int* info_codes, int 
info_codes_length, CArrowArrayStream* stream, CAdbcError* error)
         {
             GCHandle gch = 
GCHandle.FromIntPtr((IntPtr)nativeConnection->private_data);
             ConnectionStub stub = (ConnectionStub)gch.Target;
@@ -732,6 +732,7 @@ namespace Apache.Arrow.Adbc.C
                 columnNamePattern = 
Marshal.PtrToStringUTF8((IntPtr)column_name);
 #endif
 
+                // TODO (GH-1694): Marshaling is incorrect
                 GCHandle gch = GCHandle.FromIntPtr((IntPtr)table_type);
                 List<string> tableTypes = (List<string>)gch.Target;
 
@@ -791,6 +792,7 @@ namespace Apache.Arrow.Adbc.C
                     return AdbcStatusCode.UnknownError;
                 }
 
+                // TODO (GH-1694): Marshaling is incorrect
                 GCHandle gch = 
GCHandle.FromIntPtr((IntPtr)serializedPartition);
                 PartitionDescriptor descriptor = 
(PartitionDescriptor)gch.Target;
 
@@ -799,13 +801,14 @@ namespace Apache.Arrow.Adbc.C
                 return AdbcStatusCode.Success;
             }
 
-            public unsafe AdbcStatusCode GetInfo(ref CAdbcConnection 
nativeConnection, byte* info_codes, int info_codes_length, CArrowArrayStream* 
stream, ref CAdbcError error)
+            public unsafe AdbcStatusCode GetInfo(ref CAdbcConnection 
nativeConnection, int* info_codes, int info_codes_length, CArrowArrayStream* 
stream, ref CAdbcError error)
             {
                 if (nativeConnection.private_data == null)
                 {
                     return AdbcStatusCode.UnknownError;
                 }
 
+                // TODO (GH-1694): Marshaling is incorrect
                 GCHandle gch = GCHandle.FromIntPtr((IntPtr)info_codes);
                 List<int> codes = (List<int>)gch.Target;
 
diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs 
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
index c621eb6ec..82b2a8c15 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
@@ -17,7 +17,6 @@
 
 using System;
 using System.Collections.Generic;
-using System.Data;
 using System.IO;
 using System.Linq;
 using System.Runtime.InteropServices;
@@ -697,31 +696,32 @@ namespace Apache.Arrow.Adbc.C
             }
 
 #if NET5_0_OR_GREATER
-            public unsafe void Call(delegate* unmanaged<CAdbcConnection*, 
byte*, int, CArrowArrayStream*, CAdbcError*, AdbcStatusCode> fn, ref 
CAdbcConnection connection, List<int> infoCodes, CArrowArrayStream* stream)
+            public unsafe void Call(delegate* unmanaged<CAdbcConnection*, 
int*, int, CArrowArrayStream*, CAdbcError*, AdbcStatusCode> fn, ref 
CAdbcConnection connection, List<int> infoCodes, CArrowArrayStream* stream)
+            {
+                fixed (CAdbcConnection* cn = &connection)
+                fixed (CAdbcError* e = &_error)
+                {
+                    Span<int> span = CollectionsMarshal.AsSpan(infoCodes);
+                    fixed (int* spanPtr = span)
+                    {
+                        TranslateCode(fn(cn, spanPtr, infoCodes.Count, stream, 
e));
+                    }
+                }
+            }
 #else
             public unsafe void Call(IntPtr ptr, ref CAdbcConnection 
connection, List<int> infoCodes, CArrowArrayStream* stream)
-#endif
             {
-                int numInts = infoCodes.Count;
-
-                // Calculate the total number of bytes needed
-                int totalBytes = numInts * sizeof(int);
-
-                IntPtr bytePtr = Marshal.AllocHGlobal(totalBytes);
-
-                int[] intArray = infoCodes.ToArray();
-                Marshal.Copy(intArray, 0, bytePtr, numInts);
-
                 fixed (CAdbcConnection* cn = &connection)
                 fixed (CAdbcError* e = &_error)
                 {
-#if NET5_0_OR_GREATER
-                    TranslateCode(fn(cn, (byte*)bytePtr, infoCodes.Count, 
stream, e));
-#else
-                    
TranslateCode(Marshal.GetDelegateForFunctionPointer<CAdbcDriverExporter.ConnectionGetInfo>(ptr)(cn,
 (byte*)bytePtr, infoCodes.Count, stream, e));
-#endif
+                    Span<int> span = infoCodes.ToArray().AsSpan();
+                    fixed (int* spanPtr = span)
+                    {
+                        
TranslateCode(Marshal.GetDelegateForFunctionPointer<CAdbcDriverExporter.ConnectionGetInfo>(ptr)(cn,
 spanPtr, infoCodes.Count, stream, e));
+                    }
                 }
             }
+#endif
 
 #if NET5_0_OR_GREATER
             public unsafe void Call(delegate* unmanaged<CAdbcConnection*, int, 
byte*, byte*, byte*, byte**, byte*, CArrowArrayStream*, CAdbcError*, 
AdbcStatusCode> fn, ref CAdbcConnection connection, int depth, string catalog, 
string db_schema, string table_name, List<string> table_types, string 
column_name, CArrowArrayStream* stream)
@@ -737,9 +737,8 @@ namespace Apache.Arrow.Adbc.C
                 }
 
                 // need to terminate with a null entry per 
https://github.com/apache/arrow-adbc/blob/b97e22c4d6524b60bf261e1970155500645be510/adbc.h#L909-L911
-                table_types.Add(null);
-
-                byte** bTable_type = (byte**)Marshal.AllocHGlobal(IntPtr.Size 
* table_types.Count);
+                byte** bTable_type = (byte**)Marshal.AllocHGlobal(IntPtr.Size 
* (table_types.Count + 1));
+                bTable_type[table_types.Count] = null;
 
                 for (int i = 0; i < table_types.Count; i++)
                 {
@@ -751,25 +750,36 @@ namespace Apache.Arrow.Adbc.C
 #endif
                 }
 
-                using (Utf8Helper catalogHelper = new Utf8Helper(catalog))
-                using (Utf8Helper schemaHelper = new Utf8Helper(db_schema))
-                using (Utf8Helper tableNameHelper = new Utf8Helper(table_name))
-                using (Utf8Helper columnNameHelper = new 
Utf8Helper(column_name))
+                try
                 {
-                    bcatalog = (byte*)(IntPtr)(catalogHelper);
-                    bDb_schema = (byte*)(IntPtr)(schemaHelper);
-                    bTable_name = (byte*)(IntPtr)(tableNameHelper);
-                    bColumn_Name = (byte*)(IntPtr)(columnNameHelper);
-
-                    fixed (CAdbcConnection* cn = &connection)
-                    fixed (CAdbcError* e = &_error)
+                    using (Utf8Helper catalogHelper = new Utf8Helper(catalog))
+                    using (Utf8Helper schemaHelper = new Utf8Helper(db_schema))
+                    using (Utf8Helper tableNameHelper = new 
Utf8Helper(table_name))
+                    using (Utf8Helper columnNameHelper = new 
Utf8Helper(column_name))
                     {
+                        bcatalog = (byte*)(IntPtr)(catalogHelper);
+                        bDb_schema = (byte*)(IntPtr)(schemaHelper);
+                        bTable_name = (byte*)(IntPtr)(tableNameHelper);
+                        bColumn_Name = (byte*)(IntPtr)(columnNameHelper);
+
+                        fixed (CAdbcConnection* cn = &connection)
+                        fixed (CAdbcError* e = &_error)
+                        {
 #if NET5_0_OR_GREATER
-                        TranslateCode(fn(cn, depth, bcatalog, bDb_schema, 
bTable_name, bTable_type, bColumn_Name, stream, e));
+                            TranslateCode(fn(cn, depth, bcatalog, bDb_schema, 
bTable_name, bTable_type, bColumn_Name, stream, e));
 #else
-                        
TranslateCode(Marshal.GetDelegateForFunctionPointer<CAdbcDriverExporter.ConnectionGetObjects>(fn)(cn,
 depth, bcatalog, bDb_schema, bTable_name, bTable_type, bColumn_Name, stream, 
e));
+                            
TranslateCode(Marshal.GetDelegateForFunctionPointer<CAdbcDriverExporter.ConnectionGetObjects>(fn)(cn,
 depth, bcatalog, bDb_schema, bTable_name, bTable_type, bColumn_Name, stream, 
e));
 #endif
+                        }
+                    }
+                }
+                finally
+                {
+                    for (int i = 0; i < table_types.Count; i++)
+                    {
+                        Marshal.FreeCoTaskMem((IntPtr)bTable_type[i]);
                     }
+                    Marshal.FreeHGlobal((IntPtr)bTable_type);
                 }
             }
 

Reply via email to