This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 845f9c2d9 feat(csharp/src/Drivers/Apache): add connect and query timeout options (#2312)
845f9c2d9 is described below

commit 845f9c2d98416ae6871680c4afd98d7ff2e51333
Author: Bruce Irschick <[email protected]>
AuthorDate: Wed Dec 4 10:27:43 2024 -0800

    feat(csharp/src/Drivers/Apache): add connect and query timeout options (#2312)
    
    Adds options for connect and query timeouts.
    
    | Property               | Description | Default |
    | :---                   | :---        | :---    |
    | `adbc.spark.connect_timeout_ms` | Sets the timeout (in milliseconds) to open a new session. Values can be 0 (infinite) or greater than zero. | `30000` |
    | `adbc.apache.statement.query_timeout_s` | Sets the maximum time (in seconds) for a query to complete. Values can be 0 (infinite) or greater than zero. | `60` |
    
    ---------
    
    Co-authored-by: Aman Goyal <[email protected]>
    Co-authored-by: David Coe <[email protected]>
---
 csharp/src/Client/AdbcCommand.cs                   |  28 +-
 ...iveServer2Parameters.cs => ApacheParameters.cs} |  20 +-
 csharp/src/Drivers/Apache/ApacheUtility.cs         | 141 +++++++
 .../Drivers/Apache/Hive2/HiveServer2Connection.cs  | 106 +++--
 .../Drivers/Apache/Hive2/HiveServer2Parameters.cs  |   2 -
 .../src/Drivers/Apache/Hive2/HiveServer2Reader.cs  |  34 +-
 .../Drivers/Apache/Hive2/HiveServer2Statement.cs   | 134 +++++--
 .../src/Drivers/Apache/Impala/ImpalaConnection.cs  |   6 +-
 .../src/Drivers/Apache/Impala/ImpalaStatement.cs   |   2 +-
 csharp/src/Drivers/Apache/Spark/README.md          |  15 +-
 csharp/src/Drivers/Apache/Spark/SparkConnection.cs | 438 +++++++++++----------
 .../Apache/Spark/SparkDatabricksConnection.cs      |  19 +-
 .../Drivers/Apache/Spark/SparkDatabricksReader.cs  |   1 -
 .../Drivers/Apache/Spark/SparkHttpConnection.cs    |  69 ++--
 csharp/src/Drivers/Apache/Spark/SparkParameters.cs |   4 +-
 .../Apache/Spark/SparkStandardConnection.cs        |  10 +-
 csharp/src/Drivers/Apache/Spark/SparkStatement.cs  |   7 +-
 .../test/Drivers/Apache/ApacheTestConfiguration.cs |   9 +-
 csharp/test/Drivers/Apache/Common/ClientTests.cs   |  22 ++
 .../test/Drivers/Apache/Common/StatementTests.cs   | 124 +++++-
 .../Drivers/Apache/Spark/SparkConnectionTest.cs    | 236 ++++++++++-
 .../Drivers/Apache/Spark/SparkTestEnvironment.cs   |  13 +-
 csharp/test/Drivers/Apache/Spark/StatementTests.cs |   2 +
 23 files changed, 1077 insertions(+), 365 deletions(-)

diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs
index 8b85be206..c3695feaf 100644
--- a/csharp/src/Client/AdbcCommand.cs
+++ b/csharp/src/Client/AdbcCommand.cs
@@ -21,6 +21,7 @@ using System.Collections.Generic;
 using System.Data;
 using System.Data.Common;
 using System.Data.SqlTypes;
+using System.Globalization;
 using System.Linq;
 using System.Threading.Tasks;
 using Apache.Arrow.Types;
@@ -32,10 +33,11 @@ namespace Apache.Arrow.Adbc.Client
     /// </summary>
     public sealed class AdbcCommand : DbCommand
     {
-        private AdbcStatement _adbcStatement;
+        private readonly AdbcStatement _adbcStatement;
         private AdbcParameterCollection? _dbParameterCollection;
         private int _timeout = 30;
         private bool _disposed;
+        private string? _commandTimeoutProperty;
 
         /// <summary>
         /// Overloaded. Initializes <see cref="AdbcCommand"/>.
@@ -117,10 +119,32 @@ namespace Apache.Arrow.Adbc.Client
             }
         }
 
+
+        /// <summary>
+        /// Gets or sets the name of the command timeout property for the 
underlying ADBC driver.
+        /// </summary>
+        public string AdbcCommandTimeoutProperty
+        {
+            get
+            {
+                if (string.IsNullOrEmpty(_commandTimeoutProperty))
+                    throw new 
InvalidOperationException("CommandTimeoutProperty is not set.");
+
+                return _commandTimeoutProperty!;
+            }
+            set => _commandTimeoutProperty = value;
+        }
+
         public override int CommandTimeout
         {
             get => _timeout;
-            set => _timeout = value;
+            set
+            {
+                // ensures the property exists before setting the 
CommandTimeout value
+                string property = AdbcCommandTimeoutProperty;
+                _adbcStatement.SetOption(property, 
value.ToString(CultureInfo.InvariantCulture));
+                _timeout = value;
+            }
         }
 
         protected override DbParameterCollection DbParameterCollection
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs 
b/csharp/src/Drivers/Apache/ApacheParameters.cs
similarity index 66%
copy from csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
copy to csharp/src/Drivers/Apache/ApacheParameters.cs
index 2170cd17b..17c94be32 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
+++ b/csharp/src/Drivers/Apache/ApacheParameters.cs
@@ -15,19 +15,15 @@
  * limitations under the License.
  */
 
-using System;
-
-namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
+namespace Apache.Arrow.Adbc.Drivers.Apache
 {
-    public static class DataTypeConversionOptions
-    {
-        public const string None = "none";
-        public const string Scalar = "scalar";
-    }
-
-    public static class TlsOptions
+    /// <summary>
+    /// Options common to all Apache drivers.
+    /// </summary>
+    public class ApacheParameters
     {
-        public const string AllowSelfSigned = "allow_self_signed";
-        public const string AllowHostnameMismatch = "allow_hostname_mismatch";
+        public const string PollTimeMilliseconds = 
"adbc.apache.statement.polltime_ms";
+        public const string BatchSize = "adbc.apache.statement.batch_size";
+        public const string QueryTimeoutSeconds = 
"adbc.apache.statement.query_timeout_s";
     }
 }
diff --git a/csharp/src/Drivers/Apache/ApacheUtility.cs 
b/csharp/src/Drivers/Apache/ApacheUtility.cs
new file mode 100644
index 000000000..f1cb07e07
--- /dev/null
+++ b/csharp/src/Drivers/Apache/ApacheUtility.cs
@@ -0,0 +1,178 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Globalization;
+using System.Threading;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache
+{
+    /// <summary>
+    /// Helpers shared by the Apache drivers for timeout handling and exception inspection.
+    /// </summary>
+    internal class ApacheUtility
+    {
+        /// <summary>
+        /// Default query timeout, in seconds, used when no timeout option is supplied.
+        /// </summary>
+        internal const int QueryTimeoutSecondsDefault = 60;
+
+        /// <summary>
+        /// Upper bound used for every <see cref="CancellationTokenSource"/> delay: <see cref="int.MaxValue"/>
+        /// milliseconds (about 24.8 days). Larger delays are rejected by CancellationTokenSource on at least some runtimes.
+        /// </summary>
+        private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue);
+
+        /// <summary>
+        /// Unit in which a timeout value is expressed.
+        /// </summary>
+        public enum TimeUnit
+        {
+            Seconds,
+            Milliseconds
+        }
+
+        /// <summary>
+        /// Creates a token that is cancelled once <paramref name="timeout"/> has elapsed.
+        /// </summary>
+        /// <param name="timeout">
+        /// Timeout in <paramref name="timeUnit"/> units. 0 and <see cref="int.MaxValue"/> are treated as
+        /// "infinite" and mapped to the largest delay this helper uses.
+        /// </param>
+        /// <param name="timeUnit">The unit of <paramref name="timeout"/>.</param>
+        /// <returns>A token that is cancelled when the timeout expires.</returns>
+        /// <remarks>
+        /// Values that exceed the maximum delay are clamped to it rather than throwing.
+        /// NOTE(review): the backing <see cref="CancellationTokenSource"/> is never disposed, so its
+        /// timer lives until it fires; consider returning the source so callers can dispose it.
+        /// </remarks>
+        public static CancellationToken GetCancellationToken(int timeout, TimeUnit timeUnit)
+        {
+            if (timeout == 0 || timeout == int.MaxValue)
+            {
+                return GetCancellationToken(s_maxTimeout);
+            }
+
+            TimeSpan span = timeUnit == TimeUnit.Seconds
+                ? TimeSpan.FromSeconds(timeout)
+                : TimeSpan.FromMilliseconds(timeout);
+
+            // A large seconds value can exceed the delay CancellationTokenSource accepts and would throw
+            // ArgumentOutOfRangeException even though the option validated; clamp it to the maximum instead.
+            if (span > s_maxTimeout)
+            {
+                span = s_maxTimeout;
+            }
+
+            return GetCancellationToken(span);
+        }
+
+        /// <summary>
+        /// Creates a token source that cancels after <paramref name="timeSpan"/> and returns its token.
+        /// </summary>
+        private static CancellationToken GetCancellationToken(TimeSpan timeSpan)
+        {
+            var cts = new CancellationTokenSource(timeSpan);
+            return cts.Token;
+        }
+
+        /// <summary>
+        /// Validates and parses a query timeout option value.
+        /// </summary>
+        /// <param name="key">The option name; used as the parameter name in the exception.</param>
+        /// <param name="value">The raw option value.</param>
+        /// <param name="queryTimeoutSeconds">The parsed timeout, in seconds (0 means infinite).</param>
+        /// <returns>Always <c>true</c>; invalid input throws instead of returning <c>false</c>.</returns>
+        /// <exception cref="ArgumentOutOfRangeException">
+        /// <paramref name="value"/> is null/empty, not an integer, or negative.
+        /// </exception>
+        public static bool QueryTimeoutIsValid(string key, string value, out int queryTimeoutSeconds)
+        {
+            // Option values are machine-readable text, so parse them culture-invariantly.
+            if (!string.IsNullOrEmpty(value)
+                && int.TryParse(value, NumberStyles.Integer, CultureInfo.InvariantCulture, out int queryTimeout)
+                && queryTimeout >= 0)
+            {
+                queryTimeoutSeconds = queryTimeout;
+                return true;
+            }
+
+            throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value of 0 (infinite) or greater.");
+        }
+
+        /// <summary>
+        /// Determines whether <paramref name="exception"/> is, or wraps, an exception of type <typeparamref name="T"/>.
+        /// </summary>
+        /// <param name="exception">The exception to inspect.</param>
+        /// <param name="containedException">The first matching exception, or <c>null</c> if none was found.</param>
+        /// <returns><c>true</c> if a matching exception was found.</returns>
+        public static bool ContainsException<T>(Exception exception, out T? containedException) where T : Exception
+        {
+            // Same search order as the Type-based overload, which this delegates to.
+            bool found = ContainsException(exception, typeof(T), out Exception? match);
+            containedException = match as T;
+            return found;
+        }
+
+        /// <summary>
+        /// Determines whether <paramref name="exception"/> is, or wraps, an instance of <paramref name="exceptionType"/>.
+        /// </summary>
+        /// <remarks>
+        /// If <paramref name="exception"/> is an <see cref="AggregateException"/>, its direct
+        /// <see cref="AggregateException.InnerExceptions"/> are checked first; then the
+        /// <see cref="Exception.InnerException"/> chain starting at <paramref name="exception"/> itself is walked.
+        /// </remarks>
+        /// <param name="exception">The exception to inspect; <c>null</c> yields <c>false</c>.</param>
+        /// <param name="exceptionType">The type to look for; <c>null</c> yields <c>false</c>.</param>
+        /// <param name="containedException">The first matching exception, or <c>null</c> if none was found.</param>
+        /// <returns><c>true</c> if a matching exception was found.</returns>
+        public static bool ContainsException(Exception exception, Type? exceptionType, out Exception? containedException)
+        {
+            if (exception == null || exceptionType == null)
+            {
+                containedException = null;
+                return false;
+            }
+
+            if (exception is AggregateException aggregateException)
+            {
+                foreach (Exception? ex in aggregateException.InnerExceptions)
+                {
+                    if (exceptionType.IsInstanceOfType(ex))
+                    {
+                        containedException = ex;
+                        return true;
+                    }
+                }
+            }
+
+            Exception? e = exception;
+            while (e != null)
+            {
+                if (exceptionType.IsInstanceOfType(e))
+                {
+                    containedException = e;
+                    return true;
+                }
+                e = e.InnerException;
+            }
+
+            containedException = null;
+            return false;
+        }
+    }
+}
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs 
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
index c839bbaa7..d420edb2b 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -30,7 +30,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
     {
         internal const long BatchSizeDefault = 50000;
         internal const int PollTimeMillisecondsDefault = 500;
-
+        private const int ConnectTimeoutMillisecondsDefault = 30000;
         private TTransport? _transport;
         private TCLIService.Client? _client;
         private readonly Lazy<string> _vendorVersion;
@@ -45,6 +45,14 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             // 
https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects
             _vendorVersion = new Lazy<string>(() => 
GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER), 
LazyThreadSafetyMode.PublicationOnly);
             _vendorName = new Lazy<string>(() => 
GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME), 
LazyThreadSafetyMode.PublicationOnly);
+
+            if (properties.TryGetValue(ApacheParameters.QueryTimeoutSeconds, 
out string? queryTimeoutSecondsSettingValue))
+            {
+                if 
(ApacheUtility.QueryTimeoutIsValid(ApacheParameters.QueryTimeoutSeconds, 
queryTimeoutSecondsSettingValue, out int queryTimeoutSeconds))
+                {
+                    QueryTimeoutSeconds = queryTimeoutSeconds;
+                }
+            }
         }
 
         internal TCLIService.Client Client
