This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 845f9c2d9 feat(csharp/src/Drivers/Apache): add connect and query timeout options (#2312)
845f9c2d9 is described below
commit 845f9c2d98416ae6871680c4afd98d7ff2e51333
Author: Bruce Irschick <[email protected]>
AuthorDate: Wed Dec 4 10:27:43 2024 -0800
feat(csharp/src/Drivers/Apache): add connect and query timeout options (#2312)
Adds options for the connect and query timeouts:
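| Property | Description | Default |
| :--- | :--- | :--- |
| `adbc.spark.connect_timeout_ms` | Sets the timeout (in milliseconds) to open a new session. Values can be 0 (infinite) or greater than zero. | `30000` |
| `adbc.apache.statement.query_timeout_s` | Sets the maximum time (in seconds) for a query to complete. Values can be 0 (infinite) or greater than zero. | `60` |

The existing statement options are also renamed: `adbc.statement.batch_size` becomes `adbc.apache.statement.batch_size`, and `adbc.statement.polltime_milliseconds` becomes `adbc.apache.statement.polltime_ms`.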
---------
Co-authored-by: Aman Goyal <[email protected]>
Co-authored-by: David Coe <[email protected]>
---
csharp/src/Client/AdbcCommand.cs | 28 +-
...iveServer2Parameters.cs => ApacheParameters.cs} | 20 +-
csharp/src/Drivers/Apache/ApacheUtility.cs | 141 +++++++
.../Drivers/Apache/Hive2/HiveServer2Connection.cs | 106 +++--
.../Drivers/Apache/Hive2/HiveServer2Parameters.cs | 2 -
.../src/Drivers/Apache/Hive2/HiveServer2Reader.cs | 34 +-
.../Drivers/Apache/Hive2/HiveServer2Statement.cs | 134 +++++--
.../src/Drivers/Apache/Impala/ImpalaConnection.cs | 6 +-
.../src/Drivers/Apache/Impala/ImpalaStatement.cs | 2 +-
csharp/src/Drivers/Apache/Spark/README.md | 15 +-
csharp/src/Drivers/Apache/Spark/SparkConnection.cs | 438 +++++++++++----------
.../Apache/Spark/SparkDatabricksConnection.cs | 19 +-
.../Drivers/Apache/Spark/SparkDatabricksReader.cs | 1 -
.../Drivers/Apache/Spark/SparkHttpConnection.cs | 69 ++--
csharp/src/Drivers/Apache/Spark/SparkParameters.cs | 4 +-
.../Apache/Spark/SparkStandardConnection.cs | 10 +-
csharp/src/Drivers/Apache/Spark/SparkStatement.cs | 7 +-
.../test/Drivers/Apache/ApacheTestConfiguration.cs | 9 +-
csharp/test/Drivers/Apache/Common/ClientTests.cs | 22 ++
.../test/Drivers/Apache/Common/StatementTests.cs | 124 +++++-
.../Drivers/Apache/Spark/SparkConnectionTest.cs | 236 ++++++++++-
.../Drivers/Apache/Spark/SparkTestEnvironment.cs | 13 +-
csharp/test/Drivers/Apache/Spark/StatementTests.cs | 2 +
23 files changed, 1077 insertions(+), 365 deletions(-)
diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs
index 8b85be206..c3695feaf 100644
--- a/csharp/src/Client/AdbcCommand.cs
+++ b/csharp/src/Client/AdbcCommand.cs
@@ -21,6 +21,7 @@ using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlTypes;
+using System.Globalization;
using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Types;
@@ -32,10 +33,11 @@ namespace Apache.Arrow.Adbc.Client
/// </summary>
public sealed class AdbcCommand : DbCommand
{
- private AdbcStatement _adbcStatement;
+ private readonly AdbcStatement _adbcStatement;
private AdbcParameterCollection? _dbParameterCollection;
private int _timeout = 30;
private bool _disposed;
+ private string? _commandTimeoutProperty;
/// <summary>
/// Overloaded. Initializes <see cref="AdbcCommand"/>.
@@ -117,10 +119,32 @@ namespace Apache.Arrow.Adbc.Client
}
}
+
+ /// <summary>
+ /// Gets or sets the name of the command timeout property for the
underlying ADBC driver.
+ /// </summary>
+ public string AdbcCommandTimeoutProperty
+ {
+ get
+ {
+ if (string.IsNullOrEmpty(_commandTimeoutProperty))
+ throw new
InvalidOperationException("CommandTimeoutProperty is not set.");
+
+ return _commandTimeoutProperty!;
+ }
+ set => _commandTimeoutProperty = value;
+ }
+
public override int CommandTimeout
{
get => _timeout;
- set => _timeout = value;
+ set
+ {
+ // ensures the property exists before setting the
CommandTimeout value
+ string property = AdbcCommandTimeoutProperty;
+ _adbcStatement.SetOption(property,
value.ToString(CultureInfo.InvariantCulture));
+ _timeout = value;
+ }
}
protected override DbParameterCollection DbParameterCollection
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
b/csharp/src/Drivers/Apache/ApacheParameters.cs
similarity index 66%
copy from csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
copy to csharp/src/Drivers/Apache/ApacheParameters.cs
index 2170cd17b..17c94be32 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
+++ b/csharp/src/Drivers/Apache/ApacheParameters.cs
@@ -15,19 +15,15 @@
* limitations under the License.
*/
-using System;
-
-namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
+namespace Apache.Arrow.Adbc.Drivers.Apache
{
- public static class DataTypeConversionOptions
- {
- public const string None = "none";
- public const string Scalar = "scalar";
- }
-
- public static class TlsOptions
+ /// <summary>
+ /// Options common to all Apache drivers.
+ /// </summary>
+ public class ApacheParameters
{
- public const string AllowSelfSigned = "allow_self_signed";
- public const string AllowHostnameMismatch = "allow_hostname_mismatch";
+ public const string PollTimeMilliseconds =
"adbc.apache.statement.polltime_ms";
+ public const string BatchSize = "adbc.apache.statement.batch_size";
+ public const string QueryTimeoutSeconds =
"adbc.apache.statement.query_timeout_s";
}
}
diff --git a/csharp/src/Drivers/Apache/ApacheUtility.cs
b/csharp/src/Drivers/Apache/ApacheUtility.cs
new file mode 100644
index 000000000..f1cb07e07
--- /dev/null
+++ b/csharp/src/Drivers/Apache/ApacheUtility.cs
@@ -0,0 +1,141 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Threading;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache
+{
+ internal class ApacheUtility
+ {
+ internal const int QueryTimeoutSecondsDefault = 60;
+
+ public enum TimeUnit
+ {
+ Seconds,
+ Milliseconds
+ }
+
+ public static CancellationToken GetCancellationToken(int timeout,
TimeUnit timeUnit)
+ {
+ TimeSpan span;
+
+ if (timeout == 0 || timeout == int.MaxValue)
+ {
+ // the max TimeSpan for CancellationTokenSource is
int.MaxValue in milliseconds (not TimeSpan.MaxValue)
+ // no matter what the unit is
+ span = TimeSpan.FromMilliseconds(int.MaxValue);
+ }
+ else
+ {
+ if (timeUnit == TimeUnit.Seconds)
+ {
+ span = TimeSpan.FromSeconds(timeout);
+ }
+ else
+ {
+ span = TimeSpan.FromMilliseconds(timeout);
+ }
+ }
+
+ return GetCancellationToken(span);
+ }
+
+ private static CancellationToken GetCancellationToken(TimeSpan
timeSpan)
+ {
+ var cts = new CancellationTokenSource(timeSpan);
+ return cts.Token;
+ }
+
+ public static bool QueryTimeoutIsValid(string key, string value, out
int queryTimeoutSeconds)
+ {
+ if (!string.IsNullOrEmpty(value) && int.TryParse(value, out int
queryTimeout) && (queryTimeout >= 0))
+ {
+ queryTimeoutSeconds = queryTimeout;
+ return true;
+ }
+ else
+ {
+ throw new ArgumentOutOfRangeException(key, value, $"The value
'{value}' for option '{key}' is invalid. Must be a numeric value of 0
(infinite) or greater.");
+ }
+ }
+
+ public static bool ContainsException<T>(Exception exception, out T?
containedException) where T : Exception
+ {
+ if (exception is AggregateException aggregateException)
+ {
+ foreach (Exception? ex in aggregateException.InnerExceptions)
+ {
+ if (ex is T ce)
+ {
+ containedException = ce;
+ return true;
+ }
+ }
+ }
+
+ Exception? e = exception;
+ while (e != null)
+ {
+ if (e is T ce)
+ {
+ containedException = ce;
+ return true;
+ }
+ e = e.InnerException;
+ }
+
+ containedException = null;
+ return false;
+ }
+
+ public static bool ContainsException(Exception exception, Type?
exceptionType, out Exception? containedException)
+ {
+ if (exception == null || exceptionType == null)
+ {
+ containedException = null;
+ return false;
+ }
+
+ if (exception is AggregateException aggregateException)
+ {
+ foreach (Exception? ex in aggregateException.InnerExceptions)
+ {
+ if (exceptionType.IsInstanceOfType(ex))
+ {
+ containedException = ex;
+ return true;
+ }
+ }
+ }
+
+ Exception? e = exception;
+ while (e != null)
+ {
+ if (exceptionType.IsInstanceOfType(e))
+ {
+ containedException = e;
+ return true;
+ }
+ e = e.InnerException;
+ }
+
+ containedException = null;
+ return false;
+ }
+ }
+}
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
index c839bbaa7..d420edb2b 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -30,7 +30,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
{
internal const long BatchSizeDefault = 50000;
internal const int PollTimeMillisecondsDefault = 500;
-
+ private const int ConnectTimeoutMillisecondsDefault = 30000;
private TTransport? _transport;
private TCLIService.Client? _client;
private readonly Lazy<string> _vendorVersion;
@@ -45,6 +45,14 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
//
https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects
_vendorVersion = new Lazy<string>(() =>
GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER),
LazyThreadSafetyMode.PublicationOnly);
_vendorName = new Lazy<string>(() =>
GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME),
LazyThreadSafetyMode.PublicationOnly);
+
+ if (properties.TryGetValue(ApacheParameters.QueryTimeoutSeconds,
out string? queryTimeoutSecondsSettingValue))
+ {
+ if
(ApacheUtility.QueryTimeoutIsValid(ApacheParameters.QueryTimeoutSeconds,
queryTimeoutSecondsSettingValue, out int queryTimeoutSeconds))
+ {
+ QueryTimeoutSeconds = queryTimeoutSeconds;
+ }
+ }
}
internal TCLIService.Client Client
@@ -56,30 +64,48 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
internal string VendorName => _vendorName.Value;
+ protected internal int QueryTimeoutSeconds { get; set; } =
ApacheUtility.QueryTimeoutSecondsDefault;
+
internal IReadOnlyDictionary<string, string> Properties { get; }
internal async Task OpenAsync()
{
- TTransport transport = await CreateTransportAsync();
- TProtocol protocol = await CreateProtocolAsync(transport);
- _transport = protocol.Transport;
- _client = new TCLIService.Client(protocol);
- TOpenSessionReq request = CreateSessionRequest();
- TOpenSessionResp? session = await Client.OpenSession(request);
-
- // Some responses don't raise an exception. Explicitly check the
status.
- if (session == null)
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(ConnectTimeoutMilliseconds,
ApacheUtility.TimeUnit.Milliseconds);
+ try
{
- throw new HiveServer2Exception("unable to open session.
unknown error.");
+ TTransport transport = CreateTransport();
+ TProtocol protocol = await CreateProtocolAsync(transport,
cancellationToken);
+ _transport = protocol.Transport;
+ _client = new TCLIService.Client(protocol);
+ TOpenSessionReq request = CreateSessionRequest();
+
+ TOpenSessionResp? session = await Client.OpenSession(request,
cancellationToken);
+
+ // Explicitly check the session status
+ if (session == null)
+ {
+ throw new HiveServer2Exception("Unable to open session.
Unknown error.");
+ }
+ else if (session.Status.StatusCode !=
TStatusCode.SUCCESS_STATUS)
+ {
+ throw new HiveServer2Exception(session.Status.ErrorMessage)
+ .SetNativeError(session.Status.ErrorCode)
+ .SetSqlState(session.Status.SqlState);
+ }
+
+ SessionHandle = session.SessionHandle;
}
- else if (session.Status.StatusCode != TStatusCode.SUCCESS_STATUS)
+ catch (Exception ex)
+ when (ApacheUtility.ContainsException(ex, out
OperationCanceledException? _) ||
+ (ApacheUtility.ContainsException(ex, out
TTransportException? _) && cancellationToken.IsCancellationRequested))
{
- throw new HiveServer2Exception(session.Status.ErrorMessage)
- .SetNativeError(session.Status.ErrorCode)
- .SetSqlState(session.Status.SqlState);
+ throw new TimeoutException("The operation timed out while
attempting to open a session. Please try increasing connect timeout.", ex);
+ }
+ catch (Exception ex) when (ex is not HiveServer2Exception)
+ {
+ // Handle other exceptions if necessary
+ throw new HiveServer2Exception($"An unexpected error occurred
while opening the session. '{ex.Message}'", ex);
}
-
- SessionHandle = session.SessionHandle;
}
internal TSessionHandle? SessionHandle { get; private set; }
@@ -88,11 +114,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
protected internal HiveServer2TlsOption TlsOptions { get; set; } =
HiveServer2TlsOption.Empty;
- protected internal int HttpRequestTimeout { get; set; } = 30000;
+ protected internal int ConnectTimeoutMilliseconds { get; set; } =
ConnectTimeoutMillisecondsDefault;
- protected abstract Task<TTransport> CreateTransportAsync();
+ protected abstract TTransport CreateTransport();
- protected abstract Task<TProtocol> CreateProtocolAsync(TTransport
transport);
+ protected abstract Task<TProtocol> CreateProtocolAsync(TTransport
transport, CancellationToken cancellationToken = default);
protected abstract TOpenSessionReq CreateSessionRequest();
@@ -110,14 +136,14 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
throw new NotImplementedException();
}
- internal static async Task PollForResponseAsync(TOperationHandle
operationHandle, TCLIService.IAsync client, int pollTimeMilliseconds)
+ internal static async Task PollForResponseAsync(TOperationHandle
operationHandle, TCLIService.IAsync client, int pollTimeMilliseconds,
CancellationToken cancellationToken = default)
{
TGetOperationStatusResp? statusResponse = null;
do
{
- if (statusResponse != null) { await
Task.Delay(pollTimeMilliseconds); }
+ if (statusResponse != null) { await
Task.Delay(pollTimeMilliseconds, cancellationToken); }
TGetOperationStatusReq request = new(operationHandle);
- statusResponse = await client.GetOperationStatus(request);
+ statusResponse = await client.GetOperationStatus(request,
cancellationToken);
} while (statusResponse.OperationState ==
TOperationState.PENDING_STATE || statusResponse.OperationState ==
TOperationState.RUNNING_STATE);
}
@@ -129,24 +155,38 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
InfoType = infoType,
};
- TGetInfoResp getInfoResp = Client.GetInfo(req).Result;
- if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
+ try
{
- throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage)
- .SetNativeError(getInfoResp.Status.ErrorCode)
- .SetSqlState(getInfoResp.Status.SqlState);
+ TGetInfoResp getInfoResp = Client.GetInfo(req,
cancellationToken).Result;
+ if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
+ {
+ throw new
HiveServer2Exception(getInfoResp.Status.ErrorMessage)
+ .SetNativeError(getInfoResp.Status.ErrorCode)
+ .SetSqlState(getInfoResp.Status.SqlState);
+ }
+
+ return getInfoResp.InfoValue.StringValue;
+ }
+ catch (Exception ex)
+ when (ApacheUtility.ContainsException(ex, out
OperationCanceledException? _) ||
+ (ApacheUtility.ContainsException(ex, out
TTransportException? _) && cancellationToken.IsCancellationRequested))
+ {
+ throw new TimeoutException("The metadata query execution timed
out. Consider increasing the query timeout value.", ex);
+ }
+ catch (Exception ex) when (ex is not HiveServer2Exception)
+ {
+ throw new HiveServer2Exception($"An unexpected error occurred
while running metadata query. '{ex.Message}'", ex);
}
-
- return getInfoResp.InfoValue.StringValue;
}
public override void Dispose()
{
if (_client != null)
{
- TCloseSessionReq r6 = new TCloseSessionReq(SessionHandle);
- _client.CloseSession(r6).Wait();
-
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
+ TCloseSessionReq r6 = new(SessionHandle);
+ _client.CloseSession(r6, cancellationToken).Wait();
_transport?.Close();
_client.Dispose();
_transport = null;
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
index 2170cd17b..4f2bc62d2 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
@@ -15,8 +15,6 @@
* limitations under the License.
*/
-using System;
-
namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
{
public static class DataTypeConversionOptions
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
index 08b0675d0..34dbf10f2 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
@@ -25,6 +25,7 @@ using System.Threading.Tasks;
using Apache.Arrow.Ipc;
using Apache.Arrow.Types;
using Apache.Hive.Service.Rpc.Thrift;
+using Thrift.Transport;
namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
{
@@ -89,19 +90,32 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
return null;
}
- // Await the fetch response
- TFetchResultsResp response = await FetchNext(_statement,
cancellationToken);
+ try
+ {
+ // Await the fetch response
+ TFetchResultsResp response = await FetchNext(_statement,
cancellationToken);
+
+ int columnCount = GetColumnCount(response);
+ int rowCount = GetRowCount(response, columnCount);
+ if ((_statement.BatchSize > 0 && rowCount <
_statement.BatchSize) || rowCount == 0)
+ {
+ // This is the last batch
+ _statement = null;
+ }
- int columnCount = GetColumnCount(response);
- int rowCount = GetRowCount(response, columnCount);
- if ((_statement.BatchSize > 0 && rowCount < _statement.BatchSize)
|| rowCount == 0)
+ // Build the current batch, if any data exists
+ return rowCount > 0 ? CreateBatch(response, columnCount,
rowCount) : null;
+ }
+ catch (Exception ex)
+ when (ApacheUtility.ContainsException(ex, out
OperationCanceledException? _) ||
+ (ApacheUtility.ContainsException(ex, out
TTransportException? _) && cancellationToken.IsCancellationRequested))
{
- // This is the last batch
- _statement = null;
+ throw new TimeoutException("The query execution timed out.
Consider increasing the query timeout value.", ex);
+ }
+ catch (Exception ex) when (ex is not HiveServer2Exception)
+ {
+ throw new HiveServer2Exception($"An unexpected error occurred
while fetching results. '{ex.Message}'", ex);
}
-
- // Build the current batch, if any data exists
- return rowCount > 0 ? CreateBatch(response, columnCount, rowCount)
: null;
}
private RecordBatch CreateBatch(TFetchResultsResp response, int
columnCount, int rowCount)
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
index 824feceb9..06723e324 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -20,6 +20,7 @@ using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
using Apache.Hive.Service.Rpc.Thrift;
+using Thrift.Transport;
namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
{
@@ -32,33 +33,89 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
protected virtual void SetStatementProperties(TExecuteStatementReq
statement)
{
+ statement.QueryTimeout = QueryTimeoutSeconds;
}
- public override QueryResult ExecuteQuery() =>
ExecuteQueryAsync().AsTask().Result;
+ public override QueryResult ExecuteQuery()
+ {
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
+ try
+ {
+ return ExecuteQueryAsyncInternal(cancellationToken).Result;
+ }
+ catch (Exception ex)
+ when (ApacheUtility.ContainsException(ex, out
OperationCanceledException? _) ||
+ (ApacheUtility.ContainsException(ex, out
TTransportException? _) && cancellationToken.IsCancellationRequested))
+ {
+ throw new TimeoutException("The query execution timed out.
Consider increasing the query timeout value.", ex);
+ }
+ catch (Exception ex) when (ex is not HiveServer2Exception)
+ {
+ throw new HiveServer2Exception($"An unexpected error occurred
while fetching results. '{ex.Message}'", ex);
+ }
+ }
- public override UpdateResult ExecuteUpdate() =>
ExecuteUpdateAsync().Result;
+ public override UpdateResult ExecuteUpdate()
+ {
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
+ try
+ {
+ return ExecuteUpdateAsyncInternal(cancellationToken).Result;
+ }
+ catch (Exception ex)
+ when (ApacheUtility.ContainsException(ex, out
OperationCanceledException? _) ||
+ (ApacheUtility.ContainsException(ex, out
TTransportException? _) && cancellationToken.IsCancellationRequested))
+ {
+ throw new TimeoutException("The query execution timed out.
Consider increasing the query timeout value.", ex);
+ }
+ catch (Exception ex) when (ex is not HiveServer2Exception)
+ {
+ throw new HiveServer2Exception($"An unexpected error occurred
while fetching results. '{ex.Message}'", ex);
+ }
+ }
- public override async ValueTask<QueryResult> ExecuteQueryAsync()
+ private async Task<QueryResult>
ExecuteQueryAsyncInternal(CancellationToken cancellationToken = default)
{
- await ExecuteStatementAsync();
- await HiveServer2Connection.PollForResponseAsync(OperationHandle!,
Connection.Client, PollTimeMilliseconds);
- Schema schema = await GetResultSetSchemaAsync(OperationHandle!,
Connection.Client);
+ // this could either:
+ // take QueryTimeoutSeconds * 3
+ // OR
+ // take QueryTimeoutSeconds (but this could be restricting)
+ await ExecuteStatementAsync(cancellationToken); // --> get
QueryTimeout +
+ await HiveServer2Connection.PollForResponseAsync(OperationHandle!,
Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to
QueryTimeout
+ Schema schema = await GetResultSetSchemaAsync(OperationHandle!,
Connection.Client, cancellationToken); // + get the result, up to QueryTimeout
- // TODO: Ensure this is set dynamically based on server
capabilities
return new QueryResult(-1, Connection.NewReader(this, schema));
}
+ public override async ValueTask<QueryResult> ExecuteQueryAsync()
+ {
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
+ try
+ {
+ return await ExecuteQueryAsyncInternal(cancellationToken);
+ }
+ catch (Exception ex)
+ when (ApacheUtility.ContainsException(ex, out
OperationCanceledException? _) ||
+ (ApacheUtility.ContainsException(ex, out
TTransportException? _) && cancellationToken.IsCancellationRequested))
+ {
+ throw new TimeoutException("The query execution timed out.
Consider increasing the query timeout value.", ex);
+ }
+ catch (Exception ex) when (ex is not HiveServer2Exception)
+ {
+ throw new HiveServer2Exception($"An unexpected error occurred
while fetching results. '{ex.Message}'", ex);
+ }
+ }
+
private async Task<Schema> GetResultSetSchemaAsync(TOperationHandle
operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken
= default)
{
TGetResultSetMetadataResp response = await
HiveServer2Connection.GetResultSetMetadataAsync(operationHandle, client,
cancellationToken);
return Connection.SchemaParser.GetArrowSchema(response.Schema,
Connection.DataTypeConversion);
}
- public override async Task<UpdateResult> ExecuteUpdateAsync()
+ public async Task<UpdateResult>
ExecuteUpdateAsyncInternal(CancellationToken cancellationToken = default)
{
const string NumberOfAffectedRowsColumnName = "num_affected_rows";
-
- QueryResult queryResult = await ExecuteQueryAsync();
+ QueryResult queryResult = await
ExecuteQueryAsyncInternal(cancellationToken);
if (queryResult.Stream == null)
{
throw new AdbcException("no data found");
@@ -79,7 +136,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
long? affectedRows = null;
while (true)
{
- using RecordBatch nextBatch = await
stream.ReadNextRecordBatchAsync();
+ using RecordBatch nextBatch = await
stream.ReadNextRecordBatchAsync(cancellationToken);
if (nextBatch == null) { break; }
Int64Array numOfModifiedArray =
(Int64Array)nextBatch.Column(NumberOfAffectedRowsColumnName);
// Note: should only have one item, but iterate for
completeness
@@ -94,26 +151,51 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
return new UpdateResult(affectedRows ?? -1);
}
+ public override async Task<UpdateResult> ExecuteUpdateAsync()
+ {
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
+ try
+ {
+ return await ExecuteUpdateAsyncInternal(cancellationToken);
+ }
+ catch (Exception ex)
+ when (ApacheUtility.ContainsException(ex, out
OperationCanceledException? _) ||
+ (ApacheUtility.ContainsException(ex, out
TTransportException? _) && cancellationToken.IsCancellationRequested))
+ {
+ throw new TimeoutException("The query execution timed out.
Consider increasing the query timeout value.", ex);
+ }
+ catch (Exception ex) when (ex is not HiveServer2Exception)
+ {
+ throw new HiveServer2Exception($"An unexpected error occurred
while fetching results. '{ex.Message}'", ex);
+ }
+ }
+
public override void SetOption(string key, string value)
{
switch (key)
{
- case Options.PollTimeMilliseconds:
+ case ApacheParameters.PollTimeMilliseconds:
UpdatePollTimeIfValid(key, value);
break;
- case Options.BatchSize:
+ case ApacheParameters.BatchSize:
UpdateBatchSizeIfValid(key, value);
break;
+ case ApacheParameters.QueryTimeoutSeconds:
+ if (ApacheUtility.QueryTimeoutIsValid(key, value, out int
queryTimeoutSeconds))
+ {
+ QueryTimeoutSeconds = queryTimeoutSeconds;
+ }
+ break;
default:
throw AdbcException.NotImplemented($"Option '{key}' is not
implemented.");
}
}
- protected async Task ExecuteStatementAsync()
+ protected async Task ExecuteStatementAsync(CancellationToken
cancellationToken = default)
{
TExecuteStatementReq executeRequest = new
TExecuteStatementReq(Connection.SessionHandle, SqlQuery);
SetStatementProperties(executeRequest);
- TExecuteStatementResp executeResponse = await
Connection.Client.ExecuteStatement(executeRequest);
+ TExecuteStatementResp executeResponse = await
Connection.Client.ExecuteStatement(executeRequest, cancellationToken);
if (executeResponse.Status.StatusCode == TStatusCode.ERROR_STATUS)
{
throw new
HiveServer2Exception(executeResponse.Status.ErrorMessage)
@@ -127,23 +209,20 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
protected internal long BatchSize { get; private set; } =
HiveServer2Connection.BatchSizeDefault;
+ protected internal int QueryTimeoutSeconds
+ {
+ // Coordinate updates with the connection
+ get => Connection.QueryTimeoutSeconds;
+ set => Connection.QueryTimeoutSeconds = value;
+ }
+
public HiveServer2Connection Connection { get; private set; }
public TOperationHandle? OperationHandle { get; private set; }
- /// <summary>
- /// Provides the constant string key values to the <see
cref="AdbcStatement.SetOption(string, string)" /> method.
- /// </summary>
- public class Options
- {
- // Options common to all HiveServer2Statement-derived drivers go
here
- public const string PollTimeMilliseconds =
"adbc.statement.polltime_milliseconds";
- public const string BatchSize = "adbc.statement.batch_size";
- }
-
private void UpdatePollTimeIfValid(string key, string value) =>
PollTimeMilliseconds = !string.IsNullOrEmpty(key) && int.TryParse(value,
result: out int pollTimeMilliseconds) && pollTimeMilliseconds >= 0
? pollTimeMilliseconds
- : throw new ArgumentOutOfRangeException(key, value, $"The value
'{value}' for option '{key}' is invalid. Must be a numeric value greater than
or equal to -1.");
+ : throw new ArgumentOutOfRangeException(key, value, $"The value
'{value}' for option '{key}' is invalid. Must be a numeric value greater than
or equal to 0.");
private void UpdateBatchSizeIfValid(string key, string value) =>
BatchSize = !string.IsNullOrEmpty(value) && long.TryParse(value, out long
batchSize) && batchSize > 0
? batchSize
@@ -153,8 +232,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
{
if (OperationHandle != null)
{
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
TCloseOperationReq request = new
TCloseOperationReq(OperationHandle);
- Connection.Client.CloseOperation(request).Wait();
+ Connection.Client.CloseOperation(request,
cancellationToken).Wait();
OperationHandle = null;
}
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
index c6c6cc796..0e673c7c4 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
@@ -40,7 +40,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
{
}
- protected override Task<TTransport> CreateTransportAsync()
+ protected override TTransport CreateTransport()
{
string hostName = Properties["HostName"];
string? tmp;
@@ -52,10 +52,10 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
TConfiguration config = new TConfiguration();
TTransport transport = new ThriftSocketTransport(hostName, port,
config);
- return Task.FromResult(transport);
+ return transport;
}
- protected override Task<TProtocol> CreateProtocolAsync(TTransport
transport)
+ protected override Task<TProtocol> CreateProtocolAsync(TTransport
transport, CancellationToken cancellationToken = default)
{
return Task.FromResult<TProtocol>(new TBinaryProtocol(transport));
}
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs
b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs
index 0bd620ee9..f94ac3970 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs
@@ -30,7 +30,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
/// <summary>
/// Provides the constant string key values to the <see
cref="AdbcStatement.SetOption(string, string)" /> method.
/// </summary>
- public new sealed class Options : HiveServer2Statement.Options
+ public sealed class Options : ApacheParameters
{
// options specific to Impala go here
}
diff --git a/csharp/src/Drivers/Apache/Spark/README.md
b/csharp/src/Drivers/Apache/Spark/README.md
index 7d1f8b560..3b5a0e79e 100644
--- a/csharp/src/Drivers/Apache/Spark/README.md
+++ b/csharp/src/Drivers/Apache/Spark/README.md
@@ -37,9 +37,18 @@ but can also be passed in the call to `AdbcDatabase.Connect`.
| `password` | The password for the user name used for basic
authentication. | |
| `adbc.spark.data_type_conv` | Comma-separated list of data conversion
options. Each option indicates the type of conversion to perform on data
returned from the Spark server. <br><br>Allowed values: `none`, `scalar`.
<br><br>Option `none` indicates there is no conversion from Spark type to
native type (i.e., no conversion from String to Timestamp for Apache Spark over
HTTP). Example `adbc.spark.conv_data_type=none`. <br><br>Option `scalar` will
perform conversion (if necessary) from th [...]
| `adbc.spark.tls_options` | Comma-separated list of TLS/SSL options. Each
option indicates the TLS/SSL option when connecting to a Spark server.
<br><br>Allowed values: `allow_self_signed`, `allow_hostname_mismatch`.
<br><br>Option `allow_self_signed` allows certificate errors due to an unknown
certificate authority, typically when using a self-signed certificate. Option
`allow_hostname_mismatch` allow certificate errors due to a mismatch of the
hostname. (e.g., when connecting through [...]
-| `adbc.spark.http_request_timeout_ms` | Sets the timeout (in milliseconds)
when making requests to the Spark server (type: `http`). Set the value higher
than the default if you notice errors due to network timeouts. | `30000` |
-| `adbc.statement.batch_size` | Sets the maximum number of rows to retrieve in
a single batch request. | `50000` |
-| `adbc.statement.polltime_milliseconds` | If polling is necessary to get a
result, this option sets the length of time (in milliseconds) to wait between
polls. | `500` |
+| `adbc.spark.connect_timeout_ms` | Sets the timeout (in milliseconds) to open
a new session. Values can be 0 (infinite) or greater than zero. | `30000` |
+| `adbc.apache.statement.batch_size` | Sets the maximum number of rows to
retrieve in a single batch request. | `50000` |
+| `adbc.apache.statement.polltime_ms` | If polling is necessary to get a
result, this option sets the length of time (in milliseconds) to wait between
polls. | `500` |
+| `adbc.apache.statement.query_timeout_s` | Sets the maximum time (in seconds)
for a query to complete. Values can be 0 (infinite) or greater than zero. |
`60` |
+
+## Timeout Configuration
+
+Timeouts are applied hierarchically. As specified above,
`adbc.spark.connect_timeout_ms` is analogous to a ConnectTimeout and is used when
initially establishing a new session with the server.
+
+The `adbc.apache.statement.query_timeout_s` is analogous to a CommandTimeout
and applies to all subsequent requests to the server, including metadata calls
and query execution.
+
+The `adbc.apache.statement.polltime_ms` specifies the time (in milliseconds) between polls to
the service, up to the limit specified by
`adbc.apache.statement.query_timeout_s`.
## Spark Types
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index f532369e6..b3c0c56ba 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -19,8 +19,6 @@ using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
-using System.Net;
-using System.Net.Http;
using System.Reflection;
using System.Text;
using System.Text.RegularExpressions;
@@ -32,6 +30,7 @@ using Apache.Arrow.Adbc.Extensions;
using Apache.Arrow.Ipc;
using Apache.Arrow.Types;
using Apache.Hive.Service.Rpc.Thrift;
+using Thrift.Transport;
namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
{
@@ -420,26 +419,42 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
SessionHandle = SessionHandle ?? throw new
InvalidOperationException("session not created"),
GetDirectResults = sparkGetDirectResults
};
- TGetTableTypesResp resp = Client.GetTableTypes(req).Result;
- if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS)
+
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
+ try
{
- throw new HiveServer2Exception(resp.Status.ErrorMessage)
- .SetNativeError(resp.Status.ErrorCode)
- .SetSqlState(resp.Status.SqlState);
- }
+ TGetTableTypesResp resp = Client.GetTableTypes(req,
cancellationToken).Result;
- TRowSet rowSet = GetRowSetAsync(resp).Result;
- StringArray tableTypes = rowSet.Columns[0].StringVal.Values;
+ if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS)
+ {
+ throw new HiveServer2Exception(resp.Status.ErrorMessage)
+ .SetNativeError(resp.Status.ErrorCode)
+ .SetSqlState(resp.Status.SqlState);
+ }
- StringArray.Builder tableTypesBuilder = new StringArray.Builder();
- tableTypesBuilder.AppendRange(tableTypes);
+ TRowSet rowSet = GetRowSetAsync(resp,
cancellationToken).Result;
+ StringArray tableTypes = rowSet.Columns[0].StringVal.Values;
- IArrowArray[] dataArrays = new IArrowArray[]
- {
+ StringArray.Builder tableTypesBuilder = new
StringArray.Builder();
+ tableTypesBuilder.AppendRange(tableTypes);
+
+ IArrowArray[] dataArrays = new IArrowArray[]
+ {
tableTypesBuilder.Build()
- };
+ };
- return new SparkInfoArrowStream(StandardSchemas.TableTypesSchema,
dataArrays);
+ return new
SparkInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays);
+ }
+ catch (Exception ex)
+ when (ApacheUtility.ContainsException(ex, out
OperationCanceledException? _) ||
+ (ApacheUtility.ContainsException(ex, out
TTransportException? _) && cancellationToken.IsCancellationRequested))
+ {
+ throw new TimeoutException("The metadata query execution timed
out. Consider increasing the query timeout value.", ex);
+ }
+ catch (Exception ex) when (ex is not HiveServer2Exception)
+ {
+ throw new HiveServer2Exception($"An unexpected error occurred
while running metadata query. '{ex.Message}'", ex);
+ }
}
public override Schema GetTableSchema(string? catalog, string?
dbSchema, string? tableName)
@@ -450,221 +465,248 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
getColumnsReq.TableName = tableName;
getColumnsReq.GetDirectResults = sparkGetDirectResults;
- var columnsResponse = Client.GetColumns(getColumnsReq).Result;
- if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS)
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
+ try
{
- throw new Exception(columnsResponse.Status.ErrorMessage);
- }
+ var columnsResponse = Client.GetColumns(getColumnsReq,
cancellationToken).Result;
+ if (columnsResponse.Status.StatusCode ==
TStatusCode.ERROR_STATUS)
+ {
+ throw new Exception(columnsResponse.Status.ErrorMessage);
+ }
- TRowSet rowSet = GetRowSetAsync(columnsResponse).Result;
- List<TColumn> columns = rowSet.Columns;
- int rowCount = rowSet.Columns[3].StringVal.Values.Length;
+ TRowSet rowSet = GetRowSetAsync(columnsResponse,
cancellationToken).Result;
+ List<TColumn> columns = rowSet.Columns;
+ int rowCount = rowSet.Columns[3].StringVal.Values.Length;
- Field[] fields = new Field[rowCount];
- for (int i = 0; i < rowCount; i++)
+ Field[] fields = new Field[rowCount];
+ for (int i = 0; i < rowCount; i++)
+ {
+ string columnName =
columns[3].StringVal.Values.GetString(i);
+ int? columnType = columns[4].I32Val.Values.GetValue(i);
+ string typeName = columns[5].StringVal.Values.GetString(i);
+ // Note: the following two columns do not seem to be set
correctly for DECIMAL types.
+ //int? columnSize = columns[6].I32Val.Values.GetValue(i);
+ //int? decimalDigits =
columns[8].I32Val.Values.GetValue(i);
+ bool nullable = columns[10].I32Val.Values.GetValue(i) == 1;
+ IArrowType dataType =
SparkConnection.GetArrowType(columnType!.Value, typeName);
+ fields[i] = new Field(columnName, dataType, nullable);
+ }
+ return new Schema(fields, null);
+ }
+ catch (Exception ex)
+ when (ApacheUtility.ContainsException(ex, out
OperationCanceledException? _) ||
+ (ApacheUtility.ContainsException(ex, out
TTransportException? _) && cancellationToken.IsCancellationRequested))
{
- string columnName = columns[3].StringVal.Values.GetString(i);
- int? columnType = columns[4].I32Val.Values.GetValue(i);
- string typeName = columns[5].StringVal.Values.GetString(i);
- // Note: the following two columns do not seem to be set
correctly for DECIMAL types.
- //int? columnSize = columns[6].I32Val.Values.GetValue(i);
- //int? decimalDigits = columns[8].I32Val.Values.GetValue(i);
- bool nullable = columns[10].I32Val.Values.GetValue(i) == 1;
- IArrowType dataType =
SparkConnection.GetArrowType(columnType!.Value, typeName);
- fields[i] = new Field(columnName, dataType, nullable);
+ throw new TimeoutException("The metadata query execution timed
out. Consider increasing the query timeout value.", ex);
+ }
+ catch (Exception ex) when (ex is not HiveServer2Exception)
+ {
+ throw new HiveServer2Exception($"An unexpected error occurred
while running metadata query. '{ex.Message}'", ex);
}
- return new Schema(fields, null);
}
public override IArrowArrayStream GetObjects(GetObjectsDepth depth,
string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern,
IReadOnlyList<string>? tableTypes, string? columnNamePattern)
{
- Trace.TraceError($"getting objects with depth={depth.ToString()},
catalog = {catalogPattern}, dbschema = {dbSchemaPattern}, tablename =
{tableNamePattern}");
-
Dictionary<string, Dictionary<string, Dictionary<string,
TableInfo>>> catalogMap = new Dictionary<string, Dictionary<string,
Dictionary<string, TableInfo>>>();
- if (depth == GetObjectsDepth.All || depth >=
GetObjectsDepth.Catalogs)
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
+ try
{
- TGetCatalogsReq getCatalogsReq = new
TGetCatalogsReq(SessionHandle);
- getCatalogsReq.GetDirectResults = sparkGetDirectResults;
-
- TGetCatalogsResp getCatalogsResp =
Client.GetCatalogs(getCatalogsReq).Result;
- if (getCatalogsResp.Status.StatusCode ==
TStatusCode.ERROR_STATUS)
+ if (depth == GetObjectsDepth.All || depth >=
GetObjectsDepth.Catalogs)
{
- throw new Exception(getCatalogsResp.Status.ErrorMessage);
- }
- var catalogsMetadata =
GetResultSetMetadataAsync(getCatalogsResp).Result;
- IReadOnlyDictionary<string, int> columnMap =
GetColumnIndexMap(catalogsMetadata.Schema.Columns);
+ TGetCatalogsReq getCatalogsReq = new
TGetCatalogsReq(SessionHandle);
+ getCatalogsReq.GetDirectResults = sparkGetDirectResults;
- string catalogRegexp = PatternToRegEx(catalogPattern);
- TRowSet rowSet = GetRowSetAsync(getCatalogsResp).Result;
- IReadOnlyList<string> list =
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
- for (int i = 0; i < list.Count; i++)
- {
- string col = list[i];
- string catalog = col;
+ TGetCatalogsResp getCatalogsResp =
Client.GetCatalogs(getCatalogsReq, cancellationToken).Result;
- if (Regex.IsMatch(catalog, catalogRegexp,
RegexOptions.IgnoreCase))
+ if (getCatalogsResp.Status.StatusCode ==
TStatusCode.ERROR_STATUS)
{
- catalogMap.Add(catalog, new Dictionary<string,
Dictionary<string, TableInfo>>());
+ throw new
Exception(getCatalogsResp.Status.ErrorMessage);
}
- }
- // Handle the case where server does not support 'catalog' in
the namespace.
- if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern))
- {
- catalogMap.Add(string.Empty, []);
- }
- }
+ var catalogsMetadata =
GetResultSetMetadataAsync(getCatalogsResp, cancellationToken).Result;
+ IReadOnlyDictionary<string, int> columnMap =
GetColumnIndexMap(catalogsMetadata.Schema.Columns);
- if (depth == GetObjectsDepth.All || depth >=
GetObjectsDepth.DbSchemas)
- {
- TGetSchemasReq getSchemasReq = new
TGetSchemasReq(SessionHandle);
- getSchemasReq.CatalogName = catalogPattern;
- getSchemasReq.SchemaName = dbSchemaPattern;
- getSchemasReq.GetDirectResults = sparkGetDirectResults;
+ string catalogRegexp = PatternToRegEx(catalogPattern);
+ TRowSet rowSet = GetRowSetAsync(getCatalogsResp,
cancellationToken).Result;
+ IReadOnlyList<string> list =
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
+ for (int i = 0; i < list.Count; i++)
+ {
+ string col = list[i];
+ string catalog = col;
- TGetSchemasResp getSchemasResp =
Client.GetSchemas(getSchemasReq).Result;
- if (getSchemasResp.Status.StatusCode ==
TStatusCode.ERROR_STATUS)
- {
- throw new Exception(getSchemasResp.Status.ErrorMessage);
+ if (Regex.IsMatch(catalog, catalogRegexp,
RegexOptions.IgnoreCase))
+ {
+ catalogMap.Add(catalog, new Dictionary<string,
Dictionary<string, TableInfo>>());
+ }
+ }
+ // Handle the case where server does not support 'catalog'
in the namespace.
+ if (list.Count == 0 &&
string.IsNullOrEmpty(catalogPattern))
+ {
+ catalogMap.Add(string.Empty, []);
+ }
}
- TGetResultSetMetadataResp schemaMetadata =
GetResultSetMetadataAsync(getSchemasResp).Result;
- IReadOnlyDictionary<string, int> columnMap =
GetColumnIndexMap(schemaMetadata.Schema.Columns);
- TRowSet rowSet = GetRowSetAsync(getSchemasResp).Result;
-
- IReadOnlyList<string> catalogList =
rowSet.Columns[columnMap[TableCatalog]].StringVal.Values;
- IReadOnlyList<string> schemaList =
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
-
- for (int i = 0; i < catalogList.Count; i++)
+ if (depth == GetObjectsDepth.All || depth >=
GetObjectsDepth.DbSchemas)
{
- string catalog = catalogList[i];
- string schemaDb = schemaList[i];
- // It seems Spark sometimes returns empty string for
catalog on some schema (temporary tables).
- catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new
Dictionary<string, TableInfo>());
- }
- }
+ TGetSchemasReq getSchemasReq = new
TGetSchemasReq(SessionHandle);
+ getSchemasReq.CatalogName = catalogPattern;
+ getSchemasReq.SchemaName = dbSchemaPattern;
+ getSchemasReq.GetDirectResults = sparkGetDirectResults;
- if (depth == GetObjectsDepth.All || depth >=
GetObjectsDepth.Tables)
- {
- TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle);
- getTablesReq.CatalogName = catalogPattern;
- getTablesReq.SchemaName = dbSchemaPattern;
- getTablesReq.TableName = tableNamePattern;
- getTablesReq.GetDirectResults = sparkGetDirectResults;
-
- TGetTablesResp getTablesResp =
Client.GetTables(getTablesReq).Result;
- if (getTablesResp.Status.StatusCode ==
TStatusCode.ERROR_STATUS)
- {
- throw new Exception(getTablesResp.Status.ErrorMessage);
- }
+ TGetSchemasResp getSchemasResp =
Client.GetSchemas(getSchemasReq, cancellationToken).Result;
+ if (getSchemasResp.Status.StatusCode ==
TStatusCode.ERROR_STATUS)
+ {
+ throw new
Exception(getSchemasResp.Status.ErrorMessage);
+ }
- TGetResultSetMetadataResp tableMetadata =
GetResultSetMetadataAsync(getTablesResp).Result;
- IReadOnlyDictionary<string, int> columnMap =
GetColumnIndexMap(tableMetadata.Schema.Columns);
- TRowSet rowSet = GetRowSetAsync(getTablesResp).Result;
+ TGetResultSetMetadataResp schemaMetadata =
GetResultSetMetadataAsync(getSchemasResp, cancellationToken).Result;
+ IReadOnlyDictionary<string, int> columnMap =
GetColumnIndexMap(schemaMetadata.Schema.Columns);
+ TRowSet rowSet = GetRowSetAsync(getSchemasResp,
cancellationToken).Result;
- IReadOnlyList<string> catalogList =
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
- IReadOnlyList<string> schemaList =
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
- IReadOnlyList<string> tableList =
rowSet.Columns[columnMap[TableName]].StringVal.Values;
- IReadOnlyList<string> tableTypeList =
rowSet.Columns[columnMap[TableType]].StringVal.Values;
+ IReadOnlyList<string> catalogList =
rowSet.Columns[columnMap[TableCatalog]].StringVal.Values;
+ IReadOnlyList<string> schemaList =
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
- for (int i = 0; i < catalogList.Count; i++)
- {
- string catalog = catalogList[i];
- string schemaDb = schemaList[i];
- string tableName = tableList[i];
- string tableType = tableTypeList[i];
- TableInfo tableInfo = new(tableType);
-
catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName,
tableInfo);
+ for (int i = 0; i < catalogList.Count; i++)
+ {
+ string catalog = catalogList[i];
+ string schemaDb = schemaList[i];
+ // It seems Spark sometimes returns empty string for
catalog on some schema (temporary tables).
+ catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb,
new Dictionary<string, TableInfo>());
+ }
}
- }
- if (depth == GetObjectsDepth.All)
- {
- TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle);
- columnsReq.CatalogName = catalogPattern;
- columnsReq.SchemaName = dbSchemaPattern;
- columnsReq.TableName = tableNamePattern;
- columnsReq.GetDirectResults = sparkGetDirectResults;
+ if (depth == GetObjectsDepth.All || depth >=
GetObjectsDepth.Tables)
+ {
+ TGetTablesReq getTablesReq = new
TGetTablesReq(SessionHandle);
+ getTablesReq.CatalogName = catalogPattern;
+ getTablesReq.SchemaName = dbSchemaPattern;
+ getTablesReq.TableName = tableNamePattern;
+ getTablesReq.GetDirectResults = sparkGetDirectResults;
+
+ TGetTablesResp getTablesResp =
Client.GetTables(getTablesReq, cancellationToken).Result;
+ if (getTablesResp.Status.StatusCode ==
TStatusCode.ERROR_STATUS)
+ {
+ throw new Exception(getTablesResp.Status.ErrorMessage);
+ }
- if (!string.IsNullOrEmpty(columnNamePattern))
- columnsReq.ColumnName = columnNamePattern;
+ TGetResultSetMetadataResp tableMetadata =
GetResultSetMetadataAsync(getTablesResp, cancellationToken).Result;
+ IReadOnlyDictionary<string, int> columnMap =
GetColumnIndexMap(tableMetadata.Schema.Columns);
+ TRowSet rowSet = GetRowSetAsync(getTablesResp,
cancellationToken).Result;
- var columnsResponse = Client.GetColumns(columnsReq).Result;
- if (columnsResponse.Status.StatusCode ==
TStatusCode.ERROR_STATUS)
- {
- throw new Exception(columnsResponse.Status.ErrorMessage);
+ IReadOnlyList<string> catalogList =
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
+ IReadOnlyList<string> schemaList =
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
+ IReadOnlyList<string> tableList =
rowSet.Columns[columnMap[TableName]].StringVal.Values;
+ IReadOnlyList<string> tableTypeList =
rowSet.Columns[columnMap[TableType]].StringVal.Values;
+
+ for (int i = 0; i < catalogList.Count; i++)
+ {
+ string catalog = catalogList[i];
+ string schemaDb = schemaList[i];
+ string tableName = tableList[i];
+ string tableType = tableTypeList[i];
+ TableInfo tableInfo = new(tableType);
+
catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName,
tableInfo);
+ }
}
- TGetResultSetMetadataResp columnsMetadata =
GetResultSetMetadataAsync(columnsResponse).Result;
- IReadOnlyDictionary<string, int> columnMap =
GetColumnIndexMap(columnsMetadata.Schema.Columns);
- TRowSet rowSet = GetRowSetAsync(columnsResponse).Result;
-
- IReadOnlyList<string> catalogList =
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
- IReadOnlyList<string> schemaList =
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
- IReadOnlyList<string> tableList =
rowSet.Columns[columnMap[TableName]].StringVal.Values;
- IReadOnlyList<string> columnNameList =
rowSet.Columns[columnMap[ColumnName]].StringVal.Values;
- ReadOnlySpan<int> columnTypeList =
rowSet.Columns[columnMap[DataType]].I32Val.Values.Values;
- IReadOnlyList<string> typeNameList =
rowSet.Columns[columnMap[TypeName]].StringVal.Values;
- ReadOnlySpan<int> nullableList =
rowSet.Columns[columnMap[Nullable]].I32Val.Values.Values;
- IReadOnlyList<string> columnDefaultList =
rowSet.Columns[columnMap[ColumnDef]].StringVal.Values;
- ReadOnlySpan<int> ordinalPosList =
rowSet.Columns[columnMap[OrdinalPosition]].I32Val.Values.Values;
- IReadOnlyList<string> isNullableList =
rowSet.Columns[columnMap[IsNullable]].StringVal.Values;
- IReadOnlyList<string> isAutoIncrementList =
rowSet.Columns[columnMap[IsAutoIncrement]].StringVal.Values;
-
- for (int i = 0; i < catalogList.Count; i++)
+ if (depth == GetObjectsDepth.All)
{
- // For systems that don't support 'catalog' in the
namespace
- string catalog = catalogList[i] ?? string.Empty;
- string schemaDb = schemaList[i];
- string tableName = tableList[i];
- string columnName = columnNameList[i];
- short colType = (short)columnTypeList[i];
- string typeName = typeNameList[i];
- short nullable = (short)nullableList[i];
- string? isAutoIncrementString = isAutoIncrementList[i];
- bool isAutoIncrement =
(!string.IsNullOrEmpty(isAutoIncrementString) &&
(isAutoIncrementString.Equals("YES",
StringComparison.InvariantCultureIgnoreCase) ||
isAutoIncrementString.Equals("TRUE",
StringComparison.InvariantCultureIgnoreCase)));
- string isNullable = isNullableList[i] ?? "YES";
- string columnDefault = columnDefaultList[i] ?? "";
- // Spark/Databricks reports ordinal index zero-indexed,
instead of one-indexed
- int ordinalPos = ordinalPosList[i] + 1;
- TableInfo? tableInfo =
catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName);
- tableInfo?.ColumnName.Add(columnName);
- tableInfo?.ColType.Add(colType);
- tableInfo?.Nullable.Add(nullable);
- tableInfo?.IsAutoIncrement.Add(isAutoIncrement);
- tableInfo?.IsNullable.Add(isNullable);
- tableInfo?.ColumnDefault.Add(columnDefault);
- tableInfo?.OrdinalPosition.Add(ordinalPos);
- SetPrecisionScaleAndTypeName(colType, typeName, tableInfo);
- }
- }
+ TGetColumnsReq columnsReq = new
TGetColumnsReq(SessionHandle);
+ columnsReq.CatalogName = catalogPattern;
+ columnsReq.SchemaName = dbSchemaPattern;
+ columnsReq.TableName = tableNamePattern;
+ columnsReq.GetDirectResults = sparkGetDirectResults;
- StringArray.Builder catalogNameBuilder = new StringArray.Builder();
- List<IArrowArray?> catalogDbSchemasValues = new
List<IArrowArray?>();
+ if (!string.IsNullOrEmpty(columnNamePattern))
+ columnsReq.ColumnName = columnNamePattern;
- foreach (KeyValuePair<string, Dictionary<string,
Dictionary<string, TableInfo>>> catalogEntry in catalogMap)
- {
- catalogNameBuilder.Append(catalogEntry.Key);
+ var columnsResponse = Client.GetColumns(columnsReq,
cancellationToken).Result;
+ if (columnsResponse.Status.StatusCode ==
TStatusCode.ERROR_STATUS)
+ {
+ throw new
Exception(columnsResponse.Status.ErrorMessage);
+ }
- if (depth == GetObjectsDepth.Catalogs)
- {
- catalogDbSchemasValues.Add(null);
+ TGetResultSetMetadataResp columnsMetadata =
GetResultSetMetadataAsync(columnsResponse, cancellationToken).Result;
+ IReadOnlyDictionary<string, int> columnMap =
GetColumnIndexMap(columnsMetadata.Schema.Columns);
+ TRowSet rowSet = GetRowSetAsync(columnsResponse,
cancellationToken).Result;
+
+ IReadOnlyList<string> catalogList =
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
+ IReadOnlyList<string> schemaList =
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
+ IReadOnlyList<string> tableList =
rowSet.Columns[columnMap[TableName]].StringVal.Values;
+ IReadOnlyList<string> columnNameList =
rowSet.Columns[columnMap[ColumnName]].StringVal.Values;
+ ReadOnlySpan<int> columnTypeList =
rowSet.Columns[columnMap[DataType]].I32Val.Values.Values;
+ IReadOnlyList<string> typeNameList =
rowSet.Columns[columnMap[TypeName]].StringVal.Values;
+ ReadOnlySpan<int> nullableList =
rowSet.Columns[columnMap[Nullable]].I32Val.Values.Values;
+ IReadOnlyList<string> columnDefaultList =
rowSet.Columns[columnMap[ColumnDef]].StringVal.Values;
+ ReadOnlySpan<int> ordinalPosList =
rowSet.Columns[columnMap[OrdinalPosition]].I32Val.Values.Values;
+ IReadOnlyList<string> isNullableList =
rowSet.Columns[columnMap[IsNullable]].StringVal.Values;
+ IReadOnlyList<string> isAutoIncrementList =
rowSet.Columns[columnMap[IsAutoIncrement]].StringVal.Values;
+
+ for (int i = 0; i < catalogList.Count; i++)
+ {
+ // For systems that don't support 'catalog' in the
namespace
+ string catalog = catalogList[i] ?? string.Empty;
+ string schemaDb = schemaList[i];
+ string tableName = tableList[i];
+ string columnName = columnNameList[i];
+ short colType = (short)columnTypeList[i];
+ string typeName = typeNameList[i];
+ short nullable = (short)nullableList[i];
+ string? isAutoIncrementString = isAutoIncrementList[i];
+ bool isAutoIncrement =
(!string.IsNullOrEmpty(isAutoIncrementString) &&
(isAutoIncrementString.Equals("YES",
StringComparison.InvariantCultureIgnoreCase) ||
isAutoIncrementString.Equals("TRUE",
StringComparison.InvariantCultureIgnoreCase)));
+ string isNullable = isNullableList[i] ?? "YES";
+ string columnDefault = columnDefaultList[i] ?? "";
+ // Spark/Databricks reports ordinal index
zero-indexed, instead of one-indexed
+ int ordinalPos = ordinalPosList[i] + 1;
+ TableInfo? tableInfo =
catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName);
+ tableInfo?.ColumnName.Add(columnName);
+ tableInfo?.ColType.Add(colType);
+ tableInfo?.Nullable.Add(nullable);
+ tableInfo?.IsAutoIncrement.Add(isAutoIncrement);
+ tableInfo?.IsNullable.Add(isNullable);
+ tableInfo?.ColumnDefault.Add(columnDefault);
+ tableInfo?.OrdinalPosition.Add(ordinalPos);
+ SetPrecisionScaleAndTypeName(colType, typeName,
tableInfo);
+ }
}
- else
+
+ StringArray.Builder catalogNameBuilder = new
StringArray.Builder();
+ List<IArrowArray?> catalogDbSchemasValues = new
List<IArrowArray?>();
+
+ foreach (KeyValuePair<string, Dictionary<string,
Dictionary<string, TableInfo>>> catalogEntry in catalogMap)
{
- catalogDbSchemasValues.Add(GetDbSchemas(
- depth, catalogEntry.Value));
+ catalogNameBuilder.Append(catalogEntry.Key);
+
+ if (depth == GetObjectsDepth.Catalogs)
+ {
+ catalogDbSchemasValues.Add(null);
+ }
+ else
+ {
+ catalogDbSchemasValues.Add(GetDbSchemas(
+ depth, catalogEntry.Value));
+ }
}
- }
- Schema schema = StandardSchemas.GetObjectsSchema;
- IReadOnlyList<IArrowArray> dataArrays = schema.Validate(
- new List<IArrowArray>
- {
+ Schema schema = StandardSchemas.GetObjectsSchema;
+ IReadOnlyList<IArrowArray> dataArrays = schema.Validate(
+ new List<IArrowArray>
+ {
catalogNameBuilder.Build(),
catalogDbSchemasValues.BuildListArrayForType(new
StructType(StandardSchemas.DbSchemaSchema)),
- });
+ });
- return new SparkInfoArrowStream(schema, dataArrays);
+ return new SparkInfoArrowStream(schema, dataArrays);
+ }
+ catch (Exception ex)
+ when (ApacheUtility.ContainsException(ex, out
OperationCanceledException? _) ||
+ (ApacheUtility.ContainsException(ex, out
TTransportException? _) && cancellationToken.IsCancellationRequested))
+ {
+ throw new TimeoutException("The metadata query execution timed
out. Consider increasing the query timeout value.", ex);
+ }
+ catch (Exception ex) when (ex is not HiveServer2Exception)
+ {
+ throw new HiveServer2Exception($"An unexpected error occurred
while running metadata query. '{ex.Message}'", ex);
+ }
}
private static IReadOnlyDictionary<string, int>
GetColumnIndexMap(List<TColumnDesc> columns) => columns
@@ -998,15 +1040,15 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
protected abstract void ValidateAuthentication();
protected abstract void ValidateOptions();
- protected abstract Task<TRowSet> GetRowSetAsync(TGetTableTypesResp
response);
- protected abstract Task<TRowSet> GetRowSetAsync(TGetColumnsResp
response);
- protected abstract Task<TRowSet> GetRowSetAsync(TGetTablesResp
response);
- protected abstract Task<TRowSet> GetRowSetAsync(TGetCatalogsResp
getCatalogsResp);
- protected abstract Task<TRowSet> GetRowSetAsync(TGetSchemasResp
getSchemasResp);
- protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetSchemasResp response);
- protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetCatalogsResp response);
- protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetColumnsResp response);
- protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response);
+ protected abstract Task<TRowSet> GetRowSetAsync(TGetTableTypesResp
response, CancellationToken cancellationToken = default);
+ protected abstract Task<TRowSet> GetRowSetAsync(TGetColumnsResp
response, CancellationToken cancellationToken = default);
+ protected abstract Task<TRowSet> GetRowSetAsync(TGetTablesResp
response, CancellationToken cancellationToken = default);
+ protected abstract Task<TRowSet> GetRowSetAsync(TGetCatalogsResp
getCatalogsResp, CancellationToken cancellationToken = default);
+ protected abstract Task<TRowSet> GetRowSetAsync(TGetSchemasResp
getSchemasResp, CancellationToken cancellationToken = default);
+ protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken
cancellationToken = default);
+ protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken
cancellationToken = default);
+ protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken
cancellationToken = default);
+ protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken
cancellationToken = default);
internal abstract SparkServerType ServerType { get; }
diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
index 7d187fc71..d51ef42b9 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
@@ -16,6 +16,7 @@
*/
using System.Collections.Generic;
+using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
using Apache.Hive.Service.Rpc.Thrift;
@@ -43,24 +44,24 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
return req;
}
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetSchemasResp response) =>
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken
cancellationToken = default) =>
Task.FromResult(response.DirectResults.ResultSetMetadata);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetCatalogsResp response) =>
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken
cancellationToken = default) =>
Task.FromResult(response.DirectResults.ResultSetMetadata);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetColumnsResp response) =>
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken
cancellationToken = default) =>
Task.FromResult(response.DirectResults.ResultSetMetadata);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response) =>
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken
cancellationToken = default) =>
Task.FromResult(response.DirectResults.ResultSetMetadata);
- protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp
response) =>
+ protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp
response, CancellationToken cancellationToken = default) =>
Task.FromResult(response.DirectResults.ResultSet.Results);
- protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp
response) =>
+ protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp
response, CancellationToken cancellationToken = default) =>
Task.FromResult(response.DirectResults.ResultSet.Results);
- protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp
response) =>
+ protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp
response, CancellationToken cancellationToken = default) =>
Task.FromResult(response.DirectResults.ResultSet.Results);
- protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp
response) =>
+ protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp
response, CancellationToken cancellationToken = default) =>
Task.FromResult(response.DirectResults.ResultSet.Results);
- protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp
response) =>
+ protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp
response, CancellationToken cancellationToken = default) =>
Task.FromResult(response.DirectResults.ResultSet.Results);
}
}
diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
index 77ecdb6a2..059ab1690 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
@@ -15,7 +15,6 @@
* limitations under the License.
*/
-using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
index 9d34ac75c..4c068aaa5 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
@@ -120,24 +120,19 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
DataTypeConversion = DataTypeConversionParser.Parse(dataTypeConv);
Properties.TryGetValue(SparkParameters.TLSOptions, out string?
tlsOptions);
TlsOptions = TlsOptionsParser.Parse(tlsOptions);
-
Properties.TryGetValue(SparkParameters.HttpRequestTimeoutMilliseconds, out
string? requestTimeoutMs);
- if (requestTimeoutMs != null)
+ Properties.TryGetValue(SparkParameters.ConnectTimeoutMilliseconds,
out string? connectTimeoutMs);
+ if (connectTimeoutMs != null)
{
- HttpRequestTimeout = int.TryParse(requestTimeoutMs,
NumberStyles.Integer, CultureInfo.InvariantCulture, out int
requestTimeoutMsValue) && requestTimeoutMsValue > 0
- ? requestTimeoutMsValue
- : throw new
ArgumentOutOfRangeException(SparkParameters.HttpRequestTimeoutMilliseconds,
requestTimeoutMs, $"must be a value between 1 .. {int.MaxValue}. default is
30000 milliseconds.");
+ ConnectTimeoutMilliseconds = int.TryParse(connectTimeoutMs,
NumberStyles.Integer, CultureInfo.InvariantCulture, out int
connectTimeoutMsValue) && (connectTimeoutMsValue >= 0)
+ ? connectTimeoutMsValue
+ : throw new
ArgumentOutOfRangeException(SparkParameters.ConnectTimeoutMilliseconds,
connectTimeoutMs, $"must be a value of 0 (infinite) or between 1 ..
{int.MaxValue}. default is 30000 milliseconds.");
}
}
internal override IArrowArrayStream NewReader<T>(T statement, Schema
schema) => new HiveServer2Reader(statement, schema, dataTypeConversion:
statement.Connection.DataTypeConversion);
- protected override Task<TTransport> CreateTransportAsync()
+ protected override TTransport CreateTransport()
{
- foreach (var property in Properties.Keys)
- {
- Trace.TraceError($"key = {property} value =
{Properties[property]}");
- }
-
// Assumption: parameters have already been validated.
Properties.TryGetValue(SparkParameters.HostName, out string?
hostName);
Properties.TryGetValue(SparkParameters.Path, out string? path);
@@ -164,9 +159,12 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
TConfiguration config = new();
ThriftHttpTransport transport = new(httpClient, config)
{
- ConnectTimeout = HttpRequestTimeout,
+ // This value can only be set before the first call/request.
So if a new value for query timeout
+ // is set, we won't be able to update the value. Setting to
~infinite and relying on cancellation token
+ // to ensure cancelled correctly.
+ ConnectTimeout = int.MaxValue,
};
- return Task.FromResult<TTransport>(transport);
+ return transport;
}
private HttpClientHandler NewHttpClientHandler()
@@ -211,11 +209,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
}
}
- protected override async Task<TProtocol>
CreateProtocolAsync(TTransport transport)
+ protected override async Task<TProtocol>
CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken =
default)
{
- Trace.TraceError($"create protocol with {Properties.Count}
properties.");
-
- if (!transport.IsOpen) await
transport.OpenAsync(CancellationToken.None);
+ if (!transport.IsOpen) await
transport.OpenAsync(cancellationToken);
return new TBinaryProtocol(transport);
}
@@ -228,28 +224,29 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
return req;
}
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetSchemasResp response) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetCatalogsResp response) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetColumnsResp response) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client);
- protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp
response) =>
- FetchResultsAsync(response.OperationHandle);
- protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp
response) =>
- FetchResultsAsync(response.OperationHandle);
- protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp
response) =>
- FetchResultsAsync(response.OperationHandle);
- protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp
response) =>
- FetchResultsAsync(response.OperationHandle);
- protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp
response) =>
- FetchResultsAsync(response.OperationHandle);
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken
cancellationToken = default) =>
+ GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken
cancellationToken = default) =>
+ GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken
cancellationToken = default) =>
+ GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken
cancellationToken = default) =>
+ GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
+ protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp
response, CancellationToken cancellationToken = default) =>
+ FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
+ protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp
response, CancellationToken cancellationToken = default) =>
+ FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
+ protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp
response, CancellationToken cancellationToken = default) =>
+ FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
+ protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp
response, CancellationToken cancellationToken = default) =>
+ FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
+ protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp
response, CancellationToken cancellationToken = default) =>
+ FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
private async Task<TRowSet> FetchResultsAsync(TOperationHandle
operationHandle, long batchSize = BatchSizeDefault, CancellationToken
cancellationToken = default)
{
- await PollForResponseAsync(operationHandle, Client,
PollTimeMillisecondsDefault);
+ await PollForResponseAsync(operationHandle, Client,
PollTimeMillisecondsDefault, cancellationToken);
+
TFetchResultsResp fetchResp = await
FetchNextAsync(operationHandle, Client, batchSize, cancellationToken);
if (fetchResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
{
diff --git a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
index 4722efce5..6cb96dd5f 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
@@ -15,8 +15,6 @@
* limitations under the License.
*/
-using static System.Net.WebRequestMethods;
-
namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
{
/// <summary>
@@ -32,7 +30,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
public const string Type = "adbc.spark.type";
public const string DataTypeConv = "adbc.spark.data_type_conv";
public const string TLSOptions = "adbc.spark.tls_options";
- public const string HttpRequestTimeoutMilliseconds =
"adbc.spark.http_request_timeout_ms";
+ public const string ConnectTimeoutMilliseconds =
"adbc.spark.connect_timeout_ms";
}
public static class SparkAuthTypeConstants
diff --git a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs
b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs
index 51813ed6c..c8ab5772c 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs
@@ -18,6 +18,7 @@
using System;
using System.Collections.Generic;
using System.Net;
+using System.Threading;
using System.Threading.Tasks;
using Apache.Hive.Service.Rpc.Thrift;
using Thrift.Protocol;
@@ -85,7 +86,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
}
- protected override Task<TTransport> CreateTransportAsync()
+ protected override TTransport CreateTransport()
{
// Assumption: hostName and port have already been validated.
Properties.TryGetValue(SparkParameters.HostName, out string?
hostName);
@@ -94,14 +95,13 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
// Delay the open connection until later.
bool connectClient = false;
ThriftSocketTransport transport = new(hostName!, int.Parse(port!),
connectClient, config: new());
- return Task.FromResult<TTransport>(transport);
+ return transport;
}
- protected override async Task<TProtocol>
CreateProtocolAsync(TTransport transport)
+ protected override async Task<TProtocol>
CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken =
default)
{
- return await base.CreateProtocolAsync(transport);
+ return await base.CreateProtocolAsync(transport,
cancellationToken);
- //Trace.TraceError($"create protocol with {Properties.Count}
properties.");
//if (!transport.IsOpen) await
transport.OpenAsync(CancellationToken.None);
//return new TBinaryProtocol(transport);
}
diff --git a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
index e4bc3f6cd..25888b1a3 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
@@ -32,6 +32,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
{
case Options.BatchSize:
case Options.PollTimeMilliseconds:
+ case Options.QueryTimeoutSeconds:
{
SetOption(kvp.Key, kvp.Value);
break;
@@ -45,7 +46,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
// TODO: Ensure this is set dynamically depending on server
capabilities.
statement.EnforceResultPersistenceMode = false;
statement.ResultPersistenceMode = 2;
-
+ // Also have the server enforce the timeout so it
doesn't keep processing unnecessarily.
+ // Set in combination with a CancellationToken.
+ statement.QueryTimeout = QueryTimeoutSeconds;
statement.CanReadArrowResult = true;
statement.CanDownloadResult = true;
statement.ConfOverlay = SparkConnection.timestampConfig;
@@ -65,7 +68,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
/// <summary>
/// Provides the constant string key values to the <see
cref="AdbcStatement.SetOption(string, string)" /> method.
/// </summary>
- public new sealed class Options : HiveServer2Statement.Options
+ public sealed class Options : ApacheParameters
{
// options specific to Spark go here
}
diff --git a/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs
b/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs
index fb62ccd9a..ea3d7d16e 100644
--- a/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs
+++ b/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs
@@ -45,11 +45,14 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache
[JsonPropertyName("batch_size"), JsonIgnore(Condition =
JsonIgnoreCondition.WhenWritingDefault)]
public string BatchSize { get; set; } = string.Empty;
- [JsonPropertyName("polltime_milliseconds"), JsonIgnore(Condition =
JsonIgnoreCondition.WhenWritingDefault)]
+ [JsonPropertyName("polltime_ms"), JsonIgnore(Condition =
JsonIgnoreCondition.WhenWritingDefault)]
public string PollTimeMilliseconds { get; set; } = string.Empty;
- [JsonPropertyName("http_request_timeout_ms"), JsonIgnore(Condition =
JsonIgnoreCondition.WhenWritingDefault)]
- public string HttpRequestTimeoutMilliseconds { get; set; } =
string.Empty;
+ [JsonPropertyName("connect_timeout_ms"), JsonIgnore(Condition =
JsonIgnoreCondition.WhenWritingDefault)]
+ public string ConnectTimeoutMilliseconds { get; set; } = string.Empty;
+
+ [JsonPropertyName("query_timeout_s"), JsonIgnore(Condition =
JsonIgnoreCondition.WhenWritingDefault)]
+ public string QueryTimeoutSeconds { get; set; } = string.Empty;
[JsonPropertyName("type"), JsonIgnore(Condition =
JsonIgnoreCondition.WhenWritingDefault)]
public string Type { get; set; } = string.Empty;
diff --git a/csharp/test/Drivers/Apache/Common/ClientTests.cs
b/csharp/test/Drivers/Apache/Common/ClientTests.cs
index e3b0309d0..9148d7281 100644
--- a/csharp/test/Drivers/Apache/Common/ClientTests.cs
+++ b/csharp/test/Drivers/Apache/Common/ClientTests.cs
@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
+using Apache.Arrow.Adbc.Client;
using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2;
using Apache.Arrow.Adbc.Tests.Xunit;
using Xunit;
@@ -203,6 +204,27 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common
}
}
+ [SkippableFact]
+ public void VerifyTimeoutsSet()
+ {
+ using (Adbc.Client.AdbcConnection adbcConnection =
GetAdbcConnection())
+ {
+ int timeout = 99;
+ using AdbcCommand cmd = adbcConnection.CreateCommand();
+
+ // setting the timeout before the timeout property name is set should throw
+ Assert.Throws<InvalidOperationException>(() =>
+ {
+ cmd.CommandTimeout = 1;
+ });
+
+ cmd.AdbcCommandTimeoutProperty =
"adbc.apache.statement.query_timeout_s";
+ cmd.CommandTimeout = timeout;
+
+ Assert.True(cmd.CommandTimeout == timeout, $"ConnectionTimeout
is not set to {timeout}");
+ }
+ }
+
private Adbc.Client.AdbcConnection GetAdbcConnection(bool
includeTableConstraints = true)
{
return new Adbc.Client.AdbcConnection(
diff --git a/csharp/test/Drivers/Apache/Common/StatementTests.cs
b/csharp/test/Drivers/Apache/Common/StatementTests.cs
index 69eec0dd2..b793b7686 100644
--- a/csharp/test/Drivers/Apache/Common/StatementTests.cs
+++ b/csharp/test/Drivers/Apache/Common/StatementTests.cs
@@ -18,6 +18,7 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache;
using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2;
using Apache.Arrow.Adbc.Tests.Xunit;
using Xunit;
@@ -68,11 +69,11 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common
AdbcStatement statement = NewConnection().CreateStatement();
if (throws)
{
- Assert.Throws<ArgumentOutOfRangeException>(() =>
statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.PollTimeMilliseconds,
value));
+ Assert.Throws<ArgumentOutOfRangeException>(() =>
statement.SetOption(ApacheParameters.PollTimeMilliseconds, value));
}
else
{
-
statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.PollTimeMilliseconds,
value);
+ statement.SetOption(ApacheParameters.PollTimeMilliseconds,
value);
}
}
@@ -101,11 +102,74 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common
AdbcStatement statement = NewConnection().CreateStatement();
if (throws)
{
- Assert.Throws<ArgumentOutOfRangeException>(() =>
statement!.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.BatchSize,
value));
+ Assert.Throws<ArgumentOutOfRangeException>(() =>
statement!.SetOption(ApacheParameters.BatchSize, value));
}
else
{
-
statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.BatchSize,
value);
+ statement.SetOption(ApacheParameters.BatchSize, value);
+ }
+ }
+
+ /// <summary>
+ /// Validates that SetOption handles valid/invalid data correctly for
the QueryTimeout option.
+ /// </summary>
+ [SkippableTheory]
+ [InlineData("zero", true)]
+ [InlineData("-2147483648", true)]
+ [InlineData("2147483648", true)]
+ [InlineData("0", false)]
+ [InlineData("-1", true)]
+ [InlineData("1")]
+ [InlineData("2147483647")]
+ public void CanSetOptionQueryTimeout(string value, bool throws = false)
+ {
+ var testConfiguration = TestConfiguration.Clone() as TConfig;
+ testConfiguration!.QueryTimeoutSeconds = value;
+ if (throws)
+ {
+ Assert.Throws<ArgumentOutOfRangeException>(() =>
NewConnection(testConfiguration).CreateStatement());
+ }
+
+ AdbcStatement statement = NewConnection().CreateStatement();
+ if (throws)
+ {
+ Assert.Throws<ArgumentOutOfRangeException>(() =>
statement.SetOption(ApacheParameters.QueryTimeoutSeconds, value));
+ }
+ else
+ {
+ statement.SetOption(ApacheParameters.QueryTimeoutSeconds,
value);
+ }
+ }
+
+ /// <summary>
+ /// Queries the backend with various timeouts.
+ /// </summary>
+ /// <param name="statementWithExceptions">The query timeout, query text, and expected exception (if any) for this case.</param>
+ [SkippableTheory]
+ [ClassData(typeof(StatementTimeoutTestData))]
+ internal void StatementTimeoutTest(StatementWithExceptions
statementWithExceptions)
+ {
+ TConfig testConfiguration = (TConfig)TestConfiguration.Clone();
+
+ if (statementWithExceptions.QueryTimeoutSeconds.HasValue)
+ testConfiguration.QueryTimeoutSeconds =
statementWithExceptions.QueryTimeoutSeconds.Value.ToString();
+
+ if (!string.IsNullOrEmpty(statementWithExceptions.Query))
+ testConfiguration.Query = statementWithExceptions.Query!;
+
+ OutputHelper?.WriteLine($"QueryTimeoutSeconds:
{testConfiguration.QueryTimeoutSeconds}. ShouldSucceed:
{statementWithExceptions.ExceptionType == null}. Query:
[{testConfiguration.Query}]");
+
+ try
+ {
+ AdbcStatement st =
NewConnection(testConfiguration).CreateStatement();
+ st.SqlQuery = testConfiguration.Query;
+ QueryResult qr = st.ExecuteQuery();
+
+ OutputHelper?.WriteLine($"QueryResultRowCount: {qr.RowCount}");
+ }
+ catch (Exception ex) when (ApacheUtility.ContainsException(ex,
statementWithExceptions.ExceptionType, out Exception? containedException))
+ {
+ Assert.IsType(statementWithExceptions.ExceptionType!,
containedException!);
}
}
@@ -116,10 +180,58 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common
public async Task CanInteractUsingSetOptions()
{
const string columnName = "INDEX";
-
Statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.PollTimeMilliseconds,
"100");
-
Statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.BatchSize,
"10");
+ Statement.SetOption(ApacheParameters.PollTimeMilliseconds, "100");
+ Statement.SetOption(ApacheParameters.BatchSize, "10");
using TemporaryTable temporaryTable = await
NewTemporaryTableAsync(Statement, $"{columnName} INT");
await
ValidateInsertSelectDeleteSingleValueAsync(temporaryTable.TableName,
columnName, 1);
}
}
+
+ /// <summary>
+ /// Data type used for statement timeout tests.
+ /// </summary>
+ internal class StatementWithExceptions
+ {
+ public StatementWithExceptions(int? queryTimeoutSeconds, string?
query, Type? exceptionType)
+ {
+ QueryTimeoutSeconds = queryTimeoutSeconds;
+ Query = query;
+ ExceptionType = exceptionType;
+ }
+
+ /// <summary>
+ /// If null, uses the default timeout.
+ /// </summary>
+ public int? QueryTimeoutSeconds { get; }
+
+ /// <summary>
+ /// If null, expected to succeed.
+ /// </summary>
+ public Type? ExceptionType { get; }
+
+ /// <summary>
+ /// If null, uses the default TestConfiguration.
+ /// </summary>
+ public string? Query { get; }
+ }
+
+ /// <summary>
+ /// Collection of <see cref="StatementWithExceptions"/> for testing
statement timeouts.
+ /// </summary>
+ internal class StatementTimeoutTestData :
TheoryData<StatementWithExceptions>
+ {
+ public StatementTimeoutTestData()
+ {
+ string longRunningQuery = "SELECT COUNT(*) AS total_count\nFROM
(\n SELECT t1.id AS id1, t2.id AS id2\n FROM RANGE(1000000) t1\n CROSS JOIN
RANGE(10000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0";
+
+ Add(new(0, null, null));
+ Add(new(null, null, null));
+ Add(new(1, null, typeof(TimeoutException)));
+ Add(new(5, null, null));
+ Add(new(30, null, null));
+ Add(new(5, longRunningQuery, typeof(TimeoutException)));
+ Add(new(null, longRunningQuery, typeof(TimeoutException)));
+ Add(new(0, longRunningQuery, null));
+ }
+ }
}
diff --git a/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs
b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs
index c2faa9d12..34e971bd8 100644
--- a/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs
+++ b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs
@@ -19,7 +19,10 @@ using System;
using System.Collections.Generic;
using System.Globalization;
using System.Net;
+using Apache.Arrow.Adbc.Drivers.Apache;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Adbc.Drivers.Apache.Spark;
+using Thrift.Transport;
using Xunit;
using Xunit.Abstractions;
@@ -48,6 +51,231 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
OutputHelper?.WriteLine(exeption.Message);
}
+ /// <summary>
+ /// Tests connection timeout to establish a session with the backend.
+ /// </summary>
+ /// <param name="connectTimeoutMilliseconds">The timeout (in
ms)</param>
+ /// <param name="exceptionType">The exception type to expect (if
any)</param>
+ /// <param name="alternateExceptionType">An alternate exception that
may occur (if any)</param>
+ [SkippableTheory]
+ [InlineData(0, null, null)]
+ [InlineData(1, typeof(TimeoutException), typeof(TTransportException))]
+ [InlineData(10, typeof(TimeoutException), typeof(TTransportException))]
+ [InlineData(30000, null, null)]
+ [InlineData(null, null, null)]
+ public void ConnectionTimeoutTest(int? connectTimeoutMilliseconds,
Type? exceptionType, Type? alternateExceptionType)
+ {
+ SparkTestConfiguration testConfiguration =
(SparkTestConfiguration)TestConfiguration.Clone();
+
+ if (connectTimeoutMilliseconds.HasValue)
+ testConfiguration.ConnectTimeoutMilliseconds =
connectTimeoutMilliseconds.Value.ToString();
+
+ OutputHelper?.WriteLine($"ConnectTimeoutMilliseconds:
{testConfiguration.ConnectTimeoutMilliseconds}. ShouldSucceed: {exceptionType
== null}");
+
+ try
+ {
+ NewConnection(testConfiguration);
+ }
+ catch(AggregateException aex)
+ {
+ if (exceptionType != null)
+ {
+ if (alternateExceptionType != null &&
aex.InnerException?.GetType() != exceptionType)
+ {
+ if (aex.InnerException?.GetType() ==
typeof(HiveServer2Exception))
+ {
+ // a TTransportException is inside a
HiveServer2Exception
+ Assert.IsType(alternateExceptionType,
aex.InnerException!.InnerException);
+ }
+ else
+ {
+ throw;
+ }
+ }
+ else
+ {
+ Assert.IsType(exceptionType, aex.InnerException);
+ }
+ }
+ else
+ {
+ throw;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Tests the various metadata calls on a SparkConnection with different query timeouts.
+ /// </summary>
+ /// <param name="metadataWithException">The metadata action, query timeout, and expected exception(s) for this case.</param>
+ [SkippableTheory]
+ [ClassData(typeof(MetadataTimeoutTestData))]
+ internal void MetadataTimeoutTest(MetadataWithExceptions
metadataWithException)
+ {
+ SparkTestConfiguration testConfiguration =
(SparkTestConfiguration)TestConfiguration.Clone();
+
+ if (metadataWithException.QueryTimeoutSeconds.HasValue)
+ testConfiguration.QueryTimeoutSeconds =
metadataWithException.QueryTimeoutSeconds.Value.ToString();
+
+ OutputHelper?.WriteLine($"Action:
{metadataWithException.ActionName}. QueryTimeoutSeconds:
{testConfiguration.QueryTimeoutSeconds}. ShouldSucceed:
{metadataWithException.ExceptionType == null}");
+
+ try
+ {
+ metadataWithException.MetadataAction(testConfiguration);
+ }
+ catch (Exception ex) when (ApacheUtility.ContainsException(ex,
metadataWithException.ExceptionType, out Exception? containedException))
+ {
+ Assert.IsType(metadataWithException.ExceptionType!,
containedException);
+ }
+ catch (Exception ex) when (ApacheUtility.ContainsException(ex,
metadataWithException.AlternateExceptionType, out Exception?
containedException))
+ {
+ Assert.IsType(metadataWithException.AlternateExceptionType!,
containedException);
+ }
+ }
+
+ /// <summary>
+ /// Data type used for metadata timeout tests.
+ /// </summary>
+ internal class MetadataWithExceptions
+ {
+ public MetadataWithExceptions(int? queryTimeoutSeconds, string
actionName, Action<SparkTestConfiguration> action, Type? exceptionType, Type?
alternateExceptionType)
+ {
+ QueryTimeoutSeconds = queryTimeoutSeconds;
+ ActionName = actionName;
+ MetadataAction = action;
+ ExceptionType = exceptionType;
+ AlternateExceptionType = alternateExceptionType;
+ }
+
+ /// <summary>
+ /// If null, uses the default timeout.
+ /// </summary>
+ public int? QueryTimeoutSeconds { get; }
+
+ public string ActionName { get; }
+
+ /// <summary>
+ /// If null, expected to succeed.
+ /// </summary>
+ public Type? ExceptionType { get; }
+
+ /// <summary>
+ /// Sometimes you can expect one but may get another.
+ /// For example, on GetObjectsAll, sometimes a TTransportException
is expected but a TimeoutException is received during the test.
+ /// </summary>
+ public Type? AlternateExceptionType { get; }
+
+ /// <summary>
+ /// The metadata action to perform.
+ /// </summary>
+ public Action<SparkTestConfiguration> MetadataAction { get; }
+ }
+
+ /// <summary>
+ /// Used for testing timeouts on metadata calls.
+ /// </summary>
+ internal class MetadataTimeoutTestData :
TheoryData<MetadataWithExceptions>
+ {
+ public MetadataTimeoutTestData()
+ {
+ SparkConnectionTest sparkConnectionTest = new
SparkConnectionTest(null);
+
+ Action<SparkTestConfiguration> getObjectsAll =
(testConfiguration) =>
+ {
+ AdbcConnection cn =
sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetObjects(AdbcConnection.GetObjectsDepth.All,
testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema,
testConfiguration.Metadata.Table, null, null);
+ };
+
+ Action<SparkTestConfiguration> getObjectsCatalogs =
(testConfiguration) =>
+ {
+ AdbcConnection cn =
sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetObjects(AdbcConnection.GetObjectsDepth.Catalogs,
testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema,
testConfiguration.Metadata.Schema, null, null);
+ };
+
+ Action<SparkTestConfiguration> getObjectsDbSchemas =
(testConfiguration) =>
+ {
+ AdbcConnection cn =
sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetObjects(AdbcConnection.GetObjectsDepth.DbSchemas,
testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema,
testConfiguration.Metadata.Schema, null, null);
+ };
+
+ Action<SparkTestConfiguration> getObjectsTables =
(testConfiguration) =>
+ {
+ AdbcConnection cn =
sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetObjects(AdbcConnection.GetObjectsDepth.Tables,
testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema,
testConfiguration.Metadata.Schema, null, null);
+ };
+
+ AddAction("getObjectsAll", getObjectsAll, new List<Type?>() {
null, typeof(TimeoutException), null, null, null } );
+ AddAction("getObjectsCatalogs", getObjectsCatalogs);
+ AddAction("getObjectsDbSchemas", getObjectsDbSchemas);
+ AddAction("getObjectsTables", getObjectsTables);
+
+ Action<SparkTestConfiguration> getTableTypes =
(testConfiguration) =>
+ {
+ AdbcConnection cn =
sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetTableTypes();
+ };
+
+ AddAction("getTableTypes", getTableTypes);
+
+ Action<SparkTestConfiguration> getTableSchema =
(testConfiguration) =>
+ {
+ AdbcConnection cn =
sparkConnectionTest.NewConnection(testConfiguration);
+ cn.GetTableSchema(testConfiguration.Metadata.Catalog,
testConfiguration.Metadata.Schema, testConfiguration.Metadata.Table);
+ };
+
+ AddAction("getTableSchema", getTableSchema);
+ }
+
+ /// <summary>
+ /// Adds the action with the default timeouts.
+ /// </summary>
+ /// <param name="name">The friendly name of the action.</param>
+ /// <param name="action">The action to perform.</param>
+ /// <param name="alternateExceptions">Optional list of alternate
exceptions that are possible. Must have 5 items if present.</param>
+ private void AddAction(string name, Action<SparkTestConfiguration>
action, List<Type?>? alternateExceptions = null)
+ {
+ List<Type?> expectedExceptions = new List<Type?>()
+ {
+ null, // QueryTimeout = 0
+ typeof(TTransportException), // QueryTimeout = 1
+ typeof(TimeoutException), // QueryTimeout = 10
+ null, // QueryTimeout = default
+ null // QueryTimeout = 300
+ };
+
+ AddAction(name, action, expectedExceptions,
alternateExceptions);
+ }
+
+ /// <summary>
+ /// Adds the action with the specified expected exceptions for each timeout.
+ /// </summary>
+ /// <param name="action">The action to perform.</param>
+ /// <param name="expectedExceptions">The expected
exceptions.</param>
+ /// <remarks>
+ /// For List&lt;Type?&gt; the position is based on the behavior when:
+ /// [0] QueryTimeout = 0
+ /// [1] QueryTimeout = 1
+ /// [2] QueryTimeout = 10
+ /// [3] QueryTimeout = default
+ /// [4] QueryTimeout = 300
+ /// </remarks>
+ private void AddAction(string name, Action<SparkTestConfiguration>
action, List<Type?> expectedExceptions, List<Type?>? alternateExceptions)
+ {
+ Assert.True(expectedExceptions.Count == 5);
+
+ if (alternateExceptions != null)
+ {
+ Assert.True(alternateExceptions.Count == 5);
+ }
+
+ Add(new(0, name, action, expectedExceptions[0],
alternateExceptions?[0]));
+ Add(new(1, name, action, expectedExceptions[1],
alternateExceptions?[1]));
+ Add(new(10, name, action, expectedExceptions[2],
alternateExceptions?[2]));
+ Add(new(null, name, action, expectedExceptions[3],
alternateExceptions?[3]));
+ Add(new(300, name, action, expectedExceptions[4],
alternateExceptions?[4]));
+ }
+ }
+
internal class ParametersWithExceptions
{
public ParametersWithExceptions(Dictionary<string, string>
parameters, Type exceptionType)
@@ -85,11 +313,9 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
Add(new(new() { [SparkParameters.Type] =
SparkServerTypeConstants.Databricks, [SparkParameters.HostName] =
"valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] =
"httpxxz://hostname.com" }, typeof(ArgumentOutOfRangeException)));
Add(new(new() { [SparkParameters.Type] =
SparkServerTypeConstants.Databricks, [SparkParameters.HostName] =
"valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] =
"http-//hostname.com" }, typeof(UriFormatException)));
Add(new(new() { [SparkParameters.Type] =
SparkServerTypeConstants.Databricks, [SparkParameters.HostName] =
"valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] =
"httpxxz://hostname.com:1234567890" }, typeof(UriFormatException)));
- Add(new(new() { [SparkParameters.Type] =
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com",
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword" ,
[SparkParameters.HttpRequestTimeoutMilliseconds] = "0" },
typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] =
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com",
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword",
[SparkParameters.HttpRequestTimeoutMilliseconds] = "-1" },
typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] =
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com",
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword",
[SparkParameters.HttpRequestTimeoutMilliseconds] = ((long)int.MaxValue +
1).ToString() }, typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] =
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com",
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword",
[SparkParameters.HttpRequestTimeoutMilliseconds] = "non-numeric" },
typeof(ArgumentOutOfRangeException)));
- Add(new(new() { [SparkParameters.Type] =
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com",
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword",
[SparkParameters.HttpRequestTimeoutMilliseconds] = "" },
typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] =
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com",
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword",
[SparkParameters.ConnectTimeoutMilliseconds] = ((long)int.MaxValue +
1).ToString() }, typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] =
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com",
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword",
[SparkParameters.ConnectTimeoutMilliseconds] = "non-numeric" },
typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.Type] =
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com",
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword",
[SparkParameters.ConnectTimeoutMilliseconds] = "" },
typeof(ArgumentOutOfRangeException)));
}
}
}
diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
index 54c536853..16a550111 100644
--- a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
+++ b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
@@ -19,6 +19,7 @@ using System;
using System.Collections.Generic;
using System.Data.SqlTypes;
using System.Text;
+using Apache.Arrow.Adbc.Drivers.Apache;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Adbc.Drivers.Apache.Spark;
using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2;
@@ -102,15 +103,19 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
}
if (!string.IsNullOrEmpty(testConfiguration.BatchSize))
{
- parameters.Add(HiveServer2Statement.Options.BatchSize,
testConfiguration.BatchSize!);
+ parameters.Add(ApacheParameters.BatchSize,
testConfiguration.BatchSize!);
}
if (!string.IsNullOrEmpty(testConfiguration.PollTimeMilliseconds))
{
-
parameters.Add(HiveServer2Statement.Options.PollTimeMilliseconds,
testConfiguration.PollTimeMilliseconds!);
+ parameters.Add(ApacheParameters.PollTimeMilliseconds,
testConfiguration.PollTimeMilliseconds!);
}
- if
(!string.IsNullOrEmpty(testConfiguration.HttpRequestTimeoutMilliseconds))
+ if
(!string.IsNullOrEmpty(testConfiguration.ConnectTimeoutMilliseconds))
{
- parameters.Add(SparkParameters.HttpRequestTimeoutMilliseconds,
testConfiguration.HttpRequestTimeoutMilliseconds!);
+ parameters.Add(SparkParameters.ConnectTimeoutMilliseconds,
testConfiguration.ConnectTimeoutMilliseconds!);
+ }
+ if (!string.IsNullOrEmpty(testConfiguration.QueryTimeoutSeconds))
+ {
+ parameters.Add(ApacheParameters.QueryTimeoutSeconds,
testConfiguration.QueryTimeoutSeconds!);
}
return parameters;
diff --git a/csharp/test/Drivers/Apache/Spark/StatementTests.cs
b/csharp/test/Drivers/Apache/Spark/StatementTests.cs
index 25d27179a..aaafc31ba 100644
--- a/csharp/test/Drivers/Apache/Spark/StatementTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/StatementTests.cs
@@ -15,6 +15,8 @@
* limitations under the License.
*/
+using System;
+using Xunit;
using Xunit.Abstractions;
namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark