This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 103865846 feat(csharp/src/Drivers/Databricks): Add option to enable 
using direct results for statements (#2737)
103865846 is described below

commit 103865846d78821d90500b9c402efddf484d3b59
Author: Alex Guo <133057192+alexguo...@users.noreply.github.com>
AuthorDate: Mon Apr 28 14:18:04 2025 -0700

    feat(csharp/src/Drivers/Databricks): Add option to enable using direct results for statements (#2737)
    
    - Add the `adbc.databricks.enable_direct_results` connection option
    (DatabricksParameters.EnableDirectResults, default true), which sends
    getDirectResults in the Thrift execute statement request
    - If getDirectResults is set in the request, the response carries
    directResults containing the initial results (equivalent to the server
    also calling GetOperationStatus, GetResultSetMetadata, FetchResults, and
    CloseOperation)
    - If the directResults on the response report that the operation has
    already finished, don't poll for the operation status
    - getDirectResults was already set on requests for metadata commands,
    just not on the execute statement request; for Databricks, the new option
    now controls both
    
    Tested E2E using `dotnet test --filter CloudFetchE2ETest`
    
    ```
    [xUnit.net 00:00:00.11]   Starting:    
Apache.Arrow.Adbc.Tests.Drivers.Databricks
    [xUnit.net 00:01:27.27]   Finished:    
Apache.Arrow.Adbc.Tests.Drivers.Databricks
      Apache.Arrow.Adbc.Tests.Drivers.Databricks test net8.0 succeeded (87.7s)
    
    Test summary: total: 8, failed: 0, succeeded: 8, skipped: 0, duration: 87.7s
    Build succeeded in 89.1s
    ```
---
 .../Drivers/Apache/Hive2/HiveServer2Connection.cs  | 26 +++++-----
 .../Drivers/Apache/Hive2/HiveServer2Statement.cs   | 37 ++++++++++++---
 csharp/src/Drivers/Apache/Spark/SparkConnection.cs |  2 +-
 .../Databricks/CloudFetch/CloudFetchReader.cs      |  2 +-
 .../CloudFetch/CloudFetchResultFetcher.cs          | 28 +++++++++++
 .../Databricks/CloudFetch/IHiveServer2Statement.cs | 11 +++++
 .../src/Drivers/Databricks/DatabricksConnection.cs | 40 ++++++++++++++++
 .../src/Drivers/Databricks/DatabricksParameters.cs |  6 +++
 csharp/src/Drivers/Databricks/DatabricksReader.cs  | 11 +++++
 .../src/Drivers/Databricks/DatabricksStatement.cs  | 16 +++++++
 .../test/Drivers/Databricks/CloudFetchE2ETest.cs   | 55 +++++++++++-----------
 .../Drivers/Databricks/DatabricksConnectionTest.cs |  1 +
 12 files changed, 186 insertions(+), 49 deletions(-)

diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs 
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
index 9a17eb6c1..f4f0f7978 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -377,7 +377,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                 if (depth == GetObjectsDepth.All || depth >= 
GetObjectsDepth.Catalogs)
                 {
                     TGetCatalogsReq getCatalogsReq = new 
TGetCatalogsReq(SessionHandle);
-                    if (AreResultsAvailableDirectly())
+                    if (AreResultsAvailableDirectly)
                     {
                         SetDirectResults(getCatalogsReq);
                     }
@@ -416,7 +416,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                     TGetSchemasReq getSchemasReq = new 
TGetSchemasReq(SessionHandle);
                     getSchemasReq.CatalogName = catalogPattern;
                     getSchemasReq.SchemaName = dbSchemaPattern;
-                    if (AreResultsAvailableDirectly())
+                    if (AreResultsAvailableDirectly)
                     {
                         SetDirectResults(getSchemasReq);
                     }
@@ -449,7 +449,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                     getTablesReq.CatalogName = catalogPattern;
                     getTablesReq.SchemaName = dbSchemaPattern;
                     getTablesReq.TableName = tableNamePattern;
-                    if (AreResultsAvailableDirectly())
+                    if (AreResultsAvailableDirectly)
                     {
                         SetDirectResults(getTablesReq);
                     }
@@ -486,7 +486,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                     columnsReq.CatalogName = catalogPattern;
                     columnsReq.SchemaName = dbSchemaPattern;
                     columnsReq.TableName = tableNamePattern;
-                    if (AreResultsAvailableDirectly())
+                    if (AreResultsAvailableDirectly)
                     {
                         SetDirectResults(columnsReq);
                     }
@@ -594,7 +594,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                 SessionHandle = SessionHandle ?? throw new 
InvalidOperationException("session not created"),
             };
 
-            if (AreResultsAvailableDirectly())
+            if (AreResultsAvailableDirectly)
             {
                 SetDirectResults(req);
             }
@@ -786,7 +786,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
         protected abstract Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken 
cancellationToken = default);
         protected internal abstract Task<TGetResultSetMetadataResp> 
GetResultSetMetadataAsync(TGetPrimaryKeysResp response, CancellationToken 
cancellationToken = default);
 
-        protected internal virtual bool AreResultsAvailableDirectly() => false;
+        protected internal virtual bool AreResultsAvailableDirectly => false;
 
         protected virtual void SetDirectResults(TGetColumnsReq request) => 
throw new System.NotImplementedException();
 
@@ -923,7 +923,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             }
 
             TGetCatalogsReq req = new TGetCatalogsReq(SessionHandle);
