This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 8bc03ef6d feat(csharp/src/Drivers/BigQuery): implement 
AdbcStatement.Cancel on BigQuery (#3422)
8bc03ef6d is described below

commit 8bc03ef6d3eccaff60d4b11d7f07a32c61536eb5
Author: Bruce Irschick <[email protected]>
AuthorDate: Mon Oct 13 09:00:25 2025 -0700

    feat(csharp/src/Drivers/BigQuery): implement AdbcStatement.Cancel on 
BigQuery (#3422)
    
    Adds implementation of `AdbcStatement.Cancel` to the BigQuery driver.
    - A `CancellationRegistry` tracks the active `CancellationContext` instances.
    - When a `Statement.Cancel` call is made, it will cancel all the registered
    `CancellationContext` instances.
    - When a registered `JobCancellationContext` is cancelled, the driver will
    attempt to cancel the associated job on the server.
---
 csharp/src/Drivers/Apache/ApacheUtility.cs         |  48 +++--
 csharp/src/Drivers/BigQuery/BigQueryStatement.cs   | 226 +++++++++++++++++----
 csharp/src/Drivers/BigQuery/BigQueryUtils.cs       |  28 +++
 csharp/src/Drivers/BigQuery/RetryManager.cs        |   2 +-
 .../test/Drivers/Apache/Common/ApacheUtilsTests.cs | 151 ++++++++++++++
 csharp/test/Drivers/BigQuery/BigQueryUtilsTests.cs |  88 ++++++++
 csharp/test/Drivers/BigQuery/StatementTests.cs     | 115 ++++++++++-
 7 files changed, 595 insertions(+), 63 deletions(-)

diff --git a/csharp/src/Drivers/Apache/ApacheUtility.cs 
b/csharp/src/Drivers/Apache/ApacheUtility.cs
index 0ee596a71..6b14d5053 100644
--- a/csharp/src/Drivers/Apache/ApacheUtility.cs
+++ b/csharp/src/Drivers/Apache/ApacheUtility.cs
@@ -98,18 +98,6 @@ namespace Apache.Arrow.Adbc.Drivers.Apache
 
         public static bool ContainsException<T>(Exception exception, out T? 
containedException) where T : Exception
         {
-            if (exception is AggregateException aggregateException)
-            {
-                foreach (Exception? ex in aggregateException.InnerExceptions)
-                {
-                    if (ex is T ce)
-                    {
-                        containedException = ce;
-                        return true;
-                    }
-                }
-            }
-
             Exception? e = exception;
             while (e != null)
             {
@@ -118,6 +106,17 @@ namespace Apache.Arrow.Adbc.Drivers.Apache
                     containedException = ce;
                     return true;
                 }
+                else if (e is AggregateException aggregateException)
+                {
+                    foreach (Exception? ex in 
aggregateException.InnerExceptions)
+                    {
+                        if (ContainsException(ex, out T? inner))
+                        {
+                            containedException = inner;
+                            return true;
+                        }
+                    }
+                }
                 e = e.InnerException;
             }
 
@@ -127,24 +126,12 @@ namespace Apache.Arrow.Adbc.Drivers.Apache
 
         public static bool ContainsException(Exception exception, Type? 
exceptionType, out Exception? containedException)
         {
-            if (exception == null || exceptionType == null)
+            if (exceptionType == null)
             {
                 containedException = null;
                 return false;
             }
 
-            if (exception is AggregateException aggregateException)
-            {
-                foreach (Exception? ex in aggregateException.InnerExceptions)
-                {
-                    if (exceptionType.IsInstanceOfType(ex))
-                    {
-                        containedException = ex;
-                        return true;
-                    }
-                }
-            }
-
             Exception? e = exception;
             while (e != null)
             {
@@ -153,6 +140,17 @@ namespace Apache.Arrow.Adbc.Drivers.Apache
                     containedException = e;
                     return true;
                 }
+                else if (e is AggregateException aggregateException)
+                {
+                    foreach (Exception? ex in 
aggregateException.InnerExceptions)
+                    {
+                        if (ContainsException(ex, exceptionType, out 
Exception? inner))
+                        {
+                            containedException = inner;
+                            return true;
+                        }
+                    }
+                }
                 e = e.InnerException;
             }
 
diff --git a/csharp/src/Drivers/BigQuery/BigQueryStatement.cs 
b/csharp/src/Drivers/BigQuery/BigQueryStatement.cs
index 880d8b79f..cb6d7435a 100644
--- a/csharp/src/Drivers/BigQuery/BigQueryStatement.cs
+++ b/csharp/src/Drivers/BigQuery/BigQueryStatement.cs
@@ -16,8 +16,8 @@
 */
 
 using System;
+using System.Collections.Concurrent;
 using System.Collections.Generic;
-using System.Data;
 using System.Diagnostics;
 using System.IO;
 using System.Linq;
@@ -43,6 +43,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
     class BigQueryStatement : TracingStatement, ITokenProtectedResource, 
IDisposable
     {
         readonly BigQueryConnection bigQueryConnection;
+        readonly CancellationRegistry cancellationRegistry;
 
         public BigQueryStatement(BigQueryConnection bigQueryConnection) : 
base(bigQueryConnection)
         {
@@ -52,6 +53,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
             UpdateToken = bigQueryConnection.UpdateToken;
 
             this.bigQueryConnection = bigQueryConnection;
+            this.cancellationRegistry = new CancellationRegistry();
         }
 
         public Func<Task>? UpdateToken { get; set; }
@@ -105,16 +107,20 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
                     
activity?.AddBigQueryParameterTag(BigQueryParameters.GetQueryResultsOptionsTimeout,
 seconds);
                 }
 
+                JobCancellationContext cancellationContext = new 
JobCancellationContext(cancellationRegistry, job);
                 // We can't checkJobStatus, Otherwise, the timeout in 
QueryResultsOptions is meaningless.
                 // When encountering a long-running job, it should be 
controlled by the timeout in the Google SDK instead of blocking in a while loop.
                 Func<Task<BigQueryResults>> getJobResults = async () =>
                 {
-                    // if the authentication token was reset, then we need a 
new job with the latest token
-                    BigQueryJob completedJob = await 
Client.GetJobAsync(jobReference);
-                    return await 
completedJob.GetQueryResultsAsync(getQueryResultsOptions);
+                    return await 
ExecuteCancellableJobAsync(cancellationContext, activity, async (context) =>
+                    {
+                        // if the authentication token was reset, then we need 
a new job with the latest token
+                        context.Job = await Client.GetJobAsync(jobReference, 
cancellationToken: context.CancellationToken).ConfigureAwait(false);
+                        return await 
context.Job.GetQueryResultsAsync(getQueryResultsOptions, cancellationToken: 
context.CancellationToken).ConfigureAwait(false);
+                    }).ConfigureAwait(false);
                 };
 
-                BigQueryResults results = await 
ExecuteWithRetriesAsync(getJobResults, activity);
+                BigQueryResults results = await 
ExecuteWithRetriesAsync(getJobResults, activity).ConfigureAwait(false);
 
                 TokenProtectedReadClientManger clientMgr = new 
TokenProtectedReadClientManger(Credential);
                 clientMgr.UpdateToken = () => Task.Run(() =>
@@ -145,31 +151,36 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
                     }
 
                     Func<Task<BigQueryResults>> getMultiJobResults = async () 
=>
-                    {
-                        // To get the results of all statements in a 
multi-statement query, enumerate the child jobs. Related public docs: 
https://cloud.google.com/bigquery/docs/multi-statement-queries#get_all_executed_statements.
-                        // Can filter by StatementType and EvaluationKind. 
Related public docs: 
https://cloud.google.com/bigquery/docs/reference/rest/v2/Job#jobstatistics2, 
https://cloud.google.com/bigquery/docs/reference/rest/v2/Job#evaluationkind
-                        ListJobsOptions listJobsOptions = new 
ListJobsOptions();
-                        listJobsOptions.ParentJobId = 
results.JobReference.JobId;
-                        var joblist = Client.ListJobs(listJobsOptions)
-                            .Select(job => Client.GetJob(job.Reference))
-                            .Where(job => string.IsNullOrEmpty(evaluationKind) 
|| job.Statistics.ScriptStatistics.EvaluationKind.Equals(evaluationKind, 
StringComparison.OrdinalIgnoreCase))
-                            .Where(job => string.IsNullOrEmpty(statementType) 
|| job.Statistics.Query.StatementType.Equals(statementType, 
StringComparison.OrdinalIgnoreCase))
-                            .OrderBy(job => 
job.Resource.Statistics.CreationTime)
-                            .ToList();
-
-                        if (joblist.Count > 0)
                         {
-                            if (statementIndex < 1 || statementIndex > 
joblist.Count)
+                            // To get the results of all statements in a 
multi-statement query, enumerate the child jobs. Related public docs: 
https://cloud.google.com/bigquery/docs/multi-statement-queries#get_all_executed_statements.
+                            // Can filter by StatementType and EvaluationKind. 
Related public docs: 
https://cloud.google.com/bigquery/docs/reference/rest/v2/Job#jobstatistics2, 
https://cloud.google.com/bigquery/docs/reference/rest/v2/Job#evaluationkind
+                            ListJobsOptions listJobsOptions = new 
ListJobsOptions();
+                            listJobsOptions.ParentJobId = 
results.JobReference.JobId;
+                            var joblist = Client.ListJobs(listJobsOptions)
+                                .Select(job => Client.GetJob(job.Reference))
+                                .Where(job => 
string.IsNullOrEmpty(evaluationKind) || 
job.Statistics.ScriptStatistics.EvaluationKind.Equals(evaluationKind, 
StringComparison.OrdinalIgnoreCase))
+                                .Where(job => 
string.IsNullOrEmpty(statementType) || 
job.Statistics.Query.StatementType.Equals(statementType, 
StringComparison.OrdinalIgnoreCase))
+                                .OrderBy(job => 
job.Resource.Statistics.CreationTime)
+                                .ToList();
+
+                            if (joblist.Count > 0)
                             {
-                                throw new ArgumentOutOfRangeException($"The 
specified index {statementIndex} is out of range. There are {joblist.Count} 
jobs available.");
+                                if (statementIndex < 1 || statementIndex > 
joblist.Count)
+                                {
+                                    throw new 
ArgumentOutOfRangeException($"The specified index {statementIndex} is out of 
range. There are {joblist.Count} jobs available.");
+                                }
+                                BigQueryJob indexedJob = 
joblist[statementIndex - 1];
+                                cancellationContext.Job = indexedJob;
+                                return await 
ExecuteCancellableJobAsync(cancellationContext, activity, async (context) =>
+                                {
+                                    return await 
indexedJob.GetQueryResultsAsync(getQueryResultsOptions, cancellationToken: 
context.CancellationToken).ConfigureAwait(false);
+                                }).ConfigureAwait(false);
                             }
-                            return await joblist[statementIndex - 
1].GetQueryResultsAsync(getQueryResultsOptions);
-                        }
 
-                        throw new AdbcException($"Unable to obtain result from 
statement [{statementIndex}]", AdbcStatusCode.InvalidData);
-                    };
+                            throw new AdbcException($"Unable to obtain result 
from statement [{statementIndex}]", AdbcStatusCode.InvalidData);
+                        };
 
-                    results = await 
ExecuteWithRetriesAsync(getMultiJobResults, activity);
+                    results = await 
ExecuteWithRetriesAsync(getMultiJobResults, activity).ConfigureAwait(false);
                 }
 
                 if (results?.TableReference == null)
@@ -194,10 +205,18 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
 
                 long totalRows = results.TotalRows == null ? -1L : 
(long)results.TotalRows.Value;
 
-                Func<Task<IEnumerable<IArrowReader>>> func = () => 
GetArrowReaders(clientMgr, table, results.TableReference.ProjectId, 
maxStreamCount, activity);
-                IEnumerable<IArrowReader> readers = await 
ExecuteWithRetriesAsync<IEnumerable<IArrowReader>>(func, activity);
+                Func<Task<IEnumerable<IArrowReader>>> getArrowReadersFunc = 
async () =>
+                {
+                    return await 
ExecuteCancellableJobAsync(cancellationContext, activity, async (context) =>
+                    {
+                        // Cancelling this step may leave the server with 
unread streams.
+                        return await GetArrowReaders(clientMgr, table, 
results.TableReference.ProjectId, maxStreamCount, activity, 
context.CancellationToken).ConfigureAwait(false);
+                    }).ConfigureAwait(false);
+                };
+                IEnumerable<IArrowReader> readers = await 
ExecuteWithRetriesAsync(getArrowReadersFunc, activity).ConfigureAwait(false);
 
-                IArrowArrayStream stream = new MultiArrowReader(this, 
TranslateSchema(results.Schema), readers);
+                // Note: MultiArrowReader must dispose the cancellationContext.
+                IArrowArrayStream stream = new MultiArrowReader(this, 
TranslateSchema(results.Schema), readers, cancellationContext);
                 activity?.AddTag(SemanticConventions.Db.Response.ReturnedRows, 
totalRows);
                 return new QueryResult(totalRows, stream);
             });
@@ -208,14 +227,15 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
             string table,
             string projectId,
             int maxStreamCount,
-            Activity? activity)
+            Activity? activity,
+            CancellationToken cancellationToken = default)
         {
             ReadSession rs = new ReadSession { Table = table, DataFormat = 
DataFormat.Arrow };
             BigQueryReadClient bigQueryReadClient = clientMgr.ReadClient;
             ReadSession rrs = await 
bigQueryReadClient.CreateReadSessionAsync("projects/" + projectId, rs, 
maxStreamCount);
 
             var readers = rrs.Streams
-                             .Select(s => ReadChunk(bigQueryReadClient, 
s.Name, activity, this.bigQueryConnection.IsSafeToTrace))
+                             .Select(s => ReadChunk(bigQueryReadClient, 
s.Name, activity, this.bigQueryConnection.IsSafeToTrace, cancellationToken))
                              .Where(chunk => chunk != null)
                              .Cast<IArrowReader>();
 
@@ -227,6 +247,20 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
             return ExecuteUpdateInternalAsync().GetAwaiter().GetResult();
         }
 