@@ -56,30 +64,48 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 
         internal string VendorName => _vendorName.Value;
 
+        protected internal int QueryTimeoutSeconds { get; set; } = 
ApacheUtility.QueryTimeoutSecondsDefault;
+
         internal IReadOnlyDictionary<string, string> Properties { get; }
 
         internal async Task OpenAsync()
         {
-            TTransport transport = await CreateTransportAsync();
-            TProtocol protocol = await CreateProtocolAsync(transport);
-            _transport = protocol.Transport;
-            _client = new TCLIService.Client(protocol);
-            TOpenSessionReq request = CreateSessionRequest();
-            TOpenSessionResp? session = await Client.OpenSession(request);
-
-            // Some responses don't raise an exception. Explicitly check the 
status.
-            if (session == null)
+            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(ConnectTimeoutMilliseconds, 
ApacheUtility.TimeUnit.Milliseconds);
+            try
             {
-                throw new HiveServer2Exception("unable to open session. 
unknown error.");
+                TTransport transport = CreateTransport();
+                TProtocol protocol = await CreateProtocolAsync(transport, 
cancellationToken);
+                _transport = protocol.Transport;
+                _client = new TCLIService.Client(protocol);
+                TOpenSessionReq request = CreateSessionRequest();
+
+                TOpenSessionResp? session = await Client.OpenSession(request, 
cancellationToken);
+
+                // Explicitly check the session status
+                if (session == null)
+                {
+                    throw new HiveServer2Exception("Unable to open session. 
Unknown error.");
+                }
+                else if (session.Status.StatusCode != 
TStatusCode.SUCCESS_STATUS)
+                {
+                    throw new HiveServer2Exception(session.Status.ErrorMessage)
+                        .SetNativeError(session.Status.ErrorCode)
+                        .SetSqlState(session.Status.SqlState);
+                }
+
+                SessionHandle = session.SessionHandle;
             }
-            else if (session.Status.StatusCode != TStatusCode.SUCCESS_STATUS)
+            catch (Exception ex)
+                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
+                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
             {
-                throw new HiveServer2Exception(session.Status.ErrorMessage)
-                    .SetNativeError(session.Status.ErrorCode)
-                    .SetSqlState(session.Status.SqlState);
+                throw new TimeoutException("The operation timed out while 
attempting to open a session. Please try increasing connect timeout.", ex);
+            }
+            catch (Exception ex) when (ex is not HiveServer2Exception)
+            {
+                // Handle other exceptions if necessary
+                throw new HiveServer2Exception($"An unexpected error occurred 
while opening the session. '{ex.Message}'", ex);
             }
-
-            SessionHandle = session.SessionHandle;
         }
 
         internal TSessionHandle? SessionHandle { get; private set; }
@@ -88,11 +114,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 
         protected internal HiveServer2TlsOption TlsOptions { get; set; } = 
HiveServer2TlsOption.Empty;
 
-        protected internal int HttpRequestTimeout { get; set; } = 30000;
+        protected internal int ConnectTimeoutMilliseconds { get; set; } = 
ConnectTimeoutMillisecondsDefault;
 
-        protected abstract Task<TTransport> CreateTransportAsync();
+        protected abstract TTransport CreateTransport();
 
-        protected abstract Task<TProtocol> CreateProtocolAsync(TTransport 
transport);
+        protected abstract Task<TProtocol> CreateProtocolAsync(TTransport 
transport, CancellationToken cancellationToken = default);
 
         protected abstract TOpenSessionReq CreateSessionRequest();
 
@@ -110,14 +136,14 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             throw new NotImplementedException();
         }
 
-        internal static async Task PollForResponseAsync(TOperationHandle 
operationHandle, TCLIService.IAsync client, int pollTimeMilliseconds)
+        internal static async Task PollForResponseAsync(TOperationHandle 
operationHandle, TCLIService.IAsync client, int pollTimeMilliseconds, 
CancellationToken cancellationToken = default)
         {
             TGetOperationStatusResp? statusResponse = null;
             do
             {
-                if (statusResponse != null) { await 
Task.Delay(pollTimeMilliseconds); }
+                if (statusResponse != null) { await 
Task.Delay(pollTimeMilliseconds, cancellationToken); }
                 TGetOperationStatusReq request = new(operationHandle);
-                statusResponse = await client.GetOperationStatus(request);
+                statusResponse = await client.GetOperationStatus(request, 
cancellationToken);
             } while (statusResponse.OperationState == 
TOperationState.PENDING_STATE || statusResponse.OperationState == 
TOperationState.RUNNING_STATE);
         }
 
@@ -129,24 +155,38 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                 InfoType = infoType,
             };
 
-            TGetInfoResp getInfoResp = Client.GetInfo(req).Result;
-            if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
+            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            try
             {
-                throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage)
-                    .SetNativeError(getInfoResp.Status.ErrorCode)
-                    .SetSqlState(getInfoResp.Status.SqlState);
+                TGetInfoResp getInfoResp = Client.GetInfo(req, 
cancellationToken).Result;
+                if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
+                {
+                    throw new 
HiveServer2Exception(getInfoResp.Status.ErrorMessage)
+                        .SetNativeError(getInfoResp.Status.ErrorCode)
+                        .SetSqlState(getInfoResp.Status.SqlState);
+                }
+
+                return getInfoResp.InfoValue.StringValue;
+            }
+            catch (Exception ex)
+                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
+                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
+            {
+                throw new TimeoutException("The metadata query execution timed 
out. Consider increasing the query timeout value.", ex);
+            }
+            catch (Exception ex) when (ex is not HiveServer2Exception)
+            {
+                throw new HiveServer2Exception($"An unexpected error occurred 
while running metadata query. '{ex.Message}'", ex);
             }
-
-            return getInfoResp.InfoValue.StringValue;
         }
 
         public override void Dispose()
         {
             if (_client != null)
             {
-                TCloseSessionReq r6 = new TCloseSessionReq(SessionHandle);
-                _client.CloseSession(r6).Wait();
-
+                CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+                TCloseSessionReq r6 = new(SessionHandle);
+                _client.CloseSession(r6, cancellationToken).Wait();
                 _transport?.Close();
                 _client.Dispose();
                 _transport = null;
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs 
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
index 2170cd17b..4f2bc62d2 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs
@@ -15,8 +15,6 @@
  * limitations under the License.
  */
 
-using System;
-
 namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 {
     public static class DataTypeConversionOptions
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs 
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
index 08b0675d0..34dbf10f2 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
@@ -25,6 +25,7 @@ using System.Threading.Tasks;
 using Apache.Arrow.Ipc;
 using Apache.Arrow.Types;
 using Apache.Hive.Service.Rpc.Thrift;
+using Thrift.Transport;
 
 namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 {
@@ -89,19 +90,32 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                 return null;
             }
 
-            // Await the fetch response
-            TFetchResultsResp response = await FetchNext(_statement, 
cancellationToken);
+            try
+            {
+                // Await the fetch response
+                TFetchResultsResp response = await FetchNext(_statement, 
cancellationToken);
+
+                int columnCount = GetColumnCount(response);
+                int rowCount = GetRowCount(response, columnCount);
+                if ((_statement.BatchSize > 0 && rowCount < 
_statement.BatchSize) || rowCount == 0)
+                {
+                    // This is the last batch
+                    _statement = null;
+                }
 
-            int columnCount = GetColumnCount(response);
-            int rowCount = GetRowCount(response, columnCount);
-            if ((_statement.BatchSize > 0 && rowCount < _statement.BatchSize) 
|| rowCount == 0)
+                // Build the current batch, if any data exists
+                return rowCount > 0 ? CreateBatch(response, columnCount, 
rowCount) : null;
+            }
+            catch (Exception ex)
+                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
+                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
             {
-                // This is the last batch
-                _statement = null;
+                throw new TimeoutException("The query execution timed out. 
Consider increasing the query timeout value.", ex);
+            }
+            catch (Exception ex) when (ex is not HiveServer2Exception)
+            {
+                throw new HiveServer2Exception($"An unexpected error occurred 
while fetching results. '{ex.Message}'", ex);
             }
-
-            // Build the current batch, if any data exists
-            return rowCount > 0 ? CreateBatch(response, columnCount, rowCount) 
: null;
         }
 
         private RecordBatch CreateBatch(TFetchResultsResp response, int 
columnCount, int rowCount)
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs 
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
index 824feceb9..06723e324 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -20,6 +20,7 @@ using System.Threading;
 using System.Threading.Tasks;
 using Apache.Arrow.Ipc;
 using Apache.Hive.Service.Rpc.Thrift;
+using Thrift.Transport;
 
 namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 {
@@ -32,33 +33,89 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 
         protected virtual void SetStatementProperties(TExecuteStatementReq 
statement)
         {
+            statement.QueryTimeout = QueryTimeoutSeconds;
         }
 
-        public override QueryResult ExecuteQuery() => 
ExecuteQueryAsync().AsTask().Result;
+        public override QueryResult ExecuteQuery()
+        {
+            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            try
+            {
+                return ExecuteQueryAsyncInternal(cancellationToken).Result;
+            }
+            catch (Exception ex)
+                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
+                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
+            {
+                throw new TimeoutException("The query execution timed out. 
Consider increasing the query timeout value.", ex);
+            }
+            catch (Exception ex) when (ex is not HiveServer2Exception)
+            {
+                throw new HiveServer2Exception($"An unexpected error occurred 
while fetching results. '{ex.Message}'", ex);
+            }
+        }
 
-        public override UpdateResult ExecuteUpdate() => 
ExecuteUpdateAsync().Result;
+        public override UpdateResult ExecuteUpdate()
+        {
+            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            try
+            {
+                return ExecuteUpdateAsyncInternal(cancellationToken).Result;
+            }
+            catch (Exception ex)
+                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
+                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
+            {
+                throw new TimeoutException("The query execution timed out. 
Consider increasing the query timeout value.", ex);
+            }
+            catch (Exception ex) when (ex is not HiveServer2Exception)
+            {
+                throw new HiveServer2Exception($"An unexpected error occurred 
while fetching results. '{ex.Message}'", ex);
+            }
+        }
 
-        public override async ValueTask<QueryResult> ExecuteQueryAsync()
+        private async Task<QueryResult> 
ExecuteQueryAsyncInternal(CancellationToken cancellationToken = default)
         {
-            await ExecuteStatementAsync();
-            await HiveServer2Connection.PollForResponseAsync(OperationHandle!, 
Connection.Client, PollTimeMilliseconds);
-            Schema schema = await GetResultSetSchemaAsync(OperationHandle!, 
Connection.Client);
+            // this could either:
+            // take QueryTimeoutSeconds * 3
+            // OR
+            // take QueryTimeoutSeconds (but this could be restricting)
+            await ExecuteStatementAsync(cancellationToken); // --> get 
QueryTimeout +
+            await HiveServer2Connection.PollForResponseAsync(OperationHandle!, 
Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to 
QueryTimeout
+            Schema schema = await GetResultSetSchemaAsync(OperationHandle!, 
Connection.Client, cancellationToken); // + get the result, up to QueryTimeout
 
-            // TODO: Ensure this is set dynamically based on server 
capabilities
             return new QueryResult(-1, Connection.NewReader(this, schema));
         }
 
+        public override async ValueTask<QueryResult> ExecuteQueryAsync()
+        {
+            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            try
+            {
+                return await ExecuteQueryAsyncInternal(cancellationToken);
+            }
+            catch (Exception ex)
+                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
+                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
+            {
+                throw new TimeoutException("The query execution timed out. 
Consider increasing the query timeout value.", ex);
+            }
+            catch (Exception ex) when (ex is not HiveServer2Exception)
+            {
+                throw new HiveServer2Exception($"An unexpected error occurred 
while fetching results. '{ex.Message}'", ex);
+            }
+        }
+
         private async Task<Schema> GetResultSetSchemaAsync(TOperationHandle 
operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken 
= default)
         {
             TGetResultSetMetadataResp response = await 
HiveServer2Connection.GetResultSetMetadataAsync(operationHandle, client, 
cancellationToken);
             return Connection.SchemaParser.GetArrowSchema(response.Schema, 
Connection.DataTypeConversion);
         }
 
-        public override async Task<UpdateResult> ExecuteUpdateAsync()
+        public async Task<UpdateResult> 
ExecuteUpdateAsyncInternal(CancellationToken cancellationToken = default)
         {
             const string NumberOfAffectedRowsColumnName = "num_affected_rows";
-
-            QueryResult queryResult = await ExecuteQueryAsync();
+            QueryResult queryResult = await 
ExecuteQueryAsyncInternal(cancellationToken);
             if (queryResult.Stream == null)
             {
                 throw new AdbcException("no data found");
@@ -79,7 +136,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             long? affectedRows = null;
             while (true)
             {
-                using RecordBatch nextBatch = await 
stream.ReadNextRecordBatchAsync();
+                using RecordBatch nextBatch = await 
stream.ReadNextRecordBatchAsync(cancellationToken);
                 if (nextBatch == null) { break; }
                 Int64Array numOfModifiedArray = 
(Int64Array)nextBatch.Column(NumberOfAffectedRowsColumnName);
                 // Note: should only have one item, but iterate for 
completeness
@@ -94,26 +151,51 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             return new UpdateResult(affectedRows ?? -1);
         }
 
+        public override async Task<UpdateResult> ExecuteUpdateAsync()
+        {
+            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            try
+            {
+                return await ExecuteUpdateAsyncInternal(cancellationToken);
+            }
+            catch (Exception ex)
+                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
+                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
+            {
+                throw new TimeoutException("The query execution timed out. 
Consider increasing the query timeout value.", ex);
+            }
+            catch (Exception ex) when (ex is not HiveServer2Exception)
+            {
+                throw new HiveServer2Exception($"An unexpected error occurred 
while fetching results. '{ex.Message}'", ex);
+            }
+        }
+
         public override void SetOption(string key, string value)
         {
             switch (key)
             {
-                case Options.PollTimeMilliseconds:
+                case ApacheParameters.PollTimeMilliseconds:
                     UpdatePollTimeIfValid(key, value);
                     break;
-                case Options.BatchSize:
+                case ApacheParameters.BatchSize:
                     UpdateBatchSizeIfValid(key, value);
                     break;
+                case ApacheParameters.QueryTimeoutSeconds:
+                    if (ApacheUtility.QueryTimeoutIsValid(key, value, out int 
queryTimeoutSeconds))
+                    {
+                        QueryTimeoutSeconds = queryTimeoutSeconds;
+                    }
+                    break;
                 default:
                     throw AdbcException.NotImplemented($"Option '{key}' is not 
implemented.");
             }
         }
 
-        protected async Task ExecuteStatementAsync()
+        protected async Task ExecuteStatementAsync(CancellationToken 
cancellationToken = default)
         {
             TExecuteStatementReq executeRequest = new 
TExecuteStatementReq(Connection.SessionHandle, SqlQuery);
             SetStatementProperties(executeRequest);
-            TExecuteStatementResp executeResponse = await 
Connection.Client.ExecuteStatement(executeRequest);
+            TExecuteStatementResp executeResponse = await 
Connection.Client.ExecuteStatement(executeRequest, cancellationToken);
             if (executeResponse.Status.StatusCode == TStatusCode.ERROR_STATUS)
             {
                 throw new 
HiveServer2Exception(executeResponse.Status.ErrorMessage)
@@ -127,23 +209,20 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 
         protected internal long BatchSize { get; private set; } = 
HiveServer2Connection.BatchSizeDefault;
 
+        protected internal int QueryTimeoutSeconds
+        {
+            // Coordinate updates with the connection
+            get => Connection.QueryTimeoutSeconds;
+            set => Connection.QueryTimeoutSeconds = value;
+        }
+
         public HiveServer2Connection Connection { get; private set; }
 
         public TOperationHandle? OperationHandle { get; private set; }
 
-        /// <summary>
-        /// Provides the constant string key values to the <see 
cref="AdbcStatement.SetOption(string, string)" /> method.
-        /// </summary>
-        public class Options
-        {
-            // Options common to all HiveServer2Statement-derived drivers go 
here
-            public const string PollTimeMilliseconds = 
"adbc.statement.polltime_milliseconds";
-            public const string BatchSize = "adbc.statement.batch_size";
-        }
-
         private void UpdatePollTimeIfValid(string key, string value) => 
PollTimeMilliseconds = !string.IsNullOrEmpty(key) && int.TryParse(value, 
result: out int pollTimeMilliseconds) && pollTimeMilliseconds >= 0
             ? pollTimeMilliseconds
-            : throw new ArgumentOutOfRangeException(key, value, $"The value 
'{value}' for option '{key}' is invalid. Must be a numeric value greater than 
or equal to -1.");
+            : throw new ArgumentOutOfRangeException(key, value, $"The value 
'{value}' for option '{key}' is invalid. Must be a numeric value greater than 
or equal to 0.");
 
         private void UpdateBatchSizeIfValid(string key, string value) => 
BatchSize = !string.IsNullOrEmpty(value) && long.TryParse(value, out long 
batchSize) && batchSize > 0
             ? batchSize
@@ -153,8 +232,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
         {
             if (OperationHandle != null)
             {
+                CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
                 TCloseOperationReq request = new 
TCloseOperationReq(OperationHandle);
-                Connection.Client.CloseOperation(request).Wait();
+                Connection.Client.CloseOperation(request, 
cancellationToken).Wait();
                 OperationHandle = null;
             }
 
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs 
b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
index c6c6cc796..0e673c7c4 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
@@ -40,7 +40,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
         {
         }
 
-        protected override Task<TTransport> CreateTransportAsync()
+        protected override TTransport CreateTransport()
         {
             string hostName = Properties["HostName"];
             string? tmp;
@@ -52,10 +52,10 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
 
             TConfiguration config = new TConfiguration();
             TTransport transport = new ThriftSocketTransport(hostName, port, 
config);
-            return Task.FromResult(transport);
+            return transport;
         }
 
-        protected override Task<TProtocol> CreateProtocolAsync(TTransport 
transport)
+        protected override Task<TProtocol> CreateProtocolAsync(TTransport 
transport, CancellationToken cancellationToken = default)
         {
             return Task.FromResult<TProtocol>(new TBinaryProtocol(transport));
         }
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs 
b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs
index 0bd620ee9..f94ac3970 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs
@@ -30,7 +30,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
         /// <summary>
         /// Provides the constant string key values to the <see 
cref="AdbcStatement.SetOption(string, string)" /> method.
         /// </summary>
-        public new sealed class Options : HiveServer2Statement.Options
+        public sealed class Options : ApacheParameters
         {
             // options specific to Impala go here
         }
diff --git a/csharp/src/Drivers/Apache/Spark/README.md 
b/csharp/src/Drivers/Apache/Spark/README.md
index 7d1f8b560..3b5a0e79e 100644
--- a/csharp/src/Drivers/Apache/Spark/README.md
+++ b/csharp/src/Drivers/Apache/Spark/README.md
@@ -37,9 +37,18 @@ but can also be passed in the call to `AdbcDatabase.Connect`.
 | `password`             | The password for the user name used for basic 
authentication. | |
 | `adbc.spark.data_type_conv` | Comma-separated list of data conversion 
options. Each option indicates the type of conversion to perform on data 
returned from the Spark server. <br><br>Allowed values: `none`, `scalar`. 
<br><br>Option `none` indicates there is no conversion from Spark type to 
native type (i.e., no conversion from String to Timestamp for Apache Spark over 
HTTP). Example `adbc.spark.conv_data_type=none`. <br><br>Option `scalar` will 
perform conversion (if necessary) from th [...]
 | `adbc.spark.tls_options` | Comma-separated list of TLS/SSL options. Each 
option indicates the TLS/SSL option when connecting to a Spark server. 
<br><br>Allowed values: `allow_self_signed`, `allow_hostname_mismatch`. 
<br><br>Option `allow_self_signed` allows certificate errors due to an unknown 
certificate authority, typically when using a self-signed certificate. Option 
`allow_hostname_mismatch` allow certificate errors due to a mismatch of the 
hostname. (e.g., when connecting through  [...]
-| `adbc.spark.http_request_timeout_ms` | Sets the timeout (in milliseconds) 
when making requests to the Spark server (type: `http`). Set the value higher 
than the default if you notice errors due to network timeouts. | `30000` |
-| `adbc.statement.batch_size` | Sets the maximum number of rows to retrieve in 
a single batch request. | `50000` |
-| `adbc.statement.polltime_milliseconds` | If polling is necessary to get a 
result, this option sets the length of time (in milliseconds) to wait between 
polls. | `500` |
+| `adbc.spark.connect_timeout_ms` | Sets the timeout (in milliseconds) to open 
a new session. Values can be 0 (infinite) or greater than zero. | `30000` |
+| `adbc.apache.statement.batch_size` | Sets the maximum number of rows to 
retrieve in a single batch request. | `50000` |
+| `adbc.apache.statement.polltime_ms` | If polling is necessary to get a 
result, this option sets the length of time (in milliseconds) to wait between 
polls. | `500` |
+| `adbc.apache.statement.query_timeout_s` | Sets the maximum time (in seconds) 
for a query to complete. Values can be 0 (infinite) or greater than zero. | 
`60` |
+
+## Timeout Configuration
+
+Timeouts are applied hierarchically. As specified above, the 
`adbc.spark.connect_timeout_ms` is analogous to a ConnectTimeout and is used when 
initially establishing a new session with the server.
+
+The `adbc.apache.statement.query_timeout_s` is analogous to a CommandTimeout 
and applies to all subsequent requests to the server, including metadata calls 
and query execution.
+
+The `adbc.apache.statement.polltime_ms` specifies the time to wait between polls to 
the service, up to the limit specified by 
`adbc.apache.statement.query_timeout_s`.
 
 ## Spark Types
 
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs 
b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index f532369e6..b3c0c56ba 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -19,8 +19,6 @@ using System;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Linq;
-using System.Net;
-using System.Net.Http;
 using System.Reflection;
 using System.Text;
 using System.Text.RegularExpressions;
@@ -32,6 +30,7 @@ using Apache.Arrow.Adbc.Extensions;
 using Apache.Arrow.Ipc;
 using Apache.Arrow.Types;
 using Apache.Hive.Service.Rpc.Thrift;
+using Thrift.Transport;
 
 namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
 {
@@ -420,26 +419,42 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
                 SessionHandle = SessionHandle ?? throw new 
InvalidOperationException("session not created"),
                 GetDirectResults = sparkGetDirectResults
             };
-            TGetTableTypesResp resp = Client.GetTableTypes(req).Result;
-            if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS)
+
+            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            try
             {
-                throw new HiveServer2Exception(resp.Status.ErrorMessage)
-                    .SetNativeError(resp.Status.ErrorCode)
-                    .SetSqlState(resp.Status.SqlState);
-            }
+                TGetTableTypesResp resp = Client.GetTableTypes(req, 
cancellationToken).Result;
 
-            TRowSet rowSet = GetRowSetAsync(resp).Result;
-            StringArray tableTypes = rowSet.Columns[0].StringVal.Values;
+                if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS)
+                {
+                    throw new HiveServer2Exception(resp.Status.ErrorMessage)
+                        .SetNativeError(resp.Status.ErrorCode)
+                        .SetSqlState(resp.Status.SqlState);
+                }
 
-            StringArray.Builder tableTypesBuilder = new StringArray.Builder();
-            tableTypesBuilder.AppendRange(tableTypes);
+                TRowSet rowSet = GetRowSetAsync(resp, 
cancellationToken).Result;
+                StringArray tableTypes = rowSet.Columns[0].StringVal.Values;
 
-            IArrowArray[] dataArrays = new IArrowArray[]
-            {
+                StringArray.Builder tableTypesBuilder = new 
StringArray.Builder();
+                tableTypesBuilder.AppendRange(tableTypes);
+
+                IArrowArray[] dataArrays = new IArrowArray[]
+                {
                 tableTypesBuilder.Build()
-            };
+                };
 
-            return new SparkInfoArrowStream(StandardSchemas.TableTypesSchema, 
dataArrays);
+                return new 
SparkInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays);
+            }
+            catch (Exception ex)
+                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
+                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
+            {
+                throw new TimeoutException("The metadata query execution timed 
out. Consider increasing the query timeout value.", ex);
+            }
+            catch (Exception ex) when (ex is not HiveServer2Exception)
+            {
+                throw new HiveServer2Exception($"An unexpected error occurred 
while running metadata query. '{ex.Message}'", ex);
+            }
         }
 
         public override Schema GetTableSchema(string? catalog, string? 
dbSchema, string? tableName)
@@ -450,221 +465,248 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             getColumnsReq.TableName = tableName;
             getColumnsReq.GetDirectResults = sparkGetDirectResults;
 
-            var columnsResponse = Client.GetColumns(getColumnsReq).Result;
-            if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS)
+            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            try
             {
-                throw new Exception(columnsResponse.Status.ErrorMessage);
-            }
+                var columnsResponse = Client.GetColumns(getColumnsReq, 
cancellationToken).Result;
+                if (columnsResponse.Status.StatusCode == 
TStatusCode.ERROR_STATUS)
+                {
+                    throw new Exception(columnsResponse.Status.ErrorMessage);
+                }
 
-            TRowSet rowSet = GetRowSetAsync(columnsResponse).Result;
-            List<TColumn> columns = rowSet.Columns;
-            int rowCount = rowSet.Columns[3].StringVal.Values.Length;
+                TRowSet rowSet = GetRowSetAsync(columnsResponse, 
cancellationToken).Result;
+                List<TColumn> columns = rowSet.Columns;
+                int rowCount = rowSet.Columns[3].StringVal.Values.Length;
 
-            Field[] fields = new Field[rowCount];
-            for (int i = 0; i < rowCount; i++)
+                Field[] fields = new Field[rowCount];
+                for (int i = 0; i < rowCount; i++)
+                {
+                    string columnName = 
columns[3].StringVal.Values.GetString(i);
+                    int? columnType = columns[4].I32Val.Values.GetValue(i);
+                    string typeName = columns[5].StringVal.Values.GetString(i);
+                    // Note: the following two columns do not seem to be set 
correctly for DECIMAL types.
+                    //int? columnSize = columns[6].I32Val.Values.GetValue(i);
+                    //int? decimalDigits = 
columns[8].I32Val.Values.GetValue(i);
+                    bool nullable = columns[10].I32Val.Values.GetValue(i) == 1;
+                    IArrowType dataType = 
SparkConnection.GetArrowType(columnType!.Value, typeName);
+                    fields[i] = new Field(columnName, dataType, nullable);
+                }
+                return new Schema(fields, null);
+            }
+            catch (Exception ex)
+                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
+                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
             {
-                string columnName = columns[3].StringVal.Values.GetString(i);
-                int? columnType = columns[4].I32Val.Values.GetValue(i);
-                string typeName = columns[5].StringVal.Values.GetString(i);
-                // Note: the following two columns do not seem to be set 
correctly for DECIMAL types.
-                //int? columnSize = columns[6].I32Val.Values.GetValue(i);
-                //int? decimalDigits = columns[8].I32Val.Values.GetValue(i);
-                bool nullable = columns[10].I32Val.Values.GetValue(i) == 1;
-                IArrowType dataType = 
SparkConnection.GetArrowType(columnType!.Value, typeName);
-                fields[i] = new Field(columnName, dataType, nullable);
+                throw new TimeoutException("The metadata query execution timed 
out. Consider increasing the query timeout value.", ex);
+            }
+            catch (Exception ex) when (ex is not HiveServer2Exception)
+            {
+                throw new HiveServer2Exception($"An unexpected error occurred 
while running metadata query. '{ex.Message}'", ex);
             }
-            return new Schema(fields, null);
         }
 
         public override IArrowArrayStream GetObjects(GetObjectsDepth depth, 
string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, 
IReadOnlyList<string>? tableTypes, string? columnNamePattern)
         {
-            Trace.TraceError($"getting objects with depth={depth.ToString()}, 
catalog = {catalogPattern}, dbschema = {dbSchemaPattern}, tablename = 
{tableNamePattern}");
-
             Dictionary<string, Dictionary<string, Dictionary<string, 
TableInfo>>> catalogMap = new Dictionary<string, Dictionary<string, 
Dictionary<string, TableInfo>>>();
-            if (depth == GetObjectsDepth.All || depth >= 
GetObjectsDepth.Catalogs)
+            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            try
             {
-                TGetCatalogsReq getCatalogsReq = new 
TGetCatalogsReq(SessionHandle);
-                getCatalogsReq.GetDirectResults = sparkGetDirectResults;
-
-                TGetCatalogsResp getCatalogsResp = 
Client.GetCatalogs(getCatalogsReq).Result;
-                if (getCatalogsResp.Status.StatusCode == 
TStatusCode.ERROR_STATUS)
+                if (depth == GetObjectsDepth.All || depth >= 
GetObjectsDepth.Catalogs)
                 {
-                    throw new Exception(getCatalogsResp.Status.ErrorMessage);
-                }
-                var catalogsMetadata = 
GetResultSetMetadataAsync(getCatalogsResp).Result;
-                IReadOnlyDictionary<string, int> columnMap = 
GetColumnIndexMap(catalogsMetadata.Schema.Columns);
+                    TGetCatalogsReq getCatalogsReq = new 
TGetCatalogsReq(SessionHandle);
+                    getCatalogsReq.GetDirectResults = sparkGetDirectResults;
 
-                string catalogRegexp = PatternToRegEx(catalogPattern);
-                TRowSet rowSet = GetRowSetAsync(getCatalogsResp).Result;
-                IReadOnlyList<string> list = 
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
-                for (int i = 0; i < list.Count; i++)
-                {
-                    string col = list[i];
-                    string catalog = col;
+                    TGetCatalogsResp getCatalogsResp = 
Client.GetCatalogs(getCatalogsReq, cancellationToken).Result;
 
-                    if (Regex.IsMatch(catalog, catalogRegexp, 
RegexOptions.IgnoreCase))
+                    if (getCatalogsResp.Status.StatusCode == 
TStatusCode.ERROR_STATUS)
                     {
-                        catalogMap.Add(catalog, new Dictionary<string, 
Dictionary<string, TableInfo>>());
+                        throw new 
Exception(getCatalogsResp.Status.ErrorMessage);
                     }
-                }
-                // Handle the case where server does not support 'catalog' in 
the namespace.
-                if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern))
-                {
-                    catalogMap.Add(string.Empty, []);
-                }
-            }
+                    var catalogsMetadata = 
GetResultSetMetadataAsync(getCatalogsResp, cancellationToken).Result;
+                    IReadOnlyDictionary<string, int> columnMap = 
GetColumnIndexMap(catalogsMetadata.Schema.Columns);
 
-            if (depth == GetObjectsDepth.All || depth >= 
GetObjectsDepth.DbSchemas)
-            {
-                TGetSchemasReq getSchemasReq = new 
TGetSchemasReq(SessionHandle);
-                getSchemasReq.CatalogName = catalogPattern;
-                getSchemasReq.SchemaName = dbSchemaPattern;
-                getSchemasReq.GetDirectResults = sparkGetDirectResults;
+                    string catalogRegexp = PatternToRegEx(catalogPattern);
+                    TRowSet rowSet = GetRowSetAsync(getCatalogsResp, 
cancellationToken).Result;
+                    IReadOnlyList<string> list = 
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
+                    for (int i = 0; i < list.Count; i++)
+                    {
+                        string col = list[i];
+                        string catalog = col;
 
-                TGetSchemasResp getSchemasResp = 
Client.GetSchemas(getSchemasReq).Result;
-                if (getSchemasResp.Status.StatusCode == 
TStatusCode.ERROR_STATUS)
-                {
-                    throw new Exception(getSchemasResp.Status.ErrorMessage);
+                        if (Regex.IsMatch(catalog, catalogRegexp, 
RegexOptions.IgnoreCase))
+                        {
+                            catalogMap.Add(catalog, new Dictionary<string, 
Dictionary<string, TableInfo>>());
+                        }
+                    }
+                    // Handle the case where server does not support 'catalog' 
in the namespace.
+                    if (list.Count == 0 && 
string.IsNullOrEmpty(catalogPattern))
+                    {
+                        catalogMap.Add(string.Empty, []);
+                    }
                 }
 
-                TGetResultSetMetadataResp schemaMetadata = 
GetResultSetMetadataAsync(getSchemasResp).Result;
-                IReadOnlyDictionary<string, int> columnMap = 
GetColumnIndexMap(schemaMetadata.Schema.Columns);
-                TRowSet rowSet = GetRowSetAsync(getSchemasResp).Result;
-
-                IReadOnlyList<string> catalogList = 
rowSet.Columns[columnMap[TableCatalog]].StringVal.Values;
-                IReadOnlyList<string> schemaList = 
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
-
-                for (int i = 0; i < catalogList.Count; i++)
+                if (depth == GetObjectsDepth.All || depth >= 
GetObjectsDepth.DbSchemas)
                 {
-                    string catalog = catalogList[i];
-                    string schemaDb = schemaList[i];
-                    // It seems Spark sometimes returns empty string for 
catalog on some schema (temporary tables).
-                    catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new 
Dictionary<string, TableInfo>());
-                }
-            }
+                    TGetSchemasReq getSchemasReq = new 
TGetSchemasReq(SessionHandle);
+                    getSchemasReq.CatalogName = catalogPattern;
+                    getSchemasReq.SchemaName = dbSchemaPattern;
+                    getSchemasReq.GetDirectResults = sparkGetDirectResults;
 
-            if (depth == GetObjectsDepth.All || depth >= 
GetObjectsDepth.Tables)
-            {
-                TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle);
-                getTablesReq.CatalogName = catalogPattern;
-                getTablesReq.SchemaName = dbSchemaPattern;
-                getTablesReq.TableName = tableNamePattern;
-                getTablesReq.GetDirectResults = sparkGetDirectResults;
-
-                TGetTablesResp getTablesResp = 
Client.GetTables(getTablesReq).Result;
-                if (getTablesResp.Status.StatusCode == 
TStatusCode.ERROR_STATUS)
-                {
-                    throw new Exception(getTablesResp.Status.ErrorMessage);
-                }
+                    TGetSchemasResp getSchemasResp = 
Client.GetSchemas(getSchemasReq, cancellationToken).Result;
+                    if (getSchemasResp.Status.StatusCode == 
TStatusCode.ERROR_STATUS)
+                    {
+                        throw new 
Exception(getSchemasResp.Status.ErrorMessage);
+                    }
 
-                TGetResultSetMetadataResp tableMetadata = 
GetResultSetMetadataAsync(getTablesResp).Result;
-                IReadOnlyDictionary<string, int> columnMap = 
GetColumnIndexMap(tableMetadata.Schema.Columns);
-                TRowSet rowSet = GetRowSetAsync(getTablesResp).Result;
+                    TGetResultSetMetadataResp schemaMetadata = 
GetResultSetMetadataAsync(getSchemasResp, cancellationToken).Result;
+                    IReadOnlyDictionary<string, int> columnMap = 
GetColumnIndexMap(schemaMetadata.Schema.Columns);
+                    TRowSet rowSet = GetRowSetAsync(getSchemasResp, 
cancellationToken).Result;
 
-                IReadOnlyList<string> catalogList = 
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
-                IReadOnlyList<string> schemaList = 
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
-                IReadOnlyList<string> tableList = 
rowSet.Columns[columnMap[TableName]].StringVal.Values;
-                IReadOnlyList<string> tableTypeList = 
rowSet.Columns[columnMap[TableType]].StringVal.Values;
+                    IReadOnlyList<string> catalogList = 
rowSet.Columns[columnMap[TableCatalog]].StringVal.Values;
+                    IReadOnlyList<string> schemaList = 
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
 
-                for (int i = 0; i < catalogList.Count; i++)
-                {
-                    string catalog = catalogList[i];
-                    string schemaDb = schemaList[i];
-                    string tableName = tableList[i];
-                    string tableType = tableTypeList[i];
-                    TableInfo tableInfo = new(tableType);
-                    
catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName,
 tableInfo);
+                    for (int i = 0; i < catalogList.Count; i++)
+                    {
+                        string catalog = catalogList[i];
+                        string schemaDb = schemaList[i];
+                        // It seems Spark sometimes returns empty string for 
catalog on some schema (temporary tables).
+                        catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, 
new Dictionary<string, TableInfo>());
+                    }
                 }
-            }
 
-            if (depth == GetObjectsDepth.All)
-            {
-                TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle);
-                columnsReq.CatalogName = catalogPattern;
-                columnsReq.SchemaName = dbSchemaPattern;
-                columnsReq.TableName = tableNamePattern;
-                columnsReq.GetDirectResults = sparkGetDirectResults;
+                if (depth == GetObjectsDepth.All || depth >= 
GetObjectsDepth.Tables)
+                {
+                    TGetTablesReq getTablesReq = new 
TGetTablesReq(SessionHandle);
+                    getTablesReq.CatalogName = catalogPattern;
+                    getTablesReq.SchemaName = dbSchemaPattern;
+                    getTablesReq.TableName = tableNamePattern;
+                    getTablesReq.GetDirectResults = sparkGetDirectResults;
+
+                    TGetTablesResp getTablesResp = 
Client.GetTables(getTablesReq, cancellationToken).Result;
+                    if (getTablesResp.Status.StatusCode == 
TStatusCode.ERROR_STATUS)
+                    {
+                        throw new Exception(getTablesResp.Status.ErrorMessage);
+                    }
 
-                if (!string.IsNullOrEmpty(columnNamePattern))
-                    columnsReq.ColumnName = columnNamePattern;
+                    TGetResultSetMetadataResp tableMetadata = 
GetResultSetMetadataAsync(getTablesResp, cancellationToken).Result;
+                    IReadOnlyDictionary<string, int> columnMap = 
GetColumnIndexMap(tableMetadata.Schema.Columns);
+                    TRowSet rowSet = GetRowSetAsync(getTablesResp, 
cancellationToken).Result;
 
-                var columnsResponse = Client.GetColumns(columnsReq).Result;
-                if (columnsResponse.Status.StatusCode == 
TStatusCode.ERROR_STATUS)
-                {
-                    throw new Exception(columnsResponse.Status.ErrorMessage);
+                    IReadOnlyList<string> catalogList = 
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
+                    IReadOnlyList<string> schemaList = 
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
+                    IReadOnlyList<string> tableList = 
rowSet.Columns[columnMap[TableName]].StringVal.Values;
+                    IReadOnlyList<string> tableTypeList = 
rowSet.Columns[columnMap[TableType]].StringVal.Values;
+
+                    for (int i = 0; i < catalogList.Count; i++)
+                    {
+                        string catalog = catalogList[i];
+                        string schemaDb = schemaList[i];
+                        string tableName = tableList[i];
+                        string tableType = tableTypeList[i];
+                        TableInfo tableInfo = new(tableType);
+                        
catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName,
 tableInfo);
+                    }
                 }
 
-                TGetResultSetMetadataResp columnsMetadata = 
GetResultSetMetadataAsync(columnsResponse).Result;
-                IReadOnlyDictionary<string, int> columnMap = 
GetColumnIndexMap(columnsMetadata.Schema.Columns);
-                TRowSet rowSet = GetRowSetAsync(columnsResponse).Result;
-
-                IReadOnlyList<string> catalogList = 
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
-                IReadOnlyList<string> schemaList = 
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
-                IReadOnlyList<string> tableList = 
rowSet.Columns[columnMap[TableName]].StringVal.Values;
-                IReadOnlyList<string> columnNameList = 
rowSet.Columns[columnMap[ColumnName]].StringVal.Values;
-                ReadOnlySpan<int> columnTypeList = 
rowSet.Columns[columnMap[DataType]].I32Val.Values.Values;
-                IReadOnlyList<string> typeNameList = 
rowSet.Columns[columnMap[TypeName]].StringVal.Values;
-                ReadOnlySpan<int> nullableList = 
rowSet.Columns[columnMap[Nullable]].I32Val.Values.Values;
-                IReadOnlyList<string> columnDefaultList = 
rowSet.Columns[columnMap[ColumnDef]].StringVal.Values;
-                ReadOnlySpan<int> ordinalPosList = 
rowSet.Columns[columnMap[OrdinalPosition]].I32Val.Values.Values;
-                IReadOnlyList<string> isNullableList = 
rowSet.Columns[columnMap[IsNullable]].StringVal.Values;
-                IReadOnlyList<string> isAutoIncrementList = 
rowSet.Columns[columnMap[IsAutoIncrement]].StringVal.Values;
-
-                for (int i = 0; i < catalogList.Count; i++)
+                if (depth == GetObjectsDepth.All)
                 {
-                    // For systems that don't support 'catalog' in the 
namespace
-                    string catalog = catalogList[i] ?? string.Empty;
-                    string schemaDb = schemaList[i];
-                    string tableName = tableList[i];
-                    string columnName = columnNameList[i];
-                    short colType = (short)columnTypeList[i];
-                    string typeName = typeNameList[i];
-                    short nullable = (short)nullableList[i];
-                    string? isAutoIncrementString = isAutoIncrementList[i];
-                    bool isAutoIncrement = 
(!string.IsNullOrEmpty(isAutoIncrementString) && 
(isAutoIncrementString.Equals("YES", 
StringComparison.InvariantCultureIgnoreCase) || 
isAutoIncrementString.Equals("TRUE", 
StringComparison.InvariantCultureIgnoreCase)));
-                    string isNullable = isNullableList[i] ?? "YES";
-                    string columnDefault = columnDefaultList[i] ?? "";
-                    // Spark/Databricks reports ordinal index zero-indexed, 
instead of one-indexed
-                    int ordinalPos = ordinalPosList[i] + 1;
-                    TableInfo? tableInfo = 
catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName);
-                    tableInfo?.ColumnName.Add(columnName);
-                    tableInfo?.ColType.Add(colType);
-                    tableInfo?.Nullable.Add(nullable);
-                    tableInfo?.IsAutoIncrement.Add(isAutoIncrement);
-                    tableInfo?.IsNullable.Add(isNullable);
-                    tableInfo?.ColumnDefault.Add(columnDefault);
-                    tableInfo?.OrdinalPosition.Add(ordinalPos);
-                    SetPrecisionScaleAndTypeName(colType, typeName, tableInfo);
-                }
-            }
+                    TGetColumnsReq columnsReq = new 
TGetColumnsReq(SessionHandle);
+                    columnsReq.CatalogName = catalogPattern;
+                    columnsReq.SchemaName = dbSchemaPattern;
+                    columnsReq.TableName = tableNamePattern;
+                    columnsReq.GetDirectResults = sparkGetDirectResults;
 
-            StringArray.Builder catalogNameBuilder = new StringArray.Builder();
-            List<IArrowArray?> catalogDbSchemasValues = new 
List<IArrowArray?>();
+                    if (!string.IsNullOrEmpty(columnNamePattern))
+                        columnsReq.ColumnName = columnNamePattern;
 
-            foreach (KeyValuePair<string, Dictionary<string, 
Dictionary<string, TableInfo>>> catalogEntry in catalogMap)
-            {
-                catalogNameBuilder.Append(catalogEntry.Key);
+                    var columnsResponse = Client.GetColumns(columnsReq, 
cancellationToken).Result;
+                    if (columnsResponse.Status.StatusCode == 
TStatusCode.ERROR_STATUS)
+                    {
+                        throw new 
Exception(columnsResponse.Status.ErrorMessage);
+                    }
 
-                if (depth == GetObjectsDepth.Catalogs)
-                {
-                    catalogDbSchemasValues.Add(null);
+                    TGetResultSetMetadataResp columnsMetadata = 
GetResultSetMetadataAsync(columnsResponse, cancellationToken).Result;
+                    IReadOnlyDictionary<string, int> columnMap = 
GetColumnIndexMap(columnsMetadata.Schema.Columns);
+                    TRowSet rowSet = GetRowSetAsync(columnsResponse, 
cancellationToken).Result;
+
+                    IReadOnlyList<string> catalogList = 
rowSet.Columns[columnMap[TableCat]].StringVal.Values;
+                    IReadOnlyList<string> schemaList = 
rowSet.Columns[columnMap[TableSchem]].StringVal.Values;
+                    IReadOnlyList<string> tableList = 
rowSet.Columns[columnMap[TableName]].StringVal.Values;
+                    IReadOnlyList<string> columnNameList = 
rowSet.Columns[columnMap[ColumnName]].StringVal.Values;
+                    ReadOnlySpan<int> columnTypeList = 
rowSet.Columns[columnMap[DataType]].I32Val.Values.Values;
+                    IReadOnlyList<string> typeNameList = 
rowSet.Columns[columnMap[TypeName]].StringVal.Values;
+                    ReadOnlySpan<int> nullableList = 
rowSet.Columns[columnMap[Nullable]].I32Val.Values.Values;
+                    IReadOnlyList<string> columnDefaultList = 
rowSet.Columns[columnMap[ColumnDef]].StringVal.Values;
+                    ReadOnlySpan<int> ordinalPosList = 
rowSet.Columns[columnMap[OrdinalPosition]].I32Val.Values.Values;
+                    IReadOnlyList<string> isNullableList = 
rowSet.Columns[columnMap[IsNullable]].StringVal.Values;
+                    IReadOnlyList<string> isAutoIncrementList = 
rowSet.Columns[columnMap[IsAutoIncrement]].StringVal.Values;
+
+                    for (int i = 0; i < catalogList.Count; i++)
+                    {
+                        // For systems that don't support 'catalog' in the 
namespace
+                        string catalog = catalogList[i] ?? string.Empty;
+                        string schemaDb = schemaList[i];
+                        string tableName = tableList[i];
+                        string columnName = columnNameList[i];
+                        short colType = (short)columnTypeList[i];
+                        string typeName = typeNameList[i];
+                        short nullable = (short)nullableList[i];
+                        string? isAutoIncrementString = isAutoIncrementList[i];
+                        bool isAutoIncrement = 
(!string.IsNullOrEmpty(isAutoIncrementString) && 
(isAutoIncrementString.Equals("YES", 
StringComparison.InvariantCultureIgnoreCase) || 
isAutoIncrementString.Equals("TRUE", 
StringComparison.InvariantCultureIgnoreCase)));
+                        string isNullable = isNullableList[i] ?? "YES";
+                        string columnDefault = columnDefaultList[i] ?? "";
+                        // Spark/Databricks reports ordinal index 
zero-indexed, instead of one-indexed
+                        int ordinalPos = ordinalPosList[i] + 1;
+                        TableInfo? tableInfo = 
catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName);
+                        tableInfo?.ColumnName.Add(columnName);
+                        tableInfo?.ColType.Add(colType);
+                        tableInfo?.Nullable.Add(nullable);
+                        tableInfo?.IsAutoIncrement.Add(isAutoIncrement);
+                        tableInfo?.IsNullable.Add(isNullable);
+                        tableInfo?.ColumnDefault.Add(columnDefault);
+                        tableInfo?.OrdinalPosition.Add(ordinalPos);
+                        SetPrecisionScaleAndTypeName(colType, typeName, 
tableInfo);
+                    }
                 }
-                else
+
+                StringArray.Builder catalogNameBuilder = new 
StringArray.Builder();
+                List<IArrowArray?> catalogDbSchemasValues = new 
List<IArrowArray?>();
+
+                foreach (KeyValuePair<string, Dictionary<string, 
Dictionary<string, TableInfo>>> catalogEntry in catalogMap)
                 {
-                    catalogDbSchemasValues.Add(GetDbSchemas(
-                                depth, catalogEntry.Value));
+                    catalogNameBuilder.Append(catalogEntry.Key);
+
+                    if (depth == GetObjectsDepth.Catalogs)
+                    {
+                        catalogDbSchemasValues.Add(null);
+                    }
+                    else
+                    {
+                        catalogDbSchemasValues.Add(GetDbSchemas(
+                                    depth, catalogEntry.Value));
+                    }
                 }
-            }
 
-            Schema schema = StandardSchemas.GetObjectsSchema;
-            IReadOnlyList<IArrowArray> dataArrays = schema.Validate(
-                new List<IArrowArray>
-                {
+                Schema schema = StandardSchemas.GetObjectsSchema;
+                IReadOnlyList<IArrowArray> dataArrays = schema.Validate(
+                    new List<IArrowArray>
+                    {
                     catalogNameBuilder.Build(),
                     catalogDbSchemasValues.BuildListArrayForType(new 
StructType(StandardSchemas.DbSchemaSchema)),
-                });
+                    });
 
-            return new SparkInfoArrowStream(schema, dataArrays);
+                return new SparkInfoArrowStream(schema, dataArrays);
+            }
+            catch (Exception ex)
+                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
+                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
+            {
+                throw new TimeoutException("The metadata query execution timed 
out. Consider increasing the query timeout value.", ex);
+            }
+            catch (Exception ex) when (ex is not HiveServer2Exception)
+            {
+                throw new HiveServer2Exception($"An unexpected error occurred 
while running metadata query. '{ex.Message}'", ex);
+            }
         }
 
         private static IReadOnlyDictionary<string, int> 
GetColumnIndexMap(List<TColumnDesc> columns) => columns
@@ -998,15 +1040,15 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
         protected abstract void ValidateAuthentication();
         protected abstract void ValidateOptions();
 
-        protected abstract Task<TRowSet> GetRowSetAsync(TGetTableTypesResp 
response);
-        protected abstract Task<TRowSet> GetRowSetAsync(TGetColumnsResp 
response);
-        protected abstract Task<TRowSet> GetRowSetAsync(TGetTablesResp 
response);
-        protected abstract Task<TRowSet> GetRowSetAsync(TGetCatalogsResp 
getCatalogsResp);
-        protected abstract Task<TRowSet> GetRowSetAsync(TGetSchemasResp 
getSchemasResp);
-        protected abstract Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetSchemasResp response);
-        protected abstract Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetCatalogsResp response);
-        protected abstract Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetColumnsResp response);
-        protected abstract Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetTablesResp response);
+        protected abstract Task<TRowSet> GetRowSetAsync(TGetTableTypesResp 
response, CancellationToken cancellationToken = default);
+        protected abstract Task<TRowSet> GetRowSetAsync(TGetColumnsResp 
response, CancellationToken cancellationToken = default);
+        protected abstract Task<TRowSet> GetRowSetAsync(TGetTablesResp 
response, CancellationToken cancellationToken = default);
+        protected abstract Task<TRowSet> GetRowSetAsync(TGetCatalogsResp 
getCatalogsResp, CancellationToken cancellationToken = default);
+        protected abstract Task<TRowSet> GetRowSetAsync(TGetSchemasResp 
getSchemasResp, CancellationToken cancellationToken = default);
+        protected abstract Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken 
cancellationToken = default);
+        protected abstract Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken 
cancellationToken = default);
+        protected abstract Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken 
cancellationToken = default);
+        protected abstract Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken 
cancellationToken = default);
 
         internal abstract SparkServerType ServerType { get; }
 
diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs 
b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
index 7d187fc71..d51ef42b9 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
@@ -16,6 +16,7 @@
 */
 
 using System.Collections.Generic;
+using System.Threading;
 using System.Threading.Tasks;
 using Apache.Arrow.Ipc;
 using Apache.Hive.Service.Rpc.Thrift;
@@ -43,24 +44,24 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             return req;
         }
 
-        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetSchemasResp response) =>
+        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken 
cancellationToken = default) =>
             Task.FromResult(response.DirectResults.ResultSetMetadata);
-        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetCatalogsResp response) =>
+        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken 
cancellationToken = default) =>
             Task.FromResult(response.DirectResults.ResultSetMetadata);
-        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetColumnsResp response) =>
+        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken 
cancellationToken = default) =>
             Task.FromResult(response.DirectResults.ResultSetMetadata);