-            if (AreResultsAvailableDirectly())
+            if (AreResultsAvailableDirectly)
             {
                 SetDirectResults(req);
             }
@@ -950,7 +950,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             }
 
             TGetSchemasReq req = new(SessionHandle);
-            if (AreResultsAvailableDirectly())
+            if (AreResultsAvailableDirectly)
             {
                 SetDirectResults(req);
             }
@@ -987,7 +987,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             }
 
             TGetTablesReq req = new(SessionHandle);
-            if (AreResultsAvailableDirectly())
+            if (AreResultsAvailableDirectly)
             {
                 SetDirectResults(req);
             }
@@ -1032,7 +1032,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             }
 
             TGetColumnsReq req = new(SessionHandle);
-            if (AreResultsAvailableDirectly())
+            if (AreResultsAvailableDirectly)
             {
                 SetDirectResults(req);
             }
@@ -1076,7 +1076,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             }
 
             TGetPrimaryKeysReq req = new(SessionHandle);
-            if (AreResultsAvailableDirectly())
+            if (AreResultsAvailableDirectly)
             {
                 SetDirectResults(req);
             }
@@ -1119,7 +1119,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             }
 
             TGetCrossReferenceReq req = new(SessionHandle);
-            if (AreResultsAvailableDirectly())
+            if (AreResultsAvailableDirectly)
             {
                 SetDirectResults(req);
             }
@@ -1255,7 +1255,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             getColumnsReq.CatalogName = catalog;
             getColumnsReq.SchemaName = dbSchema;
             getColumnsReq.TableName = tableName;
-            if (AreResultsAvailableDirectly())
+            if (AreResultsAvailableDirectly)
             {
                 SetDirectResults(getColumnsReq);
             }
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs 
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
index 564a28e9f..bff8ca225 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -99,17 +99,28 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                 return await ExecuteMetadataCommandQuery(cancellationToken);
             }
 
+            _directResults = null;
+
             // this could either:
             // take QueryTimeoutSeconds * 3
             // OR
             // take QueryTimeoutSeconds (but this could be restricting)
             await ExecuteStatementAsync(cancellationToken); // --> get 
QueryTimeout +
-            await HiveServer2Connection.PollForResponseAsync(OperationHandle!, 
Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to 
QueryTimeout
-            TGetResultSetMetadataResp response = await 
HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, 
Connection.Client, cancellationToken);
-            Schema schema = 
Connection.SchemaParser.GetArrowSchema(response.Schema, 
Connection.DataTypeConversion);
 
+            TGetResultSetMetadataResp metadata;
+            if (_directResults?.OperationStatus?.OperationState == 
TOperationState.FINISHED_STATE)
+            {
+                // The initial response has result data so we don't need to 
poll
+                metadata = _directResults.ResultSetMetadata;
+            }
+            else
+            {
+                await 
HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, 
PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout
+                metadata = await 
HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, 
Connection.Client, cancellationToken);
+            }
             // Store metadata for use in readers
