This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new a4e89f1b1 fix(csharp): Resolve memory leaks described by #1690 (#1695)
a4e89f1b1 is described below
commit a4e89f1b1f9c65543192d4f0fbf0a0cf3011481e
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Mon Apr 1 12:43:20 2024 -0700
fix(csharp): Resolve memory leaks described by #1690 (#1695)
Fixes memory leaks. Resolves #1690.
---
csharp/src/Apache.Arrow.Adbc/C/CAdbcDriver.cs | 2 +-
.../src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs | 11 +--
.../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs | 78 ++++++++++++----------
3 files changed, 52 insertions(+), 39 deletions(-)
diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriver.cs
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriver.cs
index 7ca50c877..8f0050b96 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriver.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriver.cs
@@ -148,7 +148,7 @@ namespace Apache.Arrow.Adbc.C
/// unrecognized codes (the row will be omitted from the result).
/// </summary>
#if NET5_0_OR_GREATER
- internal delegate* unmanaged<CAdbcConnection*, byte*, int,
CArrowArrayStream*, CAdbcError*, AdbcStatusCode> ConnectionGetInfo;
+ internal delegate* unmanaged<CAdbcConnection*, int*, int,
CArrowArrayStream*, CAdbcError*, AdbcStatusCode> ConnectionGetInfo;
#else
internal IntPtr ConnectionGetInfo;
#endif
diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
index fe06d1a1a..d95915ff1 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
@@ -51,7 +51,7 @@ namespace Apache.Arrow.Adbc.C
private static unsafe delegate* unmanaged<CAdbcConnection*,
CAdbcError*, AdbcStatusCode> ConnectionRollbackPtr => &RollbackConnection;
private static unsafe delegate* unmanaged<CAdbcConnection*,
CAdbcError*, AdbcStatusCode> ConnectionCommitPtr => &CommitConnection;
private static unsafe delegate* unmanaged<CAdbcConnection*,
CAdbcError*, AdbcStatusCode> ConnectionReleasePtr => &ReleaseConnection;
- private static unsafe delegate* unmanaged<CAdbcConnection*, byte*,
int, CArrowArrayStream*, CAdbcError*, AdbcStatusCode> ConnectionGetInfoPtr =>
&GetConnectionInfo;
+ private static unsafe delegate* unmanaged<CAdbcConnection*, int*, int,
CArrowArrayStream*, CAdbcError*, AdbcStatusCode> ConnectionGetInfoPtr =>
&GetConnectionInfo;
private static unsafe delegate* unmanaged<CAdbcConnection*, byte*,
int, CArrowArrayStream*, CAdbcError*, AdbcStatusCode>
ConnectionReadPartitionPtr => &ReadConnectionPartition;
private static unsafe delegate* unmanaged<CAdbcConnection*, byte*,
byte*, CAdbcError*, AdbcStatusCode> ConnectionSetOptionPtr =>
&SetConnectionOption;
@@ -93,7 +93,7 @@ namespace Apache.Arrow.Adbc.C
private static IntPtr ConnectionCommitPtr =>
s_connectionCommit.Pointer;
private static unsafe readonly NativeDelegate<ConnectionFn>
s_connectionRelease = new NativeDelegate<ConnectionFn>(ReleaseConnection);
private static IntPtr ConnectionReleasePtr =>
s_connectionRelease.Pointer;
- internal unsafe delegate AdbcStatusCode
ConnectionGetInfo(CAdbcConnection* connection, byte* info_codes, int
info_codes_length, CArrowArrayStream* stream, CAdbcError* error);
+ internal unsafe delegate AdbcStatusCode
ConnectionGetInfo(CAdbcConnection* connection, int* info_codes, int
info_codes_length, CArrowArrayStream* stream, CAdbcError* error);
private static unsafe readonly NativeDelegate<ConnectionGetInfo>
s_connectionGetInfo = new NativeDelegate<ConnectionGetInfo>(GetConnectionInfo);
private static IntPtr ConnectionGetInfoPtr =>
s_connectionGetInfo.Pointer;
private unsafe delegate AdbcStatusCode
ConnectionReadPartition(CAdbcConnection* connection, byte*
serialized_partition, int serialized_length, CArrowArrayStream* stream,
CAdbcError* error);
@@ -452,7 +452,7 @@ namespace Apache.Arrow.Adbc.C
#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
- private unsafe static AdbcStatusCode
GetConnectionInfo(CAdbcConnection* nativeConnection, byte* info_codes, int
info_codes_length, CArrowArrayStream* stream, CAdbcError* error)
+ private unsafe static AdbcStatusCode
GetConnectionInfo(CAdbcConnection* nativeConnection, int* info_codes, int
info_codes_length, CArrowArrayStream* stream, CAdbcError* error)
{
GCHandle gch =
GCHandle.FromIntPtr((IntPtr)nativeConnection->private_data);
ConnectionStub stub = (ConnectionStub)gch.Target;
@@ -732,6 +732,7 @@ namespace Apache.Arrow.Adbc.C
columnNamePattern =
Marshal.PtrToStringUTF8((IntPtr)column_name);
#endif
+ // TODO (GH-1694): Marshaling is incorrect
GCHandle gch = GCHandle.FromIntPtr((IntPtr)table_type);
List<string> tableTypes = (List<string>)gch.Target;
@@ -791,6 +792,7 @@ namespace Apache.Arrow.Adbc.C
return AdbcStatusCode.UnknownError;
}
+ // TODO (GH-1694): Marshaling is incorrect
GCHandle gch =
GCHandle.FromIntPtr((IntPtr)serializedPartition);
PartitionDescriptor descriptor =
(PartitionDescriptor)gch.Target;
@@ -799,13 +801,14 @@ namespace Apache.Arrow.Adbc.C
return AdbcStatusCode.Success;
}
- public unsafe AdbcStatusCode GetInfo(ref CAdbcConnection
nativeConnection, byte* info_codes, int info_codes_length, CArrowArrayStream*
stream, ref CAdbcError error)
+ public unsafe AdbcStatusCode GetInfo(ref CAdbcConnection
nativeConnection, int* info_codes, int info_codes_length, CArrowArrayStream*
stream, ref CAdbcError error)
{
if (nativeConnection.private_data == null)
{
return AdbcStatusCode.UnknownError;
}
+ // TODO (GH-1694): Marshaling is incorrect
GCHandle gch = GCHandle.FromIntPtr((IntPtr)info_codes);
List<int> codes = (List<int>)gch.Target;
diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
index c621eb6ec..82b2a8c15 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
@@ -17,7 +17,6 @@
using System;
using System.Collections.Generic;
-using System.Data;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
@@ -697,31 +696,32 @@ namespace Apache.Arrow.Adbc.C
}
#if NET5_0_OR_GREATER
- public unsafe void Call(delegate* unmanaged<CAdbcConnection*,
byte*, int, CArrowArrayStream*, CAdbcError*, AdbcStatusCode> fn, ref
CAdbcConnection connection, List<int> infoCodes, CArrowArrayStream* stream)
+ public unsafe void Call(delegate* unmanaged<CAdbcConnection*,
int*, int, CArrowArrayStream*, CAdbcError*, AdbcStatusCode> fn, ref
CAdbcConnection connection, List<int> infoCodes, CArrowArrayStream* stream)
+ {
+ fixed (CAdbcConnection* cn = &connection)
+ fixed (CAdbcError* e = &_error)
+ {
+ Span<int> span = CollectionsMarshal.AsSpan(infoCodes);
+ fixed (int* spanPtr = span)
+ {
+ TranslateCode(fn(cn, spanPtr, infoCodes.Count, stream,
e));
+ }
+ }
+ }
#else
public unsafe void Call(IntPtr ptr, ref CAdbcConnection
connection, List<int> infoCodes, CArrowArrayStream* stream)
-#endif
{
- int numInts = infoCodes.Count;
-
- // Calculate the total number of bytes needed
- int totalBytes = numInts * sizeof(int);
-
- IntPtr bytePtr = Marshal.AllocHGlobal(totalBytes);
-
- int[] intArray = infoCodes.ToArray();
- Marshal.Copy(intArray, 0, bytePtr, numInts);
-
fixed (CAdbcConnection* cn = &connection)
fixed (CAdbcError* e = &_error)
{
-#if NET5_0_OR_GREATER
- TranslateCode(fn(cn, (byte*)bytePtr, infoCodes.Count,
stream, e));
-#else
-
TranslateCode(Marshal.GetDelegateForFunctionPointer<CAdbcDriverExporter.ConnectionGetInfo>(ptr)(cn,
(byte*)bytePtr, infoCodes.Count, stream, e));
-#endif
+ Span<int> span = infoCodes.ToArray().AsSpan();
+ fixed (int* spanPtr = span)
+ {
+
TranslateCode(Marshal.GetDelegateForFunctionPointer<CAdbcDriverExporter.ConnectionGetInfo>(ptr)(cn,
spanPtr, infoCodes.Count, stream, e));
+ }
}
}
+#endif
#if NET5_0_OR_GREATER
public unsafe void Call(delegate* unmanaged<CAdbcConnection*, int,
byte*, byte*, byte*, byte**, byte*, CArrowArrayStream*, CAdbcError*,
AdbcStatusCode> fn, ref CAdbcConnection connection, int depth, string catalog,
string db_schema, string table_name, List<string> table_types, string
column_name, CArrowArrayStream* stream)
@@ -737,9 +737,8 @@ namespace Apache.Arrow.Adbc.C
}
// need to terminate with a null entry per
https://github.com/apache/arrow-adbc/blob/b97e22c4d6524b60bf261e1970155500645be510/adbc.h#L909-L911
- table_types.Add(null);
-
- byte** bTable_type = (byte**)Marshal.AllocHGlobal(IntPtr.Size
* table_types.Count);
+ byte** bTable_type = (byte**)Marshal.AllocHGlobal(IntPtr.Size
* (table_types.Count + 1));
+ bTable_type[table_types.Count] = null;
for (int i = 0; i < table_types.Count; i++)
{
@@ -751,25 +750,36 @@ namespace Apache.Arrow.Adbc.C
#endif
}
- using (Utf8Helper catalogHelper = new Utf8Helper(catalog))
- using (Utf8Helper schemaHelper = new Utf8Helper(db_schema))
- using (Utf8Helper tableNameHelper = new Utf8Helper(table_name))
- using (Utf8Helper columnNameHelper = new
Utf8Helper(column_name))
+ try
{
- bcatalog = (byte*)(IntPtr)(catalogHelper);
- bDb_schema = (byte*)(IntPtr)(schemaHelper);
- bTable_name = (byte*)(IntPtr)(tableNameHelper);
- bColumn_Name = (byte*)(IntPtr)(columnNameHelper);
-
- fixed (CAdbcConnection* cn = &connection)
- fixed (CAdbcError* e = &_error)
+ using (Utf8Helper catalogHelper = new Utf8Helper(catalog))
+ using (Utf8Helper schemaHelper = new Utf8Helper(db_schema))
+ using (Utf8Helper tableNameHelper = new
Utf8Helper(table_name))
+ using (Utf8Helper columnNameHelper = new
Utf8Helper(column_name))
{
+ bcatalog = (byte*)(IntPtr)(catalogHelper);
+ bDb_schema = (byte*)(IntPtr)(schemaHelper);
+ bTable_name = (byte*)(IntPtr)(tableNameHelper);
+ bColumn_Name = (byte*)(IntPtr)(columnNameHelper);
+
+ fixed (CAdbcConnection* cn = &connection)
+ fixed (CAdbcError* e = &_error)
+ {
#if NET5_0_OR_GREATER
- TranslateCode(fn(cn, depth, bcatalog, bDb_schema,
bTable_name, bTable_type, bColumn_Name, stream, e));
+ TranslateCode(fn(cn, depth, bcatalog, bDb_schema,
bTable_name, bTable_type, bColumn_Name, stream, e));
#else
-
TranslateCode(Marshal.GetDelegateForFunctionPointer<CAdbcDriverExporter.ConnectionGetObjects>(fn)(cn,
depth, bcatalog, bDb_schema, bTable_name, bTable_type, bColumn_Name, stream,
e));
+
TranslateCode(Marshal.GetDelegateForFunctionPointer<CAdbcDriverExporter.ConnectionGetObjects>(fn)(cn,
depth, bcatalog, bDb_schema, bTable_name, bTable_type, bColumn_Name, stream,
e));
#endif
+ }
+ }
+ }
+ finally
+ {
+ for (int i = 0; i < table_types.Count; i++)
+ {
+ Marshal.FreeCoTaskMem((IntPtr)bTable_type[i]);
}
+ Marshal.FreeHGlobal((IntPtr)bTable_type);
}
}