-        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetTablesResp response) =>
+        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken 
cancellationToken = default) =>
             Task.FromResult(response.DirectResults.ResultSetMetadata);
 
-        protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp 
response) =>
+        protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp 
response, CancellationToken cancellationToken = default) =>
             Task.FromResult(response.DirectResults.ResultSet.Results);
-        protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp 
response) =>
+        protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp 
response, CancellationToken cancellationToken = default) =>
             Task.FromResult(response.DirectResults.ResultSet.Results);
-        protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp 
response) =>
+        protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp 
response, CancellationToken cancellationToken = default) =>
             Task.FromResult(response.DirectResults.ResultSet.Results);
-        protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp 
response) =>
+        protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp 
response, CancellationToken cancellationToken = default) =>
             Task.FromResult(response.DirectResults.ResultSet.Results);
-        protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp 
response) =>
+        protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp 
response, CancellationToken cancellationToken = default) =>
             Task.FromResult(response.DirectResults.ResultSet.Results);
     }
 }
diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs 
b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
index 77ecdb6a2..059ab1690 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
@@ -15,7 +15,6 @@
 * limitations under the License.
 */
 
-using System;
 using System.Collections.Generic;
 using System.Threading;
 using System.Threading.Tasks;
diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs 
b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
index 9d34ac75c..4c068aaa5 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
@@ -120,24 +120,19 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             DataTypeConversion = DataTypeConversionParser.Parse(dataTypeConv);
             Properties.TryGetValue(SparkParameters.TLSOptions, out string? 
tlsOptions);
             TlsOptions = TlsOptionsParser.Parse(tlsOptions);
