This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new ce2aa022 feat(csharp/Client): Implement support for primary schema collections (#1317)
ce2aa022 is described below

commit ce2aa022ff520640cd8cd208b6b6079cb22794df
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Tue Nov 21 14:21:34 2023 -0800

    feat(csharp/Client): Implement support for primary schema collections (#1317)
    
    Adds support for the MetaDataCollections, Restrictions, Catalogs, Schemas,
    TableTypes, Tables and Columns schema collections.
---
 csharp/src/Client/AdbcConnection.cs                | 399 ++++++++++++++++++++-
 csharp/src/Drivers/BigQuery/BigQueryConnection.cs  | 138 ++++++-
 .../Drivers/BigQuery/BigQueryInfoArrowStream.cs    |  14 +-
 csharp/src/arrow                                   |   2 +-
 csharp/test/Drivers/BigQuery/ClientTests.cs        |  42 +++
 csharp/test/Drivers/Snowflake/ClientTests.cs       |  52 ++-
 6 files changed, 609 insertions(+), 38 deletions(-)

diff --git a/csharp/src/Client/AdbcConnection.cs 
b/csharp/src/Client/AdbcConnection.cs
index a1ab743a..0ac735cc 100644
--- a/csharp/src/Client/AdbcConnection.cs
+++ b/csharp/src/Client/AdbcConnection.cs
@@ -19,6 +19,10 @@ using System;
 using System.Collections.Generic;
 using System.Data;
 using System.Data.Common;
+using System.Linq;
+using Apache.Arrow.Ipc;
+using Apache.Arrow.Types;
+using GetObjectsDepth = Apache.Arrow.Adbc.AdbcConnection.GetObjectsDepth;
 
 namespace Apache.Arrow.Adbc.Client
 {
@@ -40,10 +44,10 @@ namespace Apache.Arrow.Adbc.Client
         /// </summary>
         public AdbcConnection()
         {
-           this.AdbcDriver = null;
-           this.DecimalBehavior = DecimalBehavior.UseSqlDecimal;
-           this.adbcConnectionParameters = new Dictionary<string, string>();
-           this.adbcConnectionOptions = new Dictionary<string, string>();
+            // Start with no driver, SqlDecimal decimal handling, and empty parameter/option maps.
+            this.AdbcDriver = null;
+            this.DecimalBehavior = DecimalBehavior.UseSqlDecimal;
+            this.adbcConnectionParameters = new Dictionary<string, string>();
+            this.adbcConnectionOptions = new Dictionary<string, string>();
         }
 
         /// <summary>
@@ -160,7 +164,7 @@ namespace Apache.Arrow.Adbc.Client
                     throw new InvalidOperationException("No connection values 
are present to connect with");
                 }
 
-                if(this.AdbcDriver == null)
+                if (this.AdbcDriver == null)
                 {
                     throw new InvalidOperationException("The ADBC driver is 
not specified");
                 }
@@ -172,7 +176,7 @@ namespace Apache.Arrow.Adbc.Client
 
         public override void Close()
         {
-           this.Dispose();
+            // Close simply delegates to Dispose.
+            this.Dispose();
         }
 
         public override ConnectionState State
@@ -183,6 +187,8 @@ namespace Apache.Arrow.Adbc.Client
             }
         }
 
+        /// <summary>
+        /// The underlying ADBC connection. Throws <see cref="InvalidOperationException"/> when no
+        /// internal connection is present (presumably before Open or after Close).
+        /// </summary>
+        private Adbc.AdbcConnection Connection => this.adbcConnectionInternal 
?? throw new InvalidOperationException("Invalid operation. The connection is 
closed.");
+
         /// <summary>
         /// Builds a connection string based on the adbcConnectionParameters.
         /// </summary>
@@ -216,7 +222,7 @@ namespace Apache.Arrow.Adbc.Client
 
             this.adbcConnectionParameters.Clear();
 
-            foreach(string key in builder.Keys)
+            foreach (string key in builder.Keys)
             {
                 this.adbcConnectionParameters.Add(key, 
Convert.ToString(builder[key]));
             }
@@ -224,7 +230,7 @@ namespace Apache.Arrow.Adbc.Client
 
         public override DataTable GetSchema()
         {
-            return GetSchema(null);
+            // No collection requested: return the MetaDataCollections listing (the lookup ignores case).
+            return GetSchema("metadatacollections", null);
         }
 
         public override DataTable GetSchema(string collectionName)
@@ -234,8 +240,22 @@ namespace Apache.Arrow.Adbc.Client
 
         public override DataTable GetSchema(string collectionName, string[] 
restrictionValues)
         {
-            Schema arrowSchema = 
this.adbcConnectionInternal.GetTableSchema("", "", "");
-            return SchemaConverter.ConvertArrowSchema(arrowSchema, 
this.AdbcStatement, this.DecimalBehavior);
+            SchemaCollection collection;
+            if (!SchemaCollection.TryGetCollection(collectionName, out 
collection))
+            {
+                throw new ArgumentException(
+                    $"The requested collection ('{collectionName}') is not 
defined",
+                    nameof(collectionName));
+            }
+
+            if (restrictionValues != null && restrictionValues.Length > 
collection.Restrictions.Length)
+            {
+                throw new ArgumentException(
+                    $"More restrictions were provided than the requested 
schema ('{collectionName}') supports.",
+                    nameof(restrictionValues));
+            }
+
+            return collection.GetSchema(this.Connection, restrictionValues);
         }
 
         #region NOT_IMPLEMENTED
@@ -257,5 +277,364 @@ namespace Apache.Arrow.Adbc.Client
         }
 
         #endregion
+
+        /// <summary>
+        /// Base type for a named ADO.NET schema collection served by GetSchema. The static
+        /// constructor registers every supported collection exactly once.
+        /// </summary>
+        abstract class SchemaCollection
+        {
+            // All registered collections in registration order; enumerated by the
+            // MetaDataCollections and Restrictions collections.
+            protected static readonly List<SchemaCollection> collections;
+            // Name -> collection lookup; names are compared with OrdinalIgnoreCase.
+            private static readonly SortedDictionary<string, SchemaCollection> 
schemaCollections;
+
+            static SchemaCollection()
+            {
+                collections = new List<SchemaCollection>();
+                schemaCollections = new SortedDictionary<string, 
SchemaCollection>(StringComparer.OrdinalIgnoreCase);
+
+                // Registration order is the order reported by MetaDataCollections.
+                Add(new MetadataCollection());
+                Add(new RestrictionsCollection());
+                Add(new CatalogsCollection());
+                Add(new SchemasCollection());
+                Add(new TableTypesCollection());
+                Add(new TablesCollection());
+                Add(new ColumnsCollection());
+            }
+
+            // Registers a collection in both the ordered list and the by-name lookup.
+            private static void Add(SchemaCollection collection)
+            {
+                collections.Add(collection);
+                schemaCollections.Add(collection.Name, collection);
+            }
+
+            /// <summary>
+            /// Looks up a collection by name, ignoring case.
+            /// NOTE(review): a null name makes SortedDictionary.TryGetValue throw ArgumentNullException.
+            /// </summary>
+            public static bool TryGetCollection(string name, out 
SchemaCollection collection)
+            {
+                return schemaCollections.TryGetValue(name, out collection);
+            }
+
+            /// <summary>Collection name exposed to callers and used for lookup.</summary>
+            public abstract string Name { get; }
+            /// <summary>
+            /// Names of the supported restrictions in positional order; the length bounds how
+            /// many restriction values GetSchema accepts for this collection.
+            /// </summary>
+            public abstract string[] Restrictions { get; }
+
+            /// <summary>
+            /// Builds the DataTable for this collection. <paramref name="restrictions"/> may be
+            /// null or shorter than <see cref="Restrictions"/>.
+            /// </summary>
+            public abstract DataTable GetSchema(Adbc.AdbcConnection 
adbcConnection, string[] restrictions);
+        }
+
+        /// <summary>
+        /// The "MetaDataCollections" collection: one row per registered collection with its
+        /// number of restrictions. Takes no restrictions and does not use the connection.
+        /// </summary>
+        private sealed class MetadataCollection : SchemaCollection
+        {
+            public override string Name => "MetaDataCollections";
+            public override string[] Restrictions => new string[0];
+
+            public override DataTable GetSchema(Adbc.AdbcConnection 
adbcConnection, string[] restrictions)
+            {
+                DataTable result = new DataTable(Name);
+                result.Columns.Add("CollectionName", typeof(string));
+                result.Columns.Add("NumberOfRestrictions", typeof(int));
+
+                // Rows follow registration order.
+                foreach (SchemaCollection collection in collections)
+                {
+                    result.Rows.Add(collection.Name, 
collection.Restrictions.Length);
+                }
+
+                return result;
+            }
+        }
+
+        /// <summary>
+        /// The "Restrictions" collection: one row per (collection, restriction) pair, with
+        /// restrictions numbered from 1 within each collection. Takes no restrictions itself.
+        /// </summary>
+        private sealed class RestrictionsCollection : SchemaCollection
+        {
+            public override string Name => "Restrictions";
+            public override string[] Restrictions => new string[0];
+
+            public override DataTable GetSchema(Adbc.AdbcConnection 
adbcConnection, string[] restrictions)
+            {
+                var result = new DataTable(Name);
+                result.Columns.Add("CollectionName", typeof(string));
+                result.Columns.Add("RestrictionName", typeof(string));
+                result.Columns.Add("RestrictionNumber", typeof(int));
+
+                foreach (var collection in collections)
+                {
+                    var collectionRestrictions = collection.Restrictions;
+                    for (int i = 0; i < collectionRestrictions.Length; i++)
+                    {
+                        // RestrictionNumber is 1-based.
+                        result.Rows.Add(collection.Name, 
collectionRestrictions[i], i + 1);
+                    }
+                }
+
+                return result;
+            }
+        }
+
+        /// <summary>
+        /// Base type for collections whose rows come from an ADBC metadata call returning an Arrow
+        /// stream. The possibly nested (list-of-struct) result is flattened into one DataTable row
+        /// per element of the innermost list reached by <see cref="Map"/>.
+        /// </summary>
+        private abstract class ArrowCollection : SchemaCollection
+        {
+            /// <summary>
+            /// Output columns, in order. Each item's path names Arrow fields from the top-level
+            /// schema down to the source value. Every segment except the last must be a
+            /// list-of-struct field, and items must agree on the segments they share.
+            /// </summary>
+            protected abstract MapItem[] Map { get; }
+
+            /// <summary>Runs the ADBC call that produces this collection's Arrow data.</summary>
+            protected abstract IArrowArrayStream Invoke(Adbc.AdbcConnection connection, string[] restrictions);
+
+            public override DataTable GetSchema(Adbc.AdbcConnection adbcConnection, string[] restrictions)
+            {
+                // Flattens the hierarchical ADBC schema into a DataTable
+
+                using (IArrowArrayStream stream = Invoke(adbcConnection, restrictions))
+                {
+                    MapItem[] map = this.Map;
+                    DataTable result = new DataTable(Name);
+                    // indices[d]: field index of the list followed at depth d.
+                    List<int> indices = new List<int>();
+                    // types[d]: record type at depth d (0 is the stream's schema).
+                    List<IRecordType> types = new List<IRecordType>();
+                    // path[d]: name of the list field followed at depth d.
+                    List<string> path = new List<string>();
+                    List<Action<State>> loaders = new List<Action<State>>();
+
+                    types.Add(stream.Schema);
+
+                    for (int targetIndex = 0; targetIndex < map.Length; targetIndex++)
+                    {
+                        MapItem item = map[targetIndex];
+                        result.Columns.Add(item.AdoName, item.Type);
+
+                        // Extend (or verify) the shared path of list-of-struct fields.
+                        for (int i = 0; i < item.AdbcPath.Length - 1; i++)
+                        {
+                            string part = item.AdbcPath[i];
+                            if (i == path.Count)
+                            {
+                                int index = types[i].GetFieldIndex(part, null);
+                                if (index < 0)
+                                {
+                                    throw new InvalidOperationException($"Unable to find '{part}'");
+                                }
+                                ListType listType = types[i].GetFieldByIndex(index).DataType as ListType;
+                                if (listType == null || listType.ValueDataType.TypeId != ArrowTypeId.Struct)
+                                {
+                                    throw new InvalidOperationException($"Field '{part}' has unexpected type.");
+                                }
+
+                                path.Add(part);
+                                indices.Add(index);
+                                types.Add((IRecordType)listType.ValueDataType);
+                            }
+                            else if (!StringComparer.OrdinalIgnoreCase.Equals(path[i], part))
+                            {
+                                throw new InvalidOperationException($"expected '{path[i]}' found '{part}'");
+                            }
+                        }
+
+                        int level = item.AdbcPath.Length - 1;
+                        string fieldName = item.AdbcPath[level];
+                        // Resolve the leaf in the record type at the item's own depth. The deepest
+                        // type seen so far is only correct when map items are ordered by depth.
+                        int srcIndex = types[level].GetFieldIndex(fieldName, null);
+                        if (srcIndex < 0)
+                        {
+                            throw new InvalidOperationException($"Unable to find '{fieldName}'");
+                        }
+                        loaders.Add(State.CreateLoader(item.Type, level, srcIndex, targetIndex));
+                    }
+
+                    State state = new State(result, indices.ToArray(), loaders.ToArray());
+                    while (true)
+                    {
+                        // A ValueTask must not be read via .Result before it has completed; going
+                        // through a Task makes blocking on an asynchronous stream safe.
+                        using (RecordBatch batch = stream.ReadNextRecordBatchAsync().AsTask().GetAwaiter().GetResult())
+                        {
+                            if (batch == null) { return result; }
+
+                            state.AddRecords(batch);
+                        }
+                    }
+                }
+            }
+
+            /// <summary>
+            /// Cursor over one record batch. It walks the nested lists depth-first and emits a
+            /// DataTable row for each innermost element.
+            /// </summary>
+            private class State
+            {
+                private readonly DataTable table;
+                // Field index of the list followed at each depth.
+                private readonly int[] indices;
+                // One loader per output column; each copies one value into buffer.
+                private readonly Action<State>[] loaders;
+                // Values of the row currently being assembled.
+                private readonly object[] buffer;
+                // offsets[d]: index of the current element within records[d].
+                private readonly int[] offsets;
+                // records[0] is the batch; records[d + 1] is the struct array holding list d's values.
+                private readonly IArrowRecord[] records;
+
+                public State(DataTable table, int[] indices, Action<State>[] loaders)
+                {
+                    this.table = table;
+                    this.indices = indices;
+                    this.loaders = loaders;
+                    this.buffer = new object[loaders.Length];
+
+                    this.offsets = new int[indices.Length + 1];
+                    this.records = new IArrowRecord[indices.Length + 1];
+                }
+
+                /// <summary>Appends one row per innermost element found in <paramref name="batch"/>.</summary>
+                public void AddRecords(RecordBatch batch)
+                {
+                    ListArray[] lists = new ListArray[this.indices.Length];
+
+                    this.records[0] = batch;
+                    this.offsets[0] = 0;
+                    for (int i = 0; i < indices.Length; i++)
+                    {
+                        lists[i] = (ListArray)this.records[i].Column(indices[i]);
+                        this.records[i + 1] = (StructArray)lists[i].Values;
+                        this.offsets[i + 1] = 0;
+                    }
+
+                    Loop(lists, 0, batch.Length);
+                }
+
+                // Visits `count` consecutive elements at depth `level`, starting at offsets[level].
+                private void Loop(ListArray[] lists, int level, int count)
+                {
+                    for (int i = 0; i < count; i++)
+                    {
+                        if (level == lists.Length)
+                        {
+                            AddRow();
+                        }
+                        else
+                        {
+                            ListArray list = lists[level];
+                            int element = this.offsets[level];
+                            // Start the child cursor at this element's own value offset instead of
+                            // assuming values are packed from zero. Sliced list arrays, and null
+                            // entries that own a non-empty segment, would otherwise misalign rows.
+                            this.offsets[level + 1] = list.ValueOffsets[element];
+                            Loop(lists, level + 1, list.GetValueLength(element));
+                        }
+                        this.offsets[level]++;
+                    }
+                }
+
+                // Runs every column loader for the current position, then appends the row.
+                private void AddRow()
+                {
+                    foreach (Action<State> loader in this.loaders)
+                    {
+                        loader(this);
+                    }
+                    this.table.Rows.Add(this.buffer);
+                }
+
+                /// <summary>
+                /// Builds a loader that reads element offsets[srcLevel] of column srcIndex in
+                /// records[srcLevel] into buffer[targetIndex]. Only bool, short, int, long and
+                /// string targets are supported.
+                /// </summary>
+                public static Action<State> CreateLoader(Type type, int srcLevel, int srcIndex, int targetIndex)
+                {
+                    return Type.GetTypeCode(type) switch
+                    {
+                        TypeCode.Boolean => state =>
+                            state.buffer[targetIndex] = ((BooleanArray)state.records[srcLevel].Column(srcIndex)).GetValue(state.offsets[srcLevel]),
+                        TypeCode.Int16 => state =>
+                            state.buffer[targetIndex] = ((Int16Array)state.records[srcLevel].Column(srcIndex)).GetValue(state.offsets[srcLevel]),
+                        TypeCode.Int32 => state =>
+                            state.buffer[targetIndex] = ((Int32Array)state.records[srcLevel].Column(srcIndex)).GetValue(state.offsets[srcLevel]),
+                        TypeCode.Int64 => state =>
+                            state.buffer[targetIndex] = ((Int64Array)state.records[srcLevel].Column(srcIndex)).GetValue(state.offsets[srcLevel]),
+                        TypeCode.String => state =>
+                            state.buffer[targetIndex] = ((StringArray)state.records[srcLevel].Column(srcIndex)).GetString(state.offsets[srcLevel]),
+                        _ => throw new NotSupportedException($"Type {type.FullName} is not supported."),
+                    };
+                }
+            }
+
+            /// <summary>One output column: its DataTable name, Arrow field path, and CLR type.</summary>
+            protected struct MapItem
+            {
+                public readonly string AdoName;
+                public readonly string[] AdbcPath;
+                public readonly Type Type;
+
+                public MapItem(string adoName, string[] adbcPath, Type type)
+                {
+                    this.AdoName = adoName;
+                    this.AdbcPath = adbcPath;
+                    this.Type = type;
+                }
+            }
+        }
+
+        /// <summary>
+        /// The "Catalogs" collection: one row per catalog from a catalog-depth GetObjects call.
+        /// The optional Catalog restriction is forwarded as the catalog argument (null if absent).
+        /// </summary>
+        private sealed class CatalogsCollection : ArrowCollection
+        {
+            public override string Name => "Catalogs";
+            public override string[] Restrictions => new[] { "Catalog" };
+
+            protected override MapItem[] Map => new[]
+            {
+                new MapItem("TABLE_CATALOG", new[] { "catalog_name" }, 
typeof(string)),
+            };
+
+            protected override IArrowArrayStream Invoke(Adbc.AdbcConnection 
connection, string[] restrictions)
+            {
+                string catalog = restrictions?.Length > 0 ? restrictions[0] : 
null;
+                return connection.GetObjects(GetObjectsDepth.Catalogs, 
catalog, null, null, null, null);
+            }
+        }
+
+        /// <summary>
+        /// The "Schemas" collection: one row per (catalog, schema) pair from a schema-depth
+        /// GetObjects call. Catalog and Schema restrictions are forwarded positionally (null if absent).
+        /// </summary>
+        private class SchemasCollection : ArrowCollection
+        {
+            public override string Name => "Schemas";
+            public override string[] Restrictions => new[] { "Catalog", 
"Schema" };
+
+            protected override MapItem[] Map => new[]
+            {
+                new MapItem("TABLE_CATALOG", new [] { "catalog_name" }, 
typeof(string)),
+                new MapItem("TABLE_SCHEMA", new [] { "catalog_db_schemas", 
"db_schema_name" }, typeof(string)),
+            };
+
+            protected override IArrowArrayStream Invoke(Adbc.AdbcConnection 
connection, string[] restrictions)
+            {
+                string catalog = restrictions?.Length > 0 ? restrictions[0] : 
null;
+                string schema = restrictions?.Length > 1 ? restrictions[1] : 
null;
+                return connection.GetObjects(GetObjectsDepth.DbSchemas, 
catalog, schema, null, null, null);
+            }
+        }
+
+        /// <summary>
+        /// The "TableTypes" collection: one row per table type reported by GetTableTypes.
+        /// Takes no restrictions.
+        /// </summary>
+        private class TableTypesCollection : ArrowCollection
+        {
+            public override string Name => "TableTypes";
+            public override string[] Restrictions => new string[0];
+
+            protected override MapItem[] Map => new[]
+            {
+                new MapItem("TABLE_TYPE", new [] { "table_type" }, 
typeof(string)),
+            };
+
+            protected override IArrowArrayStream Invoke(Adbc.AdbcConnection 
connection, string[] restrictions)
+            {
+                return connection.GetTableTypes();
+            }
+        }
+
+        /// <summary>
+        /// The "Tables" collection: one row per table, flattened from a table-depth GetObjects call.
+        /// Restrictions are Catalog, Schema, Table, and TableType. TableType is a comma-separated
+        /// list of table types.
+        /// </summary>
+        private class TablesCollection : ArrowCollection
+        {
+            public override string Name => "Tables";
+            public override string[] Restrictions => new[] { "Catalog", "Schema", "Table", "TableType" };
+
+            protected override MapItem[] Map => new[]
+            {
+                new MapItem("TABLE_CATALOG", new [] { "catalog_name" }, typeof(string)),
+                new MapItem("TABLE_SCHEMA", new [] { "catalog_db_schemas", "db_schema_name" }, typeof(string)),
+                new MapItem("TABLE_NAME", new [] { "catalog_db_schemas", "db_schema_tables", "table_name" }, typeof(string)),
+                new MapItem("TABLE_TYPE", new [] { "catalog_db_schemas", "db_schema_tables", "table_type" }, typeof(string)),
+            };
+
+            protected override IArrowArrayStream Invoke(Adbc.AdbcConnection connection, string[] restrictions)
+            {
+                string catalog = restrictions?.Length > 0 ? restrictions[0] : null;
+                string schema = restrictions?.Length > 1 ? restrictions[1] : null;
+                string table = restrictions?.Length > 2 ? restrictions[2] : null;
+                // A null entry is legal here (it means "no restriction") and must not reach Split.
+                List<string> tableTypes = ParseTableTypes(restrictions?.Length > 3 ? restrictions[3] : null);
+                return connection.GetObjects(GetObjectsDepth.Tables, catalog, schema, table, tableTypes, null);
+            }
+
+            /// <summary>
+            /// Splits a comma-separated table-type restriction into trimmed, non-empty entries.
+            /// Returns null (no table-type filter) when the restriction is null or blank, or
+            /// contains no entries.
+            /// </summary>
+            private static List<string> ParseTableTypes(string restriction)
+            {
+                if (string.IsNullOrWhiteSpace(restriction))
+                {
+                    return null;
+                }
+
+                List<string> types = restriction
+                    .Split(',')
+                    .Select(type => type.Trim())
+                    .Where(type => type.Length > 0)
+                    .ToList();
+
+                return types.Count > 0 ? types : null;
+            }
+        }
+
+        /// <summary>
+        /// The "Columns" collection: one row per column, flattened from an all-depth GetObjects
+        /// call. The Catalog, Schema, Table and Column restrictions are forwarded positionally as the
+        /// catalog, schema, table and column arguments; a missing restriction is passed as null.
+        /// </summary>
+        private class ColumnsCollection : ArrowCollection
+        {
+            public override string Name => "Columns";
+            public override string[] Restrictions => new[] { "Catalog", 
"Schema", "Table", "Column" };
+
+            protected override MapItem[] Map => new[]
+            {
+                new MapItem("TABLE_CATALOG", new [] { "catalog_name" }, 
typeof(string)),
+                new MapItem("TABLE_SCHEMA", new [] { "catalog_db_schemas", 
"db_schema_name" }, typeof(string)),
+                new MapItem("TABLE_NAME", new [] { "catalog_db_schemas", 
"db_schema_tables", "table_name" }, typeof(string)),
+                new MapItem("COLUMN_NAME", new [] { "catalog_db_schemas", 
"db_schema_tables", "table_columns", "column_name" }, typeof(string)),
+                new MapItem("ORDINAL_POSITION", new [] { "catalog_db_schemas", 
"db_schema_tables", "table_columns", "ordinal_position" }, typeof(int)),
+                new MapItem("REMARKS", new [] { "catalog_db_schemas", 
"db_schema_tables", "table_columns", "remarks" }, typeof(string)),
+                new MapItem("DATA_TYPE", new [] { "catalog_db_schemas", 
"db_schema_tables", "table_columns", "xdbc_type_name" }, typeof(string)),
+                new MapItem("IS_NULLABLE", new [] { "catalog_db_schemas", 
"db_schema_tables", "table_columns", "xdbc_is_nullable" }, typeof(string)),
+                new MapItem("COLUMN_DEFAULT", new [] { "catalog_db_schemas", 
"db_schema_tables", "table_columns", "xdbc_column_def" }, typeof(string)),
+                new MapItem("IS_AUTOINCREMENT", new [] { "catalog_db_schemas", 
"db_schema_tables", "table_columns", "xdbc_is_autoincrement" }, typeof(bool)),
+                new MapItem("IS_GENERATED", new [] { "catalog_db_schemas", 
"db_schema_tables", "table_columns", "xdbc_is_generatedcolumn" }, typeof(bool)),
+                new MapItem("CHARACTER_OCTET_LENGTH", new [] { 
"catalog_db_schemas", "db_schema_tables", "table_columns", 
"xdbc_char_octet_length" }, typeof(int)),
+                new MapItem("CHARACTER_MAXIMUM_LENGTH", new [] { 
"catalog_db_schemas", "db_schema_tables", "table_columns", "xdbc_column_size" 
}, typeof(int)),
+                // NOTE(review): xdbc_decimal_digits appears to follow JDBC DECIMAL_DIGITS (a scale,
+                // not a precision); NUMERIC_PRECISION may belong with xdbc_column_size instead.
+                // Confirm against the ADBC GetObjects schema before relying on this column.
+                new MapItem("NUMERIC_PRECISION", new [] { 
"catalog_db_schemas", "db_schema_tables", "table_columns", 
"xdbc_decimal_digits" }, typeof(short)),
+                new MapItem("NUMERIC_PRECISION_RADIX", new [] { 
"catalog_db_schemas", "db_schema_tables", "table_columns", 
"xdbc_num_prec_radix" }, typeof(short)),
+                new MapItem("DATETIME_PRECISION", new [] { 
"catalog_db_schemas", "db_schema_tables", "table_columns", "xdbc_datetime_sub" 
}, typeof(short)),
+            };
+
+            protected override IArrowArrayStream Invoke(Adbc.AdbcConnection 
connection, string[] restrictions)
+            {
+                string catalog = restrictions?.Length > 0 ? restrictions[0] : 
null;
+                string schema = restrictions?.Length > 1 ? restrictions[1] : 
null;
+                string table = restrictions?.Length > 2 ? restrictions[2] : 
null;
+                string column = restrictions?.Length > 3 ? restrictions[3] : 
null;
+                // Table types are not filtered (null) for this collection.
+                return connection.GetObjects(GetObjectsDepth.All, catalog, 
schema, table, null, column);
+            }
+        }
     }
 }
