This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
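The commit described below introduces several CloudFetch tuning options that the download manager reads from the connection properties (parallel downloads, prefetch count, memory budget, retry behavior, HTTP timeout). As a minimal sketch of how a caller might supply them, the snippet below sets the corresponding `DatabricksParameters` keys as string values; the numbers mirror the defaults in the patch, while the dictionary shape and the visibility of `DatabricksParameters` to calling code are assumptions, not something the patch itself shows.

```csharp
using System.Collections.Generic;
using Apache.Arrow.Adbc.Drivers.Databricks;

// Illustrative values only (they mirror the defaults in the patch). Whether
// DatabricksParameters is visible to callers and how this dictionary reaches the
// driver's connection are assumptions made for this sketch.
var cloudFetchOptions = new Dictionary<string, string>
{
    [DatabricksParameters.CloudFetchPrefetchEnabled]   = "true", // prefetch pipeline on (default)
    [DatabricksParameters.CloudFetchParallelDownloads] = "3",    // concurrent file downloads
    [DatabricksParameters.CloudFetchPrefetchCount]     = "2",    // result chunks fetched ahead
    [DatabricksParameters.CloudFetchMemoryBufferSize]  = "200",  // buffer budget in MB
    [DatabricksParameters.CloudFetchMaxRetries]        = "3",    // download retry attempts
    [DatabricksParameters.CloudFetchRetryDelayMs]      = "500",  // base delay between retries
    [DatabricksParameters.CloudFetchTimeoutMinutes]    = "5",    // HttpClient timeout
};
```

Per the new `CloudFetchDownloadManager`, a key that is present but not a positive integer (or not a boolean, for the prefetch flag) causes an `ArgumentException` rather than a silent fallback to the default.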
The following commit(s) were added to refs/heads/main by this push:
     new 7f3d33b90 feat(csharp/src/Drivers/Apache): Add prefetch functionality to CloudFetch in Spark ADBC driver (#2678)
7f3d33b90 is described below

commit 7f3d33b9022baec92caf4494d282ff4f7e286e90
Author: Jade Wang <111902719+jadewang...@users.noreply.github.com>
AuthorDate: Thu Apr 24 08:56:46 2025 -0700

    feat(csharp/src/Drivers/Apache): Add prefetch functionality to CloudFetch in Spark ADBC driver (#2678)

    # Add Prefetch Functionality to CloudFetch in Spark ADBC Driver

    This PR enhances the CloudFetch feature in the Spark ADBC driver by implementing prefetch functionality, which improves performance by fetching multiple batches of results ahead of time.

    ## Changes

    ### CloudFetchResultFetcher Enhancements
    - **Initial Prefetch**: Added code to perform an initial prefetch of multiple batches when the fetcher starts, ensuring data is available immediately when needed.
    - **State Management**: Added tracking for the current batch offset and size, with proper state reset when starting the fetcher.

    ### Interface Updates
    - Added new methods to the `ICloudFetchResultFetcher` interface.

    ### Testing Infrastructure
    - Created the `ITestableHiveServer2Statement` interface to facilitate testing
    - Updated tests to account for prefetch behavior
    - Ensured all tests pass with the new prefetch functionality

    ## Benefits
    - **Improved Performance**: By prefetching multiple batches, data is available sooner, reducing wait times.
    - **Better Reliability**: Enhanced error handling and state management make the system more robust.
    - **More Efficient Resource Usage**: Link caching reduces unnecessary server requests.

    This implementation maintains backward compatibility while providing significant performance improvements for CloudFetch operations.
---
 .../Drivers/Apache/Hive2/HiveServer2Statement.cs | 3 +
 .../CloudFetch/CloudFetchDownloadManager.cs | 330 +++++++++++++
 .../Databricks/CloudFetch/CloudFetchDownloader.cs | 402 ++++++++++++++++
 .../CloudFetch/CloudFetchMemoryBufferManager.cs | 135 ++++++
 .../Databricks/CloudFetch/CloudFetchReader.cs | 268 ++++-------
 .../CloudFetch/CloudFetchResultFetcher.cs | 232 +++++++++
 .../Databricks/CloudFetch/DownloadResult.cs | 128 +++++
 .../Databricks/CloudFetch/EndOfResultsGuard.cs | 69 +++
 .../Databricks/CloudFetch/ICloudFetchInterfaces.cs | 217 +++++++++
 .../Databricks/CloudFetch/IHiveServer2Statement.cs | 37 ++
 .../CloudFetch/cloudfetch-pipeline-design.md | 72 +++
 .../src/Drivers/Databricks/DatabricksParameters.cs | 25 +
 .../src/Drivers/Databricks/DatabricksStatement.cs | 6 +-
 .../CloudFetch/CloudFetchDownloaderTest.cs | 534 +++++++++++++++++++++
 .../CloudFetch/CloudFetchResultFetcherTest.cs | 386 +++++++++++++++
 15 files changed, 2660 insertions(+), 184 deletions(-)

diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
index 607925333..564a28e9f 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -284,6 +284,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 public TOperationHandle? OperationHandle { get; private set; }
+ // Keep the original Client property for internal use
+ public TCLIService.Client Client => Connection.Client;
+
 private void UpdatePollTimeIfValid(string key, string value) => PollTimeMilliseconds = !string.IsNullOrEmpty(key) && int.TryParse(value, result: out int pollTimeMilliseconds) && pollTimeMilliseconds >= 0 ?
pollTimeMilliseconds : throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than or equal to 0."); diff --git a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs new file mode 100644 index 000000000..99f38fe82 --- /dev/null +++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Databricks; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Manages the CloudFetch download pipeline. + /// </summary> + internal sealed class CloudFetchDownloadManager : ICloudFetchDownloadManager + { + // Default values + private const int DefaultParallelDownloads = 3; + private const int DefaultPrefetchCount = 2; + private const int DefaultMemoryBufferSizeMB = 200; + private const bool DefaultPrefetchEnabled = true; + private const int DefaultFetchBatchSize = 2000000; + + private readonly DatabricksStatement _statement; + private readonly Schema _schema; + private readonly bool _isLz4Compressed; + private readonly ICloudFetchMemoryBufferManager _memoryManager; + private readonly BlockingCollection<IDownloadResult> _downloadQueue; + private readonly BlockingCollection<IDownloadResult> _resultQueue; + private readonly ICloudFetchResultFetcher _resultFetcher; + private readonly ICloudFetchDownloader _downloader; + private readonly HttpClient _httpClient; + private bool _isDisposed; + private bool _isStarted; + private CancellationTokenSource? _cancellationTokenSource; + + /// <summary> + /// Initializes a new instance of the <see cref="CloudFetchDownloadManager"/> class. + /// </summary> + /// <param name="statement">The HiveServer2 statement.</param> + /// <param name="schema">The Arrow schema.</param> + /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> + public CloudFetchDownloadManager(DatabricksStatement statement, Schema schema, bool isLz4Compressed) + { + _statement = statement ?? throw new ArgumentNullException(nameof(statement)); + _schema = schema ?? 
throw new ArgumentNullException(nameof(schema)); + _isLz4Compressed = isLz4Compressed; + + // Get configuration values from connection properties + var connectionProps = statement.Connection.Properties; + + // Parse parallel downloads + int parallelDownloads = DefaultParallelDownloads; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchParallelDownloads, out string? parallelDownloadsStr)) + { + if (int.TryParse(parallelDownloadsStr, out int parsedParallelDownloads) && parsedParallelDownloads > 0) + { + parallelDownloads = parsedParallelDownloads; + } + else + { + throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchParallelDownloads}: {parallelDownloadsStr}. Expected a positive integer."); + } + } + + // Parse prefetch count + int prefetchCount = DefaultPrefetchCount; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchPrefetchCount, out string? prefetchCountStr)) + { + if (int.TryParse(prefetchCountStr, out int parsedPrefetchCount) && parsedPrefetchCount > 0) + { + prefetchCount = parsedPrefetchCount; + } + else + { + throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchPrefetchCount}: {prefetchCountStr}. Expected a positive integer."); + } + } + + // Parse memory buffer size + int memoryBufferSizeMB = DefaultMemoryBufferSizeMB; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMemoryBufferSize, out string? memoryBufferSizeStr)) + { + if (int.TryParse(memoryBufferSizeStr, out int parsedMemoryBufferSize) && parsedMemoryBufferSize > 0) + { + memoryBufferSizeMB = parsedMemoryBufferSize; + } + else + { + throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchMemoryBufferSize}: {memoryBufferSizeStr}. Expected a positive integer."); + } + } + + // Parse max retries + int maxRetries = 3; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out string? maxRetriesStr)) + { + if (int.TryParse(maxRetriesStr, out int parsedMaxRetries) && parsedMaxRetries > 0) + { + maxRetries = parsedMaxRetries; + } + else + { + throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchMaxRetries}: {maxRetriesStr}. Expected a positive integer."); + } + } + + // Parse retry delay + int retryDelayMs = 500; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchRetryDelayMs, out string? retryDelayStr)) + { + if (int.TryParse(retryDelayStr, out int parsedRetryDelay) && parsedRetryDelay > 0) + { + retryDelayMs = parsedRetryDelay; + } + else + { + throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchRetryDelayMs}: {retryDelayStr}. Expected a positive integer."); + } + } + + // Parse timeout minutes + int timeoutMinutes = 5; + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchTimeoutMinutes, out string? timeoutStr)) + { + if (int.TryParse(timeoutStr, out int parsedTimeout) && parsedTimeout > 0) + { + timeoutMinutes = parsedTimeout; + } + else + { + throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchTimeoutMinutes}: {timeoutStr}. 
Expected a positive integer."); + } + } + + // Initialize the memory manager + _memoryManager = new CloudFetchMemoryBufferManager(memoryBufferSizeMB); + + // Initialize the queues with bounded capacity + _downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2); + _resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), prefetchCount * 2); + + // Initialize the HTTP client + _httpClient = new HttpClient + { + Timeout = TimeSpan.FromMinutes(timeoutMinutes) + }; + + // Initialize the result fetcher + _resultFetcher = new CloudFetchResultFetcher( + _statement, + _memoryManager, + _downloadQueue, + DefaultFetchBatchSize); + + // Initialize the downloader + _downloader = new CloudFetchDownloader( + _downloadQueue, + _resultQueue, + _memoryManager, + _httpClient, + parallelDownloads, + _isLz4Compressed, + maxRetries, + retryDelayMs); + } + + /// <summary> + /// Initializes a new instance of the <see cref="CloudFetchDownloadManager"/> class. + /// This constructor is intended for testing purposes only. + /// </summary> + /// <param name="statement">The HiveServer2 statement.</param> + /// <param name="schema">The Arrow schema.</param> + /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> + /// <param name="resultFetcher">The result fetcher.</param> + /// <param name="downloader">The downloader.</param> + internal CloudFetchDownloadManager( + DatabricksStatement statement, + Schema schema, + bool isLz4Compressed, + ICloudFetchResultFetcher resultFetcher, + ICloudFetchDownloader downloader) + { + _statement = statement ?? throw new ArgumentNullException(nameof(statement)); + _schema = schema ?? throw new ArgumentNullException(nameof(schema)); + _isLz4Compressed = isLz4Compressed; + _resultFetcher = resultFetcher ?? throw new ArgumentNullException(nameof(resultFetcher)); + _downloader = downloader ?? throw new ArgumentNullException(nameof(downloader)); + + // Create empty collections for the test + _memoryManager = new CloudFetchMemoryBufferManager(DefaultMemoryBufferSizeMB); + _downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10); + _resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10); + _httpClient = new HttpClient(); + } + + /// <inheritdoc /> + public bool HasMoreResults => !_downloader.IsCompleted || !_resultQueue.IsCompleted; + + /// <inheritdoc /> + public async Task<IDownloadResult?> GetNextDownloadedFileAsync(CancellationToken cancellationToken) + { + ThrowIfDisposed(); + + if (!_isStarted) + { + throw new InvalidOperationException("Download manager has not been started."); + } + + try + { + return await _downloader.GetNextDownloadedFileAsync(cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) when (_resultFetcher.HasError) + { + throw new AggregateException("Errors in download pipeline", new[] { ex, _resultFetcher.Error! 
}); + } + } + + /// <inheritdoc /> + public async Task StartAsync() + { + ThrowIfDisposed(); + + if (_isStarted) + { + throw new InvalidOperationException("Download manager is already started."); + } + + // Create a new cancellation token source + _cancellationTokenSource = new CancellationTokenSource(); + + // Start the result fetcher + await _resultFetcher.StartAsync(_cancellationTokenSource.Token).ConfigureAwait(false); + + // Start the downloader + await _downloader.StartAsync(_cancellationTokenSource.Token).ConfigureAwait(false); + + _isStarted = true; + } + + /// <inheritdoc /> + public async Task StopAsync() + { + if (!_isStarted) + { + return; + } + + // Cancel the token to signal all operations to stop + _cancellationTokenSource?.Cancel(); + + // Stop the downloader + await _downloader.StopAsync().ConfigureAwait(false); + + // Stop the result fetcher + await _resultFetcher.StopAsync().ConfigureAwait(false); + + // Dispose the cancellation token source + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; + + _isStarted = false; + } + + /// <inheritdoc /> + public void Dispose() + { + if (_isDisposed) + { + return; + } + + // Stop the pipeline + StopAsync().GetAwaiter().GetResult(); + + // Dispose the HTTP client + _httpClient.Dispose(); + + // Dispose the cancellation token source if it hasn't been disposed yet + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; + + // Mark the queues as completed to release any waiting threads + _downloadQueue.CompleteAdding(); + _resultQueue.CompleteAdding(); + + // Dispose any remaining results + foreach (var result in _resultQueue.GetConsumingEnumerable(CancellationToken.None)) + { + result.Dispose(); + } + + foreach (var result in _downloadQueue.GetConsumingEnumerable(CancellationToken.None)) + { + result.Dispose(); + } + + _downloadQueue.Dispose(); + _resultQueue.Dispose(); + + _isDisposed = true; + } + + private void ThrowIfDisposed() + { + if (_isDisposed) + { + throw new ObjectDisposedException(nameof(CloudFetchDownloadManager)); + } + } + } +} diff --git a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloader.cs b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloader.cs new file mode 100644 index 000000000..e3cec4a2f --- /dev/null +++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloader.cs @@ -0,0 +1,402 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.IO; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using K4os.Compression.LZ4.Streams; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Downloads files from URLs. 
+ /// </summary> + internal sealed class CloudFetchDownloader : ICloudFetchDownloader + { + private readonly BlockingCollection<IDownloadResult> _downloadQueue; + private readonly BlockingCollection<IDownloadResult> _resultQueue; + private readonly ICloudFetchMemoryBufferManager _memoryManager; + private readonly HttpClient _httpClient; + private readonly int _maxParallelDownloads; + private readonly bool _isLz4Compressed; + private readonly int _maxRetries; + private readonly int _retryDelayMs; + private readonly SemaphoreSlim _downloadSemaphore; + private Task? _downloadTask; + private CancellationTokenSource? _cancellationTokenSource; + private bool _isCompleted; + private Exception? _error; + private readonly object _errorLock = new object(); + + /// <summary> + /// Initializes a new instance of the <see cref="CloudFetchDownloader"/> class. + /// </summary> + /// <param name="downloadQueue">The queue of downloads to process.</param> + /// <param name="resultQueue">The queue to add completed downloads to.</param> + /// <param name="memoryManager">The memory buffer manager.</param> + /// <param name="httpClient">The HTTP client to use for downloads.</param> + /// <param name="maxParallelDownloads">The maximum number of parallel downloads.</param> + /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> + /// <param name="maxRetries">The maximum number of retry attempts.</param> + /// <param name="retryDelayMs">The delay between retry attempts in milliseconds.</param> + public CloudFetchDownloader( + BlockingCollection<IDownloadResult> downloadQueue, + BlockingCollection<IDownloadResult> resultQueue, + ICloudFetchMemoryBufferManager memoryManager, + HttpClient httpClient, + int maxParallelDownloads, + bool isLz4Compressed, + int maxRetries = 3, + int retryDelayMs = 500) + { + _downloadQueue = downloadQueue ?? throw new ArgumentNullException(nameof(downloadQueue)); + _resultQueue = resultQueue ?? throw new ArgumentNullException(nameof(resultQueue)); + _memoryManager = memoryManager ?? throw new ArgumentNullException(nameof(memoryManager)); + _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); + _maxParallelDownloads = maxParallelDownloads > 0 ? maxParallelDownloads : throw new ArgumentOutOfRangeException(nameof(maxParallelDownloads)); + _isLz4Compressed = isLz4Compressed; + _maxRetries = maxRetries > 0 ? maxRetries : throw new ArgumentOutOfRangeException(nameof(maxRetries)); + _retryDelayMs = retryDelayMs > 0 ? retryDelayMs : throw new ArgumentOutOfRangeException(nameof(retryDelayMs)); + _downloadSemaphore = new SemaphoreSlim(_maxParallelDownloads, _maxParallelDownloads); + _isCompleted = false; + } + + /// <inheritdoc /> + public bool IsCompleted => _isCompleted; + + /// <inheritdoc /> + public bool HasError => _error != null; + + /// <inheritdoc /> + public Exception? 
Error => _error; + + /// <inheritdoc /> + public async Task StartAsync(CancellationToken cancellationToken) + { + if (_downloadTask != null) + { + throw new InvalidOperationException("Downloader is already running."); + } + + _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _downloadTask = DownloadFilesAsync(_cancellationTokenSource.Token); + + // Wait for the download task to start + await Task.Yield(); + } + + /// <inheritdoc /> + public async Task StopAsync() + { + if (_downloadTask == null) + { + return; + } + + _cancellationTokenSource?.Cancel(); + + try + { + await _downloadTask.ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Expected when cancellation is requested + } + catch (Exception ex) + { + Debug.WriteLine($"Error stopping downloader: {ex.Message}"); + } + finally + { + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; + _downloadTask = null; + } + } + + /// <inheritdoc /> + public async Task<IDownloadResult?> GetNextDownloadedFileAsync(CancellationToken cancellationToken) + { + try + { + // Check if there's an error before trying to take from the queue + if (HasError) + { + throw new AdbcException("Error in download process", _error ?? new Exception("Unknown error")); + } + + // Try to take the next result from the queue + IDownloadResult result = await Task.Run(() => _resultQueue.Take(cancellationToken), cancellationToken); + + // Check if this is the end of results guard + if (result == EndOfResultsGuard.Instance) + { + _isCompleted = true; + return null; + } + + return result; + } + catch (OperationCanceledException) + { + // Cancellation was requested + return null; + } + catch (InvalidOperationException) when (_resultQueue.IsCompleted) + { + // Queue is completed and empty + _isCompleted = true; + return null; + } + catch (AdbcException) + { + // Re-throw AdbcExceptions (these are our own errors) + throw; + } + catch (Exception ex) + { + // If there's an error, set the error state and propagate it + SetError(ex); + throw; + } + } + + private async Task DownloadFilesAsync(CancellationToken cancellationToken) + { + await Task.Yield(); + + try + { + // Keep track of active download tasks + var downloadTasks = new ConcurrentDictionary<Task, IDownloadResult>(); + var downloadTaskCompletionSource = new TaskCompletionSource<bool>(); + + // Process items from the download queue until it's completed + foreach (var downloadResult in _downloadQueue.GetConsumingEnumerable(cancellationToken)) + { + // Check if there's an error before processing more downloads + if (HasError) + { + // Add the failed download result to the queue to signal the error + // This will be caught by GetNextDownloadedFileAsync + break; + } + + // Check if this is the end of results guard + if (downloadResult == EndOfResultsGuard.Instance) + { + // Wait for all active downloads to complete + if (downloadTasks.Count > 0) + { + try + { + await Task.WhenAll(downloadTasks.Keys).ConfigureAwait(false); + } + catch (Exception ex) + { + Debug.WriteLine($"Error waiting for downloads to complete: {ex.Message}"); + // Don't set error here, as individual download tasks will handle their own errors + } + } + + // Only add the guard if there's no error + if (!HasError) + { + // Add the guard to the result queue to signal the end of results + _resultQueue.Add(EndOfResultsGuard.Instance, cancellationToken); + _isCompleted = true; + } + break; + } + + // Acquire a download slot + await 
_downloadSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + + // Start the download task + Task downloadTask = DownloadFileAsync(downloadResult, cancellationToken) + .ContinueWith(t => + { + // Release the download slot + _downloadSemaphore.Release(); + + // Remove the task from the dictionary + downloadTasks.TryRemove(t, out _); + + // Handle any exceptions + if (t.IsFaulted) + { + Exception ex = t.Exception?.InnerException ?? new Exception("Unknown error"); + Debug.WriteLine($"Download failed: {ex.Message}"); + + // Set the download as failed + downloadResult.SetFailed(ex); + + // Set the error state to stop the download process + SetError(ex); + + // Signal that we should stop processing downloads + downloadTaskCompletionSource.TrySetException(ex); + } + }, cancellationToken); + + // Add the task to the dictionary + downloadTasks[downloadTask] = downloadResult; + + // Add the result to the result queue add the result here to assure the download sequence. + _resultQueue.Add(downloadResult, cancellationToken); + + // If there's an error, stop processing more downloads + if (HasError) + { + break; + } + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // Expected when cancellation is requested + } + catch (Exception ex) + { + Debug.WriteLine($"Error in download loop: {ex.Message}"); + SetError(ex); + } + finally + { + // If there's an error, add the error to the result queue + if (HasError) + { + CompleteWithError(); + } + } + } + + private async Task DownloadFileAsync(IDownloadResult downloadResult, CancellationToken cancellationToken) + { + string url = downloadResult.Link.FileLink; + byte[]? fileData = null; + + // Use the size directly from the download result + long size = downloadResult.Size; + + // Acquire memory before downloading + await _memoryManager.AcquireMemoryAsync(size, cancellationToken).ConfigureAwait(false); + + // Retry logic for downloading files + for (int retry = 0; retry < _maxRetries; retry++) + { + try + { + // Download the file directly + using HttpResponseMessage response = await _httpClient.GetAsync( + url, + HttpCompletionOption.ResponseHeadersRead, + cancellationToken).ConfigureAwait(false); + + response.EnsureSuccessStatusCode(); + + // Log the download size if available from response headers + long? 
contentLength = response.Content.Headers.ContentLength; + if (contentLength.HasValue && contentLength.Value > 0) + { + Debug.WriteLine($"Downloading file of size: {contentLength.Value / 1024.0 / 1024.0:F2} MB"); + } + + // Read the file data + fileData = await response.Content.ReadAsByteArrayAsync().ConfigureAwait(false); + break; // Success, exit retry loop + } + catch (Exception ex) when (retry < _maxRetries - 1 && !cancellationToken.IsCancellationRequested) + { + // Log the error and retry + Debug.WriteLine($"Error downloading file (attempt {retry + 1}/{_maxRetries}): {ex.Message}"); + await Task.Delay(_retryDelayMs * (retry + 1), cancellationToken).ConfigureAwait(false); + } + } + + if (fileData == null) + { + // Release the memory we acquired + _memoryManager.ReleaseMemory(size); + throw new InvalidOperationException($"Failed to download file from {url} after {_maxRetries} attempts."); + } + + // Process the downloaded file data + MemoryStream dataStream; + + // If the data is LZ4 compressed, decompress it + if (_isLz4Compressed) + { + try + { + dataStream = new MemoryStream(); + using (var inputStream = new MemoryStream(fileData)) + using (var decompressor = LZ4Stream.Decode(inputStream)) + { + await decompressor.CopyToAsync(dataStream, 81920, cancellationToken).ConfigureAwait(false); + } + dataStream.Position = 0; + } + catch (Exception ex) + { + // Release the memory we acquired + _memoryManager.ReleaseMemory(size); + throw new InvalidOperationException($"Error decompressing data: {ex.Message}", ex); + } + } + else + { + dataStream = new MemoryStream(fileData); + } + + // Set the download as completed with the original size + downloadResult.SetCompleted(dataStream, size); + } + + private void SetError(Exception ex) + { + lock (_errorLock) + { + if (_error == null) + { + _error = ex; + } + } + } + + private void CompleteWithError() + { + try + { + // Mark the result queue as completed to prevent further additions + _resultQueue.CompleteAdding(); + + // Mark the download as completed with error + _isCompleted = true; + } + catch (Exception ex) + { + Debug.WriteLine($"Error completing with error: {ex.Message}"); + } + } + } +} diff --git a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchMemoryBufferManager.cs b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchMemoryBufferManager.cs new file mode 100644 index 000000000..7f5a13e10 --- /dev/null +++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchMemoryBufferManager.cs @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Manages memory allocation for prefetched files. 
+ /// </summary> + internal sealed class CloudFetchMemoryBufferManager : ICloudFetchMemoryBufferManager + { + private const int DefaultMemoryBufferSizeMB = 200; + private readonly long _maxMemory; + private long _usedMemory; + private readonly SemaphoreSlim _memorySemaphore; + + /// <summary> + /// Initializes a new instance of the <see cref="CloudFetchMemoryBufferManager"/> class. + /// </summary> + /// <param name="maxMemoryMB">The maximum memory allowed for buffering in megabytes.</param> + public CloudFetchMemoryBufferManager(int? maxMemoryMB = null) + { + int memoryMB = maxMemoryMB ?? DefaultMemoryBufferSizeMB; + if (memoryMB <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxMemoryMB), "Memory buffer size must be positive."); + } + + // Convert MB to bytes + _maxMemory = memoryMB * 1024L * 1024L; + _usedMemory = 0; + _memorySemaphore = new SemaphoreSlim(1, 1); + } + + /// <inheritdoc /> + public long MaxMemory => _maxMemory; + + /// <inheritdoc /> + public long UsedMemory => Interlocked.Read(ref _usedMemory); + + /// <inheritdoc /> + public bool TryAcquireMemory(long size) + { + if (size <= 0) + { + throw new ArgumentOutOfRangeException(nameof(size), "Size must be positive."); + } + + // Try to acquire memory + long originalValue; + long newValue; + do + { + originalValue = Interlocked.Read(ref _usedMemory); + newValue = originalValue + size; + + // Check if we would exceed the maximum memory + if (newValue > _maxMemory) + { + return false; + } + } + while (Interlocked.CompareExchange(ref _usedMemory, newValue, originalValue) != originalValue); + + return true; + } + + /// <inheritdoc /> + public async Task AcquireMemoryAsync(long size, CancellationToken cancellationToken) + { + if (size <= 0) + { + throw new ArgumentOutOfRangeException(nameof(size), "Size must be positive."); + } + + // Special case: if size is greater than max memory, we'll never be able to acquire it + if (size > _maxMemory) + { + throw new ArgumentOutOfRangeException(nameof(size), $"Requested size ({size} bytes) exceeds maximum memory ({_maxMemory} bytes)."); + } + + while (!cancellationToken.IsCancellationRequested) + { + // Try to acquire memory without blocking + if (TryAcquireMemory(size)) + { + return; + } + + // If we couldn't acquire memory, wait for some to be released + await Task.Delay(10, cancellationToken).ConfigureAwait(false); + } + + // If we get here, cancellation was requested + cancellationToken.ThrowIfCancellationRequested(); + } + + /// <inheritdoc /> + public void ReleaseMemory(long size) + { + if (size <= 0) + { + throw new ArgumentOutOfRangeException(nameof(size), "Size must be positive."); + } + + // Release memory + long newValue = Interlocked.Add(ref _usedMemory, -size); + + // Ensure we don't go negative + if (newValue < 0) + { + // This should never happen if the code is correct + Interlocked.Exchange(ref _usedMemory, 0); + throw new InvalidOperationException("Memory buffer manager released more memory than was acquired."); + } + } + } +} diff --git a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs index 27b47194f..abca66d37 100644 --- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs +++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs @@ -1,4 +1,4 @@ -/* +/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. 
@@ -21,7 +21,8 @@ using System.Diagnostics; using System.Net.Http; using System.Threading; using System.Threading.Tasks; -using Apache.Arrow.Adbc.Drivers.Apache; +using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch; +using Apache.Arrow.Adbc.Drivers.Databricks; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; @@ -33,25 +34,13 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch /// </summary> internal sealed class CloudFetchReader : IArrowArrayStream { - // Default values used if not specified in connection properties - private const int DefaultMaxRetries = 3; - private const int DefaultRetryDelayMs = 500; - private const int DefaultTimeoutMinutes = 5; - - private readonly int maxRetries; - private readonly int retryDelayMs; - private readonly int timeoutMinutes; - - private DatabricksStatement? statement; private readonly Schema schema; - private List<TSparkArrowResultLink>? resultLinks; - private int linkIndex; - private ArrowStreamReader? currentReader; private readonly bool isLz4Compressed; - private long startOffset; - - // Lazy initialization of HttpClient - private readonly Lazy<HttpClient> httpClient; + private readonly ICloudFetchDownloadManager downloadManager; + private ArrowStreamReader? currentReader; + private IDownloadResult? currentDownloadResult; + private bool isPrefetchEnabled; + private bool isDisposed; /// <summary> /// Initializes a new instance of the <see cref="CloudFetchReader"/> class. @@ -61,62 +50,37 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param> public CloudFetchReader(DatabricksStatement statement, Schema schema, bool isLz4Compressed) { - this.statement = statement; this.schema = schema; this.isLz4Compressed = isLz4Compressed; - // Get configuration values from connection properties or use defaults + // Check if prefetch is enabled var connectionProps = statement.Connection.Properties; - - // Parse max retries - int parsedMaxRetries = DefaultMaxRetries; - if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxRetries, out string? maxRetriesStr) && - int.TryParse(maxRetriesStr, out parsedMaxRetries) && - parsedMaxRetries > 0) - { - // Value was successfully parsed - } - else - { - parsedMaxRetries = DefaultMaxRetries; - } - this.maxRetries = parsedMaxRetries; - - // Parse retry delay - int parsedRetryDelay = DefaultRetryDelayMs; - if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchRetryDelayMs, out string? retryDelayStr) && - int.TryParse(retryDelayStr, out parsedRetryDelay) && - parsedRetryDelay > 0) - { - // Value was successfully parsed - } - else + isPrefetchEnabled = true; // Default to true + if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchPrefetchEnabled, out string? prefetchEnabledStr)) { - parsedRetryDelay = DefaultRetryDelayMs; + if (bool.TryParse(prefetchEnabledStr, out bool parsedPrefetchEnabled)) + { + isPrefetchEnabled = parsedPrefetchEnabled; + } + else + { + throw new ArgumentException($"Invalid value for {DatabricksParameters.CloudFetchPrefetchEnabled}: {prefetchEnabledStr}. Expected a boolean value."); + } } - this.retryDelayMs = parsedRetryDelay; - // Parse timeout minutes - int parsedTimeout = DefaultTimeoutMinutes; - if (connectionProps.TryGetValue(DatabricksParameters.CloudFetchTimeoutMinutes, out string? 
timeoutStr) && - int.TryParse(timeoutStr, out parsedTimeout) && - parsedTimeout > 0) + // Initialize the download manager + if (isPrefetchEnabled) { - // Value was successfully parsed + downloadManager = new CloudFetchDownloadManager(statement, schema, isLz4Compressed); + downloadManager.StartAsync().Wait(); } else { - parsedTimeout = DefaultTimeoutMinutes; + // For now, we only support the prefetch implementation + // This flag is reserved for future use if we need to support a non-prefetch mode + downloadManager = new CloudFetchDownloadManager(statement, schema, isLz4Compressed); + downloadManager.StartAsync().Wait(); } - this.timeoutMinutes = parsedTimeout; - - // Initialize HttpClient with the configured timeout - this.httpClient = new Lazy<HttpClient>(() => - { - var client = new HttpClient(); - client.Timeout = TimeSpan.FromMinutes(this.timeoutMinutes); - return client; - }); } /// <summary> @@ -124,11 +88,6 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch /// </summary> public Schema Schema { get { return schema; } } - private HttpClient HttpClient - { - get { return httpClient.Value; } - } - /// <summary> /// Reads the next record batch from the result set. /// </summary> @@ -136,6 +95,8 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch /// <returns>The next record batch, or null if there are no more batches.</returns> public async ValueTask<RecordBatch?> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) { + ThrowIfDisposed(); + while (true) { // If we have a current reader, try to read the next batch @@ -148,159 +109,100 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch } else { + // Clean up the current reader and download result this.currentReader.Dispose(); this.currentReader = null; + + if (this.currentDownloadResult != null) + { + this.currentDownloadResult.Dispose(); + this.currentDownloadResult = null; + } } } - // If we have more links to process, download and process the next one - if (this.resultLinks != null && this.linkIndex < this.resultLinks.Count) + // If we don't have a current reader, get the next downloaded file + if (this.downloadManager != null) { - var link = this.resultLinks[this.linkIndex++]; - byte[]? 
fileData = null; + // Start the download manager if it's not already started + if (!this.isPrefetchEnabled) + { + throw new InvalidOperationException("Prefetch must be enabled."); + } try { - // Try to download with retry logic - for (int retry = 0; retry < this.maxRetries; retry++) + // Get the next downloaded file + this.currentDownloadResult = await this.downloadManager.GetNextDownloadedFileAsync(cancellationToken); + if (this.currentDownloadResult == null) { - try - { - fileData = await DownloadFileAsync(link.FileLink, cancellationToken); - break; // Success, exit retry loop - } - catch (Exception) when (retry < this.maxRetries - 1) - { - // Only delay and retry if we haven't reached max retries - await Task.Delay(this.retryDelayMs * (retry + 1), cancellationToken); - } + // No more files + return null; } - // If download still failed after all retries - if (fileData == null) + await this.currentDownloadResult.DownloadCompletedTask; + + // Create a new reader for the downloaded file + try { - throw new AdbcException($"Failed to download CloudFetch data from {link.FileLink} after {this.maxRetries} attempts"); + this.currentReader = new ArrowStreamReader(this.currentDownloadResult.DataStream); + continue; } - - ReadOnlyMemory<byte> dataToUse = new ReadOnlyMemory<byte>(fileData); - - // If the data is LZ4 compressed, decompress it - if (this.isLz4Compressed) + catch (Exception ex) { - dataToUse = Lz4Utilities.DecompressLz4(fileData); + Debug.WriteLine($"Error creating Arrow reader: {ex.Message}"); + this.currentDownloadResult.Dispose(); + this.currentDownloadResult = null; + throw; } - - // Use ChunkStream which supports ReadOnlyMemory<byte> directly - this.currentReader = new ArrowStreamReader(new ChunkStream(this.schema, dataToUse)); - continue; } catch (Exception ex) { - // Create concise error message based on exception type - string errorPrefix = $"CloudFetch link {this.linkIndex - 1}:"; - string errorMessage = ex switch - { - _ when ex.GetType().Name.Contains("LZ4") => $"{errorPrefix} LZ4 decompression failed - Data may be corrupted", - HttpRequestException or TaskCanceledException => $"{errorPrefix} Download failed - {ex.Message}", - _ => $"{errorPrefix} Processing failed - {ex.Message}" // Default case for any other exception - }; - throw new AdbcException(errorMessage, ex); + Debug.WriteLine($"Error getting next downloaded file: {ex.Message}"); + throw; } } - this.resultLinks = null; - this.linkIndex = 0; - - // If there's no statement, we're done - if (this.statement == null) - { - return null; - } - - // Fetch more results from the server - TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle!, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize); - - // Set the start row offset if we have processed some links already - if (this.startOffset > 0) - { - request.StartRowOffset = this.startOffset; - } - - TFetchResultsResp response; - try - { - response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken); - } - catch (Exception ex) - { - throw new AdbcException($"Server request failed - {ex.Message}", ex); - } - - // Check if we have URL-based results - if (response.Results.__isset.resultLinks && - response.Results.ResultLinks != null && - response.Results.ResultLinks.Count > 0) - { - this.resultLinks = response.Results.ResultLinks; - - // Update the start offset for the next fetch by calculating it from the links - if (this.resultLinks.Count > 0) - { - var lastLink = this.resultLinks[this.resultLinks.Count - 1]; - 
this.startOffset = lastLink.StartRowOffset + lastLink.RowCount; - } - - // If the server indicates there are no more rows, we can close the statement - if (!response.HasMoreRows) - { - this.statement = null; - } - } - else - { - // If there are no more results, we're done - this.statement = null; - return null; - } + // If we get here, there are no more files + return null; } } /// <summary> - /// Downloads a file from a URL. + /// Disposes the reader. /// </summary> - /// <param name="url">The URL to download from.</param> - /// <param name="cancellationToken">The cancellation token.</param> - /// <returns>The downloaded file data.</returns> - private async Task<byte[]> DownloadFileAsync(string url, CancellationToken cancellationToken) + public void Dispose() { - using HttpResponseMessage response = await HttpClient.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken); - response.EnsureSuccessStatusCode(); - - // Get the content length if available - long? contentLength = response.Content.Headers.ContentLength; - if (contentLength.HasValue && contentLength.Value > 0) + if (isDisposed) { - Debug.WriteLine($"Downloading file of size: {contentLength.Value / 1024.0 / 1024.0:F2} MB"); + return; } - return await response.Content.ReadAsByteArrayAsync(); - } - - /// <summary> - /// Disposes the reader. - /// </summary> - public void Dispose() - { if (this.currentReader != null) { this.currentReader.Dispose(); this.currentReader = null; } - // Dispose the HttpClient if it was created - if (httpClient.IsValueCreated) + if (this.currentDownloadResult != null) + { + this.currentDownloadResult.Dispose(); + this.currentDownloadResult = null; + } + + if (this.downloadManager != null) + { + this.downloadManager.Dispose(); + } + + isDisposed = true; + } + + private void ThrowIfDisposed() + { + if (isDisposed) { - httpClient.Value.Dispose(); + throw new ObjectDisposedException(nameof(CloudFetchReader)); } } } diff --git a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs new file mode 100644 index 000000000..b0ad05a6a --- /dev/null +++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Fetches result chunks from the Thrift server. 
+ /// </summary> + internal sealed class CloudFetchResultFetcher : ICloudFetchResultFetcher + { + private readonly IHiveServer2Statement _statement; + private readonly ICloudFetchMemoryBufferManager _memoryManager; + private readonly BlockingCollection<IDownloadResult> _downloadQueue; + private long _startOffset; + private bool _hasMoreResults; + private bool _isCompleted; + private Task? _fetchTask; + private CancellationTokenSource? _cancellationTokenSource; + private Exception? _error; + private long _batchSize; + + /// <summary> + /// Initializes a new instance of the <see cref="CloudFetchResultFetcher"/> class. + /// </summary> + /// <param name="statement">The HiveServer2 statement interface.</param> + /// <param name="memoryManager">The memory buffer manager.</param> + /// <param name="downloadQueue">The queue to add download tasks to.</param> + /// <param name="prefetchCount">The number of result chunks to prefetch.</param> + public CloudFetchResultFetcher( + IHiveServer2Statement statement, + ICloudFetchMemoryBufferManager memoryManager, + BlockingCollection<IDownloadResult> downloadQueue, + long batchSize) + { + _statement = statement ?? throw new ArgumentNullException(nameof(statement)); + _memoryManager = memoryManager ?? throw new ArgumentNullException(nameof(memoryManager)); + _downloadQueue = downloadQueue ?? throw new ArgumentNullException(nameof(downloadQueue)); + _hasMoreResults = true; + _isCompleted = false; + _batchSize = batchSize; + } + + /// <inheritdoc /> + public bool HasMoreResults => _hasMoreResults; + + /// <inheritdoc /> + public bool IsCompleted => _isCompleted; + + /// <inheritdoc /> + public bool HasError => _error != null; + + /// <inheritdoc /> + public Exception? Error => _error; + + /// <inheritdoc /> + public async Task StartAsync(CancellationToken cancellationToken) + { + if (_fetchTask != null) + { + throw new InvalidOperationException("Fetcher is already running."); + } + + // Reset state + _startOffset = 0; + _hasMoreResults = true; + _isCompleted = false; + _error = null; + + _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _fetchTask = FetchResultsAsync(_cancellationTokenSource.Token); + + // Wait for the fetch task to start + await Task.Yield(); + } + + /// <inheritdoc /> + public async Task StopAsync() + { + if (_fetchTask == null) + { + return; + } + + _cancellationTokenSource?.Cancel(); + + try + { + await _fetchTask.ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Expected when cancellation is requested + } + catch (Exception ex) + { + Debug.WriteLine($"Error stopping fetcher: {ex.Message}"); + } + finally + { + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; + _fetchTask = null; + } + } + + private async Task FetchResultsAsync(CancellationToken cancellationToken) + { + try + { + // Continue fetching as needed + while (_hasMoreResults && !cancellationToken.IsCancellationRequested) + { + try + { + // Fetch more results from the server + await FetchNextResultBatchAsync(cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // Expected when cancellation is requested + break; + } + catch (Exception ex) + { + Debug.WriteLine($"Error fetching results: {ex.Message}"); + _error = ex; + _hasMoreResults = false; + break; + } + } + + // Add the end of results guard to the queue + _downloadQueue.Add(EndOfResultsGuard.Instance, cancellationToken); + _isCompleted = true; + } + 
catch (Exception ex) + { + Debug.WriteLine($"Unhandled error in fetcher: {ex.Message}"); + _error = ex; + _hasMoreResults = false; + _isCompleted = true; + + // Add the end of results guard to the queue even in case of error + try + { + _downloadQueue.Add(EndOfResultsGuard.Instance, CancellationToken.None); + } + catch (Exception) + { + // Ignore any errors when adding the guard in case of error + } + } + } + + private async Task FetchNextResultBatchAsync(CancellationToken cancellationToken) + { + // Create fetch request + TFetchResultsReq request = new TFetchResultsReq(_statement.OperationHandle!, TFetchOrientation.FETCH_NEXT, _batchSize); + + // Set the start row offset if we have processed some links already + if (_startOffset > 0) + { + request.StartRowOffset = _startOffset; + } + + // Fetch results + TFetchResultsResp response; + try + { + response = await _statement.Client.FetchResults(request, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + Debug.WriteLine($"Error fetching results from server: {ex.Message}"); + _hasMoreResults = false; + throw; + } + + // Check if we have URL-based results + if (response.Results.__isset.resultLinks && + response.Results.ResultLinks != null && + response.Results.ResultLinks.Count > 0) + { + List<TSparkArrowResultLink> resultLinks = response.Results.ResultLinks; + + // Add each link to the download queue + foreach (var link in resultLinks) + { + var downloadResult = new DownloadResult(link, _memoryManager); + _downloadQueue.Add(downloadResult, cancellationToken); + } + + // Update the start offset for the next fetch + if (resultLinks.Count > 0) + { + var lastLink = resultLinks[resultLinks.Count - 1]; + _startOffset = lastLink.StartRowOffset + lastLink.RowCount; + } + + // Update whether there are more results + _hasMoreResults = response.HasMoreRows; + } + else + { + // No more results + _hasMoreResults = false; + } + } + } +} diff --git a/csharp/src/Drivers/Databricks/CloudFetch/DownloadResult.cs b/csharp/src/Drivers/Databricks/CloudFetch/DownloadResult.cs new file mode 100644 index 000000000..eb2736c33 --- /dev/null +++ b/csharp/src/Drivers/Databricks/CloudFetch/DownloadResult.cs @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.IO; +using System.Threading.Tasks; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Represents a downloaded result file with its associated metadata. + /// </summary> + internal sealed class DownloadResult : IDownloadResult + { + private readonly TaskCompletionSource<bool> _downloadCompletionSource; + private readonly ICloudFetchMemoryBufferManager _memoryManager; + private Stream? 
_dataStream; + private bool _isDisposed; + private long _size; + + /// <summary> + /// Initializes a new instance of the <see cref="DownloadResult"/> class. + /// </summary> + /// <param name="link">The link information for this result.</param> + /// <param name="memoryManager">The memory buffer manager.</param> + public DownloadResult(TSparkArrowResultLink link, ICloudFetchMemoryBufferManager memoryManager) + { + Link = link ?? throw new ArgumentNullException(nameof(link)); + _memoryManager = memoryManager ?? throw new ArgumentNullException(nameof(memoryManager)); + _downloadCompletionSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously); + _size = link.BytesNum; + } + + /// <inheritdoc /> + public TSparkArrowResultLink Link { get; } + + /// <inheritdoc /> + public Stream DataStream + { + get + { + ThrowIfDisposed(); + if (!IsCompleted) + { + throw new InvalidOperationException("Download has not completed yet."); + } + return _dataStream!; + } + } + + /// <inheritdoc /> + public long Size => _size; + + /// <inheritdoc /> + public Task DownloadCompletedTask => _downloadCompletionSource.Task; + + /// <inheritdoc /> + public bool IsCompleted => _downloadCompletionSource.Task.IsCompleted && !_downloadCompletionSource.Task.IsFaulted; + + /// <inheritdoc /> + public void SetCompleted(Stream dataStream, long size) + { + ThrowIfDisposed(); + _dataStream = dataStream ?? throw new ArgumentNullException(nameof(dataStream)); + _downloadCompletionSource.TrySetResult(true); + _size = size; + } + + /// <inheritdoc /> + public void SetFailed(Exception exception) + { + ThrowIfDisposed(); + _downloadCompletionSource.TrySetException(exception ?? throw new ArgumentNullException(nameof(exception))); + } + + /// <inheritdoc /> + public void Dispose() + { + if (_isDisposed) + { + return; + } + + if (_dataStream != null) + { + _dataStream.Dispose(); + _dataStream = null; + + // Release memory back to the manager + if (_size > 0) + { + _memoryManager.ReleaseMemory(_size); + } + } + + // Ensure any waiting tasks are completed if not already + if (!_downloadCompletionSource.Task.IsCompleted) + { + _downloadCompletionSource.TrySetCanceled(); + } + + _isDisposed = true; + } + + private void ThrowIfDisposed() + { + if (_isDisposed) + { + throw new ObjectDisposedException(nameof(DownloadResult)); + } + } + } +} diff --git a/csharp/src/Drivers/Databricks/CloudFetch/EndOfResultsGuard.cs b/csharp/src/Drivers/Databricks/CloudFetch/EndOfResultsGuard.cs new file mode 100644 index 000000000..d305082cf --- /dev/null +++ b/csharp/src/Drivers/Databricks/CloudFetch/EndOfResultsGuard.cs @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +using System; +using System.IO; +using System.Threading.Tasks; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Special marker class that indicates the end of results in the download queue. + /// </summary> + internal sealed class EndOfResultsGuard : IDownloadResult + { + private static readonly Task CompletedTask = Task.CompletedTask; + + /// <summary> + /// Gets the singleton instance of the <see cref="EndOfResultsGuard"/> class. + /// </summary> + public static EndOfResultsGuard Instance { get; } = new EndOfResultsGuard(); + + private EndOfResultsGuard() + { + // Private constructor to enforce singleton pattern + } + + /// <inheritdoc /> + public TSparkArrowResultLink Link => throw new NotSupportedException("EndOfResultsGuard does not have a link."); + + /// <inheritdoc /> + public Stream DataStream => throw new NotSupportedException("EndOfResultsGuard does not have a data stream."); + + /// <inheritdoc /> + public long Size => 0; + + /// <inheritdoc /> + public Task DownloadCompletedTask => CompletedTask; + + /// <inheritdoc /> + public bool IsCompleted => true; + + /// <inheritdoc /> + public void SetCompleted(Stream dataStream, long size) => throw new NotSupportedException("EndOfResultsGuard cannot be completed."); + + /// <inheritdoc /> + public void SetFailed(Exception exception) => throw new NotSupportedException("EndOfResultsGuard cannot fail."); + + /// <inheritdoc /> + public void Dispose() + { + // Nothing to dispose + } + } +} diff --git a/csharp/src/Drivers/Databricks/CloudFetch/ICloudFetchInterfaces.cs b/csharp/src/Drivers/Databricks/CloudFetch/ICloudFetchInterfaces.cs new file mode 100644 index 000000000..444213087 --- /dev/null +++ b/csharp/src/Drivers/Databricks/CloudFetch/ICloudFetchInterfaces.cs @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Represents a downloaded result file with its associated metadata. + /// </summary> + internal interface IDownloadResult : IDisposable + { + /// <summary> + /// Gets the link information for this result. + /// </summary> + TSparkArrowResultLink Link { get; } + + /// <summary> + /// Gets the stream containing the downloaded data. + /// </summary> + Stream DataStream { get; } + + /// <summary> + /// Gets the size of the downloaded data in bytes. + /// </summary> + long Size { get; } + + /// <summary> + /// Gets a task that completes when the download is finished. 
+ /// </summary> + Task DownloadCompletedTask { get; } + + /// <summary> + /// Gets a value indicating whether the download has completed. + /// </summary> + bool IsCompleted { get; } + + /// <summary> + /// Sets the download as completed with the provided data stream. + /// </summary> + /// <param name="dataStream">The stream containing the downloaded data.</param> + /// <param name="size">The size of the downloaded data in bytes.</param> + void SetCompleted(Stream dataStream, long size); + + /// <summary> + /// Sets the download as failed with the specified exception. + /// </summary> + /// <param name="exception">The exception that caused the failure.</param> + void SetFailed(Exception exception); + } + + /// <summary> + /// Manages memory allocation for prefetched files. + /// </summary> + internal interface ICloudFetchMemoryBufferManager + { + /// <summary> + /// Gets the maximum memory allowed for buffering in bytes. + /// </summary> + long MaxMemory { get; } + + /// <summary> + /// Gets the currently used memory in bytes. + /// </summary> + long UsedMemory { get; } + + /// <summary> + /// Tries to acquire memory for a download without blocking. + /// </summary> + /// <param name="size">The size in bytes to acquire.</param> + /// <returns>True if memory was successfully acquired, false otherwise.</returns> + bool TryAcquireMemory(long size); + + /// <summary> + /// Acquires memory for a download, blocking until memory is available. + /// </summary> + /// <param name="size">The size in bytes to acquire.</param> + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns>A task representing the asynchronous operation.</returns> + Task AcquireMemoryAsync(long size, CancellationToken cancellationToken); + + /// <summary> + /// Releases previously acquired memory. + /// </summary> + /// <param name="size">The size in bytes to release.</param> + void ReleaseMemory(long size); + } + + /// <summary> + /// Fetches result chunks from the Thrift server. + /// </summary> + internal interface ICloudFetchResultFetcher + { + /// <summary> + /// Starts the result fetcher. + /// </summary> + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns>A task representing the asynchronous operation.</returns> + Task StartAsync(CancellationToken cancellationToken); + + /// <summary> + /// Stops the result fetcher. + /// </summary> + /// <returns>A task representing the asynchronous operation.</returns> + Task StopAsync(); + + /// <summary> + /// Gets a value indicating whether there are more results available. + /// </summary> + bool HasMoreResults { get; } + + /// <summary> + /// Gets a value indicating whether the fetcher has completed fetching all results. + /// </summary> + bool IsCompleted { get; } + + /// <summary> + /// Gets a value indicating whether the fetcher encountered an error. + /// </summary> + bool HasError { get; } + + /// <summary> + /// Gets the error encountered by the fetcher, if any. + /// </summary> + Exception? Error { get; } + } + + /// <summary> + /// Downloads files from URLs. + /// </summary> + internal interface ICloudFetchDownloader + { + /// <summary> + /// Starts the downloader. + /// </summary> + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns>A task representing the asynchronous operation.</returns> + Task StartAsync(CancellationToken cancellationToken); + + /// <summary> + /// Stops the downloader. 
+ /// </summary> + /// <returns>A task representing the asynchronous operation.</returns> + Task StopAsync(); + + /// <summary> + /// Gets the next downloaded file. + /// </summary> + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns>The next downloaded file, or null if there are no more files.</returns> + Task<IDownloadResult?> GetNextDownloadedFileAsync(CancellationToken cancellationToken); + + /// <summary> + /// Gets a value indicating whether the downloader has completed all downloads. + /// </summary> + bool IsCompleted { get; } + + /// <summary> + /// Gets a value indicating whether the downloader encountered an error. + /// </summary> + bool HasError { get; } + + /// <summary> + /// Gets the error encountered by the downloader, if any. + /// </summary> + Exception? Error { get; } + } + + /// <summary> + /// Manages the CloudFetch download pipeline. + /// </summary> + internal interface ICloudFetchDownloadManager : IDisposable + { + /// <summary> + /// Gets the next downloaded file. + /// </summary> + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns>The next downloaded file, or null if there are no more files.</returns> + Task<IDownloadResult?> GetNextDownloadedFileAsync(CancellationToken cancellationToken); + + /// <summary> + /// Starts the download manager. + /// </summary> + /// <returns>A task representing the asynchronous operation.</returns> + Task StartAsync(); + + /// <summary> + /// Stops the download manager. + /// </summary> + /// <returns>A task representing the asynchronous operation.</returns> + Task StopAsync(); + + /// <summary> + /// Gets a value indicating whether there are more results available. + /// </summary> + bool HasMoreResults { get; } + } +} diff --git a/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs b/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs new file mode 100644 index 000000000..cfc92b98b --- /dev/null +++ b/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Interface for accessing HiveServer2Statement properties needed by CloudFetchResultFetcher. + /// </summary> + internal interface IHiveServer2Statement + { + /// <summary> + /// Gets the operation handle. + /// </summary> + TOperationHandle? OperationHandle { get; } + + /// <summary> + /// Gets the client. 
+ /// </summary> + TCLIService.IAsync Client { get; } + } +} diff --git a/csharp/src/Drivers/Databricks/CloudFetch/cloudfetch-pipeline-design.md b/csharp/src/Drivers/Databricks/CloudFetch/cloudfetch-pipeline-design.md new file mode 100644 index 000000000..43a70c071 --- /dev/null +++ b/csharp/src/Drivers/Databricks/CloudFetch/cloudfetch-pipeline-design.md @@ -0,0 +1,72 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> + +The current CloudFetch implementation downloads cloud result files inline with the reader, which hurts performance: the reader stalls whenever it needs to download the next result file. + +We need to add prefetch functionality to the CloudFetch downloader, i.e., the reader should not be blocked by file downloads. Instead, a separate downloader class should handle downloading the result files in parallel. + +When the current batch of downloads is finished, the downloader should be able to asynchronously fetch the next batch of files and start prefetching them. + +The file download code that currently lives in SparkCloudFetchReader.cs should be removed or refactored into the new design; to simplify the code, all downloads will go through the prefetch logic. + +The logic of the FetchResults call should also be folded into the prefetch logic. + +Additionally, we must guarantee that the reader consumes results in the same order as the TSparkArrowResultLink entries in TFetchResultsResp. + +We need to add the following configuration items: +1. How many parallel downloads are allowed (default value: 3) +2. How many files we want to prefetch (default value: 2) +3. How much memory we want to use to buffer the prefetched files (default value: 200MB) +4. Whether prefetch is enabled (default: true) + +Here are some high-level class designs for this work: + +- The current SparkCloudFetchReader class will be simplified; all file download and fetch-next-result-set logic will be moved out of this class. + +- A new DownloadResult class is central to tracking the download and usage status of each file: + - Contains the link to the file. + - Holds the memory stream of the file. + - Tracks the size of the downloaded file. + - Includes a TaskCompletionSource so that SparkCloudFetchReader can wait if a download is not finished. + - Acts as the event for the pipeline and should be disposable, holding a reference to the CloudFetchMemoryBufferManager and returning the memory upon disposal. + +- Pipeline Design: + - Uses a concurrent queue to build the pipeline, with DownloadResult as the event for each stage. + - Two workers process the pipeline: + - The result chunk fetcher worker: + - Continuously fetches results from the Thrift server via a background task and appends events to the download queue.
+ - Monitors the pending download queue and continues fetching new results only if the event count is below a configurable threshold. + - The file download worker: + - Polls events from the download queue, performs file downloads, and appends them to the result queue for consumption by the SparkCloudFetchReader. + - Adheres to the concurrent download and memory limit configurations. + - Runs as a background task. + +- A new class, CloudFetchDownloader, will maintain a concurrent queue of DownloadResult objects: + - When GetNextDownloadedFileAsync is called, it will pop and return a DownloadResult to the SparkCloudFetchReader. + - If the queue is empty but there are still results being fetched by CloudFetchResultFetcher, it should wait for new results before returning them. + - If the queue is empty and no more results are forthcoming, it should return null. + +- The CloudFetchMemoryBufferManager class will restrict how much memory can be used to buffer prefetched files: + - Memory must be acquired before scheduling a download. + - Memory is released once SparkCloudFetchReader has finished reading a file. + - DownloadResult can be made disposable to ensure safe operation. + +- The CloudFetchDownloadManager class will manage all the above components: + - SparkCloudFetchReader will obtain new download results in batches from this manager. + - It monitors both fetch and download statuses, returning null to the SparkCloudFetchReader when there are no more files (i.e., no more results fetched from the Thrift server and no pending events in any queues). + +- Interfaces should be used to decouple the implementations of each class. diff --git a/csharp/src/Drivers/Databricks/DatabricksParameters.cs b/csharp/src/Drivers/Databricks/DatabricksParameters.cs index c99fbde3e..f45350b72 100644 --- a/csharp/src/Drivers/Databricks/DatabricksParameters.cs +++ b/csharp/src/Drivers/Databricks/DatabricksParameters.cs @@ -68,6 +68,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks /// </summary> public const string ApplySSPWithQueries = "adbc.databricks.apply_ssp_with_queries"; + /// <summary> /// Prefix for server-side properties. Properties with this prefix will be passed to the server /// by executing a "set key=value" query when opening a session. @@ -85,6 +86,30 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks /// Default value is 900 seconds (15 minutes). Set to 0 to retry indefinitely. /// </summary> public const string TemporarilyUnavailableRetryTimeout = "adbc.spark.temporarily_unavailable_retry_timeout"; + + /// <summary> + /// Maximum number of parallel downloads for CloudFetch operations. + /// Default value is 3 if not specified. + /// </summary> + public const string CloudFetchParallelDownloads = "adbc.databricks.cloudfetch.parallel_downloads"; + + /// <summary> + /// Number of files to prefetch in CloudFetch operations. + /// Default value is 2 if not specified. + /// </summary> + public const string CloudFetchPrefetchCount = "adbc.databricks.cloudfetch.prefetch_count"; + + /// <summary> + /// Maximum memory buffer size in MB for CloudFetch prefetched files. + /// Default value is 200MB if not specified. + /// </summary> + public const string CloudFetchMemoryBufferSize = "adbc.databricks.cloudfetch.memory_buffer_size_mb"; + + /// <summary> + /// Whether CloudFetch prefetch functionality is enabled. + /// Default value is true if not specified.
+ /// </summary> + public const string CloudFetchPrefetchEnabled = "adbc.databricks.cloudfetch.prefetch_enabled"; } /// <summary> diff --git a/csharp/src/Drivers/Databricks/DatabricksStatement.cs b/csharp/src/Drivers/Databricks/DatabricksStatement.cs index f2c9e643a..447689e51 100644 --- a/csharp/src/Drivers/Databricks/DatabricksStatement.cs +++ b/csharp/src/Drivers/Databricks/DatabricksStatement.cs @@ -18,6 +18,7 @@ using System; using Apache.Arrow.Adbc.Drivers.Apache; using Apache.Arrow.Adbc.Drivers.Apache.Spark; +using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch; using Apache.Hive.Service.Rpc.Thrift; namespace Apache.Arrow.Adbc.Drivers.Databricks @@ -25,7 +26,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks /// <summary> /// Databricks-specific implementation of <see cref="AdbcStatement"/> /// </summary> - internal class DatabricksStatement : SparkStatement + internal class DatabricksStatement : SparkStatement, IHiveServer2Statement { private bool useCloudFetch; private bool canDecompressLz4; @@ -50,6 +51,9 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks statement.MaxBytesPerFile = maxBytesPerFile; } + // Cast the Client to IAsync for CloudFetch compatibility + TCLIService.IAsync IHiveServer2Statement.Client => Connection.Client; + public override void SetOption(string key, string value) { switch (key) diff --git a/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchDownloaderTest.cs b/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchDownloaderTest.cs new file mode 100644 index 000000000..350ddf219 --- /dev/null +++ b/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchDownloaderTest.cs @@ -0,0 +1,534 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch; +using Apache.Hive.Service.Rpc.Thrift; +using Moq; +using Moq.Protected; +using Xunit; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch +{ + public class CloudFetchDownloaderTest + { + private readonly BlockingCollection<IDownloadResult> _downloadQueue; + private readonly BlockingCollection<IDownloadResult> _resultQueue; + private readonly Mock<ICloudFetchMemoryBufferManager> _mockMemoryManager; + + public CloudFetchDownloaderTest() + { + _downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10); + _resultQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10); + _mockMemoryManager = new Mock<ICloudFetchMemoryBufferManager>(); + + // Set up memory manager defaults + _mockMemoryManager.Setup(m => m.TryAcquireMemory(It.IsAny<long>())).Returns(true); + _mockMemoryManager.Setup(m => m.AcquireMemoryAsync(It.IsAny<long>(), It.IsAny<CancellationToken>())) + .Returns(Task.CompletedTask); + } + + [Fact] + public async Task StartAsync_CalledTwice_ThrowsException() + { + // Arrange + var mockDownloader = new Mock<ICloudFetchDownloader>(); + + // Setup first call to succeed and second call to throw + mockDownloader.SetupSequence(d => d.StartAsync(It.IsAny<CancellationToken>())) + .Returns(Task.CompletedTask) + .Throws(new InvalidOperationException("Downloader is already running.")); + + // Act & Assert + await mockDownloader.Object.StartAsync(CancellationToken.None); + await Assert.ThrowsAsync<InvalidOperationException>(() => mockDownloader.Object.StartAsync(CancellationToken.None)); + } + + [Fact] + public async Task GetNextDownloadedFileAsync_ReturnsNull_WhenEndOfResultsGuardReceived() + { + // Arrange + var mockHttpMessageHandler = new Mock<HttpMessageHandler>(); + var httpClient = new HttpClient(mockHttpMessageHandler.Object); + var downloader = new CloudFetchDownloader( + _downloadQueue, + _resultQueue, + _mockMemoryManager.Object, + httpClient, + 3, // maxParallelDownloads + false); // isLz4Compressed + + // Add the end of results guard to the result queue + _resultQueue.Add(EndOfResultsGuard.Instance); + + // Act + await downloader.StartAsync(CancellationToken.None); + var result = await downloader.GetNextDownloadedFileAsync(CancellationToken.None); + + // Assert + Assert.Null(result); + Assert.True(downloader.IsCompleted); + + // Cleanup + await downloader.StopAsync(); + } + + [Fact] + public async Task DownloadFileAsync_ProcessesFile_AndAddsToResultQueue() + { + // Arrange + string testContent = "Test file content"; + byte[] testContentBytes = Encoding.UTF8.GetBytes(testContent); + + // Create a mock HTTP handler that returns our test content + var mockHttpMessageHandler = CreateMockHttpMessageHandler(testContentBytes); + var httpClient = new HttpClient(mockHttpMessageHandler.Object); + + // Create a test download result + var mockDownloadResult = new Mock<IDownloadResult>(); + var resultLink = new TSparkArrowResultLink { FileLink = "http://test.com/file1" }; + mockDownloadResult.Setup(r => r.Link).Returns(resultLink); + mockDownloadResult.Setup(r => r.Size).Returns(testContentBytes.Length); + + // Capture the stream and size passed to SetCompleted + Stream? 
capturedStream = null; + long capturedSize = 0; + mockDownloadResult.Setup(r => r.SetCompleted(It.IsAny<Stream>(), It.IsAny<long>())) + .Callback<Stream, long>((stream, size) => + { + capturedStream = stream; + capturedSize = size; + }); + + // Create the downloader and add the download to the queue + var downloader = new CloudFetchDownloader( + _downloadQueue, + _resultQueue, + _mockMemoryManager.Object, + httpClient, + 1, // maxParallelDownloads + false, // isLz4Compressed + 1, // maxRetries + 10); // retryDelayMs + + // Act + await downloader.StartAsync(CancellationToken.None); + _downloadQueue.Add(mockDownloadResult.Object); + + // Wait for the download to be processed + await Task.Delay(100); + + // Add the end of results guard to complete the downloader + _downloadQueue.Add(EndOfResultsGuard.Instance); + + // Wait for the result to be available + var result = await downloader.GetNextDownloadedFileAsync(CancellationToken.None); + + // Assert + Assert.Same(mockDownloadResult.Object, result); + + // Verify SetCompleted was called + mockDownloadResult.Verify(r => r.SetCompleted(It.IsAny<Stream>(), It.IsAny<long>()), Times.Once); + + // Verify the content of the stream + Assert.NotNull(capturedStream); + using (var reader = new StreamReader(capturedStream)) + { + string content = reader.ReadToEnd(); + Assert.Equal(testContent, content); + } + + // Verify memory was acquired + _mockMemoryManager.Verify(m => m.AcquireMemoryAsync(It.IsAny<long>(), It.IsAny<CancellationToken>()), Times.Once); + + // Cleanup + await downloader.StopAsync(); + } + + [Fact] + public async Task DownloadFileAsync_HandlesHttpError_AndSetsFailedOnDownloadResult() + { + // Arrange + // Create a mock HTTP handler that returns a 404 error + var mockHttpMessageHandler = new Mock<HttpMessageHandler>(); + mockHttpMessageHandler + .Protected() + .Setup<Task<HttpResponseMessage>>( + "SendAsync", + ItExpr.IsAny<HttpRequestMessage>(), + ItExpr.IsAny<CancellationToken>()) + .Returns<HttpRequestMessage, CancellationToken>(async (request, token) => + { + await Task.Delay(1, token); // Small delay to simulate network + return new HttpResponseMessage(HttpStatusCode.NotFound); + }); + + var httpClient = new HttpClient(mockHttpMessageHandler.Object); + + // Create a test download result + var mockDownloadResult = new Mock<IDownloadResult>(); + var resultLink = new TSparkArrowResultLink { FileLink = "http://test.com/file1" }; + mockDownloadResult.Setup(r => r.Link).Returns(resultLink); + mockDownloadResult.Setup(r => r.Size).Returns(1000); // Some arbitrary size + + // Capture when SetFailed is called + Exception? 
capturedException = null; + mockDownloadResult.Setup(r => r.SetFailed(It.IsAny<Exception>())) + .Callback<Exception>(ex => capturedException = ex); + + // Create the downloader and add the download to the queue + var downloader = new CloudFetchDownloader( + _downloadQueue, + _resultQueue, + _mockMemoryManager.Object, + httpClient, + 1, // maxParallelDownloads + false, // isLz4Compressed + 1, // maxRetries + 10); // retryDelayMs + + // Act + await downloader.StartAsync(CancellationToken.None); + _downloadQueue.Add(mockDownloadResult.Object); + + // Wait for the download to be processed + await Task.Delay(100); + + // Add the end of results guard to complete the downloader + _downloadQueue.Add(EndOfResultsGuard.Instance); + + // Assert + // Verify SetFailed was called + mockDownloadResult.Verify(r => r.SetFailed(It.IsAny<Exception>()), Times.Once); + Assert.NotNull(capturedException); + Assert.IsType<HttpRequestException>(capturedException); + + // Verify the downloader has an error + Assert.True(downloader.HasError); + Assert.NotNull(downloader.Error); + + // Verify GetNextDownloadedFileAsync throws an exception + await Assert.ThrowsAsync<AdbcException>(() => downloader.GetNextDownloadedFileAsync(CancellationToken.None)); + + // Cleanup + await downloader.StopAsync(); + } + + [Fact] + public async Task DownloadFileAsync_WithError_StopsProcessingRemainingFiles() + { + // Arrange + // Create a mock HTTP handler that returns success for the first request and error for the second + var mockHttpMessageHandler = new Mock<HttpMessageHandler>(); + + // Use a simpler approach - just make all requests fail + mockHttpMessageHandler + .Protected() + .Setup<Task<HttpResponseMessage>>( + "SendAsync", + ItExpr.IsAny<HttpRequestMessage>(), + ItExpr.IsAny<CancellationToken>()) + .ReturnsAsync(new HttpResponseMessage(HttpStatusCode.NotFound)); + + var httpClient = new HttpClient(mockHttpMessageHandler.Object); + + // Create test download results + var mockDownloadResult = new Mock<IDownloadResult>(); + var resultLink = new TSparkArrowResultLink { FileLink = "http://test.com/file1" }; + mockDownloadResult.Setup(r => r.Link).Returns(resultLink); + mockDownloadResult.Setup(r => r.Size).Returns(100); + + // Capture when SetFailed is called + Exception? 
capturedException = null; + mockDownloadResult.Setup(r => r.SetFailed(It.IsAny<Exception>())) + .Callback<Exception>(ex => capturedException = ex); + + // Create the downloader + var downloader = new CloudFetchDownloader( + _downloadQueue, + _resultQueue, + _mockMemoryManager.Object, + httpClient, + 1, // maxParallelDownloads + false, // isLz4Compressed + 1, // maxRetries + 10); // retryDelayMs + + // Act + await downloader.StartAsync(CancellationToken.None); + _downloadQueue.Add(mockDownloadResult.Object); + + // Wait for the download to be processed and fail + await Task.Delay(200); + + // Add the end of results guard + _downloadQueue.Add(EndOfResultsGuard.Instance); + + // Wait for all processing to complete + await Task.Delay(200); + + // Assert + // Verify the download failed + mockDownloadResult.Verify(r => r.SetFailed(It.IsAny<Exception>()), Times.Once); + + // Verify the downloader has an error + Assert.True(downloader.HasError); + Assert.NotNull(downloader.Error); + + // Verify GetNextDownloadedFileAsync throws an exception + await Assert.ThrowsAsync<AdbcException>(() => downloader.GetNextDownloadedFileAsync(CancellationToken.None)); + + // Cleanup + await downloader.StopAsync(); + } + + [Fact] + public async Task StopAsync_CancelsOngoingDownloads() + { + // Arrange + var cancellationTokenSource = new CancellationTokenSource(); + var downloadStarted = new TaskCompletionSource<bool>(); + var downloadCancelled = new TaskCompletionSource<bool>(); + + // Create a mock HTTP handler with a delay to simulate a long download + var mockHttpMessageHandler = new Mock<HttpMessageHandler>(); + mockHttpMessageHandler + .Protected() + .Setup<Task<HttpResponseMessage>>( + "SendAsync", + ItExpr.IsAny<HttpRequestMessage>(), + ItExpr.IsAny<CancellationToken>()) + .Returns<HttpRequestMessage, CancellationToken>(async (request, token) => + { + downloadStarted.TrySetResult(true); + + try + { + // Wait for a long time or until cancellation + await Task.Delay(10000, token); + } + catch (OperationCanceledException) + { + downloadCancelled.TrySetResult(true); + throw; + } + + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent("Test content") + }; + }); + + var httpClient = new HttpClient(mockHttpMessageHandler.Object); + + // Create a test download result + var mockDownloadResult = new Mock<IDownloadResult>(); + var resultLink = new TSparkArrowResultLink { FileLink = "http://test.com/file1" }; + mockDownloadResult.Setup(r => r.Link).Returns(resultLink); + mockDownloadResult.Setup(r => r.Size).Returns(100); + + // Create the downloader and add the download to the queue + var downloader = new CloudFetchDownloader( + _downloadQueue, + _resultQueue, + _mockMemoryManager.Object, + httpClient, + 1, // maxParallelDownloads + false); // isLz4Compressed + + // Act + await downloader.StartAsync(CancellationToken.None); + _downloadQueue.Add(mockDownloadResult.Object); + + // Wait for the download to start + await downloadStarted.Task; + + // Stop the downloader + await downloader.StopAsync(); + + // Assert + // Wait a short time for cancellation to propagate + var cancelled = await Task.WhenAny(downloadCancelled.Task, Task.Delay(1000)) == downloadCancelled.Task; + Assert.True(cancelled, "Download should have been cancelled"); + } + + [Fact] + public async Task GetNextDownloadedFileAsync_RespectsMaxParallelDownloads() + { + // Arrange + int totalDownloads = 3; + int maxParallelDownloads = 2; + var downloadStartedEvents = new TaskCompletionSource<bool>[totalDownloads]; + var 
downloadCompletedEvents = new TaskCompletionSource<bool>[totalDownloads]; + + for (int i = 0; i < totalDownloads; i++) + { + downloadStartedEvents[i] = new TaskCompletionSource<bool>(); + downloadCompletedEvents[i] = new TaskCompletionSource<bool>(); + } + + // Create a mock HTTP handler that signals when downloads start and waits for completion signal + var mockHttpMessageHandler = new Mock<HttpMessageHandler>(); + mockHttpMessageHandler + .Protected() + .Setup<Task<HttpResponseMessage>>( + "SendAsync", + ItExpr.IsAny<HttpRequestMessage>(), + ItExpr.IsAny<CancellationToken>()) + .Returns<HttpRequestMessage, CancellationToken>(async (request, token) => + { + // Extract the index from the URL + string url = request.RequestUri?.ToString() ?? ""; + if (url.Contains("file")) + { + int index = int.Parse(url.Substring(url.Length - 1)); + + if (request.Method == HttpMethod.Get) + { + // Signal that this download has started + downloadStartedEvents[index].TrySetResult(true); + + // Wait for the signal to complete this download + await downloadCompletedEvents[index].Task; + } + } + + // Return a success response + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent(Encoding.UTF8.GetBytes("Test content")) + }; + }); + + var httpClient = new HttpClient(mockHttpMessageHandler.Object); + + // Create test download results + var downloadResults = new IDownloadResult[totalDownloads]; + for (int i = 0; i < totalDownloads; i++) + { + var mockDownloadResult = new Mock<IDownloadResult>(); + var resultLink = new TSparkArrowResultLink { FileLink = $"http://test.com/file{i}" }; + mockDownloadResult.Setup(r => r.Link).Returns(resultLink); + mockDownloadResult.Setup(r => r.Size).Returns(100); + mockDownloadResult.Setup(r => r.SetCompleted(It.IsAny<Stream>(), It.IsAny<long>())) + .Callback<Stream, long>((_, _) => { }); + downloadResults[i] = mockDownloadResult.Object; + } + + // Create the downloader + var downloader = new CloudFetchDownloader( + _downloadQueue, + _resultQueue, + _mockMemoryManager.Object, + httpClient, + maxParallelDownloads, + false); // isLz4Compressed + + // Act + await downloader.StartAsync(CancellationToken.None); + + // Add all downloads to the queue + foreach (var result in downloadResults) + { + _downloadQueue.Add(result); + } + + // Wait for the first two downloads to start + await Task.WhenAll( + downloadStartedEvents[0].Task, + downloadStartedEvents[1].Task); + + // At this point, two downloads should be in progress + // Wait a bit to ensure the third download has had a chance to start if it's going to + await Task.Delay(100); + + // The third download should not have started yet + Assert.False(downloadStartedEvents[2].Task.IsCompleted, "The third download should not have started yet"); + + // Complete the first download + downloadCompletedEvents[0].SetResult(true); + + // Wait for the third download to start + await downloadStartedEvents[2].Task; + + // Complete the remaining downloads + downloadCompletedEvents[1].SetResult(true); + downloadCompletedEvents[2].SetResult(true); + + // Add the end of results guard to complete the downloader + _downloadQueue.Add(EndOfResultsGuard.Instance); + + // Cleanup + await downloader.StopAsync(); + } + + private static Mock<HttpMessageHandler> CreateMockHttpMessageHandler( + byte[]? content, + HttpStatusCode statusCode = HttpStatusCode.OK, + TimeSpan? 
delay = null) + { + var mockHandler = new Mock<HttpMessageHandler>(); + + mockHandler + .Protected() + .Setup<Task<HttpResponseMessage>>( + "SendAsync", + ItExpr.IsAny<HttpRequestMessage>(), + ItExpr.IsAny<CancellationToken>()) + .Returns<HttpRequestMessage, CancellationToken>(async (request, token) => + { + // If a delay is specified, wait for that duration + if (delay.HasValue) + { + await Task.Delay(delay.Value, token); + } + + // If the request is a HEAD request, return a response with content length + if (request.Method == HttpMethod.Head) + { + var response = new HttpResponseMessage(statusCode); + if (content != null) + { + response.Content = new ByteArrayContent(new byte[0]); + response.Content.Headers.ContentLength = content.Length; + } + return response; + } + + // For GET requests, return the actual content + var responseMessage = new HttpResponseMessage(statusCode); + if (content != null && statusCode == HttpStatusCode.OK) + { + responseMessage.Content = new ByteArrayContent(content); + responseMessage.Content.Headers.ContentLength = content.Length; + } + + return responseMessage; + }); + + return mockHandler; + } + } +} diff --git a/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchResultFetcherTest.cs b/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchResultFetcherTest.cs new file mode 100644 index 000000000..32f9a1a81 --- /dev/null +++ b/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchResultFetcherTest.cs @@ -0,0 +1,386 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch; +using Apache.Hive.Service.Rpc.Thrift; +using Moq; +using Xunit; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch +{ + /// <summary> + /// Tests for CloudFetchResultFetcher + /// </summary> + public class CloudFetchResultFetcherTest + { + private readonly Mock<ICloudFetchMemoryBufferManager> _mockMemoryManager; + private readonly BlockingCollection<IDownloadResult> _downloadQueue; + + public CloudFetchResultFetcherTest() + { + _mockMemoryManager = new Mock<ICloudFetchMemoryBufferManager>(); + _downloadQueue = new BlockingCollection<IDownloadResult>(new ConcurrentQueue<IDownloadResult>(), 10); + } + + [Fact] + public async Task StartAsync_CalledTwice_ThrowsException() + { + // Arrange + var mockClient = new Mock<TCLIService.IAsync>(); + mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>())) + .ReturnsAsync(CreateFetchResultsResponse(new List<TSparkArrowResultLink>(), false)); + + var mockStatement = new Mock<IHiveServer2Statement>(); + mockStatement.Setup(s => s.OperationHandle).Returns(CreateOperationHandle()); + mockStatement.Setup(s => s.Client).Returns(mockClient.Object); + + var fetcher = new CloudFetchResultFetcher( + mockStatement.Object, + _mockMemoryManager.Object, + _downloadQueue, + 5); // batchSize + + // Act & Assert + await fetcher.StartAsync(CancellationToken.None); + await Assert.ThrowsAsync<InvalidOperationException>(() => fetcher.StartAsync(CancellationToken.None)); + + // Cleanup + await fetcher.StopAsync(); + } + + [Fact] + public async Task FetchResultsAsync_SuccessfullyFetchesResults() + { + // Arrange + var resultLinks = new List<TSparkArrowResultLink> + { + CreateTestResultLink(0, 100, "http://test.com/file1"), + CreateTestResultLink(100, 100, "http://test.com/file2"), + CreateTestResultLink(200, 100, "http://test.com/file3") + }; + + var mockClient = new Mock<TCLIService.IAsync>(); + mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>())) + .ReturnsAsync(CreateFetchResultsResponse(resultLinks, false)); + + var mockStatement = new Mock<IHiveServer2Statement>(); + mockStatement.Setup(s => s.OperationHandle).Returns(CreateOperationHandle()); + mockStatement.Setup(s => s.Client).Returns(mockClient.Object); + + var fetcher = new CloudFetchResultFetcher( + mockStatement.Object, + _mockMemoryManager.Object, + _downloadQueue, + 5); // batchSize + + // Act + await fetcher.StartAsync(CancellationToken.None); + + // Wait for the fetcher to process the results + await Task.Delay(100); + + // Assert + // The download queue should contain our result links + // Note: With prefetch, there might be more items in the queue than just our result links + Assert.True(_downloadQueue.Count >= resultLinks.Count, + $"Expected at least {resultLinks.Count} items in queue, but found {_downloadQueue.Count}"); + + // Take all items from the queue and verify they match our result links + var downloadResults = new List<IDownloadResult>(); + while (_downloadQueue.TryTake(out var result)) + { + // Skip the end of results guard + if (result == EndOfResultsGuard.Instance) + { + continue; + } + downloadResults.Add(result); + } + + Assert.Equal(resultLinks.Count, downloadResults.Count); + + // Verify each download result has the correct link + for (int i = 0; i < resultLinks.Count; 
i++) + { + Assert.Equal(resultLinks[i].FileLink, downloadResults[i].Link.FileLink); + Assert.Equal(resultLinks[i].StartRowOffset, downloadResults[i].Link.StartRowOffset); + Assert.Equal(resultLinks[i].RowCount, downloadResults[i].Link.RowCount); + } + + // Verify the fetcher state + Assert.False(fetcher.HasMoreResults); + Assert.True(fetcher.IsCompleted); + Assert.False(fetcher.HasError); + Assert.Null(fetcher.Error); + + // Cleanup + await fetcher.StopAsync(); + } + + [Fact] + public async Task FetchResultsAsync_WithMultipleBatches_FetchesAllResults() + { + // Arrange + var firstBatchLinks = new List<TSparkArrowResultLink> + { + CreateTestResultLink(0, 100, "http://test.com/file1"), + CreateTestResultLink(100, 100, "http://test.com/file2") + }; + + var secondBatchLinks = new List<TSparkArrowResultLink> + { + CreateTestResultLink(200, 100, "http://test.com/file3"), + CreateTestResultLink(300, 100, "http://test.com/file4") + }; + + var mockClient = new Mock<TCLIService.IAsync>(); + mockClient.SetupSequence(c => c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>())) + .ReturnsAsync(CreateFetchResultsResponse(firstBatchLinks, true)) + .ReturnsAsync(CreateFetchResultsResponse(secondBatchLinks, false)); + + var mockStatement = new Mock<IHiveServer2Statement>(); + mockStatement.Setup(s => s.OperationHandle).Returns(CreateOperationHandle()); + mockStatement.Setup(s => s.Client).Returns(mockClient.Object); + + var fetcher = new CloudFetchResultFetcher( + mockStatement.Object, + _mockMemoryManager.Object, + _downloadQueue, + 5); // batchSize + + // Act + await fetcher.StartAsync(CancellationToken.None); + + // Wait for the fetcher to process all results + await Task.Delay(200); + + // Assert + // The download queue should contain all result links (both batches) + // Note: With prefetch, there might be more items in the queue than just our result links + Assert.True(_downloadQueue.Count >= firstBatchLinks.Count + secondBatchLinks.Count, + $"Expected at least {firstBatchLinks.Count + secondBatchLinks.Count} items in queue, but found {_downloadQueue.Count}"); + + // Take all items from the queue + var downloadResults = new List<IDownloadResult>(); + while (_downloadQueue.TryTake(out var result)) + { + // Skip the end of results guard + if (result == EndOfResultsGuard.Instance) + { + continue; + } + downloadResults.Add(result); + } + + Assert.Equal(firstBatchLinks.Count + secondBatchLinks.Count, downloadResults.Count); + + // Verify the fetcher state + Assert.False(fetcher.HasMoreResults); + Assert.True(fetcher.IsCompleted); + Assert.False(fetcher.HasError); + + // Cleanup + await fetcher.StopAsync(); + } + + [Fact] + public async Task FetchResultsAsync_WithEmptyResults_CompletesGracefully() + { + // Arrange + var mockClient = new Mock<TCLIService.IAsync>(); + mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>())) + .ReturnsAsync(CreateFetchResultsResponse(new List<TSparkArrowResultLink>(), false)); + + var mockStatement = new Mock<IHiveServer2Statement>(); + mockStatement.Setup(s => s.OperationHandle).Returns(CreateOperationHandle()); + mockStatement.Setup(s => s.Client).Returns(mockClient.Object); + + var fetcher = new CloudFetchResultFetcher( + mockStatement.Object, + _mockMemoryManager.Object, + _downloadQueue, + 5); // batchSize + + // Act + await fetcher.StartAsync(CancellationToken.None); + + // Wait for the fetcher to process the results + await Task.Delay(100); + + // Assert + // The download queue should be empty except for the 
end guard + var nonGuardItems = new List<IDownloadResult>(); + while (_downloadQueue.TryTake(out var result)) + { + if (result != EndOfResultsGuard.Instance) + { + nonGuardItems.Add(result); + } + } + Assert.Empty(nonGuardItems); + + // Verify the fetcher state + Assert.False(fetcher.HasMoreResults); + Assert.True(fetcher.IsCompleted); + Assert.False(fetcher.HasError); + + // Cleanup + await fetcher.StopAsync(); + } + + [Fact] + public async Task FetchResultsAsync_WithServerError_SetsErrorState() + { + // Arrange + var mockClient = new Mock<TCLIService.IAsync>(); + mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>())) + .ThrowsAsync(new InvalidOperationException("Test server error")); + + var mockStatement = new Mock<IHiveServer2Statement>(); + mockStatement.Setup(s => s.OperationHandle).Returns(CreateOperationHandle()); + mockStatement.Setup(s => s.Client).Returns(mockClient.Object); + + var fetcher = new CloudFetchResultFetcher( + mockStatement.Object, + _mockMemoryManager.Object, + _downloadQueue, + 5); // batchSize + + // Act + await fetcher.StartAsync(CancellationToken.None); + + // Wait for the fetcher to process the error + await Task.Delay(100); + + // Assert + // Verify the fetcher state + Assert.False(fetcher.HasMoreResults); + Assert.True(fetcher.IsCompleted); + Assert.True(fetcher.HasError); + Assert.NotNull(fetcher.Error); + Assert.IsType<InvalidOperationException>(fetcher.Error); + + // The download queue should have the end guard + Assert.Single(_downloadQueue); + var result = _downloadQueue.Take(); + Assert.Same(EndOfResultsGuard.Instance, result); + + // Cleanup + await fetcher.StopAsync(); + } + + [Fact] + public async Task StopAsync_CancelsFetching() + { + // Arrange + var fetchStarted = new TaskCompletionSource<bool>(); + var fetchCancelled = new TaskCompletionSource<bool>(); + + var mockClient = new Mock<TCLIService.IAsync>(); + mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>())) + .Returns(async (TFetchResultsReq req, CancellationToken token) => + { + fetchStarted.TrySetResult(true); + + try + { + // Wait for a long time or until cancellation + await Task.Delay(10000, token); + } + catch (OperationCanceledException) + { + fetchCancelled.TrySetResult(true); + throw; + } + + // Return empty results if not cancelled + return CreateFetchResultsResponse(new List<TSparkArrowResultLink>(), false); + }); + + var mockStatement = new Mock<IHiveServer2Statement>(); + mockStatement.Setup(s => s.OperationHandle).Returns(CreateOperationHandle()); + mockStatement.Setup(s => s.Client).Returns(mockClient.Object); + + var fetcher = new CloudFetchResultFetcher( + mockStatement.Object, + _mockMemoryManager.Object, + _downloadQueue, + 5); // batchSize + + // Act + await fetcher.StartAsync(CancellationToken.None); + + // Wait for the fetch to start + await fetchStarted.Task; + + // Stop the fetcher + await fetcher.StopAsync(); + + // Assert + // Wait a short time for cancellation to propagate + var cancelled = await Task.WhenAny(fetchCancelled.Task, Task.Delay(1000)) == fetchCancelled.Task; + Assert.True(cancelled, "Fetch operation should have been cancelled"); + + // Verify the fetcher state + Assert.True(fetcher.IsCompleted); + } + + private TOperationHandle CreateOperationHandle() + { + return new TOperationHandle + { + OperationId = new THandleIdentifier + { + Guid = new byte[16], + Secret = new byte[16] + }, + OperationType = TOperationType.EXECUTE_STATEMENT, + HasResultSet = true + }; + } + + 
private TFetchResultsResp CreateFetchResultsResponse(List<TSparkArrowResultLink> resultLinks, bool hasMoreRows) + { + var results = new TRowSet(); + results.__isset.resultLinks = true; + results.ResultLinks = resultLinks; + + return new TFetchResultsResp + { + Status = new TStatus { StatusCode = TStatusCode.SUCCESS_STATUS }, + HasMoreRows = hasMoreRows, + Results = results, + __isset = { results = true, hasMoreRows = true } + }; + } + + private TSparkArrowResultLink CreateTestResultLink(long startRowOffset, int rowCount, string fileLink) + { + return new TSparkArrowResultLink + { + StartRowOffset = startRowOffset, + RowCount = rowCount, + FileLink = fileLink + }; + } + } +}
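For context on how the pieces in this patch fit together at runtime: the pipeline described in cloudfetch-pipeline-design.md is consumed through ICloudFetchDownloadManager and IDownloadResult. Below is a minimal sketch of that consumer loop, assuming an already-constructed manager; the class name CloudFetchConsumerSketch and the omitted Arrow decoding are illustrative and not part of this patch.

using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch;

internal static class CloudFetchConsumerSketch
{
    public static async Task ConsumeAsync(ICloudFetchDownloadManager manager, CancellationToken cancellationToken)
    {
        await manager.StartAsync();
        try
        {
            while (true)
            {
                // Files come back in the same order as the TSparkArrowResultLink entries;
                // null signals that no more results remain.
                IDownloadResult? file = await manager.GetNextDownloadedFileAsync(cancellationToken);
                if (file == null)
                {
                    break;
                }

                using (file) // Dispose returns the buffered bytes to the memory buffer manager.
                {
                    // file.DataStream holds the downloaded Arrow data; a real reader would
                    // hand it to an Arrow stream reader here.
                    _ = file.DataStream;
                }
            }
        }
        finally
        {
            await manager.StopAsync();
        }
    }
}

Disposing each IDownloadResult is what returns its buffer to the memory buffer manager, so a loop like this keeps buffered data within the configured memory limit.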
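On the producer side, the DownloadResult contract is: acquire memory before a download starts, publish the stream with SetCompleted (or report SetFailed), and rely on Dispose to return the memory. The following sketch shows a single download under those assumptions; the bare HttpClient call, the helper name DownloadLifecycleSketch, and the error handling are illustrative stand-ins for what CloudFetchDownloader actually does (retries, LZ4 handling, and queueing are omitted).

using System;
using System.IO;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch;

internal static class DownloadLifecycleSketch
{
    public static async Task DownloadOneAsync(
        IDownloadResult result,
        ICloudFetchMemoryBufferManager memoryManager,
        HttpClient httpClient,
        CancellationToken cancellationToken)
    {
        // Reserve buffer space up front; this waits until enough memory is free.
        await memoryManager.AcquireMemoryAsync(result.Link.BytesNum, cancellationToken);
        try
        {
            byte[] bytes = await httpClient.GetByteArrayAsync(result.Link.FileLink);

            // Publishing the stream completes DownloadCompletedTask for any waiting reader.
            result.SetCompleted(new MemoryStream(bytes), bytes.Length);
        }
        catch (Exception ex)
        {
            // On failure, return the reservation immediately; on success it is returned
            // when the reader disposes the DownloadResult.
            memoryManager.ReleaseMemory(result.Link.BytesNum);
            result.SetFailed(ex);
            throw;
        }
    }
}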
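The four new DatabricksParameters keys are plain string options with the defaults documented in their XML comments. The sketch below shows the values a caller might supply, assuming the dictionary is passed to the driver alongside the existing Spark/Databricks connection properties; that wiring is outside this patch.

using System.Collections.Generic;
using Apache.Arrow.Adbc.Drivers.Databricks;

// Values shown are the documented defaults; pass different strings to change them.
var cloudFetchOptions = new Dictionary<string, string>
{
    [DatabricksParameters.CloudFetchPrefetchEnabled] = "true",  // enable/disable prefetch (default: true)
    [DatabricksParameters.CloudFetchParallelDownloads] = "3",   // concurrent downloads (default: 3)
    [DatabricksParameters.CloudFetchPrefetchCount] = "2",       // files to prefetch (default: 2)
    [DatabricksParameters.CloudFetchMemoryBufferSize] = "200",  // buffer size in MB (default: 200)
};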