This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new d9a69ed74 chore(csharp/src/Apache.Arrow.Adbc/Tracing): throw runtime exception if TraceActivity is called from async context (#3600)
d9a69ed74 is described below

commit d9a69ed7409c2edcb4ad22c78773cbb332531574
Author: Bruce Irschick <[email protected]>
AuthorDate: Mon Oct 20 15:13:48 2025 -0700

    chore(csharp/src/Apache.Arrow.Adbc/Tracing): throw runtime exception if TraceActivity is called from async context (#3600)
    
    Throws runtime exception if `TraceActivity` is called from async
    context.
---
 .../Tracing/IActivityTracerExtensions.cs           |   6 ++
 .../Tracing/TracingTests.cs                        | 101 ++++++++++++++++++++-
 2 files changed, 106 insertions(+), 1 deletion(-)

diff --git a/csharp/src/Apache.Arrow.Adbc/Tracing/IActivityTracerExtensions.cs 
b/csharp/src/Apache.Arrow.Adbc/Tracing/IActivityTracerExtensions.cs
index 5504c7e35..a71b709e1 100644
--- a/csharp/src/Apache.Arrow.Adbc/Tracing/IActivityTracerExtensions.cs
+++ b/csharp/src/Apache.Arrow.Adbc/Tracing/IActivityTracerExtensions.cs
@@ -58,6 +58,12 @@ namespace Apache.Arrow.Adbc.Tracing
         /// </remarks>
         public static T TraceActivity<T>(this IActivityTracer tracer, 
Func<Activity?, T> call, [CallerMemberName] string? activityName = null, 
string? traceParent = null)
         {
+            Type type = typeof(T);
+            if (type == typeof(Task) || (type.IsGenericType && 
type.GetGenericTypeDefinition() == typeof(Task<>)))
+            {
+                throw new InvalidOperationException($"Invalid return type 
('{type.Name}') for synchronous method call. Please use 
{nameof(TraceActivityAsync)}");
+            }
+
             return tracer.Trace.TraceActivity(call, activityName, traceParent 
?? tracer.TraceParent);
         }
 
diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Tracing/TracingTests.cs 
b/csharp/test/Apache.Arrow.Adbc.Tests/Tracing/TracingTests.cs
index a61582f06..b1134c998 100644
--- a/csharp/test/Apache.Arrow.Adbc.Tests/Tracing/TracingTests.cs
+++ b/csharp/test/Apache.Arrow.Adbc.Tests/Tracing/TracingTests.cs
@@ -19,6 +19,8 @@ using System;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading.Tasks;
 using Apache.Arrow.Adbc.Tracing;
 using Apache.Arrow.Ipc;
 using OpenTelemetry;
@@ -118,7 +120,7 @@ namespace Apache.Arrow.Adbc.Tests.Tracing
             
testClass.MethodWithActivityRecursive(nameof(TraceProducer.MethodWithActivityRecursive),
 recurseCount);
 
             int lineCount = 0;
-            foreach(var exportedActivity in exportedActivities)
+            foreach (var exportedActivity in exportedActivities)
             {
                 lineCount++;
                 Assert.NotNull(exportedActivity);
@@ -227,6 +229,32 @@ namespace Apache.Arrow.Adbc.Tests.Tracing
             Assert.Single(exportedActivities);
         }
 
+        [Fact]
+        internal async Task CanDetectInvalidAsyncCall()
+        {
+            string activitySourceName = NewName();
+            Queue<Activity> exportedActivities = new();
+            var testClass = new MyTracingConnection(new Dictionary<string, 
string>(), activitySourceName);
+            using (ActivityListener activityListener = new()
+            {
+                ShouldListenTo = source =>
+                {
+                    return source.Name == testClass.ActivitySourceName
+                        && source.Tags?.Any(t => t.Key == SourceTagName && 
t.Value?.Equals(SourceTagValue) == true) == true;
+                },
+                Sample = (ref ActivityCreationOptions<ActivityContext> 
options) => ActivitySamplingResult.AllDataAndRecorded,
+                ActivityStopped = activity => 
exportedActivities.Enqueue(activity)
+            })
+            {
+                ActivitySource.AddActivityListener(activityListener);
+                await 
Assert.ThrowsAnyAsync<InvalidOperationException>(testClass.MethodWithInvalidAsyncTraceActivity1);
+                await 
Assert.ThrowsAnyAsync<InvalidOperationException>(testClass.MethodWithInvalidAsyncTraceActivity2);
+                await Assert.ThrowsAnyAsync<InvalidOperationException>(async 
() => await testClass.MethodWithInvalidAsyncTraceActivity3());
+                await Assert.ThrowsAnyAsync<InvalidOperationException>(async 
() => await testClass.MethodWithInvalidAsyncTraceActivity4());
+                await 
Assert.ThrowsAnyAsync<InvalidOperationException>(testClass.MethodWithInvalidAsyncTraceActivity5);
+            }
+        }
+
         internal static string NewName() => 
Guid.NewGuid().ToString().Replace("-", "").ToLower();
 
         protected virtual void Dispose(bool disposing)
@@ -346,6 +374,77 @@ namespace Apache.Arrow.Adbc.Tests.Tracing
                 });
             }
 
