This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e1056a51 fix(csharp/src/Apache.Arrow.Adbc/C): GetObjects should 
preserve a null tableTypes parameter value (#1894)
4e1056a51 is described below

commit 4e1056a5128dae5b5fae01dc35571ddb7b62f366
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Thu May 30 09:53:56 2024 -0700

    fix(csharp/src/Apache.Arrow.Adbc/C): GetObjects should preserve a null tableTypes parameter value (#1894)
    
    Fix imported "GetObjects" to preserve a null value for the table types
    parameter. This enables GetObjects to work properly with the DuckDB
    driver; it is also technically the correct translation.
---
 .../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs | 27 ++++++++------
 .../Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs | 43 ++++++++++++++++++++++
 2 files changed, 59 insertions(+), 11 deletions(-)

diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs 
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
index 0083d43fc..e84250e86 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
@@ -579,18 +579,20 @@ namespace Apache.Arrow.Adbc.C
 
             public unsafe override IArrowArrayStream 
GetObjects(GetObjectsDepth depth, string? catalogPattern, string? 
dbSchemaPattern, string? tableNamePattern, IReadOnlyList<string>? tableTypes, 
string? columnNamePattern)
             {
-                tableTypes = tableTypes ?? [];
                 byte** utf8TableTypes = null;
                 try
                 {
-                    // need to terminate with a null entry per 
https://github.com/apache/arrow-adbc/blob/b97e22c4d6524b60bf261e1970155500645be510/adbc.h#L909-L911
-                    utf8TableTypes = (byte**)Marshal.AllocHGlobal(IntPtr.Size 
* (tableTypes.Count + 1));
-                    utf8TableTypes[tableTypes.Count] = null;
-
-                    for (int i = 0; i < tableTypes.Count; i++)
+                    if (tableTypes != null)
                     {
-                        string tableType = tableTypes[i];
-                        utf8TableTypes[i] = 
(byte*)MarshalExtensions.StringToCoTaskMemUTF8(tableType);
+                        // need to terminate with a null entry per 
https://github.com/apache/arrow-adbc/blob/b97e22c4d6524b60bf261e1970155500645be510/adbc.h#L909-L911
+                        utf8TableTypes = 
(byte**)Marshal.AllocHGlobal(IntPtr.Size * (tableTypes.Count + 1));
+                        utf8TableTypes[tableTypes.Count] = null;
+
+                        for (int i = 0; i < tableTypes.Count; i++)
+                        {
+                            string tableType = tableTypes[i];
+                            utf8TableTypes[i] = 
(byte*)MarshalExtensions.StringToCoTaskMemUTF8(tableType);
+                        }
                     }
 
                     using (Utf8Helper utf8Catalog = new 
Utf8Helper(catalogPattern))
@@ -614,11 +616,14 @@ namespace Apache.Arrow.Adbc.C
                 }
                 finally
                 {
-                    for (int i = 0; i < tableTypes.Count; i++)
+                    if (utf8TableTypes != null)
                     {
-                        Marshal.FreeCoTaskMem((IntPtr)utf8TableTypes[i]);
+                        for (int i = 0; i < tableTypes!.Count; i++)
+                        {
+                            Marshal.FreeCoTaskMem((IntPtr)utf8TableTypes[i]);
+                        }
+                        Marshal.FreeHGlobal((IntPtr)utf8TableTypes);
                     }
-                    Marshal.FreeHGlobal((IntPtr)utf8TableTypes);
                 }
             }
 
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs 
b/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
index 8816b828f..49d7af019 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
@@ -16,6 +16,7 @@
  */
 
 using System;
+using System.Collections.Generic;
 using System.Threading.Tasks;
 using Apache.Arrow.Types;
 using Xunit;
@@ -196,6 +197,48 @@ namespace Apache.Arrow.Adbc.Tests
             Assert.Equal(5, GetResultCount(statement2, "SELECT * from 
ingested"));
         }
 
+        [Fact]
+        public async Task GetTableTypes()
+        {
+            using var database = _duckDb.OpenDatabase("tabletypes.db");
+            using var connection = database.Connect(null);
+            using var statement = connection.CreateStatement();
+
+            var types = connection.GetTableTypes();
+            Assert.Single(types.Schema.FieldsList);
+            var data = await types.ReadNextRecordBatchAsync();
+            Assert.Null(data); // Not yet supported in DuckDB
+        }
+
+        [Fact]
+        public async Task GetCatalogs()
+        {
+            using var database = _duckDb.OpenDatabase("tablecatalogs.db");
+            using var connection = database.Connect(null);
+            using var statement = connection.CreateStatement();
+
+            statement.SqlQuery = "CREATE TABLE test(column1 INTEGER);";
+            statement.ExecuteUpdate();
+
+            var catalogs = 
connection.GetObjects(AdbcConnection.GetObjectsDepth.Catalogs, null, null, 
null, null, null);
+            Assert.NotNull(catalogs);
+            Assert.Equal(2, catalogs.Schema.FieldsList.Count);
+
+            var found = new HashSet<string>();
+            RecordBatch? group;
+            do
+            {
+                group = await catalogs.ReadNextRecordBatchAsync();
+                if (group != null && group.Column(0) is StringArray column1)
+                {
+                    found.UnionWith(column1);
+                }
+            } while (group != null);
+            Assert.Equal(3, found.Count);
+            found.ExceptWith(["system", "tablecatalogs", "temp"]);
+            Assert.Empty(found);
+        }
+
         private static long GetResultCount(AdbcStatement statement, string 
query)
         {
             statement.SqlQuery = query;

Reply via email to