+        public override void Cancel()
+        {
+            this.TraceActivity(_ =>
+            {
+                this.cancellationRegistry.CancelAll();
+            });
+        }
+
+        public override void Dispose()
+        {
+            this.cancellationRegistry.Dispose();
+            base.Dispose();
+        }
+
         private async Task<UpdateResult> ExecuteUpdateInternalAsync()
         {
             return await this.TraceActivity(async activity =>
@@ -243,9 +277,17 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
 
                 activity?.AddConditionalTag(SemanticConventions.Db.Query.Text, 
SqlQuery, this.bigQueryConnection.IsSafeToTrace);
 
+                using JobCancellationContext context = 
new(cancellationRegistry);
                 // Cannot set destination table in jobs with DDL statements, 
otherwise an error will be prompted
-                Func<Task<BigQueryResults?>> func = () => 
Client.ExecuteQueryAsync(SqlQuery, null, null, getQueryResultsOptions);
-                BigQueryResults? result = await 
ExecuteWithRetriesAsync<BigQueryResults?>(func, activity);
+                Func<Task<BigQueryResults?>> getQueryResultsAsyncFunc = async 
() =>
+                {
+                    return await ExecuteCancellableJobAsync(context, activity, 
async (context) =>
+                    {
+                        context.Job = await 
this.Client.CreateQueryJobAsync(SqlQuery, null, null, 
context.CancellationToken).ConfigureAwait(false);
+                        return await 
context.Job.GetQueryResultsAsync(getQueryResultsOptions, 
context.CancellationToken).ConfigureAwait(false);
+                    }).ConfigureAwait(false);
+                };
+                BigQueryResults? result = await 
ExecuteWithRetriesAsync(getQueryResultsAsyncFunc, activity);
                 long updatedRows = result?.NumDmlAffectedRows.HasValue == true 
? result.NumDmlAffectedRows.Value : -1L;
 
                 activity?.AddTag(SemanticConventions.Db.Response.ReturnedRows, 
updatedRows);
@@ -339,13 +381,13 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
             return type;
         }
 
-        private static IArrowReader? ReadChunk(BigQueryReadClient client, 
string streamName, Activity? activity, bool isSafeToTrace)
+        private static IArrowReader? ReadChunk(BigQueryReadClient client, 
string streamName, Activity? activity, bool isSafeToTrace, CancellationToken 
cancellationToken = default)
         {
             // Ideally we wouldn't need to indirect through a stream, but the 
necessary APIs in Arrow
             // are internal. (TODO: consider changing Arrow).
             activity?.AddConditionalBigQueryTag("read_stream", streamName, 
isSafeToTrace);
             BigQueryReadClient.ReadRowsStream readRowsStream = 
client.ReadRows(new ReadRowsRequest { ReadStream = streamName });
-            IAsyncEnumerator<ReadRowsResponse> enumerator = 
readRowsStream.GetResponseStream().GetAsyncEnumerator();
+            IAsyncEnumerator<ReadRowsResponse> enumerator = 
readRowsStream.GetResponseStream().GetAsyncEnumerator(cancellationToken);
 
             ReadRowsStream stream = new ReadRowsStream(enumerator);
             activity?.AddBigQueryTag("read_stream.has_rows", stream.HasRows);
@@ -529,19 +571,129 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
 
         private async Task<T> ExecuteWithRetriesAsync<T>(Func<Task<T>> action, 
Activity? activity) => await RetryManager.ExecuteWithRetriesAsync<T>(this, 
action, activity, MaxRetryAttempts, RetryDelayMs);
 
+        private async Task<T> ExecuteCancellableJobAsync<T>(
+            JobCancellationContext context,
+            Activity? activity,
+            Func<JobCancellationContext, Task<T>> func)
+        {
+            try
+            {
+                return await func(context).ConfigureAwait(false);
+            }
+            catch (Exception ex) when (BigQueryUtils.ContainsException(ex, out 
OperationCanceledException? cancelledEx))
+            {
+                activity?.AddException(cancelledEx!);
+                try
+                {
+                    if (context.Job != null)
+                    {
+                        activity?.AddBigQueryTag("job.cancel", 
context.Job.Reference.JobId);
+                        await context.Job.CancelAsync().ConfigureAwait(false);
+                    }
+                }
+                catch (Exception e)
+                {
+                    activity?.AddException(e);
+                }
+                throw;
+            }
+            finally
+            {
+                // Job is no longer in context after completion or cancellation
+                context.Job = null;
+            }
+        }
+
+        private class CancellationContext : IDisposable
+        {
+            private readonly CancellationRegistry cancellationRegistry;
+            private readonly CancellationTokenSource cancellationTokenSource;
+            private bool disposed;
+
+            public CancellationContext(CancellationRegistry 
cancellationRegistry)
+            {
+                cancellationTokenSource = new CancellationTokenSource();
+                this.cancellationRegistry = cancellationRegistry;
+                this.cancellationRegistry.Register(this);
+            }
+
+            public CancellationToken CancellationToken => 
cancellationTokenSource.Token;
+
+            public void Cancel()
+            {
+                cancellationTokenSource.Cancel();
+            }
+
+            public virtual void Dispose()
+            {
+                if (!disposed)
+                {
+                    cancellationRegistry.Unregister(this);
+                    cancellationTokenSource.Dispose();
+                    disposed = true;
+                }
+            }
+        }
+
+        private class JobCancellationContext : CancellationContext
+        {
+            public JobCancellationContext(CancellationRegistry 
cancellationRegistry, BigQueryJob? job = default)
+                : base(cancellationRegistry)
+            {
+                Job = job;
+            }
+
+            public BigQueryJob? Job { get; set; }
+        }
+
+        private sealed class CancellationRegistry : IDisposable
+        {
+            private readonly ConcurrentDictionary<CancellationContext, byte> 
contexts = new();
+
+            public CancellationContext Register(CancellationContext context)
+            {
+                contexts.TryAdd(context, 0);
+                return context;
+            }
+
+            public bool Unregister(CancellationContext context)
+            {
+                return contexts.TryRemove(context, out _);
+            }
+
+            public void CancelAll()
+            {
+                foreach (CancellationContext context in contexts.Keys)
+                {
+                    context.Cancel();
+                }
+            }
+
+            public void Dispose()
+            {
+                foreach (CancellationContext context in contexts.Keys)
+                {
+                    context.Dispose();
+                }
+                contexts.Clear();
+            }
+        }
+
         private class MultiArrowReader : TracingReader
         {
             private static readonly string s_assemblyName = 
BigQueryUtils.GetAssemblyName(typeof(BigQueryStatement));
             private static readonly string s_assemblyVersion = 
BigQueryUtils.GetAssemblyVersion(typeof(BigQueryStatement));
 
             readonly Schema schema;
+            readonly CancellationContext cancellationContext;
             IEnumerator<IArrowReader>? readers;
             IArrowReader? reader;
 
-            public MultiArrowReader(BigQueryStatement statement, Schema 
schema, IEnumerable<IArrowReader> readers) : base(statement)
+            public MultiArrowReader(BigQueryStatement statement, Schema 
schema, IEnumerable<IArrowReader> readers, CancellationContext 
cancellationContext) : base(statement)
             {
                 this.schema = schema;
                 this.readers = readers.GetEnumerator();
+                this.cancellationContext = cancellationContext;
             }
 
             public override Schema Schema { get { return this.schema; } }
@@ -554,6 +706,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
             {
                 return await this.TraceActivityAsync(async activity =>
                 {
+                    using CancellationTokenSource linkedCts = 
CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, 
this.cancellationContext.CancellationToken);
                     if (this.readers == null)
                     {
                         return null;
@@ -571,7 +724,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
                             this.reader = this.readers.Current;
                         }
 
-                        RecordBatch result = await 
this.reader.ReadNextRecordBatchAsync(cancellationToken);
+                        RecordBatch result = await 
this.reader.ReadNextRecordBatchAsync(linkedCts.Token).ConfigureAwait(false);
 
                         if (result != null)
                         {
@@ -591,6 +744,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
                     {
                         this.readers.Dispose();
                         this.readers = null;
+                        this.cancellationContext.Dispose();
                     }
                 }
 
diff --git a/csharp/src/Drivers/BigQuery/BigQueryUtils.cs 
b/csharp/src/Drivers/BigQuery/BigQueryUtils.cs
index 956486d2a..6cd1e5031 100644
--- a/csharp/src/Drivers/BigQuery/BigQueryUtils.cs
+++ b/csharp/src/Drivers/BigQuery/BigQueryUtils.cs
@@ -42,5 +42,33 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
         internal static string GetAssemblyName(Type type) => 
type.Assembly.GetName().Name!;
 
         internal static string GetAssemblyVersion(Type type) => 
FileVersionInfo.GetVersionInfo(type.Assembly.Location).ProductVersion ?? 
string.Empty;
+
+        public static bool ContainsException<T>(Exception exception, out T? 
containedException) where T : Exception
+        {
+            Exception? e = exception;
+            while (e != null)
+            {
+                if (e is T ce)
+                {
+                    containedException = ce;
+                    return true;
+                }
+                else if (e is AggregateException aggregateException)
+                {
+                    foreach (Exception? ex in 
aggregateException.InnerExceptions)
+                    {
+                        if (ContainsException(ex, out T? inner))
+                        {
+                            containedException = inner;
+                            return true;
+                        }
+                    }
+                }
+                e = e.InnerException;
+            }
+
+            containedException = null;
+            return false;
+        }
     }
 }
diff --git a/csharp/src/Drivers/BigQuery/RetryManager.cs 
b/csharp/src/Drivers/BigQuery/RetryManager.cs
index f3c4e8eb0..09e3621c7 100644
--- a/csharp/src/Drivers/BigQuery/RetryManager.cs
+++ b/csharp/src/Drivers/BigQuery/RetryManager.cs
@@ -49,7 +49,7 @@ namespace Apache.Arrow.Adbc.Drivers.BigQuery
                     T result = await action();
                     return result;
                 }
-                catch (Exception ex)
+                catch (Exception ex) when 
(!BigQueryUtils.ContainsException(ex, out OperationCanceledException? _))
                 {
                     activity?.AddBigQueryTag("retry_attempt", retryCount);
                     activity?.AddException(ex);
diff --git a/csharp/test/Drivers/Apache/Common/ApacheUtilsTests.cs 
b/csharp/test/Drivers/Apache/Common/ApacheUtilsTests.cs
new file mode 100644
index 000000000..f80af09af
--- /dev/null
+++ b/csharp/test/Drivers/Apache/Common/ApacheUtilsTests.cs
@@ -0,0 +1,151 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+
+using System;
+using Apache.Arrow.Adbc.Drivers.Apache;
+using Xunit;
+
+namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common
+{
+    public class ApacheUtilsTests
+    {
+        [Fact]
+        public void TestContainsExceptionWithAggregateInMiddle()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle = new AggregateException("Middle exception", 
innerMost);
+            Exception outer = new Exception("Outer exception", middle);
+            bool found = ApacheUtility.ContainsException(outer, out 
InvalidOperationException? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost, containedException);
+        }
+
+        [Fact]
+        public void TestContainsExceptionWithAggregateOnTop()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle = new Exception("Middle exception", innerMost);
+            Exception outer = new AggregateException("Outer exception", 
middle);
+            bool found = ApacheUtility.ContainsException(outer, out 
InvalidOperationException? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost, containedException);
+        }
+
+        [Fact]
+        public void TestContainsExceptionMultipleAggregate()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle1 = new AggregateException("Middle exception 1", 
innerMost);
+            Exception middle2 = new AggregateException("Middle exception 2", 
middle1);
+            Exception outer = new Exception("Outer exception", middle2);
+            bool found = ApacheUtility.ContainsException(outer, out 
InvalidOperationException? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost, containedException);
+        }
+
+        [Fact]
+        public void TestContainsAggregateException()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle = new AggregateException("Middle exception 1", 
innerMost);
+            Exception outer = new Exception("Outer exception", middle);
+            bool found = ApacheUtility.ContainsException(outer, out 
AggregateException? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(middle, containedException);
+        }
+
+        [Fact]
+        public void TestContainsMultipleInAggregate()
+        {
+            Exception innerMost1 = new InvalidOperationException("Innermost 
exception 1");
+            Exception innerMost2 = new NotImplementedException("Innermost 
exception 2");
+            Exception middle = new AggregateException("Middle exception", 
[innerMost1, innerMost2]);
+            Exception outer = new Exception("Outer exception", middle);
+            bool found = ApacheUtility.ContainsException(outer, out 
NotImplementedException? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost2, containedException);
+        }
+
+        [Fact]
+        public void TestContainsExceptionWithAggregateInMiddleByType()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle = new AggregateException("Middle exception", 
innerMost);
+            Exception outer = new Exception("Outer exception", middle);
+            bool found = ApacheUtility.ContainsException(outer, 
innerMost.GetType(), out Exception? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost, containedException);
+        }
+
+        [Fact]
+        public void TestContainsExceptionWithAggregateOnTopByType()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle = new Exception("Middle exception", innerMost);
+            Exception outer = new AggregateException("Outer exception", 
middle);
+            bool found = ApacheUtility.ContainsException(outer, 
innerMost.GetType(), out Exception? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost, containedException);
+        }
+
+        [Fact]
+        public void TestContainsExceptionMultipleAggregateByType()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle1 = new AggregateException("Middle exception 1", 
innerMost);
+            Exception middle2 = new AggregateException("Middle exception 2", 
middle1);
+            Exception outer = new Exception("Outer exception", middle2);
+            bool found = ApacheUtility.ContainsException(outer, 
innerMost.GetType(), out Exception? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost, containedException);
+        }
+
+        [Fact]
+        public void TestContainsAggregateExceptionByType()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle = new AggregateException("Middle exception 1", 
innerMost);
+            Exception outer = new Exception("Outer exception", middle);
+            bool found = ApacheUtility.ContainsException(outer, 
middle.GetType(), out Exception? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(middle, containedException);
+        }
+
+        [Fact]
+        public void TestContainsMultipleInAggregateByType()
+        {
+            Exception innerMost1 = new InvalidOperationException("Innermost 
exception 1");
+            Exception innerMost2 = new NotImplementedException("Innermost 
exception 2");
+            Exception middle = new AggregateException("Middle exception", 
[innerMost1, innerMost2]);
+            Exception outer = new Exception("Outer exception", middle);
+            bool found = ApacheUtility.ContainsException(outer, 
innerMost2.GetType(), out Exception? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost2, containedException);
+        }
+    }
+}
diff --git a/csharp/test/Drivers/BigQuery/BigQueryUtilsTests.cs 
b/csharp/test/Drivers/BigQuery/BigQueryUtilsTests.cs
new file mode 100644
index 000000000..7ee19bc46
--- /dev/null
+++ b/csharp/test/Drivers/BigQuery/BigQueryUtilsTests.cs
@@ -0,0 +1,88 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+
+using System;
+using Apache.Arrow.Adbc.Drivers.BigQuery;
+using Xunit;
+
+namespace Apache.Arrow.Adbc.Tests.Drivers.BigQuery
+{
+    public class BigQueryUtilsTests
+    {
+        [Fact]
+        public void TestContainsExceptionWithAggregateInMiddle()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle = new AggregateException("Middle exception", 
innerMost);
+            Exception outer = new Exception("Outer exception", middle);
+            bool found = BigQueryUtils.ContainsException(outer, out 
InvalidOperationException? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost, containedException);
+        }
+
+        [Fact]
+        public void TestContainsExceptionWithAggregateOnTop()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle = new Exception("Middle exception", innerMost);
+            Exception outer = new AggregateException("Outer exception", 
middle);
+            bool found = BigQueryUtils.ContainsException(outer, out 
InvalidOperationException? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost, containedException);
+        }
+
+        [Fact]
+        public void TestContainsExceptionMultipleAggregate()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle1 = new AggregateException("Middle exception 1", 
innerMost);
+            Exception middle2 = new AggregateException("Middle exception 2", 
middle1);
+            Exception outer = new Exception("Outer exception", middle2);
+            bool found = BigQueryUtils.ContainsException(outer, out 
InvalidOperationException? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost, containedException);
+        }
+
+        [Fact]
+        public void TestContainsAggregateException()
+        {
+            Exception innerMost = new InvalidOperationException("Innermost 
exception");
+            Exception middle = new AggregateException("Middle exception 1", 
innerMost);
+            Exception outer = new Exception("Outer exception", middle);
+            bool found = BigQueryUtils.ContainsException(outer, out 
AggregateException? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(middle, containedException);
+        }
+        [Fact]
+        public void TestContainsMultipleInAggregate()
+        {
+            Exception innerMost1 = new InvalidOperationException("Innermost 
exception 1");
+            Exception innerMost2 = new NotImplementedException("Innermost 
exception 2");
+            Exception middle = new AggregateException("Middle exception", 
[innerMost1, innerMost2]);
+            Exception outer = new Exception("Outer exception", middle);
+            bool found = BigQueryUtils.ContainsException(outer, out 
NotImplementedException? containedException);
+            Assert.True(found);
+            Assert.NotNull(containedException);
+            Assert.Equal(innerMost2, containedException);
+        }
+    }
+}
diff --git a/csharp/test/Drivers/BigQuery/StatementTests.cs 
b/csharp/test/Drivers/BigQuery/StatementTests.cs
index f7ab55c20..8b3a2552e 100644
--- a/csharp/test/Drivers/BigQuery/StatementTests.cs
+++ b/csharp/test/Drivers/BigQuery/StatementTests.cs
@@ -15,10 +15,15 @@
 * limitations under the License.
 */
 
+using System;
 using System.Collections.Generic;
+using System.Reflection;
+using System.Threading.Tasks;
 using Apache.Arrow.Adbc.Drivers.BigQuery;
+using Apache.Arrow.Ipc;
 using Xunit;
 using Xunit.Abstractions;
+using Xunit.Sdk;
 
 namespace Apache.Arrow.Adbc.Tests.Drivers.BigQuery
 {
@@ -37,6 +42,11 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.BigQuery
             _testConfiguration = 
MultiEnvironmentTestUtils.LoadMultiEnvironmentTestConfiguration<BigQueryTestConfiguration>(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE);
             _environments = 
MultiEnvironmentTestUtils.GetTestEnvironments<BigQueryTestEnvironment>(_testConfiguration);
             _outputHelper = outputHelper;
+            foreach (BigQueryTestEnvironment environment in _environments)
+            {
+                AdbcConnection connection = 
BigQueryTestingUtils.GetBigQueryAdbcConnection(environment);
+                _configuredConnections.Add(environment.Name!, connection);
+            }
         }
 
         [Fact]
@@ -52,10 +62,113 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.BigQuery
                 statement.SetOption(BigQueryParameters.AllowLargeResults, 
"true");
 
                 // BigQuery is currently on ADBC 1.0, so it doesn't have the 
GetOption interface. Therefore, use reflection to validate the value is set 
correctly.
-                IReadOnlyDictionary<string, string>? options = 
statement.GetType().GetProperty("Options")!.GetValue(statement) as 
IReadOnlyDictionary<string, string>;
+                const BindingFlags bindingAttr = BindingFlags.NonPublic | 
BindingFlags.Public | BindingFlags.Instance;
+                IReadOnlyDictionary<string, string>? options = 
statement.GetType().GetProperty("Options", bindingAttr)!.GetValue(statement) as 
IReadOnlyDictionary<string, string>;
                 Assert.True(options != null);
                 Assert.True(options[BigQueryParameters.AllowLargeResults] == 
"true");
             }
         }