+            public async Task<bool> MethodWithInvalidAsyncTraceActivity1()
+            {
+                // This method is intended to demonstrate incorrect usage of 
TraceActivity with async methods.
+                return await this.TraceActivity(async activity =>
+                {
+                    await Task.Delay(1);
+                    return true;
+                });
+            }
+
+            public async Task MethodWithInvalidAsyncTraceActivity2()
+            {
+                // This method is intended to demonstrate incorrect usage of 
TraceActivity with async methods.
+                await this.TraceActivity(async activity =>
+                {
+                    await Task.Delay(1);
+                    return;
+                });
+            }
+
+            public async ValueTask<bool> MethodWithInvalidAsyncTraceActivity3()
+            {
+                // This method is intended to demonstrate incorrect usage of 
TraceActivity with async methods.
+                return await this.TraceActivity(async activity =>
+                {
+                    await Task.Delay(1);
+                    return true;
+                });
+            }
+
+            public async ValueTask MethodWithInvalidAsyncTraceActivity4()
+            {
+                // This method is intended to demonstrate incorrect usage of 
TraceActivity with async methods.
+                await this.TraceActivity(async activity =>
+                {
+                    await Task.Delay(1);
+                    return;
+                });
+            }
+
+            public async Task<bool> MethodWithInvalidAsyncTraceActivity5()
+            {
+                // This method is intended to demonstrate incorrect usage of 
TraceActivity with async methods.
+                return await this.TraceActivity(async activity =>
+                {
+                    await Task.Delay(1);
+                    return await new AwaitableBool();
+                });
+            }
+
+            public class AwaitableBool
+            {
+                public BoolAwaiter GetAwaiter()
+                {
+                    return new BoolAwaiter();
+                }
+
+                public class BoolAwaiter : INotifyCompletion
+                {
+                    public bool IsCompleted => throw new 
NotImplementedException();
+                    public bool GetResult()
+                    {
+                        throw new NotImplementedException();
+                    }
+                    public void OnCompleted(Action continuation)
+                    {
+                        throw new NotImplementedException();
+                    }
+                }
+            }
+
             public override AdbcStatement CreateStatement() => throw new 
NotImplementedException();
             public override IArrowArrayStream GetObjects(GetObjectsDepth 
depth, string? catalogPattern, string? dbSchemaPattern, string? 
tableNamePattern, IReadOnlyList<string>? tableTypes, string? columnNamePattern) 
=> throw new NotImplementedException();
             public override Schema GetTableSchema(string? catalog, string? 
dbSchema, string tableName) => throw new NotImplementedException();

Reply via email to