-            return new QueryResult(-1, Connection.NewReader(this, schema, 
response));
+            Schema schema = 
Connection.SchemaParser.GetArrowSchema(metadata.Schema, 
Connection.DataTypeConversion);
+            return new QueryResult(-1, Connection.NewReader(this, schema, 
metadata));
         }
 
         public override async ValueTask<QueryResult> ExecuteQueryAsync()
@@ -257,6 +268,19 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                     .SetNativeError(executeResponse.Status.ErrorCode);
             }
             OperationHandle = executeResponse.OperationHandle;
+
+            // Capture direct results if they're available
+            if (executeResponse.DirectResults != null)
+            {
+                _directResults = executeResponse.DirectResults;
+
+                if 
(!string.IsNullOrEmpty(_directResults.OperationStatus?.DisplayMessage))
+                {
+                    throw new 
HiveServer2Exception(_directResults.OperationStatus!.DisplayMessage)
+                        .SetSqlState(_directResults.OperationStatus.SqlState)
+                        
.SetNativeError(_directResults.OperationStatus.ErrorCode);
+                }
+            }
         }
 
         protected internal int PollTimeMilliseconds { get; private set; } = 
HiveServer2Connection.PollTimeMillisecondsDefault;
@@ -279,6 +303,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
         protected internal string? ForeignCatalogName { get; set; }
         protected internal string? ForeignSchemaName { get; set; }
         protected internal string? ForeignTableName { get; set; }
+        protected internal TSparkDirectResults? _directResults { get; set; }
 
         public HiveServer2Connection Connection { get; private set; }
 
@@ -416,7 +441,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             TRowSet rowSet;
 
             // For GetColumns, we need to enhance the result with 
