This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 4d08122c6 feat(csharp/src/Drivers/Apache/Spark): implement async
overrides for Spark driver (#1830)
4d08122c6 is described below
commit 4d08122c6c6687c9e5839eaa7bda65b14ea1eac4
Author: Bruce Irschick <[email protected]>
AuthorDate: Wed May 8 19:30:44 2024 -0700
feat(csharp/src/Drivers/Apache/Spark): implement async overrides for Spark
driver (#1830)
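Implements asynchronous overrides for the Spark driver. The shared query and
update logic now lives in HiveServer2Statement, so the Impala driver gains the
same overrides (including ExecuteUpdate, which previously threw
NotImplementedException).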
* Implements ExecuteUpdateAsync
* Implements ExecuteQueryAsync
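* Converts blocking methods to async and renames them with the ...Async
  suffix, where appropriate (e.g. Open -> OpenAsync,
  CreateProtocol -> CreateProtocolAsync)
* Refactors the test base classes to use async methods (e.g.
  NewTemporaryTable -> NewTemporaryTableAsync)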
---
.../Drivers/Apache/Hive2/HiveServer2Connection.cs | 9 +--
.../Drivers/Apache/Hive2/HiveServer2Statement.cs | 83 ++++++++++++++++++----
.../src/Drivers/Apache/Impala/ImpalaConnection.cs | 5 +-
csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs | 2 +-
.../src/Drivers/Apache/Impala/ImpalaStatement.cs | 38 ++++------
csharp/src/Drivers/Apache/Spark/SparkConnection.cs | 47 ++++--------
csharp/src/Drivers/Apache/Spark/SparkDatabase.cs | 2 +-
csharp/src/Drivers/Apache/Spark/SparkStatement.cs | 58 ++-------------
csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs | 33 +++++++++
csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs | 59 +++++++--------
.../Apache/Spark/BinaryBooleanValueTests.cs | 10 +--
.../Drivers/Apache/Spark/ComplexTypesValueTests.cs | 6 +-
.../Drivers/Apache/Spark/DateTimeValueTests.cs | 14 ++--
csharp/test/Drivers/Apache/Spark/DriverTests.cs | 46 ++++++++++--
.../test/Drivers/Apache/Spark/NumericValueTests.cs | 68 +++++++++---------
csharp/test/Drivers/Apache/Spark/SparkTestBase.cs | 5 +-
.../test/Drivers/Apache/Spark/StringValueTests.cs | 20 +++---
17 files changed, 276 insertions(+), 229 deletions(-)
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
index d923a2afa..9e11cec10 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -46,17 +46,18 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
get { return this.client ?? throw new
InvalidOperationException("connection not open"); }
}
- public void Open()
+ internal async Task OpenAsync()
{
- TProtocol protocol = CreateProtocol();
+ TProtocol protocol = await CreateProtocolAsync();
this.transport = protocol.Transport;
this.client = new TCLIService.Client(protocol);
- var s0 = this.client.OpenSession(CreateSessionRequest()).Result;
+ var s0 = await this.client.OpenSession(CreateSessionRequest());
this.sessionHandle = s0.SessionHandle;
}
- protected abstract TProtocol CreateProtocol();
+ protected abstract ValueTask<TProtocol> CreateProtocolAsync();
+
protected abstract TOpenSessionReq CreateSessionRequest();
public override IArrowArrayStream GetObjects(GetObjectsDepth depth,
string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern,
IReadOnlyList<string>? tableTypes, string? columnNamePattern)
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
index b8ec90a0a..cc4abb964 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -15,15 +15,19 @@
* limitations under the License.
*/
-using System.Threading;
+using System;
+using System.Threading.Tasks;
+using Apache.Arrow.Ipc;
using Apache.Hive.Service.Rpc.Thrift;
namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
{
public abstract class HiveServer2Statement : AdbcStatement
{
- protected HiveServer2Connection connection;
- protected TOperationHandle? operationHandle;
+ private const int PollTimeMillisecondsDefault = 500;
+ private const int BatchSizeDefault = 50000;
+ protected internal HiveServer2Connection connection;
+ protected internal TOperationHandle? operationHandle;
protected HiveServer2Statement(HiveServer2Connection connection)
{
@@ -34,11 +38,64 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
{
}
- protected void ExecuteStatement()
+ protected abstract IArrowArrayStream NewReader<T>(T statement, Schema
schema) where T : HiveServer2Statement;
+
+ public override QueryResult ExecuteQuery() =>
ExecuteQueryAsync().AsTask().Result;
+
+ public override UpdateResult ExecuteUpdate() =>
ExecuteUpdateAsync().Result;
+
+ public override async ValueTask<QueryResult> ExecuteQueryAsync()
+ {
+ await ExecuteStatementAsync();
+ await PollForResponseAsync();
+ Schema schema = await GetSchemaAsync();
+
+ // TODO: Ensure this is set dynamically based on server
capabilities
+ return new QueryResult(-1, NewReader(this, schema));
+ }
+
+ public override async Task<UpdateResult> ExecuteUpdateAsync()
+ {
+ const string NumberOfAffectedRowsColumnName = "num_affected_rows";
+
+ QueryResult queryResult = await ExecuteQueryAsync();
+ if (queryResult.Stream == null)
+ {
+ throw new AdbcException("no data found");
+ }
+
+ using IArrowArrayStream stream = queryResult.Stream;
+
+ // Check if the affected rows columns are returned in the result.
+ Field affectedRowsField =
stream.Schema.GetFieldByName(NumberOfAffectedRowsColumnName);
+ if (affectedRowsField != null && affectedRowsField.DataType.TypeId
!= Types.ArrowTypeId.Int64)
+ {
+ throw new AdbcException($"Unexpected data type for column:
'{NumberOfAffectedRowsColumnName}'", new
ArgumentException(NumberOfAffectedRowsColumnName));
+ }
+
+ // If no altered rows, i.e. DDC statements, then -1 is the default.
+ long? affectedRows = null;
+ while (true)
+ {
+ using RecordBatch nextBatch = await
stream.ReadNextRecordBatchAsync();
+ if (nextBatch == null) { break; }
+ Int64Array numOfModifiedArray =
(Int64Array)nextBatch.Column(NumberOfAffectedRowsColumnName);
+ // Note: should only have one item, but iterate for
completeness
+ for (int i = 0; i < numOfModifiedArray.Length; i++)
+ {
+ // Note: handle the case where the affected rows are zero
(0).
+ affectedRows = (affectedRows ?? 0) +
numOfModifiedArray.GetValue(i).GetValueOrDefault(0);
+ }
+ }
+
+ return new UpdateResult(affectedRows ?? -1);
+ }
+
+ protected async Task ExecuteStatementAsync()
{
TExecuteStatementReq executeRequest = new
TExecuteStatementReq(this.connection.sessionHandle, this.SqlQuery);
SetStatementProperties(executeRequest);
- var executeResponse =
this.connection.Client.ExecuteStatement(executeRequest).Result;
+ TExecuteStatementResp executeResponse = await
this.connection.Client.ExecuteStatement(executeRequest);
if (executeResponse.Status.StatusCode == TStatusCode.ERROR_STATUS)
{
throw new
HiveServer2Exception(executeResponse.Status.ErrorMessage)
@@ -48,24 +105,28 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
this.operationHandle = executeResponse.OperationHandle;
}
- protected void PollForResponse()
+ protected async Task PollForResponseAsync()
{
TGetOperationStatusResp? statusResponse = null;
do
{
- if (statusResponse != null) { Thread.Sleep(500); }
+ if (statusResponse != null) { await
Task.Delay(PollTimeMilliseconds); }
TGetOperationStatusReq request = new
TGetOperationStatusReq(this.operationHandle);
- statusResponse =
this.connection.Client.GetOperationStatus(request).Result;
+ statusResponse = await
this.connection.Client.GetOperationStatus(request);
} while (statusResponse.OperationState ==
TOperationState.PENDING_STATE || statusResponse.OperationState ==
TOperationState.RUNNING_STATE);
}
- protected Schema GetSchema()
+ protected async ValueTask<Schema> GetSchemaAsync()
{
TGetResultSetMetadataReq request = new
TGetResultSetMetadataReq(this.operationHandle);
- TGetResultSetMetadataResp response =
this.connection.Client.GetResultSetMetadata(request).Result;
+ TGetResultSetMetadataResp response = await
this.connection.Client.GetResultSetMetadata(request);
return SchemaParser.GetArrowSchema(response.Schema);
}
+ protected internal int PollTimeMilliseconds { get; } =
PollTimeMillisecondsDefault;
+
+ protected internal int BatchSize { get; } = BatchSizeDefault;
+
public override void Dispose()
{
if (this.operationHandle != null)
@@ -77,7 +138,5 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
base.Dispose();
}
-
-
}
}
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
index 457647676..e9e0018d1 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
@@ -16,6 +16,7 @@
*/
using System.Collections.Generic;
+using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Ipc;
using Apache.Hive.Service.Rpc.Thrift;
@@ -32,7 +33,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
{
}
- protected override TProtocol CreateProtocol()
+ protected override ValueTask<TProtocol> CreateProtocolAsync()
{
string hostName = properties["HostName"];
string? tmp;
@@ -44,7 +45,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
TConfiguration config = new TConfiguration();
TTransport transport = new ThriftSocketTransport(hostName, port,
config);
- return new TBinaryProtocol(transport);
+ return new ValueTask<TProtocol>(new TBinaryProtocol(transport));
}
protected override TOpenSessionReq CreateSessionRequest()
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs
b/csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs
index 70cb35b66..931197f45 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs
@@ -31,7 +31,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
public override AdbcConnection Connect(IReadOnlyDictionary<string,
string>? properties)
{
ImpalaConnection connection = new
ImpalaConnection(this.properties);
- connection.Open();
+ connection.OpenAsync().Wait();
return connection;
}
}
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs
b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs
index 855f42794..5fe408c38 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs
@@ -35,62 +35,48 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
{
}
- public override QueryResult ExecuteQuery()
- {
- ExecuteStatement();
- PollForResponse();
-
- Schema schema = GetSchema();
-
- return new QueryResult(-1, new HiveServer2Reader(this, schema));
- }
-
- public override UpdateResult ExecuteUpdate()
- {
- throw new NotImplementedException();
- }
-
public override object GetValue(IArrowArray arrowArray, int index)
{
throw new NotSupportedException();
}
+ protected override IArrowArrayStream NewReader<T>(T statement, Schema
schema) => new HiveServer2Reader(statement, schema);
+
class HiveServer2Reader : IArrowArrayStream
{
- ImpalaStatement? statement;
- Schema schema;
+ HiveServer2Statement? statement;
int counter;
- public HiveServer2Reader(ImpalaStatement statement, Schema schema)
+ public HiveServer2Reader(HiveServer2Statement statement, Schema
schema)
{
this.statement = statement;
- this.schema = schema;
+ this.Schema = schema;
}
- public Schema Schema { get { return schema; } }
+ public Schema Schema { get; }
- public ValueTask<RecordBatch?>
ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
+ public async ValueTask<RecordBatch?>
ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
{
if (this.statement == null)
{
- return new ValueTask<RecordBatch?>((RecordBatch?)null);
+ return null;
}
TFetchResultsReq request = new
TFetchResultsReq(this.statement.operationHandle, TFetchOrientation.FETCH_NEXT,
50000);
- TFetchResultsResp response =
this.statement.connection.Client.FetchResults(request).Result;
+ TFetchResultsResp response = await
this.statement.connection.Client.FetchResults(request, cancellationToken);
var buffer = new System.IO.MemoryStream();
- response.WriteAsync(new TBinaryProtocol(new
TStreamTransport(null, buffer, new TConfiguration())),
cancellationToken).Wait();
+ await response.WriteAsync(new TBinaryProtocol(new
TStreamTransport(null, buffer, new TConfiguration())), cancellationToken);
System.IO.File.WriteAllBytes(string.Format("d:/src/buffer{0}.bin",
this.counter++), buffer.ToArray());
- RecordBatch result = new RecordBatch(this.schema,
response.Results.Columns.Select(GetArray),
GetArray(response.Results.Columns[0]).Length);
+ RecordBatch result = new RecordBatch(this.Schema,
response.Results.Columns.Select(GetArray),
GetArray(response.Results.Columns[0]).Length);
if (!response.HasMoreRows)
{
this.statement = null;
}
- return new ValueTask<RecordBatch?>(result);
+ return result;
}
public void Dispose()
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index 0c6dd6a57..798fc9c47 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -37,19 +37,19 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
{
public class SparkConnection : HiveServer2Connection
{
- const string userAgent = "MicrosoftSparkODBCDriver/2.7.6.1014";
+ const string UserAgent = "MicrosoftSparkODBCDriver/2.7.6.1014";
- readonly AdbcInfoCode[] infoSupportedCodes = new [] {
+ readonly AdbcInfoCode[] infoSupportedCodes = new[] {
AdbcInfoCode.DriverName,
AdbcInfoCode.DriverVersion,
AdbcInfoCode.DriverArrowVersion,
AdbcInfoCode.VendorName
};
- const string infoDriverName = "ADBC Spark Driver";
- const string infoDriverVersion = "1.0.0";
- const string infoVendorName = "Spark";
- const string infoDriverArrowVersion = "1.0.0";
+ const string InfoDriverName = "ADBC Spark Driver";
+ const string InfoDriverVersion = "1.0.0";
+ const string InfoVendorName = "Spark";
+ const string InfoDriverArrowVersion = "1.0.0";
internal static TSparkGetDirectResults sparkGetDirectResults = new
TSparkGetDirectResults(1000);
@@ -83,7 +83,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
{
}
- protected override TProtocol CreateProtocol()
+ protected override async ValueTask<TProtocol> CreateProtocolAsync()
{
Trace.TraceError($"create protocol with {properties.Count}
properties.");
@@ -101,12 +101,10 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
else
token = properties["password"];
- string uri = "https://" + hostName + "/" + path;
-
HttpClient httpClient = new HttpClient();
- httpClient.BaseAddress = new Uri(uri);
+ httpClient.BaseAddress = new UriBuilder(Uri.UriSchemeHttps,
hostName, -1, path).Uri;
httpClient.DefaultRequestHeaders.Authorization = new
AuthenticationHeaderValue("Bearer", token);
- httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(userAgent);
+ httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(UserAgent);
httpClient.DefaultRequestHeaders.AcceptEncoding.Clear();
httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new
StringWithQualityHeaderValue("identity"));
httpClient.DefaultRequestHeaders.ExpectContinue = false;
@@ -116,7 +114,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
ThriftHttpTransport transport = new
ThriftHttpTransport(httpClient, config);
// can switch to the one below if want to use the experimental one
with IPeekableTransport
// ThriftHttpTransport transport = new
ThriftHttpTransport(httpClient, config);
- transport.OpenAsync(CancellationToken.None).Wait();
+ await transport.OpenAsync(CancellationToken.None);
return new TBinaryProtocol(transport);
}
@@ -134,23 +132,6 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
return new SparkStatement(this);
}
- public override void Dispose()
- {
- /*
- if (this.client != null)
- {
- TCloseSessionReq r6 = new TCloseSessionReq(this.sessionHandle);
- this.client.CloseSession(r6).Wait();
-
- this.transport.Close();
- this.client.Dispose();
-
- this.transport = null;
- this.client = null;
- }
- */
- }
-
public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode>
codes)
{
const int strValTypeID = 0;
@@ -206,25 +187,25 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
- stringInfoBuilder.Append(infoDriverName);
+ stringInfoBuilder.Append(InfoDriverName);
break;
case AdbcInfoCode.DriverVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
- stringInfoBuilder.Append(infoDriverVersion);
+ stringInfoBuilder.Append(InfoDriverVersion);
break;
case AdbcInfoCode.DriverArrowVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
- stringInfoBuilder.Append(infoDriverArrowVersion);
+ stringInfoBuilder.Append(InfoDriverArrowVersion);
break;
case AdbcInfoCode.VendorName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
- stringInfoBuilder.Append(infoVendorName);
+ stringInfoBuilder.Append(InfoVendorName);
break;
default:
infoNameBuilder.Append((UInt32)code);
diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs
b/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs
index 726053d99..92687e40d 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs
@@ -31,7 +31,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
public override AdbcConnection Connect(IReadOnlyDictionary<string,
string>? properties)
{
SparkConnection connection = new SparkConnection(this.properties);
- connection.Open();
+ connection.OpenAsync().Wait();
return connection;
}
}
diff --git a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
index 0ec2dc136..d36c355c8 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
@@ -55,67 +55,17 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
};
}
- public override QueryResult ExecuteQuery()
- {
- ExecuteStatement();
- PollForResponse();
- Schema schema = GetSchema();
-
- // TODO: Ensure this is set dynamically based on server
capabilities
- return new QueryResult(-1, new SparkReader(this, schema));
- }
-
- public override UpdateResult ExecuteUpdate()
- {
- const string NumberOfAffectedRowsColumnName = "num_affected_rows";
-
- QueryResult queryResult = ExecuteQuery();
- if (queryResult.Stream == null)
- {
- throw new AdbcException("no data found");
- }
-
- using IArrowArrayStream stream = queryResult.Stream;
-
- // Check if the affected rows columns are returned in the result.
- Field affectedRowsField =
stream.Schema.GetFieldByName(NumberOfAffectedRowsColumnName);
- if (affectedRowsField != null && affectedRowsField.DataType.TypeId
!= Types.ArrowTypeId.Int64)
- {
- throw new AdbcException($"Unexpected data type for column:
'{NumberOfAffectedRowsColumnName}'", new
ArgumentException(NumberOfAffectedRowsColumnName));
- }
-
- // If no altered rows, i.e. DDC statements, then -1 is the default.
- long? affectedRows = null;
- while (true)
- {
- using RecordBatch nextBatch =
stream.ReadNextRecordBatchAsync().Result;
- if (nextBatch == null) { break; }
- Int64Array numOfModifiedArray =
(Int64Array)nextBatch.Column(NumberOfAffectedRowsColumnName);
- // Note: should only have one item, but iterate for
completeness
- for (int i = 0; i < numOfModifiedArray.Length; i++)
- {
- // Note: handle the case where the affected rows are zero
(0).
- affectedRows = (affectedRows ?? 0) +
numOfModifiedArray.GetValue(i).GetValueOrDefault(0);
- }
- }
-
- return new UpdateResult(affectedRows ?? -1);
- }
-
- public override object? GetValue(IArrowArray arrowArray, int index)
- {
- return base.GetValue(arrowArray, index);
- }
+ protected override IArrowArrayStream NewReader<T>(T statement, Schema
schema) => new SparkReader(statement, schema);
sealed class SparkReader : IArrowArrayStream
{
- SparkStatement? statement;
+ HiveServer2Statement? statement;
Schema schema;
List<TSparkArrowBatch>? batches;
int index;
IArrowReader? reader;
- public SparkReader(SparkStatement statement, Schema schema)
+ public SparkReader(HiveServer2Statement statement, Schema schema)
{
this.statement = statement;
this.schema = schema;
@@ -151,7 +101,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
return null;
}
- TFetchResultsReq request = new
TFetchResultsReq(this.statement.operationHandle, TFetchOrientation.FETCH_NEXT,
50000);
+ TFetchResultsReq request = new
TFetchResultsReq(this.statement.operationHandle, TFetchOrientation.FETCH_NEXT,
this.statement.BatchSize);
TFetchResultsResp response = await
this.statement.connection.client!.FetchResults(request, cancellationToken);
this.batches = response.Results.ArrowBatches;
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs
b/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs
index 37f870471..b6205aea1 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs
@@ -15,6 +15,7 @@
* limitations under the License.
*/
+using System.Threading.Tasks;
using Xunit;
namespace Apache.Arrow.Adbc.Tests
@@ -55,5 +56,37 @@ namespace Apache.Arrow.Adbc.Tests
Assert.True(queryResult.RowCount == count, "The RowCount value
does not match the counted records");
}
}
+
+ /// <summary>
+ /// Validates that a <see cref="QueryResult"/> contains a number
+ /// of records.
+ /// </summary>
+ /// <param name="queryResult">
+ /// The query result.
+ /// </param>
+ /// <param name="expectedNumberOfResults">
+ /// The number of records.
+ /// </param>
+ public static async Task CanExecuteQueryAsync(QueryResult queryResult,
long expectedNumberOfResults)
+ {
+ long count = 0;
+
+ while (queryResult.Stream != null)
+ {
+ RecordBatch nextBatch = await
queryResult.Stream.ReadNextRecordBatchAsync();
+ if (nextBatch == null) { break; }
+ count += nextBatch.Length;
+ }
+
+ Assert.True(expectedNumberOfResults == count, $"The parsed records
({count}) differ from the expected amount ({expectedNumberOfResults})");
+
+ // if the values were set, make sure they are correct
+ if (queryResult.RowCount != -1)
+ {
+ Assert.True(queryResult.RowCount == expectedNumberOfResults,
"The RowCount value does not match the expected results");
+
+ Assert.True(queryResult.RowCount == count, "The RowCount value
does not match the counted records");
+ }
+ }
}
}
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs
b/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs
index 557142b90..92271678f 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/TestBase.cs
@@ -65,11 +65,11 @@ namespace Apache.Arrow.Adbc.Tests
/// <param name="statement">The ADBC statement to apply the
update.</param>
/// <param name="columns">The columns definition in the native SQL
dialect.</param>
/// <returns>A disposable temporary table object that will drop the
table when disposed.</returns>
- protected virtual TemporaryTable NewTemporaryTable(AdbcStatement
statement, string columns)
+ protected virtual async ValueTask<TemporaryTable>
NewTemporaryTableAsync(AdbcStatement statement, string columns)
{
string tableName = NewTableName();
string sqlUpdate = string.Format("CREATE TEMPORARY IF NOT EXISTS
TABLE {0} ({1})", tableName, columns);
- return TemporaryTable.NewTemporaryTable(statement, tableName,
sqlUpdate);
+ return await TemporaryTable.NewTemporaryTableAsync(statement,
tableName, sqlUpdate);
}
/// <summary>
@@ -195,12 +195,12 @@ namespace Apache.Arrow.Adbc.Tests
/// <param name="value">The value to insert, select and delete.</param>
/// <param name="formattedValue">The formated value to insert, select
and delete.</param>
/// <returns></returns>
- protected async Task ValidateInsertSelectDeleteSingleValue(string
selectStatement, string tableName, string columnName, object value, string?
formattedValue = null)
+ protected async Task ValidateInsertSelectDeleteSingleValueAsync(string
selectStatement, string tableName, string columnName, object value, string?
formattedValue = null)
{
- InsertSingleValue(tableName, columnName, formattedValue ??
value?.ToString());
- await SelectAndValidateValues(selectStatement, value, 1);
+ await InsertSingleValueAsync(tableName, columnName, formattedValue
?? value?.ToString());
+ await SelectAndValidateValuesAsync(selectStatement, value, 1);
string whereClause = GetWhereClause(columnName, formattedValue ??
value);
- DeleteFromTable(tableName, whereClause, 1);
+ await DeleteFromTableAsync(tableName, whereClause, 1);
}
/// <summary>
@@ -211,12 +211,12 @@ namespace Apache.Arrow.Adbc.Tests
/// <param name="value">The value to insert, select and delete.</param>
/// <param name="formattedValue">The formated value to insert, select
and delete.</param>
/// <returns></returns>
- protected async Task ValidateInsertSelectDeleteSingleValue(string
tableName, string columnName, object? value, string? formattedValue = null)
+ protected async Task ValidateInsertSelectDeleteSingleValueAsync(string
tableName, string columnName, object? value, string? formattedValue = null)
{
- InsertSingleValue(tableName, columnName, formattedValue ??
value?.ToString());
- await SelectAndValidateValues(tableName, columnName, value, 1,
formattedValue);
+ await InsertSingleValueAsync(tableName, columnName, formattedValue
?? value?.ToString());
+ await SelectAndValidateValuesAsync(tableName, columnName, value,
1, formattedValue);
string whereClause = GetWhereClause(columnName, formattedValue ??
value);
- DeleteFromTable(tableName, whereClause, 1);
+ await DeleteFromTableAsync(tableName, whereClause, 1);
}
/// <summary>
@@ -225,12 +225,12 @@ namespace Apache.Arrow.Adbc.Tests
/// <param name="tableName">The name of the table to use.</param>
/// <param name="columnName">The name of the column.</param>
/// <param name="value">The value to insert.</param>
- protected virtual void InsertSingleValue(string tableName, string
columnName, string? value)
+ protected virtual async Task InsertSingleValueAsync(string tableName,
string columnName, string? value)
{
string insertNumberStatement = GetInsertValueStatement(tableName,
columnName, value);
OutputHelper?.WriteLine(insertNumberStatement);
Statement.SqlQuery = insertNumberStatement;
- UpdateResult updateResult = Statement.ExecuteUpdate();
+ UpdateResult updateResult = await Statement.ExecuteUpdateAsync();
Assert.Equal(1, updateResult.AffectedRows);
}
@@ -250,12 +250,12 @@ namespace Apache.Arrow.Adbc.Tests
/// <param name="tableName">The name of the table to use.</param>
/// <param name="whereClause">The WHERE clause string.</param>
/// <param name="expectedRowsAffected">The expected number of affected
rows.</param>
- protected virtual void DeleteFromTable(string tableName, string
whereClause, int expectedRowsAffected)
+ protected virtual async Task DeleteFromTableAsync(string tableName,
string whereClause, int expectedRowsAffected)
{
string deleteNumberStatement = GetDeleteValueStatement(tableName,
whereClause);
OutputHelper?.WriteLine(deleteNumberStatement);
Statement.SqlQuery = deleteNumberStatement;
- UpdateResult updateResult = Statement.ExecuteUpdate();
+ UpdateResult updateResult = await Statement.ExecuteUpdateAsync();
Assert.Equal(expectedRowsAffected, updateResult.AffectedRows);
}
@@ -276,10 +276,10 @@ namespace Apache.Arrow.Adbc.Tests
/// <param name="value">The value to select and validate.</param>
/// <param name="expectedLength">The number of expected results
(rows).</param>
/// <returns></returns>
- protected virtual async Task SelectAndValidateValues(string table,
string columnName, object? value, int expectedLength, string? formattedValue =
null)
+ protected virtual async Task SelectAndValidateValuesAsync(string
table, string columnName, object? value, int expectedLength, string?
formattedValue = null)
{
string selectNumberStatement =
GetSelectSingleValueStatement(table, columnName, formattedValue ?? value);
- await SelectAndValidateValues(selectNumberStatement, value,
expectedLength);
+ await SelectAndValidateValuesAsync(selectNumberStatement, value,
expectedLength);
}
/// <summary>
@@ -289,11 +289,11 @@ namespace Apache.Arrow.Adbc.Tests
/// <param name="value">The value to select and validate.</param>
/// <param name="expectedLength">The number of expected results
(rows).</param>
/// <returns></returns>
- protected virtual async Task SelectAndValidateValues(string
selectStatement, object? value, int expectedLength)
+ protected virtual async Task SelectAndValidateValuesAsync(string
selectStatement, object? value, int expectedLength)
{
Statement.SqlQuery = selectStatement;
OutputHelper?.WriteLine(selectStatement);
- QueryResult queryResult = Statement.ExecuteQuery();
+ QueryResult queryResult = await Statement.ExecuteQueryAsync();
int actualLength = 0;
using (IArrowArrayStream stream = queryResult.Stream ?? throw new
InvalidOperationException("stream is null"))
{
@@ -571,12 +571,10 @@ namespace Apache.Arrow.Adbc.Tests
/// </summary>
public string TableName { get; }
- private TemporaryTable(AdbcStatement statement, string tableName,
string sqlQuery)
+ private TemporaryTable(AdbcStatement statement, string tableName)
{
TableName = tableName;
_statement = statement;
- statement.SqlQuery = sqlQuery;
- statement.ExecuteUpdate();
}
/// <summary>
@@ -586,18 +584,20 @@ namespace Apache.Arrow.Adbc.Tests
/// <param name="tableName">The name of temporary table to
create.</param>
/// <param name="sqlUpdate">The SQL query to create the table in
the native SQL dialect.</param>
/// <returns></returns>
- public static TemporaryTable NewTemporaryTable(AdbcStatement
statement, string tableName, string sqlUpdate)
+ public static async ValueTask<TemporaryTable>
NewTemporaryTableAsync(AdbcStatement statement, string tableName, string
sqlUpdate)
{
- return new TemporaryTable(statement, tableName, sqlUpdate);
+ statement.SqlQuery = sqlUpdate;
+ await statement.ExecuteUpdateAsync();
+ return new TemporaryTable(statement, tableName);
}
/// <summary>
/// Drops the tables.
/// </summary>
- protected virtual void Drop()
+ protected virtual async Task DropAsync()
{
_statement.SqlQuery = $"DROP TABLE {TableName}";
- _statement.ExecuteUpdate();
+ await _statement.ExecuteUpdateAsync();
}
protected virtual void Dispose(bool disposing)
@@ -606,7 +606,7 @@ namespace Apache.Arrow.Adbc.Tests
{
if (disposing)
{
- Drop();
+ DropAsync().Wait();
}
_disposedValue = true;
@@ -624,6 +624,7 @@ namespace Apache.Arrow.Adbc.Tests
protected class TemporarySchema : IDisposable
{
private bool _disposedValue;
+ private readonly AdbcStatement _statement;
private TemporarySchema(string catalogName, AdbcStatement
statement)
{
@@ -632,17 +633,17 @@ namespace Apache.Arrow.Adbc.Tests
_statement = statement;
}
- public static TemporarySchema NewTemporarySchema(string
catalogName, AdbcStatement statement)
+ public static async ValueTask<TemporarySchema>
NewTemporarySchemaAsync(string catalogName, AdbcStatement statement)
{
TemporarySchema schema = new TemporarySchema(catalogName,
statement);
statement.SqlQuery = $"CREATE SCHEMA IF NOT EXISTS
{schema.CatalogName}.{schema.SchemaName}";
- statement.ExecuteUpdate();
+ await statement.ExecuteUpdateAsync();
return schema;
}
public string CatalogName { get; }
+
public string SchemaName { get; }
- private AdbcStatement _statement;
protected virtual void Dispose(bool disposing)
{
diff --git a/csharp/test/Drivers/Apache/Spark/BinaryBooleanValueTests.cs
b/csharp/test/Drivers/Apache/Spark/BinaryBooleanValueTests.cs
index a4f26e40c..fa19558f6 100644
--- a/csharp/test/Drivers/Apache/Spark/BinaryBooleanValueTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/BinaryBooleanValueTests.cs
@@ -52,9 +52,9 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestBinaryData(byte[] value)
{
string columnName = "BINARYTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} {1}", columnName, "BINARY"));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName,
"BINARY"));
string formattedValue =
$"X'{BitConverter.ToString(value).Replace("-", "")}'";
- await ValidateInsertSelectDeleteSingleValue(
+ await ValidateInsertSelectDeleteSingleValueAsync(
table.TableName,
columnName,
value,
@@ -71,9 +71,9 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestBooleanData(bool? value)
{
string columnName = "BOOLEANTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} {1}", columnName, "BOOLEAN"));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName,
"BOOLEAN"));
string? formattedValue = value == null ? null :
QuoteValue($"{value?.ToString(CultureInfo.InvariantCulture)}");
- await ValidateInsertSelectDeleteSingleValue(
+ await ValidateInsertSelectDeleteSingleValueAsync(
table.TableName,
columnName,
value,
@@ -102,7 +102,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
{
string selectStatement = $"SELECT {projectionClause};";
// Note: by default, this returns as String type, not NULL type.
- await SelectAndValidateValues(selectStatement, null, 1);
+ await SelectAndValidateValuesAsync(selectStatement, null, 1);
}
}
}
diff --git a/csharp/test/Drivers/Apache/Spark/ComplexTypesValueTests.cs
b/csharp/test/Drivers/Apache/Spark/ComplexTypesValueTests.cs
index 87f30ce89..2968c68f5 100644
--- a/csharp/test/Drivers/Apache/Spark/ComplexTypesValueTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/ComplexTypesValueTests.cs
@@ -49,7 +49,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestArrayData(string projection, string value)
{
string selectStatement = $"SELECT {projection};";
- await SelectAndValidateValues(selectStatement, value, 1);
+ await SelectAndValidateValuesAsync(selectStatement, value, 1);
}
/// <summary>
@@ -61,7 +61,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestMapData(string projection, string value)
{
string selectStatement = $"SELECT {projection};";
- await SelectAndValidateValues(selectStatement, value, 1);
+ await SelectAndValidateValuesAsync(selectStatement, value, 1);
}
/// <summary>
@@ -73,7 +73,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestStructData(string projection, string value)
{
string selectStatement = $"SELECT {projection};";
- await SelectAndValidateValues(selectStatement, value, 1);
+ await SelectAndValidateValuesAsync(selectStatement, value, 1);
}
}
}
diff --git a/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs
b/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs
index b418e39bd..c2278f388 100644
--- a/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/DateTimeValueTests.cs
@@ -58,12 +58,12 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestTimestampData(DateTimeOffset value, string
columnType)
{
string columnName = "TIMESTAMPTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} {1}", columnName, columnType));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName,
columnType));
string formattedValue = $"{value.ToString(DateTimeZoneFormat,
CultureInfo.InvariantCulture)}";
DateTimeOffset truncatedValue =
DateTimeOffset.ParseExact(formattedValue, DateTimeZoneFormat,
CultureInfo.InvariantCulture);
- await ValidateInsertSelectDeleteSingleValue(
+ await ValidateInsertSelectDeleteSingleValueAsync(
table.TableName,
columnName,
truncatedValue,
@@ -81,12 +81,12 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
Skip.If(value == DateTimeOffset.MinValue);
string columnName = "TIMESTAMPTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} {1}", columnName, columnType));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName,
columnType));
string formattedValue = $"{value.ToString(DateFormat,
CultureInfo.InvariantCulture)}";
DateTimeOffset truncatedValue =
DateTimeOffset.ParseExact(formattedValue, DateFormat,
CultureInfo.InvariantCulture);
- await ValidateInsertSelectDeleteSingleValue(
+ await ValidateInsertSelectDeleteSingleValueAsync(
table.TableName,
columnName,
// Remove timezone offset
@@ -102,12 +102,12 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestDateData(DateTimeOffset value, string columnType)
{
string columnName = "DATETYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} {1}", columnName, columnType));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName,
columnType));
string formattedValue = $"{value.ToString(DateFormat,
CultureInfo.InvariantCulture)}";
DateTimeOffset truncatedValue =
DateTimeOffset.ParseExact(formattedValue, DateFormat,
CultureInfo.InvariantCulture);
- await ValidateInsertSelectDeleteSingleValue(
+ await ValidateInsertSelectDeleteSingleValueAsync(
table.TableName,
columnName,
// Remove timezone offset
@@ -148,7 +148,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestIntervalData(string intervalClause, string value)
{
string selectStatement = $"SELECT {intervalClause} AS
INTERVAL_VALUE;";
- await SelectAndValidateValues(selectStatement, value, 1);
+ await SelectAndValidateValuesAsync(selectStatement, value, 1);
}
public static IEnumerable<object[]> TimestampData(string columnType)
diff --git a/csharp/test/Drivers/Apache/Spark/DriverTests.cs
b/csharp/test/Drivers/Apache/Spark/DriverTests.cs
index ff0afdd50..a4f3a4607 100644
--- a/csharp/test/Drivers/Apache/Spark/DriverTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/DriverTests.cs
@@ -18,6 +18,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Threading.Tasks;
using Apache.Arrow.Adbc.Tests.Metadata;
using Apache.Arrow.Adbc.Tests.Xunit;
using Apache.Arrow.Ipc;
@@ -79,13 +80,13 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
/// Validates if the driver can call GetInfo.
/// </summary>
[SkippableFact, Order(2)]
- public void CanGetInfo()
+ public async Task CanGetInfo()
{
AdbcConnection adbcConnection = NewConnection();
using IArrowArrayStream stream = adbcConnection.GetInfo(new
List<AdbcInfoCode>() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion,
AdbcInfoCode.VendorName });
- RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;
+ RecordBatch recordBatch = await stream.ReadNextRecordBatchAsync();
UInt32Array infoNameArray =
(UInt32Array)recordBatch.Column("info_name");
List<string> expectedValues = new List<string>() { "DriverName",
"DriverVersion", "VendorName" };
@@ -254,10 +255,10 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
{
string catalogName = TestConfiguration.Metadata.Catalog;
string schemaPrefix = Guid.NewGuid().ToString().Replace("-", "");
- using TemporarySchema schema =
TemporarySchema.NewTemporarySchema(catalogName, Statement);
+ using TemporarySchema schema =
TemporarySchema.NewTemporarySchemaAsync(catalogName, Statement).Result;
string schemaName = schema.SchemaName;
string fullTableName =
$"{DelimitIdentifier(catalogName)}.{DelimitIdentifier(schemaName)}.{DelimitIdentifier(tableName)}";
- using TemporaryTable temporaryTable =
TemporaryTable.NewTemporaryTable(Statement, fullTableName, $"CREATE TABLE IF
NOT EXISTS {fullTableName} (INDEX INT)");
+ using TemporaryTable temporaryTable =
TemporaryTable.NewTemporaryTableAsync(Statement, fullTableName, $"CREATE TABLE
IF NOT EXISTS {fullTableName} (INDEX INT)").Result;
using IArrowArrayStream stream = Connection.GetObjects(
depth: AdbcConnection.GetObjectsDepth.Tables,
@@ -308,13 +309,13 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
/// Validates if the driver can call GetTableTypes.
/// </summary>
[SkippableFact, Order(9)]
- public void CanGetTableTypes()
+ public async Task CanGetTableTypes()
{
AdbcConnection adbcConnection = NewConnection();
using IArrowArrayStream arrowArrayStream =
adbcConnection.GetTableTypes();
- RecordBatch recordBatch =
arrowArrayStream.ReadNextRecordBatchAsync().Result;
+ RecordBatch recordBatch = await
arrowArrayStream.ReadNextRecordBatchAsync();
StringArray stringArray =
(StringArray)recordBatch.Column("table_type");
@@ -355,6 +356,39 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
Tests.DriverTests.CanExecuteQuery(queryResult,
TestConfiguration.ExpectedResultsCount);
}
+ /// <summary>
+ /// Validates if the driver can connect to a live server and
+ /// parse the results using the asynchronous methods.
+ /// </summary>
+ [SkippableFact, Order(11)]
+ public async Task CanExecuteQueryAsync()
+ {
+ using AdbcConnection adbcConnection = NewConnection();
+ using AdbcStatement statement = adbcConnection.CreateStatement();
+
+ statement.SqlQuery = TestConfiguration.Query;
+ QueryResult queryResult = await statement.ExecuteQueryAsync();
+
+ await Tests.DriverTests.CanExecuteQueryAsync(queryResult,
TestConfiguration.ExpectedResultsCount);
+ }
+
+ /// <summary>
+ /// Validates if the driver can connect to a live server and
+ /// perform an update asynchronously.
+ /// </summary>
+ [SkippableFact, Order(12)]
+ public async Task CanExecuteUpdateAsync()
+ {
+ using AdbcConnection adbcConnection = NewConnection();
+ using AdbcStatement statement = adbcConnection.CreateStatement();
+ using TemporaryTable temporaryTable = await
NewTemporaryTableAsync(statement, "INDEX INT");
+
+ statement.SqlQuery =
GetInsertValueStatement(temporaryTable.TableName, "INDEX", "1");
+ UpdateResult updateResult = await statement.ExecuteUpdateAsync();
+
+ Assert.Equal(1, updateResult.AffectedRows);
+ }
+
public static IEnumerable<object[]> CatalogNamePatternData()
{
string? catalogName = new
DriverTests(null).TestConfiguration?.Metadata?.Catalog;
diff --git a/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs
b/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs
index 718fa8f54..f84371e1d 100644
--- a/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs
@@ -43,8 +43,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestIntegerSanity(int value)
{
string columnName = "INTTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} INT", columnName));
- await ValidateInsertSelectDeleteSingleValue(table.TableName,
columnName, value);
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} INT", columnName));
+ await ValidateInsertSelectDeleteSingleValueAsync(table.TableName,
columnName, value);
}
/// <summary>
@@ -56,8 +56,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestIntegerMinMax(int value)
{
string columnName = "INTTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} INT", columnName));
- await ValidateInsertSelectDeleteSingleValue(table.TableName,
columnName, value);
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} INT", columnName));
+ await ValidateInsertSelectDeleteSingleValueAsync(table.TableName,
columnName, value);
}
/// <summary>
@@ -69,8 +69,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestLongMinMax(long value)
{
string columnName = "BIGINTTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} BIGINT", columnName));
- await ValidateInsertSelectDeleteSingleValue(table.TableName,
columnName, value);
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} BIGINT", columnName));
+ await ValidateInsertSelectDeleteSingleValueAsync(table.TableName,
columnName, value);
}
/// <summary>
@@ -82,8 +82,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestSmallIntMinMax(short value)
{
string columnName = "SMALLINTTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} SMALLINT", columnName));
- await ValidateInsertSelectDeleteSingleValue(table.TableName,
columnName, value);
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} SMALLINT", columnName));
+ await ValidateInsertSelectDeleteSingleValueAsync(table.TableName,
columnName, value);
}
/// <summary>
@@ -95,8 +95,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestTinyIntMinMax(sbyte value)
{
string columnName = "TINYINTTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} TINYINT", columnName));
- await ValidateInsertSelectDeleteSingleValue(table.TableName,
columnName, value);
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} TINYINT", columnName));
+ await ValidateInsertSelectDeleteSingleValueAsync(table.TableName,
columnName, value);
}
/// <summary>
@@ -111,8 +111,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestSmallNumberRange(string value)
{
string columnName = "SMALLNUMBER";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} DECIMAL(2,0)", columnName));
- await ValidateInsertSelectDeleteSingleValue(table.TableName,
columnName, SqlDecimal.Parse(value));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(2,0)",
columnName));
+ await ValidateInsertSelectDeleteSingleValueAsync(table.TableName,
columnName, SqlDecimal.Parse(value));
}
/// <summary>
@@ -126,8 +126,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestSmallNumberRangeOverlimit(int value)
{
string columnName = "SMALLNUMBER";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} DECIMAL(2,0)", columnName));
- await Assert.ThrowsAsync<HiveServer2Exception>(async () => await
ValidateInsertSelectDeleteSingleValue(table.TableName, columnName, new
SqlDecimal(value)));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(2,0)",
columnName));
+ await Assert.ThrowsAsync<HiveServer2Exception>(async () => await
ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName, new
SqlDecimal(value)));
}
/// <summary>
@@ -142,8 +142,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestLargeScaleNumberRange(string value)
{
string columnName = "LARGESCALENUMBER";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} DECIMAL(38,37)", columnName));
- await ValidateInsertSelectDeleteSingleValue(table.TableName,
columnName, SqlDecimal.Parse(value));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,37)",
columnName));
+ await ValidateInsertSelectDeleteSingleValueAsync(table.TableName,
columnName, SqlDecimal.Parse(value));
}
/// <summary>
@@ -157,8 +157,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestLargeScaleNumberOverlimit(string value)
{
string columnName = "LARGESCALENUMBER";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} DECIMAL(38,37)", columnName));
- await Assert.ThrowsAsync<HiveServer2Exception>(async () => await
ValidateInsertSelectDeleteSingleValue(table.TableName, columnName,
SqlDecimal.Parse(value)));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,37)",
columnName));
+ await Assert.ThrowsAsync<HiveServer2Exception>(async () => await
ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName,
SqlDecimal.Parse(value)));
}
/// <summary>
@@ -172,8 +172,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestSmallScaleNumberRange(string value)
{
string columnName = "SMALLSCALENUMBER";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} DECIMAL(38,2)", columnName));
- await ValidateInsertSelectDeleteSingleValue(table.TableName,
columnName, SqlDecimal.Parse(value));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,2)",
columnName));
+ await ValidateInsertSelectDeleteSingleValueAsync(table.TableName,
columnName, SqlDecimal.Parse(value));
}
/// <summary>
@@ -185,8 +185,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestSmallScaleNumberOverlimit(string value)
{
string columnName = "SMALLSCALENUMBER";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} DECIMAL(38,2)", columnName));
- await Assert.ThrowsAsync<HiveServer2Exception>(async () => await
ValidateInsertSelectDeleteSingleValue(table.TableName, columnName,
SqlDecimal.Parse(value)));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,2)",
columnName));
+ await Assert.ThrowsAsync<HiveServer2Exception>(async () => await
ValidateInsertSelectDeleteSingleValueAsync(table.TableName, columnName,
SqlDecimal.Parse(value)));
}
@@ -200,13 +200,13 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestRoundingNumbers(decimal input, decimal output)
{
string columnName = "SMALLSCALENUMBER";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} DECIMAL(38,2)", columnName));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} DECIMAL(38,2)",
columnName));
SqlDecimal value = new SqlDecimal(input);
SqlDecimal returned = new SqlDecimal(output);
- InsertSingleValue(table.TableName, columnName, value.ToString());
- await SelectAndValidateValues(table.TableName, columnName,
returned, 1);
+ await InsertSingleValueAsync(table.TableName, columnName,
value.ToString());
+ await SelectAndValidateValuesAsync(table.TableName, columnName,
returned, 1);
string whereClause = GetWhereClause(columnName, returned);
- DeleteFromTable(table.TableName, whereClause, 1);
+ await DeleteFromTableAsync(table.TableName, whereClause, 1);
}
/// <summary>
@@ -225,12 +225,12 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestDoubleValuesInsertSelectDelete(double value)
{
string columnName = "DOUBLETYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} DOUBLE", columnName));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} DOUBLE", columnName));
string valueString = ConvertDoubleToString(value);
- InsertSingleValue(table.TableName, columnName, valueString);
- await SelectAndValidateValues(table.TableName, columnName, value,
1);
+ await InsertSingleValueAsync(table.TableName, columnName,
valueString);
+ await SelectAndValidateValuesAsync(table.TableName, columnName,
value, 1);
string whereClause = GetWhereClause(columnName, value);
- DeleteFromTable(table.TableName, whereClause, 1);
+ await DeleteFromTableAsync(table.TableName, whereClause, 1);
}
/// <summary>
@@ -252,12 +252,12 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestFloatValuesInsertSelectDelete(float value)
{
string columnName = "FLOATTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} FLOAT", columnName));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} FLOAT", columnName));
string valueString = ConvertFloatToString(value);
- InsertSingleValue(table.TableName, columnName, valueString);
- await SelectAndValidateValues(table.TableName, columnName, value,
1);
+ await InsertSingleValueAsync(table.TableName, columnName,
valueString);
+ await SelectAndValidateValuesAsync(table.TableName, columnName,
value, 1);
string whereClause = GetWhereClause(columnName, value);
- DeleteFromTable(table.TableName, whereClause, 1);
+ await DeleteFromTableAsync(table.TableName, whereClause, 1);
}
}
}
diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestBase.cs
b/csharp/test/Drivers/Apache/Spark/SparkTestBase.cs
index 303ffed32..4a1482e47 100644
--- a/csharp/test/Drivers/Apache/Spark/SparkTestBase.cs
+++ b/csharp/test/Drivers/Apache/Spark/SparkTestBase.cs
@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
+using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache.Spark;
using Xunit.Abstractions;
@@ -32,12 +33,12 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
protected override AdbcDriver NewDriver => new SparkDriver();
- protected override TemporaryTable NewTemporaryTable(AdbcStatement
statement, string columns) {
+ protected override async ValueTask<TemporaryTable>
NewTemporaryTableAsync(AdbcStatement statement, string columns) {
string tableName = NewTableName();
// Note: Databricks/Spark doesn't support TEMPORARY table.
string sqlUpdate = string.Format("CREATE TABLE {0} ({1})",
tableName, columns);
OutputHelper?.WriteLine(sqlUpdate);
- return TemporaryTable.NewTemporaryTable(statement, tableName,
sqlUpdate);
+ return await TemporaryTable.NewTemporaryTableAsync(statement,
tableName, sqlUpdate);
}
protected override string Delimiter => "`";
diff --git a/csharp/test/Drivers/Apache/Spark/StringValueTests.cs
b/csharp/test/Drivers/Apache/Spark/StringValueTests.cs
index 2d2e1d292..9d2f708d0 100644
--- a/csharp/test/Drivers/Apache/Spark/StringValueTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/StringValueTests.cs
@@ -55,8 +55,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestStringData(string? value)
{
string columnName = "STRINGTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} {1}", columnName, "STRING"));
- await ValidateInsertSelectDeleteSingleValue(
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName,
"STRING"));
+ await ValidateInsertSelectDeleteSingleValueAsync(
table.TableName,
columnName,
value,
@@ -75,8 +75,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestVarcharData(string? value)
{
string columnName = "VARCHARTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} {1}", columnName, "VARCHAR(100)"));
- await ValidateInsertSelectDeleteSingleValue(
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName,
"VARCHAR(100)"));
+ await ValidateInsertSelectDeleteSingleValueAsync(
table.TableName,
columnName,
value,
@@ -96,15 +96,15 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
{
string columnName = "CHARTYPE";
int fieldLength = 100;
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} {1}", columnName, $"CHAR({fieldLength})"));
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName,
$"CHAR({fieldLength})"));
string? formattedValue = value != null ?
QuoteValue(value.PadRight(fieldLength)) : value;
string? paddedValue = value != null ? value.PadRight(fieldLength)
: value;
- InsertSingleValue(table.TableName, columnName, formattedValue);
- await SelectAndValidateValues(table.TableName, columnName,
paddedValue, 1, formattedValue);
+ await InsertSingleValueAsync(table.TableName, columnName,
formattedValue);
+ await SelectAndValidateValuesAsync(table.TableName, columnName,
paddedValue, 1, formattedValue);
string whereClause = GetWhereClause(columnName, formattedValue ??
paddedValue);
- DeleteFromTable(table.TableName, whereClause, 1);
+ await DeleteFromTableAsync(table.TableName, whereClause, 1);
}
/// <summary>
@@ -115,8 +115,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public async Task TestVarcharExceptionData(string value)
{
string columnName = "VARCHARTYPE";
- using TemporaryTable table = NewTemporaryTable(Statement,
string.Format("{0} {1}", columnName, "VARCHAR(10)"));
- AdbcException exception = await
Assert.ThrowsAsync<HiveServer2Exception>(async () => await
ValidateInsertSelectDeleteSingleValue(
+ using TemporaryTable table = await
NewTemporaryTableAsync(Statement, string.Format("{0} {1}", columnName,
"VARCHAR(10)"));
+ AdbcException exception = await
Assert.ThrowsAsync<HiveServer2Exception>(async () => await
ValidateInsertSelectDeleteSingleValueAsync(
table.TableName,
columnName,
value,