This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 9eee98deb feat(csharp): Implement CloudFetch for Databricks Spark driver (#2634)
9eee98deb is described below

commit 9eee98deb5d902dbe5322e5a2a9b5c922a80b220
Author: Jade Wang <[email protected]>
AuthorDate: Mon Mar 31 13:21:27 2025 -0700

    feat(csharp): Implement CloudFetch for Databricks Spark driver (#2634)
    
    Initial implementation of the CloudFetch feature in the Databricks Spark
    driver.

    - Add a new CloudFetchReader that downloads and decompresses CloudFetch
      result files.
    - Add test cases for small and large result sets.

    Follow-up changes planned after this:

    - Add prefetching to the downloader
    - Add renewal of expired presigned URLs
    - Add retries
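
A minimal usage sketch (illustrative only, not part of this patch). It assumes
`connection` is an already-open ADBC connection to a Databricks Spark endpoint
and uses the statement option keys added by this commit:

    var statement = connection.CreateStatement();
    statement.SetOption("adbc.spark.cloudfetch.enabled", "true");
    statement.SetOption("adbc.spark.cloudfetch.lz4.enabled", "true");
    statement.SetOption("adbc.spark.cloudfetch.max_bytes_per_file", "20971520"); // 20MB
    statement.SqlQuery = "SELECT * FROM range(1000000)";
    QueryResult result = await statement.ExecuteQueryAsync();

    long totalRows = 0;
    RecordBatch? batch;
    while ((batch = await result.Stream!.ReadNextRecordBatchAsync()) != null)
    {
        totalRows += batch.Length; // rows per downloaded Arrow batch
    }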
---
 .../Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj |   2 +
 .../Drivers/Apache/Hive2/HiveServer2Connection.cs  |   2 +-
 .../Apache/Hive2/HiveServer2HttpConnection.cs      |   2 +-
 .../Drivers/Apache/Hive2/HiveServer2Statement.cs   |  14 +-
 .../Drivers/Apache/Impala/ImpalaHttpConnection.cs  |   2 +-
 .../Apache/Impala/ImpalaStandardConnection.cs      |   2 +-
 .../Spark/CloudFetch/SparkCloudFetchReader.cs      | 318 +++++++++++++++++++++
 csharp/src/Drivers/Apache/Spark/SparkConnection.cs |   3 +-
 .../Apache/Spark/SparkDatabricksConnection.cs      |  33 ++-
 .../Drivers/Apache/Spark/SparkDatabricksReader.cs  |   2 +-
 .../Drivers/Apache/Spark/SparkHttpConnection.cs    |   2 +-
 csharp/src/Drivers/Apache/Spark/SparkParameters.cs |  19 ++
 csharp/src/Drivers/Apache/Spark/SparkStatement.cs  | 103 ++++++-
 .../Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj  |   1 +
 .../test/Drivers/Apache/Spark/CloudFetchE2ETest.cs |  94 ++++++
 15 files changed, 580 insertions(+), 19 deletions(-)

diff --git a/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj b/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj
index 7e4c7c096..2ad285410 100644
--- a/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj
+++ b/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj
@@ -6,6 +6,8 @@
 
   <ItemGroup>
     <PackageReference Include="ApacheThrift" Version="0.21.0" />
+    <PackageReference Include="K4os.Compression.LZ4" Version="1.3.8" />
+    <PackageReference Include="K4os.Compression.LZ4.Streams" Version="1.3.8" />
     <PackageReference Include="System.Net.Http" Version="4.3.4" />
     <PackageReference Include="System.Text.Json" Version="8.0.5" />
   </ItemGroup>
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
index ab3efea31..990cb4774 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -354,7 +354,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 
         internal abstract SchemaParser SchemaParser { get; }
 
-        internal abstract IArrowArrayStream NewReader<T>(T statement, Schema schema) where T : HiveServer2Statement;
+        internal abstract IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) where T : HiveServer2Statement;
 
         public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList<string>? tableTypes, string? columnNamePattern)
         {
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs
index 187e5712c..6ebcb9267 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs
@@ -144,7 +144,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             return new HiveServer2Statement(this);
         }
 
-        internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(
+        internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(
             statement,
             schema,
             dataTypeConversion: statement.Connection.DataTypeConversion,
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
index c08f997ca..9042b4205 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -84,9 +84,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             // take QueryTimeoutSeconds (but this could be restricting)
             await ExecuteStatementAsync(cancellationToken); // --> get QueryTimeout +
             await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout
-            Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); // + get the result, up to QueryTimeout
+            TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken);
+            Schema schema = Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion);
 
-            return new QueryResult(-1, Connection.NewReader(this, schema));
+            // Store metadata for use in readers
+            return new QueryResult(-1, Connection.NewReader(this, schema, response));
         }
 
         public override async ValueTask<QueryResult> ExecuteQueryAsync()
@@ -108,12 +110,6 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             }
         }
 
-        private async Task<Schema> GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default)
-        {
-            TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(operationHandle, client, cancellationToken);
-            return Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion);
-        }
-
         public async Task<UpdateResult> ExecuteUpdateAsyncInternal(CancellationToken cancellationToken = default)
         {
             const string NumberOfAffectedRowsColumnName = "num_affected_rows";
@@ -195,7 +191,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 
         protected async Task ExecuteStatementAsync(CancellationToken cancellationToken = default)
         {
-            TExecuteStatementReq executeRequest = new TExecuteStatementReq(Connection.SessionHandle, SqlQuery);
+            TExecuteStatementReq executeRequest = new TExecuteStatementReq(Connection.SessionHandle!, SqlQuery!);
             SetStatementProperties(executeRequest);
             TExecuteStatementResp executeResponse = await Connection.Client.ExecuteStatement(executeRequest, cancellationToken);
             if (executeResponse.Status.StatusCode == TStatusCode.ERROR_STATUS)
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs
index 914ba9269..ef5c34166 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs
@@ -123,7 +123,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
             TlsOptions = HiveServer2TlsImpl.GetHttpTlsOptions(Properties);
         }
 
-        internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
+        internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
 
         protected override TTransport CreateTransport()
         {
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs
index 99ac368be..2665070bb 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs
@@ -149,7 +149,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
             return request;
         }
 
-        internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
+        internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
 
         internal override ImpalaServerType ServerType => ImpalaServerType.Standard;
 
diff --git a/csharp/src/Drivers/Apache/Spark/CloudFetch/SparkCloudFetchReader.cs b/csharp/src/Drivers/Apache/Spark/CloudFetch/SparkCloudFetchReader.cs
new file mode 100644
index 000000000..343bb5a0d
--- /dev/null
+++ b/csharp/src/Drivers/Apache/Spark/CloudFetch/SparkCloudFetchReader.cs
@@ -0,0 +1,318 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
+using Apache.Arrow.Ipc;
+using Apache.Hive.Service.Rpc.Thrift;
+using K4os.Compression.LZ4.Streams;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache.Spark.CloudFetch
+{
+    /// <summary>
+    /// Reader for CloudFetch results from Databricks Spark Thrift server.
+    /// Handles downloading and processing URL-based result sets.
+    /// </summary>
+    internal sealed class SparkCloudFetchReader : IArrowArrayStream
+    {
+        // Default values used if not specified in connection properties
+        private const int DefaultMaxRetries = 3;
+        private const int DefaultRetryDelayMs = 500;
+        private const int DefaultTimeoutMinutes = 5;
+
+        private readonly int maxRetries;
+        private readonly int retryDelayMs;
+        private readonly int timeoutMinutes;
+
+        private HiveServer2Statement? statement;
+        private readonly Schema schema;
+        private List<TSparkArrowResultLink>? resultLinks;
+        private int linkIndex;
+        private ArrowStreamReader? currentReader;
+        private readonly bool isLz4Compressed;
+        private long startOffset;
+
+        // Lazy initialization of HttpClient
+        private readonly Lazy<HttpClient> httpClient;
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="SparkCloudFetchReader"/> class.
+        /// </summary>
+        /// <param name="statement">The HiveServer2 statement.</param>
+        /// <param name="schema">The Arrow schema.</param>
+        /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param>
+        public SparkCloudFetchReader(HiveServer2Statement statement, Schema schema, bool isLz4Compressed)
+        {
+            this.statement = statement;
+            this.schema = schema;
+            this.isLz4Compressed = isLz4Compressed;
+
+            // Get configuration values from connection properties or use defaults
+            var connectionProps = statement.Connection.Properties;
+
+            // Parse max retries
+            int parsedMaxRetries = DefaultMaxRetries;
+            if (connectionProps.TryGetValue(SparkParameters.CloudFetchMaxRetries, out string? maxRetriesStr) &&
+                int.TryParse(maxRetriesStr, out parsedMaxRetries) &&
+                parsedMaxRetries > 0)
+            {
+                // Value was successfully parsed
+            }
+            else
+            {
+                parsedMaxRetries = DefaultMaxRetries;
+            }
+            this.maxRetries = parsedMaxRetries;
+
+            // Parse retry delay
+            int parsedRetryDelay = DefaultRetryDelayMs;
+            if (connectionProps.TryGetValue(SparkParameters.CloudFetchRetryDelayMs, out string? retryDelayStr) &&
+                int.TryParse(retryDelayStr, out parsedRetryDelay) &&
+                parsedRetryDelay > 0)
+            {
+                // Value was successfully parsed
+            }
+            else
+            {
+                parsedRetryDelay = DefaultRetryDelayMs;
+            }
+            this.retryDelayMs = parsedRetryDelay;
+
+            // Parse timeout minutes
+            int parsedTimeout = DefaultTimeoutMinutes;
+            if (connectionProps.TryGetValue(SparkParameters.CloudFetchTimeoutMinutes, out string? timeoutStr) &&
+                int.TryParse(timeoutStr, out parsedTimeout) &&
+                parsedTimeout > 0)
+            {
+                // Value was successfully parsed
+            }
+            else
+            {
+                parsedTimeout = DefaultTimeoutMinutes;
+            }
+            this.timeoutMinutes = parsedTimeout;
+
+            // Initialize HttpClient with the configured timeout
+            this.httpClient = new Lazy<HttpClient>(() =>
+            {
+                var client = new HttpClient();
+                client.Timeout = TimeSpan.FromMinutes(this.timeoutMinutes);
+                return client;
+            });
+        }
+
+        /// <summary>
+        /// Gets the Arrow schema.
+        /// </summary>
+        public Schema Schema { get { return schema; } }
+
+        private HttpClient HttpClient
+        {
+            get { return httpClient.Value; }
+        }
+
+        /// <summary>
+        /// Reads the next record batch from the result set.
+        /// </summary>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The next record batch, or null if there are no more batches.</returns>
+        public async ValueTask<RecordBatch?> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
+        {
+            while (true)
+            {
+                // If we have a current reader, try to read the next batch
+                if (this.currentReader != null)
+                {
+                    RecordBatch? next = await this.currentReader.ReadNextRecordBatchAsync(cancellationToken);
+                    if (next != null)
+                    {
+                        return next;
+                    }
+                    else
+                    {
+                        this.currentReader.Dispose();
+                        this.currentReader = null;
+                    }
+                }
+
+                // If we have more links to process, download and process the next one
+                if (this.resultLinks != null && this.linkIndex < this.resultLinks.Count)
+                {
+                    var link = this.resultLinks[this.linkIndex++];
+                    byte[]? fileData = null;
+
+                    // Retry logic for downloading files
+                    for (int retry = 0; retry < this.maxRetries; retry++)
+                    {
+                        try
+                        {
+                            fileData = await DownloadFileAsync(link.FileLink, cancellationToken);
+                            break; // Success, exit retry loop
+                        }
+                        catch (Exception ex) when (retry < this.maxRetries - 1)
+                        {
+                            // Log the error and retry
+                            Debug.WriteLine($"Error downloading file (attempt {retry + 1}/{this.maxRetries}): {ex.Message}");
+                            await Task.Delay(this.retryDelayMs * (retry + 1), cancellationToken);
+                        }
+                    }
+
+                    // Process the downloaded file data
+                    MemoryStream dataStream;
+
+                    // If the data is LZ4 compressed, decompress it
+                    if (this.isLz4Compressed)
+                    {
+                        try
+                        {
+                            dataStream = new MemoryStream();
+                            using (var inputStream = new MemoryStream(fileData!))
+                            using (var decompressor = LZ4Stream.Decode(inputStream))
+                            {
+                                await decompressor.CopyToAsync(dataStream);
+                            }
+                            dataStream.Position = 0;
+                        }
+                        catch (Exception ex)
+                        {
+                            Debug.WriteLine($"Error decompressing data: {ex.Message}");
+                            continue; // Skip this link and try the next one
+                        }
+                    }
+                    else
+                    {
+                        dataStream = new MemoryStream(fileData!);
+                    }
+
+                    try
+                    {
+                        this.currentReader = new ArrowStreamReader(dataStream);
+                        continue;
+                    }
+                    catch (Exception ex)
+                    {
+                        Debug.WriteLine($"Error creating Arrow reader: {ex.Message}");
+                        dataStream.Dispose();
+                        continue; // Skip this link and try the next one
+                    }
+                }
+
+                this.resultLinks = null;
+                this.linkIndex = 0;
+
+                // If there's no statement, we're done
+                if (this.statement == null)
+                {
+                    return null;
+                }
+
+                // Fetch more results from the server
+                TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle!, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize);
+
+                // Set the start row offset if we have processed some links already
+                if (this.startOffset > 0)
+                {
+                    request.StartRowOffset = this.startOffset;
+                }
+
+                TFetchResultsResp response;
+                try
+                {
+                    response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken);
+                }
+                catch (Exception ex)
+                {
+                    Debug.WriteLine($"Error fetching results from server: {ex.Message}");
+                    this.statement = null; // Mark as done due to error
+                    return null;
+                }
+
+                // Check if we have URL-based results
+                if (response.Results.__isset.resultLinks &&
+                    response.Results.ResultLinks != null &&
+                    response.Results.ResultLinks.Count > 0)
+                {
+                    this.resultLinks = response.Results.ResultLinks;
+
+                    // Update the start offset for the next fetch by calculating it from the links
+                    if (this.resultLinks.Count > 0)
+                    {
+                        var lastLink = this.resultLinks[this.resultLinks.Count - 1];
+                        this.startOffset = lastLink.StartRowOffset + lastLink.RowCount;
+                    }
+
+                    // If the server indicates there are no more rows, we can close the statement
+                    if (!response.HasMoreRows)
+                    {
+                        this.statement = null;
+                    }
+                }
+                else
+                {
+                    // If there are no more results, we're done
+                    this.statement = null;
+                    return null;
+                }
+            }
+        }
+
+        /// <summary>
+        /// Downloads a file from a URL.
+        /// </summary>
+        /// <param name="url">The URL to download from.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The downloaded file data.</returns>
+        private async Task<byte[]> DownloadFileAsync(string url, CancellationToken cancellationToken)
+        {
+            using HttpResponseMessage response = await HttpClient.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
+            response.EnsureSuccessStatusCode();
+
+            // Get the content length if available
+            long? contentLength = response.Content.Headers.ContentLength;
+            if (contentLength.HasValue && contentLength.Value > 0)
+            {
+                Debug.WriteLine($"Downloading file of size: {contentLength.Value / 1024.0 / 1024.0:F2} MB");
+            }
+
+            return await response.Content.ReadAsByteArrayAsync();
+        }
+
+        /// <summary>
+        /// Disposes the reader.
+        /// </summary>
+        public void Dispose()
+        {
+            if (this.currentReader != null)
+            {
+                this.currentReader.Dispose();
+                this.currentReader = null;
+            }
+
+            // Dispose the HttpClient if it was created
+            if (httpClient.IsValueCreated)
+            {
+                httpClient.Value.Dispose();
+            }
+        }
+    }
+}
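
For illustration (not part of the patch): the download retry above uses a linear
backoff, waiting retryDelayMs * (attempt + 1) between attempts. With the defaults
(maxRetries = 3, retryDelayMs = 500) the schedule works out as follows:

    // Hypothetical standalone sketch of the backoff schedule, assuming the defaults above.
    int maxRetries = 3, retryDelayMs = 500;
    for (int retry = 0; retry < maxRetries - 1; retry++)
    {
        // 500 ms before attempt 2, 1000 ms before attempt 3
        Console.WriteLine($"delay before attempt {retry + 2}: {retryDelayMs * (retry + 1)} ms");
    }
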
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index b9b40dfd1..2eb11e941 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -63,7 +63,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
 
         public override AdbcStatement CreateStatement()
         {
-            return new SparkStatement(this);
+            SparkStatement statement = new SparkStatement(this);
+            return statement;
         }
 
         protected internal override int PositionRequiredOffset => 1;
diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
index 14d94acf3..a2413635b 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
@@ -18,6 +18,8 @@
 using System.Collections.Generic;
 using System.Threading;
 using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
+using Apache.Arrow.Adbc.Drivers.Apache.Spark.CloudFetch;
 using Apache.Arrow.Ipc;
 using Apache.Hive.Service.Rpc.Thrift;
 
@@ -29,7 +31,35 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
         {
         }
 
-        internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new SparkDatabricksReader(statement, schema);
+        internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null)
+        {
+            // Get result format from metadata response if available
+            TSparkRowSetType resultFormat = TSparkRowSetType.ARROW_BASED_SET;
+            bool isLz4Compressed = false;
+
+            if (metadataResp != null)
+            {
+                if (metadataResp.__isset.resultFormat)
+                {
+                    resultFormat = metadataResp.ResultFormat;
+                }
+
+                if (metadataResp.__isset.lz4Compressed)
+                {
+                    isLz4Compressed = metadataResp.Lz4Compressed;
+                }
+            }
+
+            // Choose the appropriate reader based on the result format
+            if (resultFormat == TSparkRowSetType.URL_BASED_SET)
+            {
+                return new SparkCloudFetchReader(statement, schema, isLz4Compressed);
+            }
+            else
+            {
+                return new SparkDatabricksReader(statement, schema);
+            }
+        }
 
         internal override SchemaParser SchemaParser => new SparkDatabricksSchemaParser();
 
@@ -40,6 +70,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             var req = new TOpenSessionReq
             {
                 Client_protocol = TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7,
+                Client_protocol_i64 = (long)TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7,
                 CanUseMultipleCatalogs = true,
             };
             return req;
diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
index 059ab1690..0e0166926 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
@@ -68,7 +68,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
                     return null;
                 }
 
-                TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize);
+                TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle!, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize);
                 TFetchResultsResp response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken);
                 this.batches = response.Results.ArrowBatches;
 
diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
index e28ab4632..fd7f18097 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
@@ -139,7 +139,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             TlsOptions = HiveServer2TlsImpl.GetHttpTlsOptions(Properties);
         }
 
-        internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
+        internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
 
         protected override TTransport CreateTransport()
         {
diff --git a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
index 8e75ae3f5..b5587197d 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
@@ -33,6 +33,25 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
         public const string Type = "adbc.spark.type";
         public const string DataTypeConv = "adbc.spark.data_type_conv";
         public const string ConnectTimeoutMilliseconds = "adbc.spark.connect_timeout_ms";
+
+        // CloudFetch configuration parameters
+        /// <summary>
+        /// Maximum number of retry attempts for CloudFetch downloads.
+        /// Default value is 3 if not specified.
+        /// </summary>
+        public const string CloudFetchMaxRetries = "adbc.spark.cloudfetch.max_retries";
+
+        /// <summary>
+        /// Delay in milliseconds between CloudFetch retry attempts.
+        /// Default value is 500ms if not specified.
+        /// </summary>
+        public const string CloudFetchRetryDelayMs = "adbc.spark.cloudfetch.retry_delay_ms";
+
+        /// <summary>
+        /// Timeout in minutes for CloudFetch HTTP operations.
+        /// Default value is 5 minutes if not specified.
+        /// </summary>
+        public const string CloudFetchTimeoutMinutes = "adbc.spark.cloudfetch.timeout_minutes";
     }
 
     public static class SparkAuthTypeConstants
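
For illustration (not part of the patch): SparkCloudFetchReader reads these values
from the connection's Properties map, so a caller could supply them as driver
connection properties. A minimal sketch (how the dictionary reaches the connection
depends on the host application and is not shown here):

    // Hypothetical property map using the constants added above.
    var properties = new Dictionary<string, string>
    {
        [SparkParameters.CloudFetchMaxRetries] = "5",       // default: 3
        [SparkParameters.CloudFetchRetryDelayMs] = "1000",  // default: 500 ms
        [SparkParameters.CloudFetchTimeoutMinutes] = "10",  // default: 5 minutes
    };
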
diff --git a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
index ffe491e72..4c4e61562 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
@@ -15,6 +15,7 @@
 * limitations under the License.
 */
 
+using System;
 using System.Collections.Generic;
 using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
 using Apache.Hive.Service.Rpc.Thrift;
@@ -23,6 +24,14 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
 {
     internal class SparkStatement : HiveServer2Statement
     {
+        // Default maximum bytes per file for CloudFetch
+        private const long DefaultMaxBytesPerFile = 20 * 1024 * 1024; // 20MB
+
+        // CloudFetch configuration
+        private bool useCloudFetch = true;
+        private bool canDecompressLz4 = true;
+        private long maxBytesPerFile = DefaultMaxBytesPerFile;
+
         internal SparkStatement(SparkConnection connection)
             : base(connection)
         {
@@ -37,7 +46,12 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             // Set in combination with a CancellationToken.
             statement.QueryTimeout = QueryTimeoutSeconds;
             statement.CanReadArrowResult = true;
-            statement.CanDownloadResult = true;
+
+            // Set CloudFetch capabilities
+            statement.CanDownloadResult = useCloudFetch;
+            statement.CanDecompressLZ4Result = canDecompressLz4;
+            statement.MaxBytesPerFile = maxBytesPerFile;
+
 #pragma warning disable CS0618 // Type or member is obsolete
             statement.ConfOverlay = SparkConnection.timestampConfig;
 #pragma warning restore CS0618 // Type or member is obsolete
@@ -54,12 +68,97 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
             };
         }
 
+        public override void SetOption(string key, string value)
+        {
+            switch (key)
+            {
+                case Options.UseCloudFetch:
+                    if (bool.TryParse(value, out bool useCloudFetchValue))
+                    {
+                        this.useCloudFetch = useCloudFetchValue;
+                    }
+                    else
+                    {
+                        throw new ArgumentException($"Invalid value for {key}: {value}. Expected a boolean value.");
+                    }
+                    break;
+                case Options.CanDecompressLz4:
+                    if (bool.TryParse(value, out bool canDecompressLz4Value))
+                    {
+                        this.canDecompressLz4 = canDecompressLz4Value;
+                    }
+                    else
+                    {
+                        throw new ArgumentException($"Invalid value for {key}: {value}. Expected a boolean value.");
+                    }
+                    break;
+                case Options.MaxBytesPerFile:
+                    if (long.TryParse(value, out long maxBytesPerFileValue))
+                    {
+                        this.maxBytesPerFile = maxBytesPerFileValue;
+                    }
+                    else
+                    {
+                        throw new ArgumentException($"Invalid value for {key}: {value}. Expected a long value.");
+                    }
+                    break;
+                default:
+                    base.SetOption(key, value);
+                    break;
+            }
+        }
+
+        /// <summary>
+        /// Sets whether to use CloudFetch for retrieving results.
+        /// </summary>
+        /// <param name="useCloudFetch">Whether to use CloudFetch.</param>
+        internal void SetUseCloudFetch(bool useCloudFetch)
+        {
+            this.useCloudFetch = useCloudFetch;
+        }
+
+        /// <summary>
+        /// Gets whether CloudFetch is enabled.
+        /// </summary>
+        public bool UseCloudFetch => useCloudFetch;
+
+        /// <summary>
+        /// Sets whether the client can decompress LZ4 compressed results.
+        /// </summary>
+        /// <param name="canDecompressLz4">Whether the client can decompress LZ4.</param>
+        internal void SetCanDecompressLz4(bool canDecompressLz4)
+        {
+            this.canDecompressLz4 = canDecompressLz4;
+        }
+
+        /// <summary>
+        /// Gets whether LZ4 decompression is enabled.
+        /// </summary>
+        public bool CanDecompressLz4 => canDecompressLz4;
+
+        /// <summary>
+        /// Sets the maximum bytes per file for CloudFetch.
+        /// </summary>
+        /// <param name="maxBytesPerFile">The maximum bytes per file.</param>
+        internal void SetMaxBytesPerFile(long maxBytesPerFile)
+        {
+            this.maxBytesPerFile = maxBytesPerFile;
+        }
+
+        /// <summary>
+        /// Gets the maximum bytes per file for CloudFetch.
+        /// </summary>
+        public long MaxBytesPerFile => maxBytesPerFile;
+
         /// <summary>
         /// Provides the constant string key values to the <see cref="AdbcStatement.SetOption(string, string)" /> method.
         /// </summary>
         public sealed class Options : ApacheParameters
         {
-            // options specific to Spark go here
+            // CloudFetch options
+            public const string UseCloudFetch = "adbc.spark.cloudfetch.enabled";
+            public const string CanDecompressLz4 = "adbc.spark.cloudfetch.lz4.enabled";
+            public const string MaxBytesPerFile = "adbc.spark.cloudfetch.max_bytes_per_file";
         }
     }
 }
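
For illustration (not part of the patch): SetOption above accepts the new option
keys and validates each value, throwing ArgumentException when it does not parse
as the expected type; `statement` is assumed to be a statement created from a
Spark connection:

    statement.SetOption(SparkStatement.Options.UseCloudFetch, "true");         // parsed as bool
    statement.SetOption(SparkStatement.Options.CanDecompressLz4, "true");      // parsed as bool
    statement.SetOption(SparkStatement.Options.MaxBytesPerFile, "10485760");   // parsed as long (10MB)
    // statement.SetOption(SparkStatement.Options.UseCloudFetch, "maybe");     // would throw ArgumentException
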
diff --git a/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj b/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj
index 63365312e..af779a699 100644
--- a/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj
+++ b/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj
@@ -14,6 +14,7 @@
     </PackageReference>
     <PackageReference Include="Xunit.SkippableFact" Version="1.5.23" />
     <PackageReference Include="System.Net.Http" Version="4.3.4" />
+    <PackageReference Include="K4os.Compression.LZ4" Version="1.3.8" />
   </ItemGroup>
 
   <ItemGroup>
diff --git a/csharp/test/Drivers/Apache/Spark/CloudFetchE2ETest.cs b/csharp/test/Drivers/Apache/Spark/CloudFetchE2ETest.cs
new file mode 100644
index 000000000..0325c3e98
--- /dev/null
+++ b/csharp/test/Drivers/Apache/Spark/CloudFetchE2ETest.cs
@@ -0,0 +1,94 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Reflection;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Spark;
+using Apache.Arrow.Adbc.Drivers.Apache.Spark.CloudFetch;
+using Apache.Arrow.Types;
+using Xunit;
+using Xunit.Abstractions;
+using Apache.Arrow.Adbc.Client;
+using Apache.Arrow.Adbc.Tests.Drivers.Apache.Common;
+
+namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
+{
+    /// <summary>
+    /// End-to-end tests for the CloudFetch feature in the Spark ADBC driver.
+    /// </summary>
+    public class CloudFetchE2ETest : TestBase<SparkTestConfiguration, SparkTestEnvironment>
+    {
+        public CloudFetchE2ETest(ITestOutputHelper? outputHelper)
+            : base(outputHelper, new SparkTestEnvironment.Factory())
+        {
+            // Skip the test if the SPARK_TEST_CONFIG_FILE environment variable is not set
+            Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable));
+        }
+
+        /// <summary>
+        /// Integration test for running a large query against a real Databricks cluster.
+        /// </summary>
+        [Fact]
+        public async Task TestRealDatabricksCloudFetchSmallResultSet()
+        {
+            await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM range(1000)", 1000);
+        }
+
+        [Fact]
+        public async Task TestRealDatabricksCloudFetchLargeResultSet()
+        {
+            await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM main.tpcds_sf10_delta.catalog_sales LIMIT 1000000", 1000000);
+        }
+
+        private async Task TestRealDatabricksCloudFetchLargeQuery(string query, int rowCount)
+        {
+            // Create a statement with CloudFetch enabled
+            var statement = Connection.CreateStatement();
+            statement.SetOption(SparkStatement.Options.UseCloudFetch, "true");
+            statement.SetOption(SparkStatement.Options.CanDecompressLz4, "true");
+            statement.SetOption(SparkStatement.Options.MaxBytesPerFile, "10485760"); // 10MB
+
+
+            // Execute a query that generates a large result set using range function
+            statement.SqlQuery = query;
+
+            // Execute the query and get the result
+            var result = await statement.ExecuteQueryAsync();
+
+
+            if (result.Stream == null)
+            {
+                throw new InvalidOperationException("Result stream is null");
+            }
+
+            // Read all the data and count rows
+            long totalRows = 0;
+            RecordBatch? batch;
+            while ((batch = await result.Stream.ReadNextRecordBatchAsync()) != null)
+            {
+                totalRows += batch.Length;
+            }
+
+            Assert.True(totalRows >= rowCount);
+
+            // Also log to the test output helper if available
+            OutputHelper?.WriteLine($"Read {totalRows} rows from range function");
+        }
+    }
+}