BASE_TYPE_NAME
-            if (Connection.AreResultsAvailableDirectly() && 
resp.DirectResults?.ResultSet?.Results != null)
+            if (Connection.AreResultsAvailableDirectly && 
resp.DirectResults?.ResultSet?.Results != null)
             {
                 // Get data from direct results
                 metadata = resp.DirectResults.ResultSetMetadata;
@@ -454,7 +479,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
         private async Task<QueryResult> GetQueryResult(TSparkDirectResults? 
directResults, CancellationToken cancellationToken)
         {
             Schema schema;
-            if (Connection.AreResultsAvailableDirectly() && 
directResults?.ResultSet?.Results != null)
+            if (Connection.AreResultsAvailableDirectly && 
directResults?.ResultSet?.Results != null)
             {
                 TGetResultSetMetadataResp resultSetMetadata = 
directResults.ResultSetMetadata;
                 schema = 
Connection.SchemaParser.GetArrowSchema(resultSetMetadata.Schema, 
Connection.DataTypeConversion);
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs 
b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index 925073e35..c7e25861e 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -117,7 +117,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
 
         protected override bool IsColumnSizeValidForDecimal => false;
 
-        protected internal override bool AreResultsAvailableDirectly() => true;
+        protected internal override bool AreResultsAvailableDirectly => true;
 
         protected override void SetDirectResults(TGetColumnsReq request) => 
request.GetDirectResults = sparkGetDirectResults;
 
diff --git a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs 
b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
index abca66d37..1e3861833 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
@@ -1,4 +1,4 @@
-/*
+/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
diff --git 
a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs 
b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
index b0ad05a6a..3da5608ed 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
@@ -129,6 +129,14 @@ namespace 
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
         {
             try
             {
+                // Process direct results first, if available
+                if (_statement.HasDirectResults && 
_statement.DirectResults?.ResultSet?.Results?.ResultLinks?.Count > 0)
+                {
+                    // Yield execution so the download queue doesn't get 
blocked before downloader is started
+                    await Task.Yield();
+                    ProcessDirectResultsAsync(cancellationToken);
+                }
+
                 // Continue fetching as needed
                 while (_hasMoreResults && 
!cancellationToken.IsCancellationRequested)
                 {
@@ -228,5 +236,25 @@ namespace 
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
                 _hasMoreResults = false;
             }
         }
+
+        private void ProcessDirectResultsAsync(CancellationToken 
cancellationToken)
+        {
+            List<TSparkArrowResultLink> resultLinks = 
_statement.DirectResults!.ResultSet.Results.ResultLinks;
+
+            foreach (var link in resultLinks)
+            {
+                var downloadResult = new DownloadResult(link, _memoryManager);
+                _downloadQueue.Add(downloadResult, cancellationToken);
+            }
+
+            // Update the start offset for the next fetch
+            if (resultLinks.Count > 0)
+            {
+                var lastLink = resultLinks[resultLinks.Count - 1];
+                _startOffset = lastLink.StartRowOffset + lastLink.RowCount;
+            }
+
+            _hasMoreResults = _statement.DirectResults!.ResultSet.HasMoreRows;
+        }
     }
 }
diff --git a/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs 
b/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
index cfc92b98b..ee77dce9d 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
@@ -33,5 +33,16 @@ namespace 
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
         /// Gets the client.
         /// </summary>
         TCLIService.IAsync Client { get; }
+
+        /// <summary>
+        /// Gets the direct results.
+        /// </summary>
+        TSparkDirectResults? DirectResults { get; }
+
+        /// <summary>
+        /// Checks if direct results are available.
+        /// </summary>
+        /// <returns>True if direct results are available and contain result 
data, false otherwise.</returns>
+        bool HasDirectResults { get; }
     }
 }
diff --git a/csharp/src/Drivers/Databricks/DatabricksConnection.cs 
b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
index aefe2df89..cd7fc02a0 100644
--- a/csharp/src/Drivers/Databricks/DatabricksConnection.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
@@ -33,6 +33,13 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
     internal class DatabricksConnection : SparkHttpConnection
     {
         private bool _applySSPWithQueries = false;
+        private bool _enableDirectResults = true;
+
+        internal static TSparkGetDirectResults defaultGetDirectResults = new()
+        {
+            MaxRows = 2000000,
+            MaxBytes = 404857600
+        };
 
         // CloudFetch configuration
         private const long DefaultMaxBytesPerFile = 20 * 1024 * 1024; // 20MB
@@ -62,6 +69,18 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
                 }
             }
 
+            if 
(Properties.TryGetValue(DatabricksParameters.EnableDirectResults, out string? 
enableDirectResultsStr))
+            {
+                if (bool.TryParse(enableDirectResultsStr, out bool 
enableDirectResultsValue))
+                {
+                    _enableDirectResults = enableDirectResultsValue;
+                }
+                else
+                {
+                    throw new ArgumentException($"Parameter 
'{DatabricksParameters.EnableDirectResults}' value '{enableDirectResultsStr}' 
could not be parsed. Valid values are 'true' and 'false'.");
+                }
+            }
+
             // Parse CloudFetch options from connection properties
             if (Properties.TryGetValue(DatabricksParameters.UseCloudFetch, out 
string? useCloudFetchStr))
             {
@@ -110,6 +129,11 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
         /// </summary>
         internal bool ApplySSPWithQueries => _applySSPWithQueries;
 
+        /// <summary>
+        /// Gets whether direct results are enabled.
+        /// </summary>
+        internal bool EnableDirectResults => _enableDirectResults;
+
         /// <summary>
         /// Gets whether CloudFetch is enabled.
         /// </summary>
@@ -145,6 +169,22 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
             return baseHandler;
         }
 
+        protected internal override bool AreResultsAvailableDirectly => 
_enableDirectResults;
+
+        protected override void SetDirectResults(TGetColumnsReq request) => 
request.GetDirectResults = defaultGetDirectResults;
+
+        protected override void SetDirectResults(TGetCatalogsReq request) => 
request.GetDirectResults = defaultGetDirectResults;
+
+        protected override void SetDirectResults(TGetSchemasReq request) => 
request.GetDirectResults = defaultGetDirectResults;
+
+        protected override void SetDirectResults(TGetTablesReq request) => 
request.GetDirectResults = defaultGetDirectResults;
+
+        protected override void SetDirectResults(TGetTableTypesReq request) => 
request.GetDirectResults = defaultGetDirectResults;
+
+        protected override void SetDirectResults(TGetPrimaryKeysReq request) 
=> request.GetDirectResults = defaultGetDirectResults;
+
+        protected override void SetDirectResults(TGetCrossReferenceReq 
request) => request.GetDirectResults = defaultGetDirectResults;
+
         internal override IArrowArrayStream NewReader<T>(T statement, Schema 
schema, TGetResultSetMetadataResp? metadataResp = null)
         {
             // Get result format from metadata response if available
diff --git a/csharp/src/Drivers/Databricks/DatabricksParameters.cs 
b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
index f45350b72..7c6e9a69f 100644
--- a/csharp/src/Drivers/Databricks/DatabricksParameters.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
@@ -61,6 +61,12 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
         /// </summary>
         public const string CloudFetchTimeoutMinutes = 
"adbc.databricks.cloudfetch.timeout_minutes";
 
+        /// <summary>
+        /// Whether to enable the use of direct results when executing queries.
+        /// Default value is true if not specified.
+        /// </summary>
+        public const string EnableDirectResults = 
"adbc.databricks.enable_direct_results";
+
         /// <summary>
         /// Whether to apply service side properties (SSP) with queries. If 
false, SSP will be applied
         /// by setting the Thrift configuration when the session is opened.
diff --git a/csharp/src/Drivers/Databricks/DatabricksReader.cs 
b/csharp/src/Drivers/Databricks/DatabricksReader.cs
index 56abfbb20..cdd131111 100644
--- a/csharp/src/Drivers/Databricks/DatabricksReader.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksReader.cs
@@ -39,6 +39,17 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
             this.statement = statement;
             this.schema = schema;
             this.isLz4Compressed = isLz4Compressed;
+
+            // If we have direct results, initialize the batches from them
+            if (statement.HasDirectResults)
+            {
+                this.batches = 
statement.DirectResults!.ResultSet.Results.ArrowBatches;
+
+                if (!statement.DirectResults.ResultSet.HasMoreRows)
+                {
+                    this.statement = null;
+                }
+            }
         }
 
         public Schema Schema { get { return schema; } }
diff --git a/csharp/src/Drivers/Databricks/DatabricksStatement.cs 
b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
index 447689e51..72cdb8ac7 100644
--- a/csharp/src/Drivers/Databricks/DatabricksStatement.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
@@ -49,6 +49,22 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
             statement.CanDownloadResult = useCloudFetch;
             statement.CanDecompressLZ4Result = canDecompressLz4;
             statement.MaxBytesPerFile = maxBytesPerFile;
+
+            if (Connection.AreResultsAvailableDirectly)
+            {
+                statement.GetDirectResults = 
DatabricksConnection.defaultGetDirectResults;
+            }
+        }
+
+        /// <summary>
+        /// Checks if direct results are available.
+        /// </summary>
+        /// <returns>True if direct results are available and contain result 
data, false otherwise.</returns>
+        public bool HasDirectResults => DirectResults?.ResultSet != null && 
DirectResults?.ResultSetMetadata != null;
+
+        public TSparkDirectResults? DirectResults
+        {
+            get { return _directResults; }
         }
 
         // Cast the Client to IAsync for CloudFetch compatibility
diff --git a/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs 
b/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
index 0d9bbfa90..96b040274 100644
--- a/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
+++ b/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
@@ -16,6 +16,7 @@
 */
 
 using System;
+using System.Collections.Generic;
 using System.Threading.Tasks;
 using Apache.Arrow.Adbc.Drivers.Databricks;
 using Xunit;
@@ -35,42 +36,40 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks
             Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable));
         }
 
