This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c7bbd5b3 feat(csharp/src/Drivers/Apache): add support for Statement.Cancel (#3302)
3c7bbd5b3 is described below

commit 3c7bbd5b3a83e6346b75dccbc593c63969164c0d
Author: Bruce Irschick <bruce.irsch...@improving.com>
AuthorDate: Thu Aug 21 13:06:31 2025 -0700

    feat(csharp/src/Drivers/Apache): add support for Statement.Cancel (#3302)
    
    Add support for `AdbcStatement.Cancel`.
    
    - If a `CancellationTokenSource` exists, it is cancelled.
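    - If an operation is in progress, a `CancelOperation` request is sent to
    the server using the current operation handle (the executing query issues
    it when it observes the cancellation while polling for the result).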
    - If no operation handle or cancellation source is available, `Cancel`
    is a no-op.
    
    Once the query result is created and returned, the statement's
    cancellation source is released, and it is up to the consumer of the
    `QueryResult` to close the stream.
    
    Note to reviewers: Use the "Hide Whitespace" option.
    
    closes #3287
---
 csharp/src/Drivers/Apache/ApacheUtility.cs         |  21 ++-
 .../Drivers/Apache/Hive2/HiveServer2Statement.cs   | 147 +++++++++++++++++----
 .../src/Drivers/Databricks/DatabricksConnection.cs |  10 ++
 .../test/Drivers/Apache/Common/StatementTests.cs   |  32 +++++
 .../test/Drivers/Apache/Impala/StatementTests.cs   |   9 ++
 csharp/test/Drivers/Apache/Spark/StatementTests.cs |  16 ++-
 .../Databricks/E2E/DatabricksTestConfiguration.cs  |   6 +
 .../Databricks/E2E/DatabricksTestEnvironment.cs    |   8 ++
 .../test/Drivers/Databricks/E2E/StatementTests.cs  |  34 ++++-
 9 files changed, 239 insertions(+), 44 deletions(-)

diff --git a/csharp/src/Drivers/Apache/ApacheUtility.cs 
b/csharp/src/Drivers/Apache/ApacheUtility.cs
index 104d498fd..0ee596a71 100644
--- a/csharp/src/Drivers/Apache/ApacheUtility.cs
+++ b/csharp/src/Drivers/Apache/ApacheUtility.cs
@@ -33,7 +33,20 @@ namespace Apache.Arrow.Adbc.Drivers.Apache
             Milliseconds
         }
 
+        public static CancellationTokenSource GetCancellationTokenSource(int 
timeout, TimeUnit timeUnit)
+        {
+            TimeSpan timeSpan = CalculateTimeSpan(timeout, timeUnit);
+            return new CancellationTokenSource(timeSpan);
+        }
+
         public static CancellationToken GetCancellationToken(int timeout, 
TimeUnit timeUnit)
+        {
+            TimeSpan timeSpan = CalculateTimeSpan(timeout, timeUnit);
+            var cts = new CancellationTokenSource(timeSpan);
+            return cts.Token;
+        }
+
+        private static TimeSpan CalculateTimeSpan(int timeout, TimeUnit 
timeUnit)
         {
             TimeSpan span;
 
@@ -55,13 +68,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache
                 }
             }
 
-            return GetCancellationToken(span);
-        }
-
-        private static CancellationToken GetCancellationToken(TimeSpan 
timeSpan)
-        {
-            var cts = new CancellationTokenSource(timeSpan);
-            return cts.Token;
+            return span;
         }
 
         public static bool QueryTimeoutIsValid(string key, string value, out 
int queryTimeoutSeconds)
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs 
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
index 0ed673ea3..10b8ec2d6 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -17,6 +17,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
@@ -52,6 +53,10 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
         protected const string PrimaryKeyPrefix = "PK_";
         protected const string ForeignKeyPrefix = "FK_";
 
+        // Lock to ensure consistent access to TokenSource
+        private readonly object _tokenSourceLock = new();
+        private CancellationTokenSource? _executeTokenSource;
+
         internal HiveServer2Statement(HiveServer2Connection connection)
             : base(connection)
         {
@@ -77,45 +82,49 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 
         public override QueryResult ExecuteQuery()
         {
-            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            CancellationTokenSource ts = SetTokenSource();
             try
             {
-                return ExecuteQueryAsyncInternal(cancellationToken).Result;
+                return ExecuteQueryAsyncInternal(ts.Token).Result;
             }
-            catch (Exception ex)
-                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
-                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
+            catch (Exception ex) when (IsCancellation(ex, ts.Token))
             {
-                throw new TimeoutException("The query execution timed out. 
Consider increasing the query timeout value.", ex);
+                throw new TimeoutException("The query execution timed out or 
was cancelled. Consider increasing the query timeout value.", ex);
             }
             catch (Exception ex) when (ex is not HiveServer2Exception)
             {
                 throw new HiveServer2Exception($"An unexpected error occurred 
while fetching results. '{ApacheUtility.FormatExceptionMessage(ex)}'", ex);
             }
+            finally
+            {
+                DisposeTokenSource();
+            }
         }
 
         public override UpdateResult ExecuteUpdate()
         {
-            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            CancellationTokenSource ts = SetTokenSource();
             try
             {
-                return ExecuteUpdateAsyncInternal(cancellationToken).Result;
+                return ExecuteUpdateAsyncInternal(ts.Token).Result;
             }
-            catch (Exception ex)
-                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
-                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
+            catch (Exception ex) when (IsCancellation(ex, ts.Token))
             {
-                throw new TimeoutException("The query execution timed out. 
Consider increasing the query timeout value.", ex);
+                throw new TimeoutException("The query execution timed out or 
was cancelled. Consider increasing the query timeout value.", ex);
             }
             catch (Exception ex) when (ex is not HiveServer2Exception)
             {
                 throw new HiveServer2Exception($"An unexpected error occurred 
while fetching results. '{ApacheUtility.FormatExceptionMessage(ex)}'", ex);
             }
+            finally
+            {
+                DisposeTokenSource();
+            }
         }
 
         private async Task<QueryResult> 
ExecuteQueryAsyncInternal(CancellationToken cancellationToken = default)
         {
-            return await this.TraceActivityAsync(async _ =>
+            return await this.TraceActivityAsync(async activity =>
             {
                 if (IsMetadataCommand)
                 {
@@ -136,8 +145,17 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                 }
                 else
                 {
-                    await 
HiveServer2Connection.PollForResponseAsync(response.OperationHandle!, 
Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to 
QueryTimeout
-                    metadata = await 
HiveServer2Connection.GetResultSetMetadataAsync(response.OperationHandle!, 
Connection.Client, cancellationToken);
+                    try
+                    {
+                        await 
HiveServer2Connection.PollForResponseAsync(response.OperationHandle!, 
Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to 
QueryTimeout
+                        metadata = await 
HiveServer2Connection.GetResultSetMetadataAsync(response.OperationHandle!, 
Connection.Client, cancellationToken);
+                    }
+                    catch (Exception ex) when (IsCancellation(ex, 
cancellationToken))
+                    {
+                        // If the operation was cancelled, we need to cancel 
the operation on the server
+                        await CancelOperationAsync(activity, 
response.OperationHandle);
+                        throw;
+                    }
                 }
                 Schema schema = GetSchemaFromMetadata(metadata);
                 return new QueryResult(-1, Connection.NewReader(this, schema, 
response, metadata));
@@ -146,21 +164,23 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 
         public override async ValueTask<QueryResult> ExecuteQueryAsync()
         {
-            CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            CancellationTokenSource ts = SetTokenSource();
             try
             {
-                return await ExecuteQueryAsyncInternal(cancellationToken);
+                return await ExecuteQueryAsyncInternal(ts.Token);
             }
-            catch (Exception ex)
-                when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
-                     (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
+            catch (Exception ex) when (IsCancellation(ex, ts.Token))
             {
-                throw new TimeoutException("The query execution timed out. 
Consider increasing the query timeout value.", ex);
+                throw new TimeoutException("The query execution timed out or 
was cancelled. Consider increasing the query timeout value.", ex);
             }
             catch (Exception ex) when (ex is not HiveServer2Exception)
             {
                 throw new HiveServer2Exception($"An unexpected error occurred 
while fetching results. '{ApacheUtility.FormatExceptionMessage(ex)}'", ex);
             }
+            finally
+            {
+                DisposeTokenSource();
+            }
         }
 
         private async Task<UpdateResult> 
ExecuteUpdateAsyncInternal(CancellationToken cancellationToken = default)
@@ -221,21 +241,23 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
         {
             return await this.TraceActivityAsync(async _ =>
             {
-                CancellationToken cancellationToken = 
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+                CancellationTokenSource ts = SetTokenSource();
                 try
                 {
-                    return await ExecuteUpdateAsyncInternal(cancellationToken);
+                    return await ExecuteUpdateAsyncInternal(ts.Token);
                 }
-                catch (Exception ex)
-                    when (ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
-                         (ApacheUtility.ContainsException(ex, out 
TTransportException? _) && cancellationToken.IsCancellationRequested))
+                catch (Exception ex) when (IsCancellation(ex, ts.Token))
                 {
-                    throw new TimeoutException("The query execution timed out. 
Consider increasing the query timeout value.", ex);
+                    throw new TimeoutException("The query execution timed out 
or was cancelled. Consider increasing the query timeout value.", ex);
                 }
                 catch (Exception ex) when (ex is not HiveServer2Exception)
                 {
                     throw new HiveServer2Exception($"An unexpected error 
occurred while fetching results. '{ApacheUtility.FormatExceptionMessage(ex)}'", 
ex);
                 }
+                finally
+                {
+                    DisposeTokenSource();
+                }
             });
         }
 
@@ -1006,5 +1028,76 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             directResults = null;
             return false;
         }
+
+        /// <inheritdoc/>
+        public override void Cancel()
+        {
+            this.TraceActivity(_ =>
+            {
+                // This will cancel any operation using the current token 
source
+                CancelTokenSource();
+            });
+        }
+
+        private async Task CancelOperationAsync(Activity? activity, 
TOperationHandle? operationHandle)
+        {
+            if (operationHandle == null)
+            {
+                return;
+            }
+            using CancellationTokenSource cancellationTokenSource = 
ApacheUtility.GetCancellationTokenSource(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+            try
+            {
+                activity?.AddEvent(
+                    "db.operation.cancel_operation.starting",
+                    [new(SemanticConventions.Db.Operation.OperationId, new 
Guid(operationHandle.OperationId.Guid).ToString("N"))]);
+                TCancelOperationReq req = new(operationHandle);
+                TCancelOperationResp resp = await Client.CancelOperation(req, 
cancellationTokenSource.Token);
+                HiveServer2Connection.HandleThriftResponse(resp.Status, 
activity);
+                activity?.AddEvent(
+                    "db.operation.cancel_operation.completed",
+                    [new(SemanticConventions.Db.Response.StatusCode, 
resp.Status.StatusCode.ToString())]);
+            }
+            catch (Exception ex)
+            {
+                activity?.AddException(ex);
+            }
+        }
+
+        private CancellationTokenSource SetTokenSource()
+        {
+            lock (_tokenSourceLock)
+            {
+                if (_executeTokenSource != null)
+                {
+                    throw new InvalidOperationException("Simultaneous query or 
update execution is not allowed. Ensure to complete the query or update before 
starting a new one.");
+                }
+
+                _executeTokenSource = 
ApacheUtility.GetCancellationTokenSource(QueryTimeoutSeconds, 
ApacheUtility.TimeUnit.Seconds);
+                return _executeTokenSource;
+            }
+        }
+
+        private void CancelTokenSource()
+        {
+            lock (_tokenSourceLock)
+            {
+                // Cancel any running execution
+                _executeTokenSource?.Cancel();
+            }
+        }
+
+        private void DisposeTokenSource()
+        {
+            lock (_tokenSourceLock)
+            {
+                _executeTokenSource?.Dispose();
+                _executeTokenSource = null;
+            }
+        }
+
+        private static bool IsCancellation(Exception ex, CancellationToken 
cancellationToken) =>
+            ApacheUtility.ContainsException(ex, out 
OperationCanceledException? _) ||
+            (ApacheUtility.ContainsException(ex, out TTransportException? _) 
&& cancellationToken.IsCancellationRequested);
     }
 }
diff --git a/csharp/src/Drivers/Databricks/DatabricksConnection.cs 
b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
index 2739197e8..9ffe7b657 100644
--- a/csharp/src/Drivers/Databricks/DatabricksConnection.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
@@ -294,6 +294,16 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
         /// </summary>
         internal bool EnableDirectResults => _enableDirectResults;
 
+        /// <inheritdoc/>
+        protected internal override bool TrySetGetDirectResults(IRequest 
request)
+        {
+            if (EnableDirectResults)
+            {
+                return base.TrySetGetDirectResults(request);
+            }
+            return false;
+        }
+
         /// <summary>
         /// Gets whether CloudFetch is enabled.
         /// </summary>
diff --git a/csharp/test/Drivers/Apache/Common/StatementTests.cs 
b/csharp/test/Drivers/Apache/Common/StatementTests.cs
index 24fb017f5..dd9b2d074 100644
--- a/csharp/test/Drivers/Apache/Common/StatementTests.cs
+++ b/csharp/test/Drivers/Apache/Common/StatementTests.cs
@@ -171,6 +171,38 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common
             }
         }
 
+        internal virtual async Task CanCancelStatementTest(string query)
+        {
+            const int millisecondsDelay = 500;
+
+            TConfig testConfiguration = (TConfig)TestConfiguration.Clone();
+            testConfiguration.QueryTimeoutSeconds = "0"; // no timeout
+            testConfiguration.Query = query;
+
+            AdbcStatement st = 
NewConnection(testConfiguration).CreateStatement();
+            // Note: for this test to be valid, the query needs to run for 
more time than the delay value!
+            st.SqlQuery = testConfiguration.Query;
+            for (int i = 0; i < 10; i++)
+            {
+                // Reuse the statement to check for issue that might arise 
from using the Statement multiple times.
+                try
+                {
+                    Task<QueryResult> queryTask = Task.Run(st.ExecuteQuery);
+
+                    await Task.Delay(millisecondsDelay);
+                    st.Cancel();
+
+                    QueryResult queryResult = await queryTask;
+                    OutputHelper?.WriteLine($"QueryResultRowCount: 
{queryResult.RowCount}");
+                    Assert.Fail("Expecting query to timeout, but it did not.");
+                }
+                catch (Exception ex) when (ApacheUtility.ContainsException(ex, 
typeof(TimeoutException), out Exception? containedException))
+                {
+                    Assert.IsType<TimeoutException>(containedException!);
+                }
+            }
+        }
+
         /// <summary>
         /// Validates if the driver can execute update statements.
         /// </summary>
diff --git a/csharp/test/Drivers/Apache/Impala/StatementTests.cs 
b/csharp/test/Drivers/Apache/Impala/StatementTests.cs
index 0b943fc29..6f7612fc7 100644
--- a/csharp/test/Drivers/Apache/Impala/StatementTests.cs
+++ b/csharp/test/Drivers/Apache/Impala/StatementTests.cs
@@ -24,6 +24,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala
 {
     public class StatementTests : 
Common.StatementTests<ApacheTestConfiguration, ImpalaTestEnvironment>
     {
+        private const string LongRunningQuery = "SELECT COUNT(*) AS 
total_count\nFROM (\n  SELECT t1.id AS id1, t2.id AS id2\n  FROM RANGE(1000000) 
t1\n  CROSS JOIN RANGE(100000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0";
+
         public StatementTests(ITestOutputHelper? outputHelper)
             : base(outputHelper, new ImpalaTestEnvironment.Factory())
         {
@@ -47,6 +49,13 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala
             await 
base.CanGetCrossReferenceFromChildTable(TestConfiguration.Metadata.Catalog, 
TestConfiguration.Metadata.Schema);
         }
 
+        [SkippableTheory(Skip = "Untested")]
+        [InlineData(LongRunningQuery)]
+        internal override async Task CanCancelStatementTest(string query)
+        {
+            await base.CanCancelStatementTest(query);
+        }
+
         protected override void PrepareCreateTableWithForeignKeys(string 
fullTableNameParent, out string sqlUpdate, out string tableNameChild, out 
string fullTableNameChild, out IReadOnlyList<string> foreignKeys)
         {
             CreateNewTableName(out tableNameChild, out fullTableNameChild);
diff --git a/csharp/test/Drivers/Apache/Spark/StatementTests.cs 
b/csharp/test/Drivers/Apache/Spark/StatementTests.cs
index e9c28c76d..05613933e 100644
--- a/csharp/test/Drivers/Apache/Spark/StatementTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/StatementTests.cs
@@ -17,6 +17,7 @@
 
 using System;
 using System.Collections.Generic;
+using System.Threading.Tasks;
 using Apache.Arrow.Adbc.Tests.Drivers.Apache.Common;
 using Xunit;
 using Xunit.Abstractions;
@@ -37,15 +38,22 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
             base.StatementTimeoutTest(statementWithExceptions);
         }
 
+        [SkippableTheory]
+        [InlineData(LongRunningStatementTimeoutTestData.LongRunningQuery)]
+        internal override async Task CanCancelStatementTest(string query)
+        {
+            await base.CanCancelStatementTest(query);
+        }
+
         internal class LongRunningStatementTimeoutTestData : 
ShortRunningStatementTimeoutTestData
         {
+            internal const string LongRunningQuery = "SELECT COUNT(*) AS 
total_count\nFROM (\n  SELECT t1.id AS id1, t2.id AS id2\n  FROM RANGE(1000000) 
t1\n  CROSS JOIN RANGE(100000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0";
             public LongRunningStatementTimeoutTestData()
             {
-                string longRunningQuery = "SELECT COUNT(*) AS 
total_count\nFROM (\n  SELECT t1.id AS id1, t2.id AS id2\n  FROM RANGE(1000000) 
t1\n  CROSS JOIN RANGE(100000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0";
 
-                Add(new(5, longRunningQuery, typeof(TimeoutException)));
-                Add(new(null, longRunningQuery, typeof(TimeoutException)));
-                Add(new(0, longRunningQuery, null));
+                Add(new(5, LongRunningQuery, typeof(TimeoutException)));
+                Add(new(null, LongRunningQuery, typeof(TimeoutException)));
+                Add(new(0, LongRunningQuery, null));
             }
         }
 
diff --git a/csharp/test/Drivers/Databricks/E2E/DatabricksTestConfiguration.cs 
b/csharp/test/Drivers/Databricks/E2E/DatabricksTestConfiguration.cs
index 2b17c6498..ab2ae5a42 100644
--- a/csharp/test/Drivers/Databricks/E2E/DatabricksTestConfiguration.cs
+++ b/csharp/test/Drivers/Databricks/E2E/DatabricksTestConfiguration.cs
@@ -51,5 +51,11 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks
 
         [JsonPropertyName("isCITesting"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
         public bool IsCITesting { get; set; } = false;
+
+        [JsonPropertyName("enableRunAsyncInThriftOp"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
+        public string EnableRunAsyncInThriftOp { get; set; } = string.Empty;
+
+        [JsonPropertyName("enableDirectResults"), JsonIgnore(Condition = 
JsonIgnoreCondition.WhenWritingDefault)]
+        public string EnableDirectResults { get; set; } = string.Empty;
     }
 }
diff --git a/csharp/test/Drivers/Databricks/E2E/DatabricksTestEnvironment.cs 
b/csharp/test/Drivers/Databricks/E2E/DatabricksTestEnvironment.cs
index 6f9c08473..d93d7823a 100644
--- a/csharp/test/Drivers/Databricks/E2E/DatabricksTestEnvironment.cs
+++ b/csharp/test/Drivers/Databricks/E2E/DatabricksTestEnvironment.cs
@@ -157,6 +157,14 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks
             {
                 parameters.Add(DatabricksParameters.TraceStateEnabled, 
testConfiguration.TraceStateEnabled!);
             }
+            if 
(!string.IsNullOrEmpty(testConfiguration.EnableRunAsyncInThriftOp))
+            {
+                parameters.Add(DatabricksParameters.EnableRunAsyncInThriftOp, 
testConfiguration.EnableRunAsyncInThriftOp!);
+            }
+            if (!string.IsNullOrEmpty(testConfiguration.EnableDirectResults))
+            {
+                parameters.Add(DatabricksParameters.EnableDirectResults, 
testConfiguration.EnableDirectResults!);
+            }
             if (testConfiguration.HttpOptions != null)
             {
                 if (testConfiguration.HttpOptions.Tls != null)
diff --git a/csharp/test/Drivers/Databricks/E2E/StatementTests.cs 
b/csharp/test/Drivers/Databricks/E2E/StatementTests.cs
index acbba677c..07520b3a6 100644
--- a/csharp/test/Drivers/Databricks/E2E/StatementTests.cs
+++ b/csharp/test/Drivers/Databricks/E2E/StatementTests.cs
@@ -90,16 +90,38 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks
             base.StatementTimeoutTest(statementWithExceptions);
         }
 
-        internal class LongRunningStatementTimeoutTestData : 
ShortRunningStatementTimeoutTestData
+        [SkippableTheory]
+        [InlineData(LongRunningStatementTimeoutTestData.LongRunningQuery, 
"false", "true")]
+        [InlineData(LongRunningStatementTimeoutTestData.LongRunningQuery, 
"true", "true")]
+        [InlineData(LongRunningStatementTimeoutTestData.LongRunningQuery, 
"true", "false")]
+        internal async Task DatabricksCanCancelStatementTest(string query, 
string enableRunAsyncInThriftOp, string enableDirectResults)
         {
-            public LongRunningStatementTimeoutTestData() : base("SELECT 1")
+            string enableRunAsyncInThriftOpOrig = 
TestConfiguration.EnableRunAsyncInThriftOp;
+            string enableDirectResultsOrig = 
TestConfiguration.EnableDirectResults;
+            try
             {
-                string longRunningQuery = "SELECT COUNT(*) AS 
total_count\nFROM (\n  SELECT t1.id AS id1, t2.id AS id2\n  FROM RANGE(1000000) 
t1\n  CROSS JOIN RANGE(100000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0";
+                TestConfiguration.EnableRunAsyncInThriftOp = 
enableRunAsyncInThriftOp;
+                TestConfiguration.EnableDirectResults = enableDirectResults;
+                await base.CanCancelStatementTest(query);
+            }
+            finally
+            {
+                TestConfiguration.EnableRunAsyncInThriftOp = 
enableRunAsyncInThriftOpOrig;
+                TestConfiguration.EnableDirectResults = 
enableDirectResultsOrig;
+            }
+        }
 
+        internal class LongRunningStatementTimeoutTestData : 
ShortRunningStatementTimeoutTestData
+        {
+            internal const string LongRunningQuery = "SELECT COUNT(*) AS 
total_count\nFROM (\n  SELECT t1.id AS id1, t2.id AS id2\n  FROM RANGE(1000000) 
t1\n  CROSS JOIN RANGE(100000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0";
+            private const string DefaultQuery = "SELECT 1";
+
+            public LongRunningStatementTimeoutTestData() : base(DefaultQuery)
+            {
                 // Add Databricks-specific long-running query tests
-                Add(new(5, longRunningQuery, typeof(TimeoutException)));
-                Add(new(null, longRunningQuery, typeof(TimeoutException)));
-                Add(new(0, longRunningQuery, null));
+                Add(new(5, LongRunningQuery, typeof(TimeoutException)));
+                Add(new(null, LongRunningQuery, typeof(TimeoutException)));
+                Add(new(0, LongRunningQuery, null));
             }
         }
 

Reply via email to