-            
Properties.TryGetValue(SparkParameters.HttpRequestTimeoutMilliseconds, out 
string? requestTimeoutMs);
-            if (requestTimeoutMs != null)
+            Properties.TryGetValue(SparkParameters.ConnectTimeoutMilliseconds, 
out string? connectTimeoutMs);
+            if (connectTimeoutMs != null)
             {
-                HttpRequestTimeout = int.TryParse(requestTimeoutMs, 
NumberStyles.Integer, CultureInfo.InvariantCulture, out int 
requestTimeoutMsValue) && requestTimeoutMsValue > 0
-                    ? requestTimeoutMsValue
-                    : throw new 
ArgumentOutOfRangeException(SparkParameters.HttpRequestTimeoutMilliseconds, 
requestTimeoutMs, $"must be a value between 1 .. {int.MaxValue}. default is 
30000 milliseconds.");
+                ConnectTimeoutMilliseconds = int.TryParse(connectTimeoutMs, 
NumberStyles.Integer, CultureInfo.InvariantCulture, out int 
connectTimeoutMsValue) && (connectTimeoutMsValue >= 0)
+                    ? connectTimeoutMsValue
+                    : throw new 
ArgumentOutOfRangeException(SparkParameters.ConnectTimeoutMilliseconds, 
connectTimeoutMs, $"must be a value of 0 (infinite) or between 1 .. 
{int.MaxValue}. default is 30000 milliseconds.");
             }
         }
 
         internal override IArrowArrayStream NewReader<T>(T statement, Schema 
schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: 
statement.Connection.DataTypeConversion);
 
-        protected override Task<TTransport> CreateTransportAsync()
+        protected override TTransport CreateTransport()
         {
-            foreach (var property in Properties.Keys)
-            {
-                Trace.TraceError($"key = {property} value = 
{Properties[property]}");
-            }
-
             // Assumption: parameters have already been validated.
             Properties.TryGetValue(SparkParameters.HostName, out string? 
hostName);
             Properties.TryGetValue(SparkParameters.Path, out string? path);
@@ -164,9 +159,12 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             TConfiguration config = new();
             ThriftHttpTransport transport = new(httpClient, config)
             {
-                ConnectTimeout = HttpRequestTimeout,
+                // This value can only be set before the first call/request. 
So if a new value for query timeout
+                // is set, we won't be able to update the value. Setting to 
~infinite and relying on cancellation token
+                // to ensure cancelled correctly.
+                ConnectTimeout = int.MaxValue,
             };
-            return Task.FromResult<TTransport>(transport);
+            return transport;
         }
 
         private HttpClientHandler NewHttpClientHandler()
@@ -211,11 +209,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             }
         }
 
