This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 9287e9a1b fix(csharp/src/Apache.Arrow.Adbc): Fix marshaling in three functions where it was broken (#1758)
9287e9a1b is described below

commit 9287e9a1be8dcdaec95fbedee9e6294f278261c5
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Wed Apr 24 15:26:14 2024 -0700

    fix(csharp/src/Apache.Arrow.Adbc): Fix marshaling in three functions where it was broken (#1758)
    
    Closes #1694.
---
 .../src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs | 53 +++++++++++++++++-----
 1 file changed, 41 insertions(+), 12 deletions(-)

diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs 
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
index 4455a25bc..a8b3089b3 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
@@ -17,10 +17,12 @@
 
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Runtime.InteropServices;
 using Apache.Arrow.C;
 using Apache.Arrow.Ipc;
 
+
 #if NETSTANDARD
 using Apache.Arrow.Adbc.Extensions;
 #endif
@@ -663,7 +665,7 @@ namespace Apache.Arrow.Adbc.C
                     nativeConnection->private_data = 
(void*)GCHandle.ToIntPtr(handle);
                     return AdbcStatusCode.Success;
                 }
-                catch(Exception e)
+                catch (Exception e)
                 {
                     return SetError(error, e);
                 }
@@ -773,9 +775,31 @@ namespace Apache.Arrow.Adbc.C
                 columnNamePattern = 
Marshal.PtrToStringUTF8((IntPtr)column_name);
 #endif
 
-                // TODO (GH-1694): Marshaling is incorrect
-                GCHandle gch = GCHandle.FromIntPtr((IntPtr)table_type);
-                List<string> tableTypes = (List<string>)gch.Target;
+                List<string> tableTypes = null;
+                const int maxTableTypeCount = 100;
+                if (table_type != null)
+                {
+                    int count = 0;
+                    while (table_type[count] != null && count <= 
maxTableTypeCount)
+                    {
+                        count++;
+                    }
+
+                    if (count > maxTableTypeCount)
+                    {
+                        throw new InvalidOperationException($"We do not expect 
to get more than {maxTableTypeCount} table types");
+                    }
+
+                    tableTypes = new List<string>(count);
+                    for (int i = 0; i < count; i++)
+                    {
+#if NETSTANDARD
+                        
tableTypes.Add(MarshalExtensions.PtrToStringUTF8((IntPtr)table_type[i]));
+#else
+                        
tableTypes.Add(Marshal.PtrToStringUTF8((IntPtr)table_type[i]));
+#endif
+                    }
+                }
 
                 AdbcConnection.GetObjectsDepth goDepth = 
(AdbcConnection.GetObjectsDepth)depth;
 
@@ -812,20 +836,25 @@ namespace Apache.Arrow.Adbc.C
 
             public unsafe void ReadPartition(byte* serializedPartition, int 
serialized_length, CArrowArrayStream* stream)
             {
-                // TODO (GH-1694): Marshaling is incorrect
-                GCHandle gch = 
GCHandle.FromIntPtr((IntPtr)serializedPartition);
-                PartitionDescriptor descriptor = 
(PartitionDescriptor)gch.Target;
+                byte[] partition = new byte[serialized_length];
+                fixed (byte* partitionPtr = partition)
+                {
+                    Buffer.MemoryCopy(serializedPartition, partitionPtr, 
serialized_length, serialized_length);
+                }
 
-                
CArrowArrayStreamExporter.ExportArrayStream(connection.ReadPartition(descriptor),
 stream);
+                
CArrowArrayStreamExporter.ExportArrayStream(connection.ReadPartition(new 
PartitionDescriptor(partition)), stream);
             }
 
             public unsafe void GetInfo(int* info_codes, int info_codes_length, 
CArrowArrayStream* stream)
             {
-                // TODO (GH-1694): Marshaling is incorrect
-                GCHandle gch = GCHandle.FromIntPtr((IntPtr)info_codes);
-                List<int> codes = (List<int>)gch.Target;
+                int[] infoCodes = new int[info_codes_length];
+                fixed (int* infoCodesPtr = infoCodes)
+                {
+                    long length = (long)info_codes_length * sizeof(int);
+                    Buffer.MemoryCopy(info_codes, infoCodesPtr, length, 
length);
+                }
 
-                
CArrowArrayStreamExporter.ExportArrayStream(connection.GetInfo(codes), stream);
+                
CArrowArrayStreamExporter.ExportArrayStream(connection.GetInfo(infoCodes.ToList()),
 stream);
             }
 
             public unsafe void InitConnection(ref CAdbcDatabase nativeDatabase)

Reply via email to