+
+        [Fact]
+        public async Task CanCancelStatement()
+        {
+            foreach (BigQueryTestEnvironment environment in _environments)
+            {
+                AdbcConnection adbcConnection = 
GetAdbcConnection(environment.Name);
+
+                AdbcStatement statement = adbcConnection.CreateStatement();
+
+                // Execute the query/cancel multiple times to validate 
consistent behavior
+                const int iterations = 3;
+                for (int i = 0; i < iterations; i++)
+                {
+                    _outputHelper?.WriteLine($"Iteration {i + 1} of 
{iterations}");
+                    // Generate unique column names so query will not be 
served from cache
+                    string columnName1 = Guid.NewGuid().ToString("N");
+                    string columnName2 = Guid.NewGuid().ToString("N");
+                    statement.SqlQuery = $"SELECT 
GENERATE_ARRAY(`{columnName2}`, 10000) AS `{columnName1}` FROM 
UNNEST(GENERATE_ARRAY(0, 100000)) AS `{columnName2}`";
+                    _outputHelper?.WriteLine($"Query: {statement.SqlQuery}");
+
+                    // Expect this to take about 10 seconds without 
cancellation
+                    Task<QueryResult> queryTask = 
Task.Run(statement.ExecuteQuery);
+
+                    await Task.Yield();
+                    await Task.Delay(3000);
+                    statement.Cancel();
+
+                    try
+                    {
+                        QueryResult queryResult = await queryTask;
+                        Assert.Fail("Expecting OperationCanceledException to 
be thrown.");
+                    }
+                    catch (Exception ex) when (ex is 
OperationCanceledException)
+                    {
+                        _outputHelper?.WriteLine($"Received expected 
OperationCanceledException: {ex.Message}");
+                    }
+                    catch (Exception ex) when (ex is not FailException)
+                    {
+                        Assert.Fail($"Expecting OperationCanceledException to 
be thrown. Instead, received {ex.GetType().Name}: {ex.Message}");
+                    }
+                }
+            }
+        }
+
+        [Fact]
+        public async Task CanCancelStreamFromStatement()
+        {
+            foreach (BigQueryTestEnvironment environment in _environments)
+            {
+                AdbcConnection adbcConnection = 
GetAdbcConnection(environment.Name);
+
+                AdbcStatement statement = adbcConnection.CreateStatement();
+
+                // Run several queries, then cancel them all with one Cancel()
+                const int iterations = 3;
+                QueryResult[] results = new QueryResult[iterations];
+                for (int i = 0; i < iterations; i++)
+                {
+                    _outputHelper?.WriteLine($"Iteration {i + 1} of 
{iterations}");
+                    // Generate unique column names so query will not be 
served from cache
+                    string columnName1 = Guid.NewGuid().ToString("N");
+                    string columnName2 = Guid.NewGuid().ToString("N");
+                    statement.SqlQuery = $"SELECT `{columnName2}` AS 
`{columnName1}` FROM UNNEST(GENERATE_ARRAY(1, 100)) AS `{columnName2}`";
+                    _outputHelper?.WriteLine($"Query: {statement.SqlQuery}");
+
+                    // Keep the result; its stream is read after Cancel()
+                    results[i] = statement.ExecuteQuery();
+                }
+                statement.Cancel();
+                for (int index = 0; index < iterations; index++)
+                {
+                    try
+                    {
+                        QueryResult queryResult = results[index];
+                        IArrowArrayStream? stream = queryResult.Stream;
+                        Assert.NotNull(stream);
+                        RecordBatch batch = await 
stream.ReadNextRecordBatchAsync();
+
+                        Assert.Fail("Expecting OperationCanceledException to 
be thrown.");
+                    }
+                    catch (Exception ex) when 
(BigQueryUtils.ContainsException(ex, out OperationCanceledException? _))
+                    {
+                        _outputHelper?.WriteLine($"Received expected 
OperationCanceledException: {ex.Message}");
+                    }
+                    catch (Exception ex) when (ex is not FailException)
+                    {
+                        Assert.Fail($"Expecting OperationCanceledException to 
be thrown. Instead, received {ex.GetType().Name}: {ex.Message}");
+                    }
+                }
+            }
+        }
+
+        private AdbcConnection GetAdbcConnection(string? environmentName)
+        {
+            if (string.IsNullOrEmpty(environmentName))
+            {
+                throw new ArgumentNullException(nameof(environmentName));
+            }
+
+            return _configuredConnections[environmentName!];
+        }
     }
 }


Reply via email to