This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new e39d71549 feat(csharp): Add support for Prepare to ImportedStatement and to ADO.NET wrapper (#2628)
e39d71549 is described below

commit e39d71549bde79ebd2662480ae82b94140c62e03
Author: Curt Hagenlocher <c...@hagenlocher.org>
AuthorDate: Tue Mar 18 07:31:08 2025 -0700

    feat(csharp): Add support for Prepare to ImportedStatement and to ADO.NET wrapper (#2628)
    
    Closes #2616.
---
 .../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs | 70 ++++++++++++++--------
 csharp/src/Client/AdbcCommand.cs                   | 59 ++++++++++++++----
 .../Client/DuckDbClientTests.cs                    | 32 +++++++++-
 .../Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs | 32 ++++++++++
 4 files changed, 157 insertions(+), 36 deletions(-)

diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
index 73ae28b7d..133fff57a 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
@@ -797,6 +797,15 @@ namespace Apache.Arrow.Adbc.C
                 Dispose(false);
             }
 
+            public override string? SqlQuery
+            {
+                set
+                {
+                    SetSqlQuery(value);
+                    base.SqlQuery = value;
+                }
+            }
+
             private unsafe ref CAdbcDriver Driver
             {
                 get
@@ -885,12 +894,6 @@ namespace Apache.Arrow.Adbc.C
 
             public unsafe override QueryResult ExecuteQuery()
             {
-                if (SqlQuery != null)
-                {
-                    // TODO: Consider moving this to the setter
-                    SetSqlQuery(SqlQuery);
-                }
-
                 using (CallHelper caller = new CallHelper())
                 {
                     fixed (CAdbcStatement* statement = &_nativeStatement)
@@ -911,12 +914,6 @@ namespace Apache.Arrow.Adbc.C
 
             public override unsafe Schema ExecuteSchema()
             {
-                if (SqlQuery != null)
-                {
-                    // TODO: Consider moving this to the setter
-                    SetSqlQuery(SqlQuery);
-                }
-
                 using (CallHelper caller = new CallHelper())
                 {
                     fixed (CAdbcStatement* statement = &_nativeStatement)
@@ -936,12 +933,6 @@ namespace Apache.Arrow.Adbc.C
 
             public unsafe override UpdateResult ExecuteUpdate()
             {
-                if (SqlQuery != null)
-                {
-                    // TODO: Consider moving this to the setter
-                    SetSqlQuery(SqlQuery);
-                }
-
                 using (CallHelper caller = new CallHelper())
                 {
                     fixed (CAdbcStatement* statement = &_nativeStatement)
@@ -962,12 +953,6 @@ namespace Apache.Arrow.Adbc.C
 
             public unsafe override PartitionedResult ExecutePartitioned()
             {
-                if (SqlQuery != null)
-                {
-                    // TODO: Consider moving this to the setter
-                    SetSqlQuery(SqlQuery);
-                }
-
                 using (CallHelper caller = new CallHelper())
                 {
                     fixed (CAdbcStatement* statement = &_nativeStatement)
@@ -1013,6 +998,41 @@ namespace Apache.Arrow.Adbc.C
                 }
             }
 
+            public unsafe override Schema GetParameterSchema()
+            {
+                using (CallHelper caller = new CallHelper())
+                {
+                    fixed (CAdbcStatement* statement = &_nativeStatement)
+                    {
+                        caller.TranslateCode(
+#if NET5_0_OR_GREATER
+                            Driver.StatementGetParameterSchema
+#else
+                            Marshal.GetDelegateForFunctionPointer<StatementGetParameterSchema>(Driver.StatementGetParameterSchema)
+#endif
+                            (statement, caller.CreateSchema(), & caller._error));
+                    }
+                    return caller.ImportSchema();
+                }
+            }
+
+            public unsafe override void Prepare()
+            {
+                using (CallHelper caller = new CallHelper())
+                {
+                    fixed (CAdbcStatement* statement = &_nativeStatement)
+                    {
+                        caller.TranslateCode(
+#if NET5_0_OR_GREATER
+                            Driver.StatementPrepare
+#else
+                            Marshal.GetDelegateForFunctionPointer<StatementPrepare>(Driver.StatementPrepare)
+#endif
+                            (statement, &caller._error));
+                    }
+                }
+            }
+
             public unsafe override void SetOption(string key, string value)
             {
                 using (CallHelper caller = new CallHelper())
@@ -1055,7 +1075,7 @@ namespace Apache.Arrow.Adbc.C
                 }
             }
 
-            private unsafe void SetSqlQuery(string sqlQuery)
+            private unsafe void SetSqlQuery(string? sqlQuery)
             {
                 fixed (CAdbcStatement* statement = &_nativeStatement)
                 {
diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs
index a317ca19c..7ad0678af 100644
--- a/csharp/src/Client/AdbcCommand.cs
+++ b/csharp/src/Client/AdbcCommand.cs
@@ -483,12 +483,61 @@ namespace Apache.Arrow.Adbc.Client
             }
         }
 
+        public override void Prepare()
+        {
+            _adbcStatement.Prepare();
+            var schema = _adbcStatement.GetParameterSchema();
+
+            DbParameterCollection.Clear();
+
+            foreach (Field field in schema.FieldsList)
+            {
+                AdbcParameter parameter = new AdbcParameter
+                {
+                    ParameterName = field.Name,
+                    IsNullable = field.IsNullable,
+                    DbType = field.DataType.TypeId switch
+                    {
+                        ArrowTypeId.UInt8 => DbType.Byte,
+                        ArrowTypeId.UInt16 => DbType.UInt16,
+                        ArrowTypeId.UInt32 => DbType.UInt32,
+                        ArrowTypeId.UInt64 => DbType.UInt64,
+                        ArrowTypeId.Int8 => DbType.SByte,
+                        ArrowTypeId.Int16 => DbType.Int16,
+                        ArrowTypeId.Int32 => DbType.Int32,
+                        ArrowTypeId.Int64 => DbType.Int64,
+                        ArrowTypeId.Float => DbType.Single,
+                        ArrowTypeId.Double => DbType.Double,
+                        ArrowTypeId.Boolean => DbType.Boolean,
+                        ArrowTypeId.String => DbType.String,
+                        ArrowTypeId.Date32 => DbType.Date,
+                        ArrowTypeId.Date64 => DbType.DateTime,
+                        ArrowTypeId.Time32 => DbType.Time,
+                        ArrowTypeId.Time64 => DbType.Time,
+                        ArrowTypeId.Timestamp => DbType.DateTime,
+                        ArrowTypeId.Decimal32 or
+                        ArrowTypeId.Decimal64 or
+                        ArrowTypeId.Decimal128 or
+                        ArrowTypeId.Decimal256 => DbType.Decimal,
+                        _ => DbType.Object,
+                    },
+                };
+                DbParameterCollection.Add(parameter);
+            }
+        }
+
+        protected override DbParameter CreateDbParameter()
+        {
+            return new AdbcParameter();
+        }
+
 #if NET5_0_OR_GREATER
         public override ValueTask DisposeAsync()
         {
             return base.DisposeAsync();
         }
 #endif
+
         #region NOT_IMPLEMENTED
 
         public override bool DesignTimeVisible { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
@@ -507,16 +556,6 @@ namespace Apache.Arrow.Adbc.Client
             throw new NotImplementedException();
         }
 
-        public override void Prepare()
-        {
-            throw new NotImplementedException();
-        }
-
-        protected override DbParameter CreateDbParameter()
-        {
-            return new AdbcParameter();
-        }
-
         #endregion
 
         private class AdbcParameterCollection : DbParameterCollection
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs
index 31e25f662..0a8e89d82 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs
@@ -15,8 +15,9 @@
 * limitations under the License.
 */
 
-using System.Collections.Generic;
+using System.Data;
 using Apache.Arrow.Adbc.Client;
+using Apache.Arrow.Types;
 using Xunit;
 
 namespace Apache.Arrow.Adbc.Tests.Client
@@ -117,6 +118,35 @@ namespace Apache.Arrow.Adbc.Tests.Client
             });
         }
 
+        [Fact]
+        public void BindParameters()
+        {
+            using var connection = _duckDb.CreateConnection("bindparameters.db", null);
+            connection.Open();
+            var command = connection.CreateCommand();
+
+            command.CommandText = "select ?, ?";
+            command.Prepare();
+            Assert.Equal(2, command.Parameters.Count);
+            Assert.Equal("0", command.Parameters[0].ParameterName);
+            Assert.Equal(DbType.Object, command.Parameters[0].DbType);
+            Assert.Equal("1", command.Parameters[1].ParameterName);
+            Assert.Equal(DbType.Object, command.Parameters[1].DbType);
+
+            command.Parameters[0].DbType = DbType.Int32;
+            command.Parameters[0].Value = 1;
+            command.Parameters[1].DbType = DbType.String;
+            command.Parameters[1].Value = "foo";
+
+            using var reader = command.ExecuteReader();
+            long count = 0;
+            while (reader.Read())
+            {
+                count++;
+            }
+            Assert.Equal(1, count);
+        }
+
         private static long GetResultCount(AdbcCommand command, string query)
         {
             command.CommandText = "SELECT * from test";
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
index d5ad20c9c..63bc1022b 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs
@@ -209,6 +209,33 @@ namespace Apache.Arrow.Adbc.Tests
             Assert.Equal(6, GetResultCount(statement3, "SELECT * from main.ingested"));
         }
 
+        [Fact]
+        public void PrepareAndBind()
+        {
+            using var database = _duckDb.OpenDatabase("bind.db");
+            using var connection = database.Connect(null);
+            using var statement = connection.CreateStatement();
+
+            statement.SqlQuery = "select ?, ?";
+            statement.Prepare();
+            var schema = statement.GetParameterSchema();
+            Assert.Equal(2, schema.FieldsList.Count);
+            Assert.Equal("0", schema.FieldsList[0].Name);
+            Assert.Equal(ArrowTypeId.Null, schema.FieldsList[0].DataType.TypeId);
+            Assert.Equal("1", schema.FieldsList[1].Name);
+            Assert.Equal(ArrowTypeId.Null, schema.FieldsList[1].DataType.TypeId);
+
+            schema = new Schema([new Field("0", Int32Type.Default, false), new Field("1", StringType.Default, false)], null);
+            RecordBatch recordBatch = new RecordBatch(schema, [
+                new Int32Array.Builder().AppendRange([1]).Build(),
+                new StringArray.Builder().AppendRange(["foo"]).Build()
+                ], 1);
+            statement.Bind(recordBatch, schema);
+
+            var results = statement.ExecuteQuery();
+            Assert.Equal(1, GetResultCount(results));
+        }
+
         [Fact]
         public async Task GetTableTypes()
         {
@@ -255,6 +282,11 @@ namespace Apache.Arrow.Adbc.Tests
         {
             statement.SqlQuery = query;
             var results = statement.ExecuteQuery();
+            return GetResultCount(results);
+        }
+
+        private static long GetResultCount(QueryResult results)
+        {
             long count = 0;
             using (var stream = results.Stream ?? throw new InvalidOperationException("no results found"))
             {

Reply via email to