This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new a830bf26c fix(csharp): an assortment of small fixes not worth individual pull requests (#1807)
a830bf26c is described below

commit a830bf26ce67fb31c139049afcdca9c60168b4ab
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Fri May 3 05:09:08 2024 -0700

    fix(csharp): an assortment of small fixes not worth individual pull requests (#1807)
    
    Closes #1806
---
 .../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs | 99 ++++++++++++++++------
 .../Extensions/CollectionExtensions.cs             |  2 +-
 csharp/src/Client/AdbcCommand.cs                   | 49 ++++++-----
 csharp/src/Client/AdbcConnection.cs                | 22 ++---
 csharp/src/Client/AdbcDataReader.cs                |  2 +-
 csharp/test/Drivers/Interop/Snowflake/CastTests.cs |  2 +-
 .../test/Drivers/Interop/Snowflake/ClientTests.cs  |  2 +-
 .../test/Drivers/Interop/Snowflake/DriverTests.cs  |  2 +-
 8 files changed, 111 insertions(+), 69 deletions(-)

diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
index c8a61cf44..a191775ec 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
@@ -151,20 +151,35 @@ namespace Apache.Arrow.Adbc.C
                 if (parameters == null) throw new ArgumentNullException(nameof(parameters));
 
                 CAdbcDatabase nativeDatabase = new CAdbcDatabase();
-
-                using (CallHelper caller = new CallHelper())
+                ImportedAdbcDatabase? result = null;
+                try
                 {
-                    caller.Call(Driver.DatabaseNew, ref nativeDatabase);
-
-                    foreach (KeyValuePair<string, string> pair in parameters)
+                    using (CallHelper caller = new CallHelper())
                     {
-                        caller.Call(Driver.DatabaseSetOption, ref nativeDatabase, pair.Key, pair.Value);
+                        caller.Call(Driver.DatabaseNew, ref nativeDatabase);
+
+                        foreach (KeyValuePair<string, string> pair in parameters)
+                        {
+                            caller.Call(Driver.DatabaseSetOption, ref nativeDatabase, pair.Key, pair.Value);
+                        }
+
+                        caller.Call(Driver.DatabaseInit, ref nativeDatabase);
                     }
 
-                    caller.Call(Driver.DatabaseInit, ref nativeDatabase);
+                    result = new ImportedAdbcDatabase(this, nativeDatabase);
+                }
+                finally
+                {
+                    if (result == null && nativeDatabase.private_data != null)
+                    {
+                        using (CallHelper caller = new CallHelper())
+                        {
+                            caller.Call(Driver.DatabaseRelease, ref nativeDatabase);
+                        }
+                    }
                 }
 
-                return new ImportedAdbcDatabase(this, nativeDatabase);
+                return result;
             }
 
             public unsafe override void Dispose()
@@ -219,23 +234,38 @@ namespace Apache.Arrow.Adbc.C
             public unsafe override AdbcConnection Connect(IReadOnlyDictionary<string, string>? options)
             {
                 CAdbcConnection nativeConnection = new CAdbcConnection();
-
-                using (CallHelper caller = new CallHelper())
+                ImportedAdbcConnection? result = null;
+                try
                 {
-                    caller.Call(Driver.ConnectionNew, ref nativeConnection);
+                    using (CallHelper caller = new CallHelper())
+                    {
+                        caller.Call(Driver.ConnectionNew, ref nativeConnection);
+
+                        if (options != null)
+                        {
+                            foreach (KeyValuePair<string, string> pair in options)
+                            {
+                                caller.Call(Driver.ConnectionSetOption, ref nativeConnection, pair.Key, pair.Value);
+                            }
+                        }
+
+                        caller.Call(Driver.ConnectionInit, ref nativeConnection, ref _nativeDatabase);
 
-                    if (options != null)
+                        result = new ImportedAdbcConnection(_driver, nativeConnection);
+                    }
+                }
+                finally
+                {
+                    if (result == null && nativeConnection.private_data != null)
                     {
-                        foreach (KeyValuePair<string, string> pair in options)
+                        using (CallHelper caller = new CallHelper())
                         {
-                            caller.Call(Driver.ConnectionSetOption, ref nativeConnection, pair.Key, pair.Value);
+                            caller.Call(Driver.ConnectionRelease, ref nativeConnection);
                         }
                     }
-
-                    caller.Call(Driver.ConnectionInit, ref nativeConnection, ref _nativeDatabase);
                 }
 
-                return new ImportedAdbcConnection(_driver, nativeConnection);
+                return result;
             }
 
             public override void Dispose()
@@ -338,22 +368,37 @@ namespace Apache.Arrow.Adbc.C
             public unsafe override AdbcStatement CreateStatement()
             {
                 CAdbcStatement nativeStatement = new CAdbcStatement();
-
-                using (CallHelper caller = new CallHelper())
+                ImportedAdbcStatement? result = null;
+                try
                 {
-                    fixed (CAdbcConnection* connection = &_nativeConnection)
+                    using (CallHelper caller = new CallHelper())
                     {
-                        caller.TranslateCode(
+                        fixed (CAdbcConnection* connection = &_nativeConnection)
+                        {
+                            caller.TranslateCode(
 #if NET5_0_OR_GREATER
-                            Driver.StatementNew
+                                Driver.StatementNew
 #else
-                            Marshal.GetDelegateForFunctionPointer<StatementNew>(Driver.StatementNew)
+                                Marshal.GetDelegateForFunctionPointer<StatementNew>(Driver.StatementNew)
 #endif
-                            (connection, &nativeStatement, &caller._error));
+                                (connection, &nativeStatement, &caller._error));
+                        }
+
+                        result = new ImportedAdbcStatement(_driver, nativeStatement);
+                    }
+                }
+                finally
+                {
+                    if (result == null && nativeStatement.private_data != null)
+                    {
+                        using (CallHelper caller = new CallHelper())
+                        {
+                            caller.Call(Driver.StatementRelease, ref nativeStatement);
+                        }
                     }
                 }
 
-                return new ImportedAdbcStatement(_driver, nativeStatement);
+                return result;
             }
 
             public unsafe override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
@@ -831,7 +876,11 @@ namespace Apache.Arrow.Adbc.C
             {
                 Debug.Assert(_schema != null);
                 Schema schema = CArrowSchemaImporter.ImportSchema(_schema);
+
+                // ImportSchema makes a copy so we need to free the original
+                CArrowSchema.Free(_schema);
                 _schema = null;
+
                 return schema;
             }
 
diff --git a/csharp/src/Apache.Arrow.Adbc/Extensions/CollectionExtensions.cs b/csharp/src/Apache.Arrow.Adbc/Extensions/CollectionExtensions.cs
index f153b11f2..20cd09420 100644
--- a/csharp/src/Apache.Arrow.Adbc/Extensions/CollectionExtensions.cs
+++ b/csharp/src/Apache.Arrow.Adbc/Extensions/CollectionExtensions.cs
@@ -30,7 +30,7 @@ namespace Apache.Arrow.Adbc.Extensions
         public static Span<T> AsSpan<T>(this IReadOnlyList<T> list)
         {
             T[]? array = list as T[];
-            if (array != null) { return array.AsSpan(); }
+            if (array != null) { return (Span<T>)array; }
 
 #if NET5_0_OR_GREATER
             List<T>? concreteList = list as List<T>;
diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs
index ffc05ae98..5b3ed7c24 100644
--- a/csharp/src/Client/AdbcCommand.cs
+++ b/csharp/src/Client/AdbcCommand.cs
@@ -29,28 +29,23 @@ namespace Apache.Arrow.Adbc.Client
     {
         private AdbcStatement adbcStatement;
         private int _timeout = 30;
+        private bool _disposed;
 
         /// <summary>
         /// Overloaded. Initializes <see cref="AdbcCommand"/>.
         /// </summary>
-        /// <param name="adbcStatement">
-        /// The <see cref="AdbcStatement"/> to use.
-        /// </param>
         /// <param name="adbcConnection">
         /// The <see cref="AdbcConnection"/> to use.
         /// </param>
         /// <exception cref="ArgumentNullException"></exception>
-        public AdbcCommand(AdbcStatement adbcStatement, AdbcConnection adbcConnection) : base()
+        public AdbcCommand(AdbcConnection adbcConnection) : base()
         {
-            if (adbcStatement == null)
-                throw new ArgumentNullException(nameof(adbcStatement));
-
             if (adbcConnection == null)
                 throw new ArgumentNullException(nameof(adbcConnection));
 
-            this.adbcStatement = adbcStatement;
             this.DbConnection = adbcConnection;
             this.DecimalBehavior = adbcConnection.DecimalBehavior;
+            this.adbcStatement = adbcConnection.CreateStatement();
         }
 
         /// <summary>
@@ -61,31 +56,39 @@ namespace Apache.Arrow.Adbc.Client
         public AdbcCommand(string query, AdbcConnection adbcConnection) : base()
         {
             if (string.IsNullOrEmpty(query))
-                throw new ArgumentNullException(nameof(adbcStatement));
+                throw new ArgumentNullException(nameof(query));
 
             if (adbcConnection == null)
                 throw new ArgumentNullException(nameof(adbcConnection));
 
-            this.adbcStatement = adbcConnection.AdbcStatement;
+            this.adbcStatement = adbcConnection.CreateStatement();
             this.CommandText = query;
 
             this.DbConnection = adbcConnection;
             this.DecimalBehavior = adbcConnection.DecimalBehavior;
         }
 
+        // For testing
+        internal AdbcCommand(AdbcStatement adbcStatement, AdbcConnection adbcConnection)
+        {
+            this.adbcStatement = adbcStatement;
+            this.DbConnection = adbcConnection;
+            this.DecimalBehavior = adbcConnection.DecimalBehavior;
+        }
+
         /// <summary>
         /// Gets the <see cref="AdbcStatement"/> associated with
         /// this <see cref="AdbcCommand"/>.
         /// </summary>
-        public AdbcStatement AdbcStatement => this.adbcStatement;
+        public AdbcStatement AdbcStatement => _disposed ? throw new ObjectDisposedException(nameof(AdbcCommand)) : this.adbcStatement;
 
         public DecimalBehavior DecimalBehavior { get; set; }
 
         public override string CommandText
         {
-            get => this.adbcStatement.SqlQuery ?? string.Empty;
+            get => AdbcStatement.SqlQuery ?? string.Empty;
 #nullable disable
-            set => this.adbcStatement.SqlQuery = string.IsNullOrEmpty(value) ? null : value;
+            set => AdbcStatement.SqlQuery = string.IsNullOrEmpty(value) ? null : value;
 #nullable restore
         }
 
@@ -116,15 +119,15 @@ namespace Apache.Arrow.Adbc.Client
         /// </summary>
         public byte[]? SubstraitPlan
         {
-            get => this.adbcStatement.SubstraitPlan;
-            set => this.adbcStatement.SubstraitPlan = value;
+            get => AdbcStatement.SubstraitPlan;
+            set => AdbcStatement.SubstraitPlan = value;
         }
 
         protected override DbConnection? DbConnection { get; set; }
 
         public override int ExecuteNonQuery()
         {
-            return Convert.ToInt32(this.adbcStatement.ExecuteUpdate().AffectedRows);
+            return Convert.ToInt32(AdbcStatement.ExecuteUpdate().AffectedRows);
         }
 
         /// <summary>
@@ -134,7 +137,7 @@ namespace Apache.Arrow.Adbc.Client
         /// <returns></returns>
         public long ExecuteUpdate()
         {
-            return this.adbcStatement.ExecuteUpdate().AffectedRows;
+            return AdbcStatement.ExecuteUpdate().AffectedRows;
         }
 
         /// <summary>
@@ -143,7 +146,7 @@ namespace Apache.Arrow.Adbc.Client
         /// <returns><see cref="Result"></returns>
         public QueryResult ExecuteQuery()
         {
-            QueryResult executed = this.adbcStatement.ExecuteQuery();
+            QueryResult executed = AdbcStatement.ExecuteQuery();
 
             return executed;
         }
@@ -171,6 +174,9 @@ namespace Apache.Arrow.Adbc.Client
         /// <returns><see cref="AdbcDataReader"/></returns>
         public new AdbcDataReader ExecuteReader(CommandBehavior behavior)
         {
+            if (_disposed)
+                throw new ObjectDisposedException(nameof(AdbcCommand));
+
             bool closeConnection = (behavior & CommandBehavior.CloseConnection) != 0;
             switch (behavior & ~CommandBehavior.CloseConnection)
             {
@@ -186,15 +192,14 @@ namespace Apache.Arrow.Adbc.Client
 
         protected override void Dispose(bool disposing)
         {
-            if (disposing)
+            if (disposing && !_disposed)
             {
                 // TODO: ensure not in the middle of pulling
-                this.adbcStatement?.Dispose();
+                this.adbcStatement.Dispose();
+                _disposed = true;
             }
 
             base.Dispose(disposing);
-
-            GC.SuppressFinalize(this);
         }
 
 #if NET5_0_OR_GREATER
diff --git a/csharp/src/Client/AdbcConnection.cs b/csharp/src/Client/AdbcConnection.cs
index 667b1b89d..5347ba779 100644
--- a/csharp/src/Client/AdbcConnection.cs
+++ b/csharp/src/Client/AdbcConnection.cs
@@ -38,7 +38,6 @@ namespace Apache.Arrow.Adbc.Client
         private readonly Dictionary<string, string> adbcConnectionParameters;
         private readonly Dictionary<string, string> adbcConnectionOptions;
 
-        private AdbcStatement? adbcStatement;
         private AdbcTransaction? currentTransaction;
 
         /// <summary>
@@ -115,22 +114,12 @@ namespace Apache.Arrow.Adbc.Client
         public AdbcDriver? AdbcDriver { get; set; }
 
         /// <summary>
-        /// Gets the <see cref="AdbcStatement"/> associated with the
-        /// connection.
+        /// Creates an <see cref="AdbcStatement"/> for the connection.
         /// </summary>
-        internal AdbcStatement AdbcStatement
+        internal AdbcStatement CreateStatement()
         {
-            get
-            {
-                if (this.adbcStatement == null)
-                {
-                    // need to have a connection in order to have a statement
-                    EnsureConnectionOpen();
-                    this.adbcStatement = this.adbcConnectionInternal!.CreateStatement();
-                }
-
-                return this.adbcStatement;
-            }
+            EnsureConnectionOpen();
+            return this.adbcConnectionInternal!.CreateStatement();
         }
 
 #if NET5_0_OR_GREATER
@@ -147,7 +136,7 @@ namespace Apache.Arrow.Adbc.Client
         {
             EnsureConnectionOpen();
 
-            return new AdbcCommand(this.AdbcStatement, this);
+            return new AdbcCommand(this);
         }
 
         /// <summary>
@@ -163,7 +152,6 @@ namespace Apache.Arrow.Adbc.Client
         {
             this.adbcConnectionInternal?.Dispose();
             this.adbcConnectionInternal = null;
-            this.adbcStatement = null;
 
             base.Dispose(disposing);
         }
diff --git a/csharp/src/Client/AdbcDataReader.cs b/csharp/src/Client/AdbcDataReader.cs
index 5504ce83c..124399865 100644
--- a/csharp/src/Client/AdbcDataReader.cs
+++ b/csharp/src/Client/AdbcDataReader.cs
@@ -91,7 +91,7 @@ namespace Apache.Arrow.Adbc.Client
         /// <summary>
         /// The total number of record batches in the result.
         /// </summary>
-        public int TotalBatches { get; set; }
+        public int TotalBatches { get; private set; }
 
         private RecordBatch RecordBatch
         {
diff --git a/csharp/test/Drivers/Interop/Snowflake/CastTests.cs b/csharp/test/Drivers/Interop/Snowflake/CastTests.cs
index a46734768..dd5ab2e7d 100644
--- a/csharp/test/Drivers/Interop/Snowflake/CastTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/CastTests.cs
@@ -408,7 +408,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
         {
             statement.SqlQuery = string.Format("ALTER SESSION SET TIMEZONE = '{0}'", timezone);
             UpdateResult result = statement.ExecuteUpdate();
-            Assert.Equal(-1, result.AffectedRows);
+            Assert.Equal(1, result.AffectedRows);
         }
 
         public void Dispose()
diff --git a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
index 04cf0d1eb..f90e278f1 100644
--- a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
@@ -54,7 +54,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
             {
                 string[] queries = SnowflakeTestingUtils.GetQueries(testConfiguration);
 
-                List<int> expectedResults = new List<int>() { -1, 1, 1 };
+                List<int> expectedResults = new List<int>() { 1, 1, 1 };
 
                 Tests.ClientTests.CanClientExecuteUpdate(adbcConnection, testConfiguration, queries, expectedResults);
             }
diff --git a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
index 0dcff9aa9..da27ef1b5 100644
--- a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
@@ -100,7 +100,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
         {
             string[] queries = SnowflakeTestingUtils.GetQueries(_testConfiguration);
 
-            List<int> expectedResults = new List<int>() { -1, 1, 1 };
+            List<int> expectedResults = new List<int>() { 1, 1, 1 };
 
             for (int i = 0; i < queries.Length; i++)
             {

Reply via email to