This is an automated email from the ASF dual-hosted git repository. curth pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push: new 3c7bbd5b3 feat(csharp/src/Drivers/Apache): add support for Statement.Cancel (#3302) 3c7bbd5b3 is described below commit 3c7bbd5b3a83e6346b75dccbc593c63969164c0d Author: Bruce Irschick <bruce.irsch...@improving.com> AuthorDate: Thu Aug 21 13:06:31 2025 -0700 feat(csharp/src/Drivers/Apache): add support for Statement.Cancel (#3302) Add support for `AdbcStatement.Cancel`. - If a `CancellationTokenSource` exists, it will be set to Cancel. - If an operation is in progress, a `CancelOperation` will be sent using the current operation handle. - If no operation handle or cancellation source is available, then `Cancel` is a no-op Once the query result is created and returned, it is up to consumer of the QueryResult to close the Stream. Note to reviewers: Use the "Hide Whitespace" option. closes #3287 --- csharp/src/Drivers/Apache/ApacheUtility.cs | 21 ++- .../Drivers/Apache/Hive2/HiveServer2Statement.cs | 147 +++++++++++++++++---- .../src/Drivers/Databricks/DatabricksConnection.cs | 10 ++ .../test/Drivers/Apache/Common/StatementTests.cs | 32 +++++ .../test/Drivers/Apache/Impala/StatementTests.cs | 9 ++ csharp/test/Drivers/Apache/Spark/StatementTests.cs | 16 ++- .../Databricks/E2E/DatabricksTestConfiguration.cs | 6 + .../Databricks/E2E/DatabricksTestEnvironment.cs | 8 ++ .../test/Drivers/Databricks/E2E/StatementTests.cs | 34 ++++- 9 files changed, 239 insertions(+), 44 deletions(-) diff --git a/csharp/src/Drivers/Apache/ApacheUtility.cs b/csharp/src/Drivers/Apache/ApacheUtility.cs index 104d498fd..0ee596a71 100644 --- a/csharp/src/Drivers/Apache/ApacheUtility.cs +++ b/csharp/src/Drivers/Apache/ApacheUtility.cs @@ -33,7 +33,20 @@ namespace Apache.Arrow.Adbc.Drivers.Apache Milliseconds } + public static CancellationTokenSource GetCancellationTokenSource(int timeout, TimeUnit timeUnit) + { + TimeSpan timeSpan = CalculateTimeSpan(timeout, timeUnit); + return new CancellationTokenSource(timeSpan); + } + public static CancellationToken GetCancellationToken(int timeout, TimeUnit timeUnit) + { + TimeSpan timeSpan = CalculateTimeSpan(timeout, timeUnit); + var cts = new CancellationTokenSource(timeSpan); + return cts.Token; + } + + private static TimeSpan CalculateTimeSpan(int timeout, TimeUnit timeUnit) { TimeSpan span; @@ -55,13 +68,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache } } - return GetCancellationToken(span); - } - - private static CancellationToken GetCancellationToken(TimeSpan timeSpan) - { - var cts = new CancellationTokenSource(timeSpan); - return cts.Token; + return span; } public static bool QueryTimeoutIsValid(string key, string value, out int queryTimeoutSeconds) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index 0ed673ea3..10b8ec2d6 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -52,6 +53,10 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 protected const string PrimaryKeyPrefix = "PK_"; protected const string ForeignKeyPrefix = "FK_"; + // Lock to ensure consistent access to TokenSource + private readonly object _tokenSourceLock = new(); + private CancellationTokenSource? _executeTokenSource; + internal HiveServer2Statement(HiveServer2Connection connection) : base(connection) { @@ -77,45 +82,49 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 public override QueryResult ExecuteQuery() { - CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + CancellationTokenSource ts = SetTokenSource(); try { - return ExecuteQueryAsyncInternal(cancellationToken).Result; + return ExecuteQueryAsyncInternal(ts.Token).Result; } - catch (Exception ex) - when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || - (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + catch (Exception ex) when (IsCancellation(ex, ts.Token)) { - throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + throw new TimeoutException("The query execution timed out or was cancelled. Consider increasing the query timeout value.", ex); } catch (Exception ex) when (ex is not HiveServer2Exception) { throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ApacheUtility.FormatExceptionMessage(ex)}'", ex); } + finally + { + DisposeTokenSource(); + } } public override UpdateResult ExecuteUpdate() { - CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + CancellationTokenSource ts = SetTokenSource(); try { - return ExecuteUpdateAsyncInternal(cancellationToken).Result; + return ExecuteUpdateAsyncInternal(ts.Token).Result; } - catch (Exception ex) - when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || - (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + catch (Exception ex) when (IsCancellation(ex, ts.Token)) { - throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + throw new TimeoutException("The query execution timed out or was cancelled. Consider increasing the query timeout value.", ex); } catch (Exception ex) when (ex is not HiveServer2Exception) { throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ApacheUtility.FormatExceptionMessage(ex)}'", ex); } + finally + { + DisposeTokenSource(); + } } private async Task<QueryResult> ExecuteQueryAsyncInternal(CancellationToken cancellationToken = default) { - return await this.TraceActivityAsync(async _ => + return await this.TraceActivityAsync(async activity => { if (IsMetadataCommand) { @@ -136,8 +145,17 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 } else { - await HiveServer2Connection.PollForResponseAsync(response.OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout - metadata = await HiveServer2Connection.GetResultSetMetadataAsync(response.OperationHandle!, Connection.Client, cancellationToken); + try + { + await HiveServer2Connection.PollForResponseAsync(response.OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout + metadata = await HiveServer2Connection.GetResultSetMetadataAsync(response.OperationHandle!, Connection.Client, cancellationToken); + } + catch (Exception ex) when (IsCancellation(ex, cancellationToken)) + { + // If the operation was cancelled, we need to cancel the operation on the server + await CancelOperationAsync(activity, response.OperationHandle); + throw; + } } Schema schema = GetSchemaFromMetadata(metadata); return new QueryResult(-1, Connection.NewReader(this, schema, response, metadata)); @@ -146,21 +164,23 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 public override async ValueTask<QueryResult> ExecuteQueryAsync() { - CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + CancellationTokenSource ts = SetTokenSource(); try { - return await ExecuteQueryAsyncInternal(cancellationToken); + return await ExecuteQueryAsyncInternal(ts.Token); } - catch (Exception ex) - when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || - (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + catch (Exception ex) when (IsCancellation(ex, ts.Token)) { - throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + throw new TimeoutException("The query execution timed out or was cancelled. Consider increasing the query timeout value.", ex); } catch (Exception ex) when (ex is not HiveServer2Exception) { throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ApacheUtility.FormatExceptionMessage(ex)}'", ex); } + finally + { + DisposeTokenSource(); + } } private async Task<UpdateResult> ExecuteUpdateAsyncInternal(CancellationToken cancellationToken = default) @@ -221,21 +241,23 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { return await this.TraceActivityAsync(async _ => { - CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + CancellationTokenSource ts = SetTokenSource(); try { - return await ExecuteUpdateAsyncInternal(cancellationToken); + return await ExecuteUpdateAsyncInternal(ts.Token); } - catch (Exception ex) - when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || - (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + catch (Exception ex) when (IsCancellation(ex, ts.Token)) { - throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); + throw new TimeoutException("The query execution timed out or was cancelled. Consider increasing the query timeout value.", ex); } catch (Exception ex) when (ex is not HiveServer2Exception) { throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ApacheUtility.FormatExceptionMessage(ex)}'", ex); } + finally + { + DisposeTokenSource(); + } }); } @@ -1006,5 +1028,76 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 directResults = null; return false; } + + /// <inheritdoc/> + public override void Cancel() + { + this.TraceActivity(_ => + { + // This will cancel any operation using the current token source + CancelTokenSource(); + }); + } + + private async Task CancelOperationAsync(Activity? activity, TOperationHandle? operationHandle) + { + if (operationHandle == null) + { + return; + } + using CancellationTokenSource cancellationTokenSource = ApacheUtility.GetCancellationTokenSource(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + activity?.AddEvent( + "db.operation.cancel_operation.starting", + [new(SemanticConventions.Db.Operation.OperationId, new Guid(operationHandle.OperationId.Guid).ToString("N"))]); + TCancelOperationReq req = new(operationHandle); + TCancelOperationResp resp = await Client.CancelOperation(req, cancellationTokenSource.Token); + HiveServer2Connection.HandleThriftResponse(resp.Status, activity); + activity?.AddEvent( + "db.operation.cancel_operation.completed", + [new(SemanticConventions.Db.Response.StatusCode, resp.Status.StatusCode.ToString())]); + } + catch (Exception ex) + { + activity?.AddException(ex); + } + } + + private CancellationTokenSource SetTokenSource() + { + lock (_tokenSourceLock) + { + if (_executeTokenSource != null) + { + throw new InvalidOperationException("Simultaneous query or update execution is not allowed. Ensure to complete the query or update before starting a new one."); + } + + _executeTokenSource = ApacheUtility.GetCancellationTokenSource(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + return _executeTokenSource; + } + } + + private void CancelTokenSource() + { + lock (_tokenSourceLock) + { + // Cancel any running execution + _executeTokenSource?.Cancel(); + } + } + + private void DisposeTokenSource() + { + lock (_tokenSourceLock) + { + _executeTokenSource?.Dispose(); + _executeTokenSource = null; + } + } + + private static bool IsCancellation(Exception ex, CancellationToken cancellationToken) => + ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested); } } diff --git a/csharp/src/Drivers/Databricks/DatabricksConnection.cs b/csharp/src/Drivers/Databricks/DatabricksConnection.cs index 2739197e8..9ffe7b657 100644 --- a/csharp/src/Drivers/Databricks/DatabricksConnection.cs +++ b/csharp/src/Drivers/Databricks/DatabricksConnection.cs @@ -294,6 +294,16 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks /// </summary> internal bool EnableDirectResults => _enableDirectResults; + /// <inheritdoc/> + protected internal override bool TrySetGetDirectResults(IRequest request) + { + if (EnableDirectResults) + { + return base.TrySetGetDirectResults(request); + } + return false; + } + /// <summary> /// Gets whether CloudFetch is enabled. /// </summary> diff --git a/csharp/test/Drivers/Apache/Common/StatementTests.cs b/csharp/test/Drivers/Apache/Common/StatementTests.cs index 24fb017f5..dd9b2d074 100644 --- a/csharp/test/Drivers/Apache/Common/StatementTests.cs +++ b/csharp/test/Drivers/Apache/Common/StatementTests.cs @@ -171,6 +171,38 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common } } + internal virtual async Task CanCancelStatementTest(string query) + { + const int millisecondsDelay = 500; + + TConfig testConfiguration = (TConfig)TestConfiguration.Clone(); + testConfiguration.QueryTimeoutSeconds = "0"; // no timeout + testConfiguration.Query = query; + + AdbcStatement st = NewConnection(testConfiguration).CreateStatement(); + // Note: for this test to be valid, the query needs to run for more time than the delay value! + st.SqlQuery = testConfiguration.Query; + for (int i = 0; i < 10; i++) + { + // Reuse the statement to check for issue that might arise from using the Statement multiple times. + try + { + Task<QueryResult> queryTask = Task.Run(st.ExecuteQuery); + + await Task.Delay(millisecondsDelay); + st.Cancel(); + + QueryResult queryResult = await queryTask; + OutputHelper?.WriteLine($"QueryResultRowCount: {queryResult.RowCount}"); + Assert.Fail("Expecting query to timeout, but it did not."); + } + catch (Exception ex) when (ApacheUtility.ContainsException(ex, typeof(TimeoutException), out Exception? containedException)) + { + Assert.IsType<TimeoutException>(containedException!); + } + } + } + /// <summary> /// Validates if the driver can execute update statements. /// </summary> diff --git a/csharp/test/Drivers/Apache/Impala/StatementTests.cs b/csharp/test/Drivers/Apache/Impala/StatementTests.cs index 0b943fc29..6f7612fc7 100644 --- a/csharp/test/Drivers/Apache/Impala/StatementTests.cs +++ b/csharp/test/Drivers/Apache/Impala/StatementTests.cs @@ -24,6 +24,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala { public class StatementTests : Common.StatementTests<ApacheTestConfiguration, ImpalaTestEnvironment> { + private const string LongRunningQuery = "SELECT COUNT(*) AS total_count\nFROM (\n SELECT t1.id AS id1, t2.id AS id2\n FROM RANGE(1000000) t1\n CROSS JOIN RANGE(100000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0"; + public StatementTests(ITestOutputHelper? outputHelper) : base(outputHelper, new ImpalaTestEnvironment.Factory()) { @@ -47,6 +49,13 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala await base.CanGetCrossReferenceFromChildTable(TestConfiguration.Metadata.Catalog, TestConfiguration.Metadata.Schema); } + [SkippableTheory(Skip = "Untested")] + [InlineData(LongRunningQuery)] + internal override async Task CanCancelStatementTest(string query) + { + await base.CanCancelStatementTest(query); + } + protected override void PrepareCreateTableWithForeignKeys(string fullTableNameParent, out string sqlUpdate, out string tableNameChild, out string fullTableNameChild, out IReadOnlyList<string> foreignKeys) { CreateNewTableName(out tableNameChild, out fullTableNameChild); diff --git a/csharp/test/Drivers/Apache/Spark/StatementTests.cs b/csharp/test/Drivers/Apache/Spark/StatementTests.cs index e9c28c76d..05613933e 100644 --- a/csharp/test/Drivers/Apache/Spark/StatementTests.cs +++ b/csharp/test/Drivers/Apache/Spark/StatementTests.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; +using System.Threading.Tasks; using Apache.Arrow.Adbc.Tests.Drivers.Apache.Common; using Xunit; using Xunit.Abstractions; @@ -37,15 +38,22 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark base.StatementTimeoutTest(statementWithExceptions); } + [SkippableTheory] + [InlineData(LongRunningStatementTimeoutTestData.LongRunningQuery)] + internal override async Task CanCancelStatementTest(string query) + { + await base.CanCancelStatementTest(query); + } + internal class LongRunningStatementTimeoutTestData : ShortRunningStatementTimeoutTestData { + internal const string LongRunningQuery = "SELECT COUNT(*) AS total_count\nFROM (\n SELECT t1.id AS id1, t2.id AS id2\n FROM RANGE(1000000) t1\n CROSS JOIN RANGE(100000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0"; public LongRunningStatementTimeoutTestData() { - string longRunningQuery = "SELECT COUNT(*) AS total_count\nFROM (\n SELECT t1.id AS id1, t2.id AS id2\n FROM RANGE(1000000) t1\n CROSS JOIN RANGE(100000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0"; - Add(new(5, longRunningQuery, typeof(TimeoutException))); - Add(new(null, longRunningQuery, typeof(TimeoutException))); - Add(new(0, longRunningQuery, null)); + Add(new(5, LongRunningQuery, typeof(TimeoutException))); + Add(new(null, LongRunningQuery, typeof(TimeoutException))); + Add(new(0, LongRunningQuery, null)); } } diff --git a/csharp/test/Drivers/Databricks/E2E/DatabricksTestConfiguration.cs b/csharp/test/Drivers/Databricks/E2E/DatabricksTestConfiguration.cs index 2b17c6498..ab2ae5a42 100644 --- a/csharp/test/Drivers/Databricks/E2E/DatabricksTestConfiguration.cs +++ b/csharp/test/Drivers/Databricks/E2E/DatabricksTestConfiguration.cs @@ -51,5 +51,11 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks [JsonPropertyName("isCITesting"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public bool IsCITesting { get; set; } = false; + + [JsonPropertyName("enableRunAsyncInThriftOp"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string EnableRunAsyncInThriftOp { get; set; } = string.Empty; + + [JsonPropertyName("enableDirectResults"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string EnableDirectResults { get; set; } = string.Empty; } } diff --git a/csharp/test/Drivers/Databricks/E2E/DatabricksTestEnvironment.cs b/csharp/test/Drivers/Databricks/E2E/DatabricksTestEnvironment.cs index 6f9c08473..d93d7823a 100644 --- a/csharp/test/Drivers/Databricks/E2E/DatabricksTestEnvironment.cs +++ b/csharp/test/Drivers/Databricks/E2E/DatabricksTestEnvironment.cs @@ -157,6 +157,14 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks { parameters.Add(DatabricksParameters.TraceStateEnabled, testConfiguration.TraceStateEnabled!); } + if (!string.IsNullOrEmpty(testConfiguration.EnableRunAsyncInThriftOp)) + { + parameters.Add(DatabricksParameters.EnableRunAsyncInThriftOp, testConfiguration.EnableRunAsyncInThriftOp!); + } + if (!string.IsNullOrEmpty(testConfiguration.EnableDirectResults)) + { + parameters.Add(DatabricksParameters.EnableDirectResults, testConfiguration.EnableDirectResults!); + } if (testConfiguration.HttpOptions != null) { if (testConfiguration.HttpOptions.Tls != null) diff --git a/csharp/test/Drivers/Databricks/E2E/StatementTests.cs b/csharp/test/Drivers/Databricks/E2E/StatementTests.cs index acbba677c..07520b3a6 100644 --- a/csharp/test/Drivers/Databricks/E2E/StatementTests.cs +++ b/csharp/test/Drivers/Databricks/E2E/StatementTests.cs @@ -90,16 +90,38 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks base.StatementTimeoutTest(statementWithExceptions); } - internal class LongRunningStatementTimeoutTestData : ShortRunningStatementTimeoutTestData + [SkippableTheory] + [InlineData(LongRunningStatementTimeoutTestData.LongRunningQuery, "false", "true")] + [InlineData(LongRunningStatementTimeoutTestData.LongRunningQuery, "true", "true")] + [InlineData(LongRunningStatementTimeoutTestData.LongRunningQuery, "true", "false")] + internal async Task DatabricksCanCancelStatementTest(string query, string enableRunAsyncInThriftOp, string enableDirectResults) { - public LongRunningStatementTimeoutTestData() : base("SELECT 1") + string enableRunAsyncInThriftOpOrig = TestConfiguration.EnableRunAsyncInThriftOp; + string enableDirectResultsOrig = TestConfiguration.EnableDirectResults; + try { - string longRunningQuery = "SELECT COUNT(*) AS total_count\nFROM (\n SELECT t1.id AS id1, t2.id AS id2\n FROM RANGE(1000000) t1\n CROSS JOIN RANGE(100000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0"; + TestConfiguration.EnableRunAsyncInThriftOp = enableRunAsyncInThriftOp; + TestConfiguration.EnableDirectResults = enableDirectResults; + await base.CanCancelStatementTest(query); + } + finally + { + TestConfiguration.EnableRunAsyncInThriftOp = enableRunAsyncInThriftOpOrig; + TestConfiguration.EnableDirectResults = enableDirectResultsOrig; + } + } + internal class LongRunningStatementTimeoutTestData : ShortRunningStatementTimeoutTestData + { + internal const string LongRunningQuery = "SELECT COUNT(*) AS total_count\nFROM (\n SELECT t1.id AS id1, t2.id AS id2\n FROM RANGE(1000000) t1\n CROSS JOIN RANGE(100000) t2\n) subquery\nWHERE MOD(id1 + id2, 2) = 0"; + private const string DefaultQuery = "SELECT 1"; + + public LongRunningStatementTimeoutTestData() : base(DefaultQuery) + { // Add Databricks-specific long-running query tests - Add(new(5, longRunningQuery, typeof(TimeoutException))); - Add(new(null, longRunningQuery, typeof(TimeoutException))); - Add(new(0, longRunningQuery, null)); + Add(new(5, LongRunningQuery, typeof(TimeoutException))); + Add(new(null, LongRunningQuery, typeof(TimeoutException))); + Add(new(0, LongRunningQuery, null)); } }