-        protected override async Task<TProtocol> 
CreateProtocolAsync(TTransport transport)
+        protected override async Task<TProtocol> 
CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = 
default)
         {
-            Trace.TraceError($"create protocol with {Properties.Count} 
properties.");
-
-            if (!transport.IsOpen) await 
transport.OpenAsync(CancellationToken.None);
+            if (!transport.IsOpen) await 
transport.OpenAsync(cancellationToken);
             return new TBinaryProtocol(transport);
         }
 
@@ -228,28 +224,29 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             return req;
         }
 
-        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetSchemasResp response) =>
-            GetResultSetMetadataAsync(response.OperationHandle, Client);
-        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetCatalogsResp response) =>
-            GetResultSetMetadataAsync(response.OperationHandle, Client);
-        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetColumnsResp response) =>
-            GetResultSetMetadataAsync(response.OperationHandle, Client);
-        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetTablesResp response) =>
-            GetResultSetMetadataAsync(response.OperationHandle, Client);
-        protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp 
response) =>
-            FetchResultsAsync(response.OperationHandle);
-        protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp 
response) =>
-            FetchResultsAsync(response.OperationHandle);
-        protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp 
response) =>
-            FetchResultsAsync(response.OperationHandle);
-        protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp 
response) =>
-            FetchResultsAsync(response.OperationHandle);
-        protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp 
response) =>
-            FetchResultsAsync(response.OperationHandle);
+        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken 
cancellationToken = default) =>
+            GetResultSetMetadataAsync(response.OperationHandle, Client, 
cancellationToken);
+        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken 
cancellationToken = default) =>
+            GetResultSetMetadataAsync(response.OperationHandle, Client, 
cancellationToken);
+        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken 
cancellationToken = default) =>
+            GetResultSetMetadataAsync(response.OperationHandle, Client, 
cancellationToken);
+        protected override Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken 
cancellationToken = default) =>
+            GetResultSetMetadataAsync(response.OperationHandle, Client, 
cancellationToken);
+        protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp 
response, CancellationToken cancellationToken = default) =>
+            FetchResultsAsync(response.OperationHandle, cancellationToken: 
cancellationToken);
+        protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp 
response, CancellationToken cancellationToken = default) =>
+            FetchResultsAsync(response.OperationHandle, cancellationToken: 
cancellationToken);
+        protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp 
response, CancellationToken cancellationToken = default) =>
+            FetchResultsAsync(response.OperationHandle, cancellationToken: 
cancellationToken);
+        protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp 
response, CancellationToken cancellationToken = default) =>
+            FetchResultsAsync(response.OperationHandle, cancellationToken: 
cancellationToken);
+        protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp 
response, CancellationToken cancellationToken = default) =>
+            FetchResultsAsync(response.OperationHandle, cancellationToken: 
cancellationToken);
 
         private async Task<TRowSet> FetchResultsAsync(TOperationHandle 
operationHandle, long batchSize = BatchSizeDefault, CancellationToken 
cancellationToken = default)
         {
-            await PollForResponseAsync(operationHandle, Client, 
PollTimeMillisecondsDefault);
+            await PollForResponseAsync(operationHandle, Client, 
PollTimeMillisecondsDefault, cancellationToken);
+
             TFetchResultsResp fetchResp = await 
FetchNextAsync(operationHandle, Client, batchSize, cancellationToken);
             if (fetchResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
             {
diff --git a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs 
b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
index 4722efce5..6cb96dd5f 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
@@ -15,8 +15,6 @@
  * limitations under the License.
  */
 
-using static System.Net.WebRequestMethods;
-
 namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
 {
     /// <summary>
@@ -32,7 +30,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
         public const string Type = "adbc.spark.type";
         public const string DataTypeConv = "adbc.spark.data_type_conv";
         public const string TLSOptions = "adbc.spark.tls_options";
-        public const string HttpRequestTimeoutMilliseconds = 
"adbc.spark.http_request_timeout_ms";
+        public const string ConnectTimeoutMilliseconds = 
"adbc.spark.connect_timeout_ms";
     }
 
     public static class SparkAuthTypeConstants
diff --git a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs 
b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs
index 51813ed6c..c8ab5772c 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs
@@ -18,6 +18,7 @@
 using System;
 using System.Collections.Generic;
 using System.Net;
+using System.Threading;
 using System.Threading.Tasks;
 using Apache.Hive.Service.Rpc.Thrift;
 using Thrift.Protocol;
@@ -85,7 +86,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
 
         }
 
-        protected override Task<TTransport> CreateTransportAsync()
+        protected override TTransport CreateTransport()
         {
             // Assumption: hostName and port have already been validated.
             Properties.TryGetValue(SparkParameters.HostName, out string? 
hostName);
@@ -94,14 +95,13 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             // Delay the open connection until later.
             bool connectClient = false;
             ThriftSocketTransport transport = new(hostName!, int.Parse(port!), 
connectClient, config: new());
-            return Task.FromResult<TTransport>(transport);
+            return transport;
         }
 
-        protected override async Task<TProtocol> 
CreateProtocolAsync(TTransport transport)
+        protected override async Task<TProtocol> 
CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = 
default)
         {
-            return await base.CreateProtocolAsync(transport);
+            return await base.CreateProtocolAsync(transport, 
cancellationToken);
 
-            //Trace.TraceError($"create protocol with {Properties.Count} 
properties.");
             //if (!transport.IsOpen) await 
transport.OpenAsync(CancellationToken.None);
             //return new TBinaryProtocol(transport);
         }
diff --git a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs 
b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
index e4bc3f6cd..25888b1a3 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
@@ -32,6 +32,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
                 {
                     case Options.BatchSize:
                     case Options.PollTimeMilliseconds:
+                    case Options.QueryTimeoutSeconds:
                         {
                             SetOption(kvp.Key, kvp.Value);
                             break;
@@ -45,7 +46,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             // TODO: Ensure this is set dynamically depending on server 
capabilities.
             statement.EnforceResultPersistenceMode = false;
             statement.ResultPersistenceMode = 2;
-
+            // This seems like a good idea to have the server timeout so it 
doesn't keep processing unnecessarily.
+            // Set in combination with a CancellationToken.
+            statement.QueryTimeout = QueryTimeoutSeconds;
             statement.CanReadArrowResult = true;
             statement.CanDownloadResult = true;
             statement.ConfOverlay = SparkConnection.timestampConfig;
@@ -65,7 +68,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
         /// <summary>
         /// Provides the constant string key values to the <see 
cref="AdbcStatement.SetOption(string, string)" /> method.
         /// </summary>
-        public new sealed class Options : HiveServer2Statement.Options
+        public sealed class Options : ApacheParameters
         {
             // options specific to Spark go here
         }
diff --git a/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs 
b/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs
index fb62ccd9a..ea3d7d16e 100644
--- a/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs
+++ b/csharp/test/Drivers/Apache/ApacheTestConfiguration.cs
@@ -45,11 +45,14 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache
         [JsonPropertyName("batch_size"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
         public string BatchSize { get; set; } = string.Empty;
 
-        [JsonPropertyName("polltime_milliseconds"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
+        [JsonPropertyName("polltime_ms"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
         public string PollTimeMilliseconds { get; set; } = string.Empty;
 
-        [JsonPropertyName("http_request_timeout_ms"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
-        public string HttpRequestTimeoutMilliseconds { get; set; } = 
string.Empty;
+        [JsonPropertyName("connect_timeout_ms"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
+        public string ConnectTimeoutMilliseconds { get; set; } = string.Empty;
+
+        [JsonPropertyName("query_timeout_s"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
+        public string QueryTimeoutSeconds { get; set; } = string.Empty;
 
         [JsonPropertyName("type"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
         public string Type { get; set; } = string.Empty;
diff --git a/csharp/test/Drivers/Apache/Common/ClientTests.cs 
b/csharp/test/Drivers/Apache/Common/ClientTests.cs
index e3b0309d0..9148d7281 100644
--- a/csharp/test/Drivers/Apache/Common/ClientTests.cs
+++ b/csharp/test/Drivers/Apache/Common/ClientTests.cs
@@ -17,6 +17,7 @@
 
 using System;
 using System.Collections.Generic;
+using Apache.Arrow.Adbc.Client;
 using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2;
 using Apache.Arrow.Adbc.Tests.Xunit;
 using Xunit;
@@ -203,6 +204,27 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common
             }
         }
 
+        /// <summary>
+        /// Validates that CommandTimeout can only be set after AdbcCommandTimeoutProperty names the driver option.
+        /// </summary>
+        [SkippableFact]
+        public void VerifyTimeoutsSet()
+        {
+            using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection())
+            {
+                int timeout = 99;
+                using AdbcCommand cmd = adbcConnection.CreateCommand();
+
+                // Setting the timeout before the timeout property name is assigned must throw.
+                Assert.Throws<InvalidOperationException>(() => { cmd.CommandTimeout = 1; });
+
+                cmd.AdbcCommandTimeoutProperty = "adbc.apache.statement.query_timeout_s";
+                cmd.CommandTimeout = timeout;
+
+                Assert.Equal(timeout, cmd.CommandTimeout);
+            }
+        }
+
         private Adbc.Client.AdbcConnection GetAdbcConnection(bool 
includeTableConstraints = true)
         {
             return new Adbc.Client.AdbcConnection(
diff --git a/csharp/test/Drivers/Apache/Common/StatementTests.cs 
b/csharp/test/Drivers/Apache/Common/StatementTests.cs
index 69eec0dd2..b793b7686 100644
--- a/csharp/test/Drivers/Apache/Common/StatementTests.cs
+++ b/csharp/test/Drivers/Apache/Common/StatementTests.cs
@@ -18,6 +18,7 @@
 using System;
 using System.Collections.Generic;
 using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache;
 using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2;
 using Apache.Arrow.Adbc.Tests.Xunit;
 using Xunit;
@@ -68,11 +69,11 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common
             AdbcStatement statement = NewConnection().CreateStatement();
             if (throws)
             {
-                Assert.Throws<ArgumentOutOfRangeException>(() => 
statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.PollTimeMilliseconds,
 value));
+                Assert.Throws<ArgumentOutOfRangeException>(() => 
statement.SetOption(ApacheParameters.PollTimeMilliseconds, value));
             }
             else
             {
-                
statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.PollTimeMilliseconds,
 value);
+                statement.SetOption(ApacheParameters.PollTimeMilliseconds, 
value);
             }
         }
 
@@ -101,11 +102,74 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common
             AdbcStatement statement = NewConnection().CreateStatement();
             if (throws)
             {
-                Assert.Throws<ArgumentOutOfRangeException>(() => 
statement!.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.BatchSize,
 value));
+                Assert.Throws<ArgumentOutOfRangeException>(() => 
statement!.SetOption(ApacheParameters.BatchSize, value));
             }
             else
             {
-                
statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.BatchSize,
 value);
+                statement.SetOption(ApacheParameters.BatchSize, value);
+            }
+        }
+
+        /// <summary>
+        /// Validates that SetOption handles valid/invalid data correctly for 
the QueryTimeout option.
+        /// </summary>
+        [SkippableTheory]
+        [InlineData("zero", true)]
+        [InlineData("-2147483648", true)]
+        [InlineData("2147483648", true)]
+        [InlineData("0", false)]
+        [InlineData("-1", true)]
+        [InlineData("1")]
+        [InlineData("2147483647")]
+        public void CanSetOptionQueryTimeout(string value, bool throws = false)
+        {
+            var testConfiguration = TestConfiguration.Clone() as TConfig;
+            testConfiguration!.QueryTimeoutSeconds = value;
+            if (throws) // an invalid value supplied through the connection configuration must be rejected by the time CreateStatement returns; NOTE(review): valid values never exercise this configuration path
+            {
+                Assert.Throws<ArgumentOutOfRangeException>(() => 
NewConnection(testConfiguration).CreateStatement());
+            }
+
+            AdbcStatement statement = NewConnection().CreateStatement();
+            if (throws) // the same value set directly on a statement must also be rejected
+            {
+                Assert.Throws<ArgumentOutOfRangeException>(() => 
statement.SetOption(ApacheParameters.QueryTimeoutSeconds, value));
+            }
+            else
+            {
+                statement.SetOption(ApacheParameters.QueryTimeoutSeconds, 
value);
+            }
+        }
+
+        /// <summary>
+        /// Executes a query with various query timeouts; when a failure contains the expected exception type, that type is asserted.
+        /// </summary>
+        /// <param name="statementWithExceptions">The query timeout, query text and expected exception type for this case.</param>
+        [SkippableTheory]
+        [ClassData(typeof(StatementTimeoutTestData))]
+        internal void StatementTimeoutTest(StatementWithExceptions 
statementWithExceptions)
+        {
+            TConfig testConfiguration = (TConfig)TestConfiguration.Clone();
+
+            if (statementWithExceptions.QueryTimeoutSeconds.HasValue)
+                testConfiguration.QueryTimeoutSeconds = 
statementWithExceptions.QueryTimeoutSeconds.Value.ToString();
+
+            if (!string.IsNullOrEmpty(statementWithExceptions.Query))
+                testConfiguration.Query = statementWithExceptions.Query!;
+
+            OutputHelper?.WriteLine($"QueryTimeoutSeconds: 
{testConfiguration.QueryTimeoutSeconds}. ShouldSucceed: 
{statementWithExceptions.ExceptionType == null}. Query: 
[{testConfiguration.Query}]");
+
+            try
+            {
+                AdbcStatement st = 
NewConnection(testConfiguration).CreateStatement(); // NOTE(review): the connection and statement are not disposed
+                st.SqlQuery = testConfiguration.Query;
+                QueryResult qr = st.ExecuteQuery();
+
+                OutputHelper?.WriteLine($"QueryResultRowCount: {qr.RowCount}"); // NOTE(review): reached even when ExceptionType is non-null, so an expected timeout that never happens still passes
+            }
+            catch (Exception ex) when (ApacheUtility.ContainsException(ex, 
statementWithExceptions.ExceptionType, out Exception? containedException))
+            {
+                Assert.IsType(statementWithExceptions.ExceptionType!, 
containedException!);
             }
         }
 
@@ -116,10 +180,58 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common
         public async Task CanInteractUsingSetOptions()
         {
             const string columnName = "INDEX";
-            
Statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.PollTimeMilliseconds,
 "100");
-            
Statement.SetOption(Adbc.Drivers.Apache.Hive2.HiveServer2Statement.Options.BatchSize,
 "10");
+            Statement.SetOption(ApacheParameters.PollTimeMilliseconds, "100"); // options are set before the temporary table is created and used
+            Statement.SetOption(ApacheParameters.BatchSize, "10");
             using TemporaryTable temporaryTable = await 
NewTemporaryTableAsync(Statement, $"{columnName} INT");
             await 
ValidateInsertSelectDeleteSingleValueAsync(temporaryTable.TableName, 
columnName, 1);
         }
     }
+
+    /// <summary>
+    /// Data type used for statement timeout tests.
+    /// </summary>
+    internal class StatementWithExceptions
+    {
+        public StatementWithExceptions(int? queryTimeoutSeconds, string? 
query, Type? exceptionType)
+        {
+            QueryTimeoutSeconds = queryTimeoutSeconds;
+            Query = query;
+            ExceptionType = exceptionType;
+        }
+
+        /// <summary>
+        /// The query timeout (in seconds) to configure. If null, the test configuration's timeout (or the driver default) is used.
+        /// </summary>
+        public int? QueryTimeoutSeconds { get; }
+
+        /// <summary>
+        /// The exception type expected to be contained in a failure. If null, the query is expected to succeed.
+        /// </summary>
+        public Type? ExceptionType { get; }
+
+        /// <summary>
+        /// The query to execute. If null or empty, the query from the default TestConfiguration is used.
+        /// </summary>
+        public string? Query { get; }
+    }
+
+    /// <summary>
+    /// Collection of <see cref="StatementWithExceptions"/> for testing 
statement timeouts.
+    /// </summary>
+    internal class StatementTimeoutTestData : 
TheoryData<StatementWithExceptions>
+    {
+        public StatementTimeoutTestData()
+        {
+            string longRunningQuery = "SELECT COUNT(*) AS total_count\nFROM 
(\n  SELECT t1.id AS id1, t2.id AS id2\n  FROM RANGE(1000000) t1\n  CROSS JOIN 
RANGE(10000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0"; // 10^6 x 10^4 cross join, intended to outlast the finite timeouts below
+
+            Add(new(0, null, null)); // 0 = infinite timeout; default query succeeds
+            Add(new(null, null, null)); // configured/default timeout; default query succeeds
+            Add(new(1, null, typeof(TimeoutException))); // 1 s is expected to be too short even for the default query
+            Add(new(5, null, null)); // default query expected to finish within 5 s
+            Add(new(30, null, null)); // default query expected to finish within 30 s
+            Add(new(5, longRunningQuery, typeof(TimeoutException))); // long-running query expected to exceed 5 s
+            Add(new(null, longRunningQuery, typeof(TimeoutException))); // long-running query expected to exceed the configured/default timeout
+            Add(new(0, longRunningQuery, null)); // infinite timeout: long-running query expected to complete
+        }
+    }
 }
diff --git a/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs 
b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs
index c2faa9d12..34e971bd8 100644
--- a/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs
+++ b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs
@@ -19,7 +19,10 @@ using System;
 using System.Collections.Generic;
 using System.Globalization;
 using System.Net;
+using Apache.Arrow.Adbc.Drivers.Apache;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
 using Apache.Arrow.Adbc.Drivers.Apache.Spark;
+using Thrift.Transport;
 using Xunit;
 using Xunit.Abstractions;
 
@@ -48,6 +51,231 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
             OutputHelper?.WriteLine(exeption.Message);
         }
 
+        /// <summary>
+        /// Tests the connect timeout used when establishing a session with the backend.
+        /// </summary>
+        /// <param name="connectTimeoutMilliseconds">The connect timeout (in 
ms). Null leaves the test configuration's value unchanged; 0 means infinite.</param>
+        /// <param name="exceptionType">The exception type expected as the inner exception of the <see cref="AggregateException"/> (if 
any)</param>
+        /// <param name="alternateExceptionType">An alternate exception type that may instead be wrapped in a <see cref="HiveServer2Exception"/> (if 
any)</param>
+        [SkippableTheory]
+        [InlineData(0, null, null)]
+        [InlineData(1, typeof(TimeoutException), typeof(TTransportException))]
+        [InlineData(10, typeof(TimeoutException), typeof(TTransportException))]
+        [InlineData(30000, null, null)]
+        [InlineData(null, null, null)]
+        public void ConnectionTimeoutTest(int? connectTimeoutMilliseconds, 
Type? exceptionType, Type? alternateExceptionType)
+        {
+            SparkTestConfiguration testConfiguration = 
(SparkTestConfiguration)TestConfiguration.Clone();
+
+            if (connectTimeoutMilliseconds.HasValue)
+                testConfiguration.ConnectTimeoutMilliseconds = 
connectTimeoutMilliseconds.Value.ToString();
+
+            OutputHelper?.WriteLine($"ConnectTimeoutMilliseconds: 
{testConfiguration.ConnectTimeoutMilliseconds}. ShouldSucceed: {exceptionType 
== null}");
+
+            try
+            {
+                NewConnection(testConfiguration); // NOTE(review): if exceptionType is non-null and this succeeds, the test still passes; the connection is also never disposed
+            }
+            catch(AggregateException aex)
+            {
+                if (exceptionType != null)
+                {
+                    if (alternateExceptionType != null && 
aex.InnerException?.GetType() != exceptionType)
+                    {
+                        if (aex.InnerException?.GetType() == 
typeof(HiveServer2Exception))
+                        {
+                            // the alternate exception (e.g. a TTransportException) is wrapped in a 
HiveServer2Exception
+                            Assert.IsType(alternateExceptionType, 
aex.InnerException!.InnerException);
+                        }
+                        else
+                        {
+                            throw;
+                        }
+                    }
+                    else
+                    {
+                        Assert.IsType(exceptionType, aex.InnerException);
+                    }
+                }
+                else
+                {
+                    throw;
+                }
+            }
+        }
+
+        /// <summary>
+        /// Tests the various metadata calls on a SparkConnection with different query timeouts.
+        /// </summary>
+        /// <param name="metadataWithException">The metadata action, query timeout and expected (or alternate) exception types for this case.</param>
+        [SkippableTheory]
+        [ClassData(typeof(MetadataTimeoutTestData))]
+        internal void MetadataTimeoutTest(MetadataWithExceptions 
metadataWithException)
+        {
+            SparkTestConfiguration testConfiguration = 
(SparkTestConfiguration)TestConfiguration.Clone();
+
+            if (metadataWithException.QueryTimeoutSeconds.HasValue)
+                testConfiguration.QueryTimeoutSeconds = 
metadataWithException.QueryTimeoutSeconds.Value.ToString();
+
+            OutputHelper?.WriteLine($"Action: 
{metadataWithException.ActionName}. QueryTimeoutSeconds: 
{testConfiguration.QueryTimeoutSeconds}. ShouldSucceed: 
{metadataWithException.ExceptionType == null}");
+
+            try
+            {
+                metadataWithException.MetadataAction(testConfiguration); // NOTE(review): if an exception is expected but the action succeeds, the test still passes
+            }
+            catch (Exception ex) when (ApacheUtility.ContainsException(ex, 
metadataWithException.ExceptionType, out Exception? containedException))
+            {
+                Assert.IsType(metadataWithException.ExceptionType!, 
containedException);
+            }
+            catch (Exception ex) when (ApacheUtility.ContainsException(ex, 
metadataWithException.AlternateExceptionType, out Exception? 
containedException))
+            {
+                Assert.IsType(metadataWithException.AlternateExceptionType!, 
containedException);
+            }
+        }
+
+        /// <summary>
+        /// Data type used for metadata timeout tests.
+        /// </summary>
+        internal class MetadataWithExceptions
+        {
+            public MetadataWithExceptions(int? queryTimeoutSeconds, string 
actionName, Action<SparkTestConfiguration> action, Type? exceptionType, Type? 
alternateExceptionType)
+            {
+                QueryTimeoutSeconds = queryTimeoutSeconds;
+                ActionName = actionName;
+                MetadataAction = action;
+                ExceptionType = exceptionType;
+                AlternateExceptionType = alternateExceptionType;
+            }
+
+            /// <summary>
+            /// If null, the test configuration's timeout (or the driver default) is used.
+            /// </summary>
+            public int? QueryTimeoutSeconds { get; }
+
+            public string ActionName { get; } // friendly name written to the test output
+
+            /// <summary>
+            /// If null, expected to succeed.
+            /// </summary>
+            public Type? ExceptionType { get; }
+
+            /// <summary>
+            /// Sometimes you can expect one but may get another.
+            /// For example, on GetObjectsAll, sometimes a TTransportException 
is expected but a TaskCanceledException is received during the test. NOTE(review): MetadataTimeoutTestData registers TimeoutException (not TaskCanceledException) as the getObjectsAll alternate; confirm which is intended.
+            /// </summary>
+            public Type? AlternateExceptionType { get; }
+
+            /// <summary>
+            /// The metadata action to perform.
+            /// </summary>
+            public Action<SparkTestConfiguration> MetadataAction { get; }
+        }
+
+        /// <summary>
+        /// Used for testing timeouts on metadata calls.
+        /// </summary>
+        internal class MetadataTimeoutTestData : 
TheoryData<MetadataWithExceptions>
+        {
+            public MetadataTimeoutTestData()
+            {
+                SparkConnectionTest sparkConnectionTest = new 
SparkConnectionTest(null); // only used to open connections; no output helper is supplied
+
+                Action<SparkTestConfiguration> getObjectsAll = 
(testConfiguration) =>
+                {
+                    AdbcConnection cn = 
sparkConnectionTest.NewConnection(testConfiguration); // NOTE(review): connections opened by these actions are never disposed
+                    cn.GetObjects(AdbcConnection.GetObjectsDepth.All, 
testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, 
testConfiguration.Metadata.Table, null, null);
+                };
+
+                Action<SparkTestConfiguration> getObjectsCatalogs = 
(testConfiguration) =>
+                {
+                    AdbcConnection cn = 
sparkConnectionTest.NewConnection(testConfiguration);
+                    cn.GetObjects(AdbcConnection.GetObjectsDepth.Catalogs, 
testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, 
testConfiguration.Metadata.Schema, null, null); // NOTE(review): Metadata.Schema is passed as the table-name pattern (getObjectsAll passes Metadata.Table); confirm intended
+                };
+
+                Action<SparkTestConfiguration> getObjectsDbSchemas = 
(testConfiguration) =>
+                {
+                    AdbcConnection cn = 
sparkConnectionTest.NewConnection(testConfiguration);
+                    cn.GetObjects(AdbcConnection.GetObjectsDepth.DbSchemas, 
testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, 
testConfiguration.Metadata.Schema, null, null); // NOTE(review): Metadata.Schema is passed as the table-name pattern (getObjectsAll passes Metadata.Table); confirm intended
+                };
+
+                Action<SparkTestConfiguration> getObjectsTables = 
(testConfiguration) =>
+                {
+                    AdbcConnection cn = 
sparkConnectionTest.NewConnection(testConfiguration);
+                    cn.GetObjects(AdbcConnection.GetObjectsDepth.Tables, 
testConfiguration.Metadata.Catalog, testConfiguration.Metadata.Schema, 
testConfiguration.Metadata.Schema, null, null); // NOTE(review): Metadata.Schema is passed as the table-name pattern (getObjectsAll passes Metadata.Table); confirm intended
+                };
+
+                AddAction("getObjectsAll", getObjectsAll, new List<Type?>() { 
null, typeof(TimeoutException), null, null, null } ); // at QueryTimeout = 1 a TimeoutException may occur instead of the default TTransportException
+                AddAction("getObjectsCatalogs", getObjectsCatalogs);
+                AddAction("getObjectsDbSchemas", getObjectsDbSchemas);
+                AddAction("getObjectsTables", getObjectsTables);
+
+                Action<SparkTestConfiguration> getTableTypes = 
(testConfiguration) =>
+                {
+                    AdbcConnection cn = 
sparkConnectionTest.NewConnection(testConfiguration);
+                    cn.GetTableTypes();
+                };
+
+                AddAction("getTableTypes", getTableTypes);
+
+                Action<SparkTestConfiguration> getTableSchema = 
(testConfiguration) =>
+                {
+                    AdbcConnection cn = 
sparkConnectionTest.NewConnection(testConfiguration);
+                    cn.GetTableSchema(testConfiguration.Metadata.Catalog, 
testConfiguration.Metadata.Schema, testConfiguration.Metadata.Table);
+                };
+
+                AddAction("getTableSchema", getTableSchema);
+            }
+
+            /// <summary>
+            /// Adds the action once per test timeout (0, 1, 10, default, 300 seconds) using the default expected exceptions.
+            /// </summary>
+            /// <param name="name">The friendly name of the action.</param>
+            /// <param name="action">The action to perform.</param>
+            /// <param name="alternateExceptions">Optional list of alternate 
exceptions that are possible. Must have 5 items if present.</param>
+            private void AddAction(string name, Action<SparkTestConfiguration> 
action, List<Type?>? alternateExceptions = null)
+            {
+                List<Type?> expectedExceptions = new List<Type?>()
+                {
+                    null, // QueryTimeout = 0
+                    typeof(TTransportException), // QueryTimeout = 1
+                    typeof(TimeoutException), // QueryTimeout = 10
+                    null, // QueryTimeout = default
+                    null // QueryTimeout = 300
+                };
+
+                AddAction(name, action, expectedExceptions, 
alternateExceptions);
+            }
+
+            /// <summary>
+            /// Adds the action, identified by <paramref name="name"/>, once per test timeout with the given expected and optional <paramref name="alternateExceptions"/>.
+            /// </summary>
+            /// <param name="action">The action to perform.</param>
+            /// <param name="expectedExceptions">The expected 
exception type (or null for success) per timeout; must have 5 items.</param>
+            /// <remarks>
+            /// For each exception list, the position is based on the behavior when:
+            ///    [0] QueryTimeout = 0
+            ///    [1] QueryTimeout = 1
+            ///    [2] QueryTimeout = 10
+            ///    [3] QueryTimeout = default
+            ///    [4] QueryTimeout = 300
+            /// </remarks>
+            private void AddAction(string name, Action<SparkTestConfiguration> 
action, List<Type?> expectedExceptions, List<Type?>? alternateExceptions)
+            {
+                Assert.True(expectedExceptions.Count == 5);
+
+                if (alternateExceptions != null)
+                {
+                    Assert.True(alternateExceptions.Count == 5);
+                }
+
+                Add(new(0, name, action, expectedExceptions[0], 
alternateExceptions?[0]));
+                Add(new(1, name, action, expectedExceptions[1], 
alternateExceptions?[1]));
+                Add(new(10, name, action, expectedExceptions[2], 
alternateExceptions?[2]));
+                Add(new(null, name, action, expectedExceptions[3], 
alternateExceptions?[3]));
+                Add(new(300, name, action, expectedExceptions[4], 
alternateExceptions?[4]));
+            }
+        }
+
         internal class ParametersWithExceptions
         {
             public ParametersWithExceptions(Dictionary<string, string> 
parameters, Type exceptionType)
@@ -85,11 +313,9 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
                 Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = 
"valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = 
"httpxxz://hostname.com" }, typeof(ArgumentOutOfRangeException)));
                 Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = 
"valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = 
"http-//hostname.com" }, typeof(UriFormatException)));
                 Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Databricks, [SparkParameters.HostName] = 
"valid.server.com", [SparkParameters.Token] = "abcdef", [AdbcOptions.Uri] = 
"httpxxz://hostname.com:1234567890" }, typeof(UriFormatException)));
-                Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", 
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword" , 
[SparkParameters.HttpRequestTimeoutMilliseconds] = "0" }, 
typeof(ArgumentOutOfRangeException)));
-                Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", 
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", 
[SparkParameters.HttpRequestTimeoutMilliseconds] = "-1" }, 
typeof(ArgumentOutOfRangeException)));
-                Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", 
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", 
[SparkParameters.HttpRequestTimeoutMilliseconds] = ((long)int.MaxValue + 
1).ToString() }, typeof(ArgumentOutOfRangeException)));
-                Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", 
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", 
[SparkParameters.HttpRequestTimeoutMilliseconds] = "non-numeric" }, 
typeof(ArgumentOutOfRangeException)));
-                Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", 
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", 
[SparkParameters.HttpRequestTimeoutMilliseconds] = "" }, 
typeof(ArgumentOutOfRangeException)));
+                Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", 
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", 
[SparkParameters.ConnectTimeoutMilliseconds] = ((long)int.MaxValue + 
1).ToString() }, typeof(ArgumentOutOfRangeException)));
+                Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", 
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", 
[SparkParameters.ConnectTimeoutMilliseconds] = "non-numeric" }, 
typeof(ArgumentOutOfRangeException)));
+                Add(new(new() { [SparkParameters.Type] = 
SparkServerTypeConstants.Http, [SparkParameters.HostName] = "valid.server.com", 
[AdbcOptions.Username] = "user", [AdbcOptions.Password] = "myPassword", 
[SparkParameters.ConnectTimeoutMilliseconds] = "" }, 
typeof(ArgumentOutOfRangeException)));
             }
         }
     }
diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs 
b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
index 54c536853..16a550111 100644
--- a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
+++ b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
@@ -19,6 +19,7 @@ using System;
 using System.Collections.Generic;
 using System.Data.SqlTypes;
 using System.Text;
+using Apache.Arrow.Adbc.Drivers.Apache;
 using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
 using Apache.Arrow.Adbc.Drivers.Apache.Spark;
 using Apache.Arrow.Adbc.Tests.Drivers.Apache.Hive2;
@@ -102,15 +103,19 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
             }
             if (!string.IsNullOrEmpty(testConfiguration.BatchSize))
             {
-                parameters.Add(HiveServer2Statement.Options.BatchSize, 
testConfiguration.BatchSize!);
+                parameters.Add(ApacheParameters.BatchSize, 
testConfiguration.BatchSize!);
             }
             if (!string.IsNullOrEmpty(testConfiguration.PollTimeMilliseconds))
             {
-                
parameters.Add(HiveServer2Statement.Options.PollTimeMilliseconds, 
testConfiguration.PollTimeMilliseconds!);
+                parameters.Add(ApacheParameters.PollTimeMilliseconds, 
testConfiguration.PollTimeMilliseconds!);
             }
-            if 
(!string.IsNullOrEmpty(testConfiguration.HttpRequestTimeoutMilliseconds))
+            if 
(!string.IsNullOrEmpty(testConfiguration.ConnectTimeoutMilliseconds))
             {
-                parameters.Add(SparkParameters.HttpRequestTimeoutMilliseconds, 
testConfiguration.HttpRequestTimeoutMilliseconds!);
+                parameters.Add(SparkParameters.ConnectTimeoutMilliseconds, 
testConfiguration.ConnectTimeoutMilliseconds!);
+            }
+            if (!string.IsNullOrEmpty(testConfiguration.QueryTimeoutSeconds))
+            {
+                parameters.Add(ApacheParameters.QueryTimeoutSeconds, 
testConfiguration.QueryTimeoutSeconds!);
             }
 
             return parameters;
diff --git a/csharp/test/Drivers/Apache/Spark/StatementTests.cs 
b/csharp/test/Drivers/Apache/Spark/StatementTests.cs
index 25d27179a..aaafc31ba 100644
--- a/csharp/test/Drivers/Apache/Spark/StatementTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/StatementTests.cs
@@ -15,6 +15,8 @@
 * limitations under the License.
 */
 
+using System;
+using Xunit;
 using Xunit.Abstractions;
 
 namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark

Reply via email to