-        /// <summary>
-        /// Integration test for running a large query against a real 
Databricks cluster.
-        /// </summary>
-        [Fact]
-        public async Task TestRealDatabricksCloudFetchSmallResultSet()
-        {
-            await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM 
range(1000)", 1000);
-        }
-
-        [Fact]
-        public async Task TestRealDatabricksCloudFetchLargeResultSet()
+        public static IEnumerable<object[]> TestCases()
         {
-            await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM 
main.tpcds_sf10_delta.catalog_sales LIMIT 1000000", 1000000);
-        }
+            // Test cases format: (query, expected row count, use cloud fetch, 
enable direct results)
 
-        [Fact]
-        public async Task TestRealDatabricksNoCloudFetchSmallResultSet()
-        {
-            await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM 
range(1000)", 1000, false);
-        }
+            string smallQuery = $"SELECT * FROM range(1000)";
+            yield return new object[] { smallQuery, 1000, true, true };
+            yield return new object[] { smallQuery, 1000, false, true };
+            yield return new object[] { smallQuery, 1000, true, false };
+            yield return new object[] { smallQuery, 1000, false, false };
 
-        [Fact]
-        public async Task TestRealDatabricksNoCloudFetchLargeResultSet()
-        {
-            await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM 
main.tpcds_sf10_delta.catalog_sales LIMIT 1000000", 1000000, false);
+            string largeQuery = $"SELECT * FROM 
main.tpcds_sf10_delta.catalog_sales LIMIT 1000000";
+            yield return new object[] { largeQuery, 1000000, true, true };
+            yield return new object[] { largeQuery, 1000000, false, true };
+            yield return new object[] { largeQuery, 1000000, true, false };
+            yield return new object[] { largeQuery, 1000000, false, false };
         }
 
