This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new a830bf26c fix(csharp): an assortment of small fixes not worth
individual pull requests (#1807)
a830bf26c is described below
commit a830bf26ce67fb31c139049afcdca9c60168b4ab
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Fri May 3 05:09:08 2024 -0700
fix(csharp): an assortment of small fixes not worth individual pull
requests (#1807)
Closes #1806
---
.../src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs | 99 ++++++++++++++++------
.../Extensions/CollectionExtensions.cs | 2 +-
csharp/src/Client/AdbcCommand.cs | 49 ++++++-----
csharp/src/Client/AdbcConnection.cs | 22 ++---
csharp/src/Client/AdbcDataReader.cs | 2 +-
csharp/test/Drivers/Interop/Snowflake/CastTests.cs | 2 +-
.../test/Drivers/Interop/Snowflake/ClientTests.cs | 2 +-
.../test/Drivers/Interop/Snowflake/DriverTests.cs | 2 +-
8 files changed, 111 insertions(+), 69 deletions(-)
diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
index c8a61cf44..a191775ec 100644
--- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
+++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
@@ -151,20 +151,35 @@ namespace Apache.Arrow.Adbc.C
if (parameters == null) throw new
ArgumentNullException(nameof(parameters));
CAdbcDatabase nativeDatabase = new CAdbcDatabase();
-
- using (CallHelper caller = new CallHelper())
+ ImportedAdbcDatabase? result = null;
+ try
{
- caller.Call(Driver.DatabaseNew, ref nativeDatabase);
-
- foreach (KeyValuePair<string, string> pair in parameters)
+ using (CallHelper caller = new CallHelper())
{
- caller.Call(Driver.DatabaseSetOption, ref
nativeDatabase, pair.Key, pair.Value);
+ caller.Call(Driver.DatabaseNew, ref nativeDatabase);
+
+ foreach (KeyValuePair<string, string> pair in
parameters)
+ {
+ caller.Call(Driver.DatabaseSetOption, ref
nativeDatabase, pair.Key, pair.Value);
+ }
+
+ caller.Call(Driver.DatabaseInit, ref nativeDatabase);
}
- caller.Call(Driver.DatabaseInit, ref nativeDatabase);
+ result = new ImportedAdbcDatabase(this, nativeDatabase);
+ }
+ finally
+ {
+ if (result == null && nativeDatabase.private_data != null)
+ {
+ using (CallHelper caller = new CallHelper())
+ {
+ caller.Call(Driver.DatabaseRelease, ref
nativeDatabase);
+ }
+ }
}
- return new ImportedAdbcDatabase(this, nativeDatabase);
+ return result;
}
public unsafe override void Dispose()
@@ -219,23 +234,38 @@ namespace Apache.Arrow.Adbc.C
public unsafe override AdbcConnection
Connect(IReadOnlyDictionary<string, string>? options)
{
CAdbcConnection nativeConnection = new CAdbcConnection();
-
- using (CallHelper caller = new CallHelper())
+ ImportedAdbcConnection? result = null;
+ try
{
- caller.Call(Driver.ConnectionNew, ref nativeConnection);
+ using (CallHelper caller = new CallHelper())
+ {
+ caller.Call(Driver.ConnectionNew, ref
nativeConnection);
+
+ if (options != null)
+ {
+ foreach (KeyValuePair<string, string> pair in
options)
+ {
+ caller.Call(Driver.ConnectionSetOption, ref
nativeConnection, pair.Key, pair.Value);
+ }
+ }
+
+ caller.Call(Driver.ConnectionInit, ref
nativeConnection, ref _nativeDatabase);
- if (options != null)
+ result = new ImportedAdbcConnection(_driver,
nativeConnection);
+ }
+ }
+ finally
+ {
+ if (result == null && nativeConnection.private_data !=
null)
{
- foreach (KeyValuePair<string, string> pair in options)
+ using (CallHelper caller = new CallHelper())
{
- caller.Call(Driver.ConnectionSetOption, ref
nativeConnection, pair.Key, pair.Value);
+ caller.Call(Driver.ConnectionRelease, ref
nativeConnection);
}
}
-
- caller.Call(Driver.ConnectionInit, ref nativeConnection,
ref _nativeDatabase);
}
- return new ImportedAdbcConnection(_driver, nativeConnection);
+ return result;
}
public override void Dispose()
@@ -338,22 +368,37 @@ namespace Apache.Arrow.Adbc.C
public unsafe override AdbcStatement CreateStatement()
{
CAdbcStatement nativeStatement = new CAdbcStatement();
-
- using (CallHelper caller = new CallHelper())
+ ImportedAdbcStatement? result = null;
+ try
{
- fixed (CAdbcConnection* connection = &_nativeConnection)
+ using (CallHelper caller = new CallHelper())
{
- caller.TranslateCode(
+ fixed (CAdbcConnection* connection =
&_nativeConnection)
+ {
+ caller.TranslateCode(
#if NET5_0_OR_GREATER
- Driver.StatementNew
+ Driver.StatementNew
#else
-
Marshal.GetDelegateForFunctionPointer<StatementNew>(Driver.StatementNew)
+
Marshal.GetDelegateForFunctionPointer<StatementNew>(Driver.StatementNew)
#endif
- (connection, &nativeStatement, &caller._error));
+ (connection, &nativeStatement,
&caller._error));
+ }
+
+ result = new ImportedAdbcStatement(_driver,
nativeStatement);
+ }
+ }
+ finally
+ {
+ if (result == null && nativeStatement.private_data != null)
+ {
+ using (CallHelper caller = new CallHelper())
+ {
+ caller.Call(Driver.StatementRelease, ref
nativeStatement);
+ }
}
}
- return new ImportedAdbcStatement(_driver, nativeStatement);
+ return result;
}
public unsafe override IArrowArrayStream
GetInfo(IReadOnlyList<AdbcInfoCode> codes)
@@ -831,7 +876,11 @@ namespace Apache.Arrow.Adbc.C
{
Debug.Assert(_schema != null);
Schema schema = CArrowSchemaImporter.ImportSchema(_schema);
+
+ // ImportSchema makes a copy so we need to free the original
+ CArrowSchema.Free(_schema);
_schema = null;
+
return schema;
}
diff --git a/csharp/src/Apache.Arrow.Adbc/Extensions/CollectionExtensions.cs
b/csharp/src/Apache.Arrow.Adbc/Extensions/CollectionExtensions.cs
index f153b11f2..20cd09420 100644
--- a/csharp/src/Apache.Arrow.Adbc/Extensions/CollectionExtensions.cs
+++ b/csharp/src/Apache.Arrow.Adbc/Extensions/CollectionExtensions.cs
@@ -30,7 +30,7 @@ namespace Apache.Arrow.Adbc.Extensions
public static Span<T> AsSpan<T>(this IReadOnlyList<T> list)
{
T[]? array = list as T[];
- if (array != null) { return array.AsSpan(); }
+ if (array != null) { return (Span<T>)array; }
#if NET5_0_OR_GREATER
List<T>? concreteList = list as List<T>;
diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs
index ffc05ae98..5b3ed7c24 100644
--- a/csharp/src/Client/AdbcCommand.cs
+++ b/csharp/src/Client/AdbcCommand.cs
@@ -29,28 +29,23 @@ namespace Apache.Arrow.Adbc.Client
{
private AdbcStatement adbcStatement;
private int _timeout = 30;
+ private bool _disposed;
/// <summary>
/// Overloaded. Initializes <see cref="AdbcCommand"/>.
/// </summary>
- /// <param name="adbcStatement">
- /// The <see cref="AdbcStatement"/> to use.
- /// </param>
/// <param name="adbcConnection">
/// The <see cref="AdbcConnection"/> to use.
/// </param>
/// <exception cref="ArgumentNullException"></exception>
- public AdbcCommand(AdbcStatement adbcStatement, AdbcConnection
adbcConnection) : base()
+ public AdbcCommand(AdbcConnection adbcConnection) : base()
{
- if (adbcStatement == null)
- throw new ArgumentNullException(nameof(adbcStatement));
-
if (adbcConnection == null)
throw new ArgumentNullException(nameof(adbcConnection));
- this.adbcStatement = adbcStatement;
this.DbConnection = adbcConnection;
this.DecimalBehavior = adbcConnection.DecimalBehavior;
+ this.adbcStatement = adbcConnection.CreateStatement();
}
/// <summary>
@@ -61,31 +56,39 @@ namespace Apache.Arrow.Adbc.Client
public AdbcCommand(string query, AdbcConnection adbcConnection) :
base()
{
if (string.IsNullOrEmpty(query))
- throw new ArgumentNullException(nameof(adbcStatement));
+ throw new ArgumentNullException(nameof(query));
if (adbcConnection == null)
throw new ArgumentNullException(nameof(adbcConnection));
- this.adbcStatement = adbcConnection.AdbcStatement;
+ this.adbcStatement = adbcConnection.CreateStatement();
this.CommandText = query;
this.DbConnection = adbcConnection;
this.DecimalBehavior = adbcConnection.DecimalBehavior;
}
+ // For testing
+ internal AdbcCommand(AdbcStatement adbcStatement, AdbcConnection
adbcConnection)
+ {
+ this.adbcStatement = adbcStatement;
+ this.DbConnection = adbcConnection;
+ this.DecimalBehavior = adbcConnection.DecimalBehavior;
+ }
+
/// <summary>
/// Gets the <see cref="AdbcStatement"/> associated with
/// this <see cref="AdbcCommand"/>.
/// </summary>
- public AdbcStatement AdbcStatement => this.adbcStatement;
+ public AdbcStatement AdbcStatement => _disposed ? throw new
ObjectDisposedException(nameof(AdbcCommand)) : this.adbcStatement;
public DecimalBehavior DecimalBehavior { get; set; }
public override string CommandText
{
- get => this.adbcStatement.SqlQuery ?? string.Empty;
+ get => AdbcStatement.SqlQuery ?? string.Empty;
#nullable disable
- set => this.adbcStatement.SqlQuery = string.IsNullOrEmpty(value) ?
null : value;
+ set => AdbcStatement.SqlQuery = string.IsNullOrEmpty(value) ? null
: value;
#nullable restore
}
@@ -116,15 +119,15 @@ namespace Apache.Arrow.Adbc.Client
/// </summary>
public byte[]? SubstraitPlan
{
- get => this.adbcStatement.SubstraitPlan;
- set => this.adbcStatement.SubstraitPlan = value;
+ get => AdbcStatement.SubstraitPlan;
+ set => AdbcStatement.SubstraitPlan = value;
}
protected override DbConnection? DbConnection { get; set; }
public override int ExecuteNonQuery()
{
- return
Convert.ToInt32(this.adbcStatement.ExecuteUpdate().AffectedRows);
+ return Convert.ToInt32(AdbcStatement.ExecuteUpdate().AffectedRows);
}
/// <summary>
@@ -134,7 +137,7 @@ namespace Apache.Arrow.Adbc.Client
/// <returns></returns>
public long ExecuteUpdate()
{
- return this.adbcStatement.ExecuteUpdate().AffectedRows;
+ return AdbcStatement.ExecuteUpdate().AffectedRows;
}
/// <summary>
@@ -143,7 +146,7 @@ namespace Apache.Arrow.Adbc.Client
/// <returns><see cref="Result"></returns>
public QueryResult ExecuteQuery()
{
- QueryResult executed = this.adbcStatement.ExecuteQuery();
+ QueryResult executed = AdbcStatement.ExecuteQuery();
return executed;
}
@@ -171,6 +174,9 @@ namespace Apache.Arrow.Adbc.Client
/// <returns><see cref="AdbcDataReader"/></returns>
public new AdbcDataReader ExecuteReader(CommandBehavior behavior)
{
+ if (_disposed)
+ throw new ObjectDisposedException(nameof(AdbcCommand));
+
bool closeConnection = (behavior &
CommandBehavior.CloseConnection) != 0;
switch (behavior & ~CommandBehavior.CloseConnection)
{
@@ -186,15 +192,14 @@ namespace Apache.Arrow.Adbc.Client
protected override void Dispose(bool disposing)
{
- if (disposing)
+ if (disposing && !_disposed)
{
// TODO: ensure not in the middle of pulling
- this.adbcStatement?.Dispose();
+ this.adbcStatement.Dispose();
+ _disposed = true;
}
base.Dispose(disposing);
-
- GC.SuppressFinalize(this);
}
#if NET5_0_OR_GREATER
diff --git a/csharp/src/Client/AdbcConnection.cs
b/csharp/src/Client/AdbcConnection.cs
index 667b1b89d..5347ba779 100644
--- a/csharp/src/Client/AdbcConnection.cs
+++ b/csharp/src/Client/AdbcConnection.cs
@@ -38,7 +38,6 @@ namespace Apache.Arrow.Adbc.Client
private readonly Dictionary<string, string> adbcConnectionParameters;
private readonly Dictionary<string, string> adbcConnectionOptions;
- private AdbcStatement? adbcStatement;
private AdbcTransaction? currentTransaction;
/// <summary>
@@ -115,22 +114,12 @@ namespace Apache.Arrow.Adbc.Client
public AdbcDriver? AdbcDriver { get; set; }
/// <summary>
- /// Gets the <see cref="AdbcStatement"/> associated with the
- /// connection.
+ /// Creates an <see cref="AdbcStatement"/> for the connection.
/// </summary>
- internal AdbcStatement AdbcStatement
+ internal AdbcStatement CreateStatement()
{
- get
- {
- if (this.adbcStatement == null)
- {
- // need to have a connection in order to have a statement
- EnsureConnectionOpen();
- this.adbcStatement =
this.adbcConnectionInternal!.CreateStatement();
- }
-
- return this.adbcStatement;
- }
+ EnsureConnectionOpen();
+ return this.adbcConnectionInternal!.CreateStatement();
}
#if NET5_0_OR_GREATER
@@ -147,7 +136,7 @@ namespace Apache.Arrow.Adbc.Client
{
EnsureConnectionOpen();
- return new AdbcCommand(this.AdbcStatement, this);
+ return new AdbcCommand(this);
}
/// <summary>
@@ -163,7 +152,6 @@ namespace Apache.Arrow.Adbc.Client
{
this.adbcConnectionInternal?.Dispose();
this.adbcConnectionInternal = null;
- this.adbcStatement = null;
base.Dispose(disposing);
}
diff --git a/csharp/src/Client/AdbcDataReader.cs
b/csharp/src/Client/AdbcDataReader.cs
index 5504ce83c..124399865 100644
--- a/csharp/src/Client/AdbcDataReader.cs
+++ b/csharp/src/Client/AdbcDataReader.cs
@@ -91,7 +91,7 @@ namespace Apache.Arrow.Adbc.Client
/// <summary>
/// The total number of record batches in the result.
/// </summary>
- public int TotalBatches { get; set; }
+ public int TotalBatches { get; private set; }
private RecordBatch RecordBatch
{
diff --git a/csharp/test/Drivers/Interop/Snowflake/CastTests.cs
b/csharp/test/Drivers/Interop/Snowflake/CastTests.cs
index a46734768..dd5ab2e7d 100644
--- a/csharp/test/Drivers/Interop/Snowflake/CastTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/CastTests.cs
@@ -408,7 +408,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
{
statement.SqlQuery = string.Format("ALTER SESSION SET TIMEZONE =
'{0}'", timezone);
UpdateResult result = statement.ExecuteUpdate();
- Assert.Equal(-1, result.AffectedRows);
+ Assert.Equal(1, result.AffectedRows);
}
public void Dispose()
diff --git a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
index 04cf0d1eb..f90e278f1 100644
--- a/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/ClientTests.cs
@@ -54,7 +54,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
{
string[] queries =
SnowflakeTestingUtils.GetQueries(testConfiguration);
- List<int> expectedResults = new List<int>() { -1, 1, 1 };
+ List<int> expectedResults = new List<int>() { 1, 1, 1 };
Tests.ClientTests.CanClientExecuteUpdate(adbcConnection,
testConfiguration, queries, expectedResults);
}
diff --git a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
index 0dcff9aa9..da27ef1b5 100644
--- a/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
+++ b/csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
@@ -100,7 +100,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake
{
string[] queries =
SnowflakeTestingUtils.GetQueries(_testConfiguration);
- List<int> expectedResults = new List<int>() { -1, 1, 1 };
+ List<int> expectedResults = new List<int>() { 1, 1, 1 };
for (int i = 0; i < queries.Length; i++)
{