diff --git a/csharp/src/Drivers/BigQuery/BigQueryConnection.cs 
b/csharp/src/Drivers/BigQuery/BigQueryConnection.cs
index d1a91e70..48a9c953 100644
--- a/csharp/src/Drivers/BigQuery/BigQueryConnection.cs
+++ b/csharp/src/Drivers/BigQuery/BigQueryConnection.cs
@@ -87,7 +87,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
             {
                 
this.properties.TryGetValue(BigQueryParameters.AuthenticationType, out 
authenticationType);
 
-                
if(!authenticationType.Equals(BigQueryConstants.UserAuthenticationType, 
StringComparison.OrdinalIgnoreCase) &&
+                if 
(!authenticationType.Equals(BigQueryConstants.UserAuthenticationType, 
StringComparison.OrdinalIgnoreCase) &&
                     
!authenticationType.Equals(BigQueryConstants.ServiceAccountAuthenticationType, 
StringComparison.OrdinalIgnoreCase))
                 {
                     throw new ArgumentException($"The 
{BigQueryParameters.AuthenticationType} parameter can only be 
`{BigQueryConstants.UserAuthenticationType}` or 
`{BigQueryConstants.ServiceAccountAuthenticationType}`");
@@ -232,7 +232,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
                 infoValue
             };
 
-            return new BigQueryInfoArrowStream(StandardSchemas.GetInfoSchema, 
dataArrays, 4);
+            return new BigQueryInfoArrowStream(StandardSchemas.GetInfoSchema, 
dataArrays);
         }
 
         public override IArrowArrayStream GetObjects(
@@ -246,7 +246,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
             List<IArrowArray> dataArrays = GetCatalogs(depth, catalogPattern, 
dbSchemaPattern,
                 tableNamePattern, tableTypes, columnNamePattern);
 
-            return new 
BigQueryInfoArrowStream(StandardSchemas.GetObjectsSchema, dataArrays, 1);
+            return new 
BigQueryInfoArrowStream(StandardSchemas.GetObjectsSchema, dataArrays);
         }
 
         private List<IArrowArray> GetCatalogs(
@@ -362,7 +362,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
             if (tableNamePattern != null)
             {
                 query = string.Concat(query, string.Format(" WHERE table_name 
LIKE '{0}'", Sanitize(tableNamePattern)));
-                if (tableTypes.Count > 0)
+                if (tableTypes?.Count > 0)
                 {
                     List<string> sanitizedTypes = tableTypes.Select(x => 
Sanitize(x)).ToList();
 
@@ -371,7 +371,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
             }
             else
             {
-                if (tableTypes.Count > 0)
+                if (tableTypes?.Count > 0)
                 {
                     List<string> sanitizedTypes = tableTypes.Select(x => 
Sanitize(x)).ToList();
                     query = string.Concat(query, string.Format(" WHERE 
table_type IN ('{0}')", string.Join("', '", sanitizedTypes).ToUpper()));
@@ -796,7 +796,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
 
         private ParsedDecimalValues ParsePrecisionAndScale(string type)
         {
-            if(string.IsNullOrWhiteSpace(type)) throw new 
ArgumentNullException(nameof(type));
+            if (string.IsNullOrWhiteSpace(type)) throw new 
ArgumentNullException(nameof(type));
 
             string[] values = type.Substring(type.IndexOf("(") + 
1).TrimEnd(')').Split(",".ToCharArray());
 
@@ -817,7 +817,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
                 tableTypesBuilder.Build()
             };
 
-            return new 
BigQueryInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays, 1);
+            return new 
BigQueryInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays);
         }
 
         private ListArray CreateNestedListArray(List<IArrowArray> arrayList, 
IArrowType dataType)
@@ -850,12 +850,14 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
 
             ArrayData data = ArrayDataConcatenator.Concatenate(arrayDataList);
 
-            IArrowArray value = null;
-
             if (data == null)
-                value = new NullArray(0);
-            else
-                value = ArrowArrayFactory.BuildArray(data);
+            {
+                EmptyArrayCreationVisitor visitor = new 
EmptyArrayCreationVisitor(0);
+                dataType.Accept(visitor);
+                data = visitor.Result;
+            }
+
+            IArrowArray value = ArrowArrayFactory.BuildArray(data);
 
             valueOffsetsBufferBuilder.Append(length);
 
@@ -880,7 +882,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
         {
             Dictionary<string, string> options = new Dictionary<string, 
string>();
 
-            foreach (KeyValuePair<string,string> keyValuePair in 
this.properties)
+            foreach (KeyValuePair<string, string> keyValuePair in 
this.properties)
             {
                 if (keyValuePair.Key == BigQueryParameters.AllowLargeResults)
                 {
@@ -977,5 +979,115 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
             XdbcDataType_XDBC_WCHAR = -8,
             XdbcDataType_XDBC_WVARCHAR = -9,
         }
+
+        /// <summary>
+        /// Builds <see cref="ArrayData"/> of a given length for an Arrow type using only empty
+        /// buffers, recursing into child types. CreateNestedListArray uses it with length 0 to
+        /// stand in for missing child data.
+        /// NOTE(review): each array reports a null count equal to its length while its buffers
+        /// are empty, so the result is only coherent for length 0. List and map layouts here carry
+        /// a single buffer, and a map's key and value become direct children rather than one
+        /// key/value struct child. Confirm against the Arrow columnar layout before reusing this.
+        /// </summary>
+        private class EmptyArrayCreationVisitor :
+            IArrowTypeVisitor<BooleanType>,
+            IArrowTypeVisitor<FixedWidthType>,
+            IArrowTypeVisitor<BinaryType>,
+            IArrowTypeVisitor<StringType>,
+            IArrowTypeVisitor<ListType>,
+            IArrowTypeVisitor<FixedSizeListType>,
+            IArrowTypeVisitor<StructType>,
+            IArrowTypeVisitor<UnionType>,
+            IArrowTypeVisitor<MapType>
+        {
+            // Array data produced by the most recent Visit call.
+            public ArrayData Result { get; private set; }
+            private readonly int _length;
+
+            public EmptyArrayCreationVisitor(int length)
+            {
+                _length = length;
+            }
+
+            public void Visit(BooleanType type)
+            {
+                Result = new ArrayData(type, _length, _length, 0, new[] { 
ArrowBuffer.Empty, ArrowBuffer.Empty });
+            }
+
+            public void Visit(FixedWidthType type)
+            {
+                Result = new ArrayData(type, _length, _length, 0, new[] { 
ArrowBuffer.Empty, ArrowBuffer.Empty });
+            }
+
+            public void Visit(BinaryType type)
+            {
+                Result = new ArrayData(type, _length, _length, 0, new[] { 
ArrowBuffer.Empty, ArrowBuffer.Empty, ArrowBuffer.Empty });
+            }
+
+            public void Visit(StringType type)
+            {
+                Result = new ArrayData(type, _length, _length, 0, new[] { 
ArrowBuffer.Empty, ArrowBuffer.Empty, ArrowBuffer.Empty });
+            }
+
+            public void Visit(ListType type)
+            {
+                // Build the value child first; Result is overwritten by the recursive call.
+                type.ValueDataType.Accept(this);
+                ArrayData child = Result;
+
+                Result = new ArrayData(type, _length, _length, 0, new[] { 
ArrowBuffer.Empty }, new[] { child });
+            }
+
+            public void Visit(FixedSizeListType type)
+            {
+                // NOTE(review): the child gets the same length as the list, not length * list size.
+                type.ValueDataType.Accept(this);
+                ArrayData child = Result;
+
+                Result = new ArrayData(type, _length, _length, 0, new[] { 
ArrowBuffer.Empty }, new[] { child });
+            }
+
+            public void Visit(StructType type)
+            {
+                // One child per struct field, built in field order.
+                ArrayData[] children = new ArrayData[type.Fields.Count];
+                for (int i = 0; i < type.Fields.Count; i++)
+                {
+                    type.Fields[i].DataType.Accept(this);
+                    children[i] = Result;
+                }
+
+                Result = new ArrayData(type, _length, _length, 0, new[] { 
ArrowBuffer.Empty }, children);
+            }
+
+            public void Visit(UnionType type)
+            {
+                // Sparse unions get one buffer, dense unions two.
+                int bufferCount = type.Mode switch
+                {
+                    UnionMode.Sparse => 1,
+                    UnionMode.Dense => 2,
+                    _ => throw new InvalidOperationException($"Unknown 
UnionMode {type.Mode}"),
+                };
+
+                ArrayData[] children = new ArrayData[type.Fields.Count];
+                for (int i = 0; i < type.Fields.Count; i++)
+                {
+                    type.Fields[i].DataType.Accept(this);
+                    children[i] = Result;
+                }
+
+                ArrowBuffer[] buffers = new ArrowBuffer[bufferCount];
+                buffers[0] = ArrowBuffer.Empty;
+                if (bufferCount > 1)
+                {
+                    buffers[1] = ArrowBuffer.Empty;
+                }
+
+                Result = new ArrayData(type, _length, _length, 0, buffers, 
children);
+            }
+
+            public void Visit(MapType type)
+            {
+                // Key and value arrays become children 0 and 1 (see the layout note on the class).
+                ArrayData[] children = new ArrayData[2];
+                type.KeyField.DataType.Accept(this);
+                children[0] = Result;
+                type.ValueField.DataType.Accept(this);
+                children[1] = Result;
+
+                Result = new ArrayData(type, _length, _length, 0, new[] { 
ArrowBuffer.Empty }, children);
+            }
+
+            // Fallback for types without a specific overload above.
+            public void Visit(IArrowType type)
+            {
+                throw new NotImplementedException($"EmptyArrayCreationVisitor 
for {type.Name} is not supported yet.");
+            }
+        }
     }
 }
diff --git a/csharp/src/Drivers/BigQuery/BigQueryInfoArrowStream.cs 
b/csharp/src/Drivers/BigQuery/BigQueryInfoArrowStream.cs
index 10abdbab..75a67de8 100644
--- a/csharp/src/Drivers/BigQuery/BigQueryInfoArrowStream.cs
+++ b/csharp/src/Drivers/BigQuery/BigQueryInfoArrowStream.cs
@@ -28,25 +28,27 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
     internal class BigQueryInfoArrowStream : IArrowArrayStream
     {
         private Schema schema;
-        private IEnumerable<IArrowArray> data;
-        private int length;
+        // The single batch this stream yields; null once read or disposed.
+        private RecordBatch batch;
 
-        public BigQueryInfoArrowStream(Schema schema, IEnumerable<IArrowArray> data, int length)
+        /// <summary>
+        /// Creates a stream that yields one record batch built from <paramref name="data"/>.
+        /// </summary>
+        /// <param name="schema">Schema describing the columns in <paramref name="data"/>.</param>
+        /// <param name="data">
+        /// The batch's columns. The row count is taken from the first column, or is
+        /// zero when there are no columns.
+        /// </param>
+        public BigQueryInfoArrowStream(Schema schema, IReadOnlyList<IArrowArray> data)
         {
             this.schema = schema;
-            this.data = data;
-            this.length = length;
+            // Avoid indexing data[0] when the column list is empty.
+            int rowCount = data.Count > 0 ? data[0].Length : 0;
+            this.batch = new RecordBatch(schema, data, rowCount);
         }
 
         public Schema Schema { get { return this.schema; } }
 
+        /// <summary>
+        /// Returns the batch on the first call and null on every later call.
+        /// Once returned, the batch is no longer disposed by this stream.
+        /// </summary>
         public ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
         {
-            return new ValueTask<RecordBatch>(new RecordBatch(schema, data, length));
+            RecordBatch batch = this.batch;
+            this.batch = null;
+            return new ValueTask<RecordBatch>(batch);
         }
 
+        /// <summary>
+        /// Disposes the batch if it was never read.
+        /// </summary>
         public void Dispose()
         {
+            this.batch?.Dispose();
+            this.batch = null;
         }
     }
 }
diff --git a/csharp/src/arrow b/csharp/src/arrow
index 79e328b3..c5a1eb01 160000
--- a/csharp/src/arrow
+++ b/csharp/src/arrow
@@ -1 +1 @@
-Subproject commit 79e328b3b7ce23002bc46904c1944654ce4cd0a3
+Subproject commit c5a1eb01631f770dd7b3a6ecf309022d34af4c1c
diff --git a/csharp/test/Drivers/BigQuery/ClientTests.cs b/csharp/test/Drivers/BigQuery/ClientTests.cs
index fd234068..d9672d22 100644
--- a/csharp/test/Drivers/BigQuery/ClientTests.cs
+++ b/csharp/test/Drivers/BigQuery/ClientTests.cs
@@ -105,6 +105,48 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.BigQuery
             }
         }
 
+        /// <summary>
+        /// Exercises GetSchema for each supported collection against BigQuery,
+        /// restricting by the first catalog and schema it finds.
+        /// </summary>
+        [SkippableFact]
+        public void VerifySchemaTables()
+        {
+            BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration<BigQueryTestConfiguration>(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE);
+
+            using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection(testConfiguration))
+            {
+                adbcConnection.Open();
+
+                var collections = adbcConnection.GetSchema("MetaDataCollections");
+                Assert.Equal(7, collections.Rows.Count);
+                Assert.Equal(2, collections.Columns.Count);
+
+                var restrictions = adbcConnection.GetSchema("Restrictions");
+                Assert.Equal(11, restrictions.Rows.Count);
+                Assert.Equal(3, restrictions.Columns.Count);
+
+                var catalogs = adbcConnection.GetSchema("Catalogs");
+                Assert.Equal(1, catalogs.Columns.Count);
+                // Fail with a clear message instead of an IndexOutOfRangeException.
+                Assert.True(catalogs.Rows.Count > 0, "expected at least one catalog");
+                var catalog = (string)catalogs.Rows[0][0];
+
+                catalogs = adbcConnection.GetSchema("Catalogs", new[] { catalog });
+                Assert.Equal(1, catalogs.Rows.Count);
+
+                var schemas = adbcConnection.GetSchema("Schemas", new[] { catalog });
+                Assert.Equal(2, schemas.Columns.Count);
+                Assert.True(schemas.Rows.Count > 0, $"expected at least one schema in catalog {catalog}");
+                var schema = (string)schemas.Rows[0][1];
+
+                schemas = adbcConnection.GetSchema("Schemas", new[] { catalog, schema });
+                Assert.Equal(1, schemas.Rows.Count);
+
+                var tableTypes = adbcConnection.GetSchema("TableTypes");
+                Assert.Equal(1, tableTypes.Columns.Count);
+
+                var tables = adbcConnection.GetSchema("Tables", new[] { catalog, schema });
+                Assert.Equal(4, tables.Columns.Count);
+
+                var columns = adbcConnection.GetSchema("Columns", new[] { catalog, schema });
+                Assert.Equal(16, columns.Columns.Count);
+            }
+        }
+
         private Adbc.Client.AdbcConnection 
GetAdbcConnection(BigQueryTestConfiguration testConfiguration)
         {
             return new Adbc.Client.AdbcConnection(
diff --git a/csharp/test/Drivers/Snowflake/ClientTests.cs b/csharp/test/Drivers/Snowflake/ClientTests.cs
index 187ef7b2..77d12fb4 100644
--- a/csharp/test/Drivers/Snowflake/ClientTests.cs
+++ b/csharp/test/Drivers/Snowflake/ClientTests.cs
@@ -174,28 +174,67 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
             }
         }
 
+        /// <summary>
+        /// Exercises GetSchema for each supported collection against Snowflake,
+        /// restricting by the first catalog and INFORMATION_SCHEMA.
+        /// </summary>
+        [SkippableFact]
+        public void VerifySchemaTables()
+        {
+            SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration<SnowflakeTestConfiguration>(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE);
+
+            using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration))
+            {
+                adbcConnection.Open();
+
+                var collections = adbcConnection.GetSchema("MetaDataCollections");
+                Assert.Equal(7, collections.Rows.Count);
+                Assert.Equal(2, collections.Columns.Count);
+
+                var restrictions = adbcConnection.GetSchema("Restrictions");
+                Assert.Equal(11, restrictions.Rows.Count);
+                Assert.Equal(3, restrictions.Columns.Count);
+
+                var catalogs = adbcConnection.GetSchema("Catalogs");
+                Assert.Equal(1, catalogs.Columns.Count);
+                // Fail with a clear message instead of an IndexOutOfRangeException.
+                Assert.True(catalogs.Rows.Count > 0, "expected at least one catalog");
+                var catalog = (string)catalogs.Rows[0][0];
+
+                catalogs = adbcConnection.GetSchema("Catalogs", new[] { catalog });
+                Assert.Equal(1, catalogs.Rows.Count);
+
+                var schemas = adbcConnection.GetSchema("Schemas", new[] { catalog });
+                Assert.Equal(2, schemas.Columns.Count);
+
+                var schema = "INFORMATION_SCHEMA";
+                schemas = adbcConnection.GetSchema("Schemas", new[] { catalog, schema });
+                Assert.Equal(1, schemas.Rows.Count);
+
+                var tableTypes = adbcConnection.GetSchema("TableTypes");
+                Assert.Equal(1, tableTypes.Columns.Count);
+
+                // The set of INFORMATION_SCHEMA views may change between Snowflake
+                // releases, so check structure rather than exact row counts.
+                var tables = adbcConnection.GetSchema("Tables", new[] { catalog, schema });
+                Assert.Equal(4, tables.Columns.Count);
+                Assert.True(tables.Rows.Count > 0, "expected tables in INFORMATION_SCHEMA");
+
+                var columns = adbcConnection.GetSchema("Columns", new[] { catalog, schema });
+                Assert.Equal(16, columns.Columns.Count);
+                // Every table or view has at least one column.
+                Assert.True(columns.Rows.Count >= tables.Rows.Count, "expected at least one column per table");
+            }
+        }
+
         private Adbc.Client.AdbcConnection 
GetSnowflakeAdbcConnectionUsingConnectionString(SnowflakeTestConfiguration 
testConfiguration)
         {
             // see https://arrow.apache.org/adbc/0.5.1/driver/snowflake.html
 
             DbConnectionStringBuilder builder = new 
DbConnectionStringBuilder(true);
-
             builder[SnowflakeParameters.ACCOUNT] = testConfiguration.Account;
             builder[SnowflakeParameters.WAREHOUSE] = 
testConfiguration.Warehouse;
             builder[SnowflakeParameters.HOST] = testConfiguration.Host;
             builder[SnowflakeParameters.DATABASE] = testConfiguration.Database;
             builder[SnowflakeParameters.USERNAME] = testConfiguration.User;
-
             if 
(!string.IsNullOrEmpty(testConfiguration.AuthenticationTokenPath))
             {
                 builder[SnowflakeParameters.AUTH_TYPE] = 
testConfiguration.AuthenticationType;
-
                 string privateKey = 
File.ReadAllText(testConfiguration.AuthenticationTokenPath);
-
                 if (testConfiguration.AuthenticationType.Equals("auth_jwt", 
StringComparison.OrdinalIgnoreCase))
                 {
                     builder[SnowflakeParameters.PKCS8_VALUE] = privateKey;
-
                     if(!string.IsNullOrEmpty(testConfiguration.Pkcs8Passcode))
                     {
                         builder[SnowflakeParameters.PKCS8_PASS] = 
testConfiguration.Pkcs8Passcode;
@@ -206,15 +245,12 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
             {
                 builder[SnowflakeParameters.PASSWORD] = 
testConfiguration.Password;
             }
-
             AdbcDriver snowflakeDriver = 
SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration);
-
             return new Adbc.Client.AdbcConnection(builder.ConnectionString)
             {
                 AdbcDriver = snowflakeDriver
             };
         }
-
         private Adbc.Client.AdbcConnection 
GetSnowflakeAdbcConnection(SnowflakeTestConfiguration testConfiguration)
         {
             Dictionary<string, string> parameters = new Dictionary<string, 
string>();


Reply via email to