-        private async Task TestRealDatabricksCloudFetchLargeQuery(string 
query, int rowCount, bool useCloudFetch = true)
+        /// <summary>
+        /// Integration test for running queries against a real Databricks 
cluster with different CloudFetch settings.
+        /// </summary>
+        [Theory]
+        [MemberData(nameof(TestCases))]
+        private async Task TestRealDatabricksCloudFetch(string query, int 
rowCount, bool useCloudFetch, bool enableDirectResults)
         {
-            // Create a statement with CloudFetch enabled
-            var statement = Connection.CreateStatement();
-            statement.SetOption(DatabricksParameters.UseCloudFetch, 
useCloudFetch.ToString());
-            statement.SetOption(DatabricksParameters.CanDecompressLz4, "true");
-            statement.SetOption(DatabricksParameters.MaxBytesPerFile, 
"10485760"); // 10MB
+            var connection = NewConnection(TestConfiguration, new 
Dictionary<string, string>
+            {
+                [DatabricksParameters.UseCloudFetch] = 
useCloudFetch.ToString(),
+                [DatabricksParameters.EnableDirectResults] = 
enableDirectResults.ToString(),
+                [DatabricksParameters.CanDecompressLz4] = "true",
+                [DatabricksParameters.MaxBytesPerFile] = "10485760" // 10MB
+            });
 
             // Execute a query that generates a large result set using range 
function
+            var statement = connection.CreateStatement();
             statement.SqlQuery = query;
 
             // Execute the query and get the result
diff --git a/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs 
b/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs
index 859ee7e84..5c334957f 100644
--- a/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs
+++ b/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs
@@ -300,6 +300,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks
                 Add(new(new() { [SparkParameters.HostName] = 
"valid.server.com", [SparkParameters.Token] = "abcdef", 
[DatabricksParameters.CanDecompressLz4] = "notabool"}, 
typeof(ArgumentException)));
                 Add(new(new() { [SparkParameters.HostName] = 
"valid.server.com", [SparkParameters.Token] = "abcdef", 
[DatabricksParameters.MaxBytesPerFile] = "notanumber" }, 
typeof(ArgumentException)));
                 Add(new(new() { [SparkParameters.HostName] = 
"valid.server.com", [SparkParameters.Token] = "abcdef", 
[DatabricksParameters.MaxBytesPerFile] = "-100" }, 
typeof(ArgumentOutOfRangeException)));
+                Add(new(new() { [SparkParameters.HostName] = 
"valid.server.com", [SparkParameters.Token] = "abcdef", 
[DatabricksParameters.EnableDirectResults] = "notabool" }, 
typeof(ArgumentException)));
                 Add(new(new() { /*[SparkParameters.Type] = 
SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] = 
"valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] 
= "-1" }, typeof(ArgumentOutOfRangeException)));
                 Add(new(new() { /*[SparkParameters.Type] = 
SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] = 
"valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] 
= IPEndPoint.MinPort.ToString(CultureInfo.InvariantCulture) }, 
typeof(ArgumentOutOfRangeException)));
                 Add(new(new() { /*[SparkParameters.Type] = 
SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] = 
"valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] 
= (IPEndPoint.MaxPort + 1).ToString(CultureInfo.InvariantCulture) }, 
typeof(ArgumentOutOfRangeException)));

Reply via email to