This is an automated email from the ASF dual-hosted git repository. curth pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push: new e39d71549 feat(csharp): Add support for Prepare to ImportedStatement and to ADO.NET wrapper (#2628) e39d71549 is described below commit e39d71549bde79ebd2662480ae82b94140c62e03 Author: Curt Hagenlocher <c...@hagenlocher.org> AuthorDate: Tue Mar 18 07:31:08 2025 -0700 feat(csharp): Add support for Prepare to ImportedStatement and to ADO.NET wrapper (#2628) Closes #2616. --- .../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs | 70 ++++++++++++++-------- csharp/src/Client/AdbcCommand.cs | 59 ++++++++++++++---- .../Client/DuckDbClientTests.cs | 32 +++++++++- .../Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs | 32 ++++++++++ 4 files changed, 157 insertions(+), 36 deletions(-) diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs index 73ae28b7d..133fff57a 100644 --- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs +++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs @@ -797,6 +797,15 @@ namespace Apache.Arrow.Adbc.C Dispose(false); } + public override string? SqlQuery + { + set + { + SetSqlQuery(value); + base.SqlQuery = value; + } + } + private unsafe ref CAdbcDriver Driver { get @@ -885,12 +894,6 @@ namespace Apache.Arrow.Adbc.C public unsafe override QueryResult ExecuteQuery() { - if (SqlQuery != null) - { - // TODO: Consider moving this to the setter - SetSqlQuery(SqlQuery); - } - using (CallHelper caller = new CallHelper()) { fixed (CAdbcStatement* statement = &_nativeStatement) @@ -911,12 +914,6 @@ namespace Apache.Arrow.Adbc.C public override unsafe Schema ExecuteSchema() { - if (SqlQuery != null) - { - // TODO: Consider moving this to the setter - SetSqlQuery(SqlQuery); - } - using (CallHelper caller = new CallHelper()) { fixed (CAdbcStatement* statement = &_nativeStatement) @@ -936,12 +933,6 @@ namespace Apache.Arrow.Adbc.C public unsafe override UpdateResult ExecuteUpdate() { - if (SqlQuery != null) - { - // TODO: Consider moving this to the setter - SetSqlQuery(SqlQuery); - } - using (CallHelper caller = new CallHelper()) { fixed (CAdbcStatement* statement = &_nativeStatement) @@ -962,12 +953,6 @@ namespace Apache.Arrow.Adbc.C public unsafe override PartitionedResult ExecutePartitioned() { - if (SqlQuery != null) - { - // TODO: Consider moving this to the setter - SetSqlQuery(SqlQuery); - } - using (CallHelper caller = new CallHelper()) { fixed (CAdbcStatement* statement = &_nativeStatement) @@ -1013,6 +998,41 @@ namespace Apache.Arrow.Adbc.C } } + public unsafe override Schema GetParameterSchema() + { + using (CallHelper caller = new CallHelper()) + { + fixed (CAdbcStatement* statement = &_nativeStatement) + { + caller.TranslateCode( +#if NET5_0_OR_GREATER + Driver.StatementGetParameterSchema +#else + Marshal.GetDelegateForFunctionPointer<StatementGetParameterSchema>(Driver.StatementGetParameterSchema) +#endif + (statement, caller.CreateSchema(), & caller._error)); + } + return caller.ImportSchema(); + } + } + + public unsafe override void Prepare() + { + using (CallHelper caller = new CallHelper()) + { + fixed (CAdbcStatement* statement = &_nativeStatement) + { + caller.TranslateCode( +#if NET5_0_OR_GREATER + Driver.StatementPrepare +#else + Marshal.GetDelegateForFunctionPointer<StatementPrepare>(Driver.StatementPrepare) +#endif + (statement, &caller._error)); + } + } + } + public unsafe override void SetOption(string key, string value) { using (CallHelper caller = new CallHelper()) @@ -1055,7 +1075,7 @@ namespace Apache.Arrow.Adbc.C } } - private unsafe void SetSqlQuery(string sqlQuery) + private unsafe void SetSqlQuery(string? sqlQuery) { fixed (CAdbcStatement* statement = &_nativeStatement) { diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs index a317ca19c..7ad0678af 100644 --- a/csharp/src/Client/AdbcCommand.cs +++ b/csharp/src/Client/AdbcCommand.cs @@ -483,12 +483,61 @@ namespace Apache.Arrow.Adbc.Client } } + public override void Prepare() + { + _adbcStatement.Prepare(); + var schema = _adbcStatement.GetParameterSchema(); + + DbParameterCollection.Clear(); + + foreach (Field field in schema.FieldsList) + { + AdbcParameter parameter = new AdbcParameter + { + ParameterName = field.Name, + IsNullable = field.IsNullable, + DbType = field.DataType.TypeId switch + { + ArrowTypeId.UInt8 => DbType.Byte, + ArrowTypeId.UInt16 => DbType.UInt16, + ArrowTypeId.UInt32 => DbType.UInt32, + ArrowTypeId.UInt64 => DbType.UInt64, + ArrowTypeId.Int8 => DbType.SByte, + ArrowTypeId.Int16 => DbType.Int16, + ArrowTypeId.Int32 => DbType.Int32, + ArrowTypeId.Int64 => DbType.Int64, + ArrowTypeId.Float => DbType.Single, + ArrowTypeId.Double => DbType.Double, + ArrowTypeId.Boolean => DbType.Boolean, + ArrowTypeId.String => DbType.String, + ArrowTypeId.Date32 => DbType.Date, + ArrowTypeId.Date64 => DbType.DateTime, + ArrowTypeId.Time32 => DbType.Time, + ArrowTypeId.Time64 => DbType.Time, + ArrowTypeId.Timestamp => DbType.DateTime, + ArrowTypeId.Decimal32 or + ArrowTypeId.Decimal64 or + ArrowTypeId.Decimal128 or + ArrowTypeId.Decimal256 => DbType.Decimal, + _ => DbType.Object, + }, + }; + DbParameterCollection.Add(parameter); + } + } + + protected override DbParameter CreateDbParameter() + { + return new AdbcParameter(); + } + #if NET5_0_OR_GREATER public override ValueTask DisposeAsync() { return base.DisposeAsync(); } #endif + #region NOT_IMPLEMENTED public override bool DesignTimeVisible { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } @@ -507,16 +556,6 @@ namespace Apache.Arrow.Adbc.Client throw new NotImplementedException(); } - public override void Prepare() - { - throw new NotImplementedException(); - } - - protected override DbParameter CreateDbParameter() - { - return new AdbcParameter(); - } - #endregion private class AdbcParameterCollection : DbParameterCollection diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs index 31e25f662..0a8e89d82 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/Client/DuckDbClientTests.cs @@ -15,8 +15,9 @@ * limitations under the License. */ -using System.Collections.Generic; +using System.Data; using Apache.Arrow.Adbc.Client; +using Apache.Arrow.Types; using Xunit; namespace Apache.Arrow.Adbc.Tests.Client @@ -117,6 +118,35 @@ namespace Apache.Arrow.Adbc.Tests.Client }); } + [Fact] + public void BindParameters() + { + using var connection = _duckDb.CreateConnection("bindparameters.db", null); + connection.Open(); + var command = connection.CreateCommand(); + + command.CommandText = "select ?, ?"; + command.Prepare(); + Assert.Equal(2, command.Parameters.Count); + Assert.Equal("0", command.Parameters[0].ParameterName); + Assert.Equal(DbType.Object, command.Parameters[0].DbType); + Assert.Equal("1", command.Parameters[1].ParameterName); + Assert.Equal(DbType.Object, command.Parameters[1].DbType); + + command.Parameters[0].DbType = DbType.Int32; + command.Parameters[0].Value = 1; + command.Parameters[1].DbType = DbType.String; + command.Parameters[1].Value = "foo"; + + using var reader = command.ExecuteReader(); + long count = 0; + while (reader.Read()) + { + count++; + } + Assert.Equal(1, count); + } + private static long GetResultCount(AdbcCommand command, string query) { command.CommandText = "SELECT * from test"; diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs index d5ad20c9c..63bc1022b 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/ImportedDuckDbTests.cs @@ -209,6 +209,33 @@ namespace Apache.Arrow.Adbc.Tests Assert.Equal(6, GetResultCount(statement3, "SELECT * from main.ingested")); } + [Fact] + public void PrepareAndBind() + { + using var database = _duckDb.OpenDatabase("bind.db"); + using var connection = database.Connect(null); + using var statement = connection.CreateStatement(); + + statement.SqlQuery = "select ?, ?"; + statement.Prepare(); + var schema = statement.GetParameterSchema(); + Assert.Equal(2, schema.FieldsList.Count); + Assert.Equal("0", schema.FieldsList[0].Name); + Assert.Equal(ArrowTypeId.Null, schema.FieldsList[0].DataType.TypeId); + Assert.Equal("1", schema.FieldsList[1].Name); + Assert.Equal(ArrowTypeId.Null, schema.FieldsList[1].DataType.TypeId); + + schema = new Schema([new Field("0", Int32Type.Default, false), new Field("1", StringType.Default, false)], null); + RecordBatch recordBatch = new RecordBatch(schema, [ + new Int32Array.Builder().AppendRange([1]).Build(), + new StringArray.Builder().AppendRange(["foo"]).Build() + ], 1); + statement.Bind(recordBatch, schema); + + var results = statement.ExecuteQuery(); + Assert.Equal(1, GetResultCount(results)); + } + [Fact] public async Task GetTableTypes() { @@ -255,6 +282,11 @@ namespace Apache.Arrow.Adbc.Tests { statement.SqlQuery = query; var results = statement.ExecuteQuery(); + return GetResultCount(results); + } + + private static long GetResultCount(QueryResult results) + { long count = 0; using (var stream = results.Stream ?? throw new InvalidOperationException("no results found")) {