This is an automated email from the ASF dual-hosted git repository. curth pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push: new 103865846 feat(csharp/src/Drivers/Databricks): Add option to enable using direct results for statements (#2737) 103865846 is described below commit 103865846d78821d90500b9c402efddf484d3b59 Author: Alex Guo <133057192+alexguo...@users.noreply.github.com> AuthorDate: Mon Apr 28 14:18:04 2025 -0700 feat(csharp/src/Drivers/Databricks): Add option to enable using direct results for statements (#2737) - Add option to set EnableDirectResults, which sends getDirectResults in the Thrift execute statement request - If getDirectResults is set in the request, the response's directResults field is populated with the initial results (equivalent to the server calling GetOperationStatus, GetResultSetMetadata, FetchResults, and CloseOperation) - If directResults is set on the response and reports that the operation has already finished, don't poll for the operation status - We already set getDirectResults on requests for metadata commands, just not on the execute statement request Tested E2E using `dotnet test --filter CloudFetchE2ETest` ``` [xUnit.net 00:00:00.11] Starting: Apache.Arrow.Adbc.Tests.Drivers.Databricks [xUnit.net 00:01:27.27] Finished: Apache.Arrow.Adbc.Tests.Drivers.Databricks Apache.Arrow.Adbc.Tests.Drivers.Databricks test net8.0 succeeded (87.7s) Test summary: total: 8, failed: 0, succeeded: 8, skipped: 0, duration: 87.7s Build succeeded in 89.1s ``` --- .../Drivers/Apache/Hive2/HiveServer2Connection.cs | 26 +++++----- .../Drivers/Apache/Hive2/HiveServer2Statement.cs | 37 ++++++++++++--- csharp/src/Drivers/Apache/Spark/SparkConnection.cs | 2 +- .../Databricks/CloudFetch/CloudFetchReader.cs | 2 +- .../CloudFetch/CloudFetchResultFetcher.cs | 28 +++++++++++ .../Databricks/CloudFetch/IHiveServer2Statement.cs | 11 +++++ .../src/Drivers/Databricks/DatabricksConnection.cs | 40 ++++++++++++++++ .../src/Drivers/Databricks/DatabricksParameters.cs | 6 +++ csharp/src/Drivers/Databricks/DatabricksReader.cs | 11 +++++ 
.../src/Drivers/Databricks/DatabricksStatement.cs | 16 +++++++ .../test/Drivers/Databricks/CloudFetchE2ETest.cs | 55 +++++++++++----------- .../Drivers/Databricks/DatabricksConnectionTest.cs | 1 + 12 files changed, 186 insertions(+), 49 deletions(-) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs index 9a17eb6c1..f4f0f7978 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs @@ -377,7 +377,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) { TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(SessionHandle); - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(getCatalogsReq); } @@ -416,7 +416,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 TGetSchemasReq getSchemasReq = new TGetSchemasReq(SessionHandle); getSchemasReq.CatalogName = catalogPattern; getSchemasReq.SchemaName = dbSchemaPattern; - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(getSchemasReq); } @@ -449,7 +449,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 getTablesReq.CatalogName = catalogPattern; getTablesReq.SchemaName = dbSchemaPattern; getTablesReq.TableName = tableNamePattern; - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(getTablesReq); } @@ -486,7 +486,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 columnsReq.CatalogName = catalogPattern; columnsReq.SchemaName = dbSchemaPattern; columnsReq.TableName = tableNamePattern; - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(columnsReq); } @@ -594,7 +594,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 SessionHandle = SessionHandle ?? 
throw new InvalidOperationException("session not created"), }; - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(req); } @@ -786,7 +786,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default); protected internal abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetPrimaryKeysResp response, CancellationToken cancellationToken = default); - protected internal virtual bool AreResultsAvailableDirectly() => false; + protected internal virtual bool AreResultsAvailableDirectly => false; protected virtual void SetDirectResults(TGetColumnsReq request) => throw new System.NotImplementedException(); @@ -923,7 +923,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 } TGetCatalogsReq req = new TGetCatalogsReq(SessionHandle); - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(req); } @@ -950,7 +950,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 } TGetSchemasReq req = new(SessionHandle); - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(req); } @@ -987,7 +987,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 } TGetTablesReq req = new(SessionHandle); - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(req); } @@ -1032,7 +1032,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 } TGetColumnsReq req = new(SessionHandle); - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(req); } @@ -1076,7 +1076,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 } TGetPrimaryKeysReq req = new(SessionHandle); - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(req); } @@ -1119,7 +1119,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 } TGetCrossReferenceReq req = 
new(SessionHandle); - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(req); } @@ -1255,7 +1255,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 getColumnsReq.CatalogName = catalog; getColumnsReq.SchemaName = dbSchema; getColumnsReq.TableName = tableName; - if (AreResultsAvailableDirectly()) + if (AreResultsAvailableDirectly) { SetDirectResults(getColumnsReq); } diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index 564a28e9f..bff8ca225 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -99,17 +99,28 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 return await ExecuteMetadataCommandQuery(cancellationToken); } + _directResults = null; + // this could either: // take QueryTimeoutSeconds * 3 // OR // take QueryTimeoutSeconds (but this could be restricting) await ExecuteStatementAsync(cancellationToken); // --> get QueryTimeout + - await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout - TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken); - Schema schema = Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion); + TGetResultSetMetadataResp metadata; + if (_directResults?.OperationStatus?.OperationState == TOperationState.FINISHED_STATE) + { + // The initial response has result data so we don't need to poll + metadata = _directResults.ResultSetMetadata; + } + else + { + await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout + metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, 
cancellationToken); + } // Store metadata for use in readers - return new QueryResult(-1, Connection.NewReader(this, schema, response)); + Schema schema = Connection.SchemaParser.GetArrowSchema(metadata.Schema, Connection.DataTypeConversion); + return new QueryResult(-1, Connection.NewReader(this, schema, metadata)); } public override async ValueTask<QueryResult> ExecuteQueryAsync() @@ -257,6 +268,19 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 .SetNativeError(executeResponse.Status.ErrorCode); } OperationHandle = executeResponse.OperationHandle; + + // Capture direct results if they're available + if (executeResponse.DirectResults != null) + { + _directResults = executeResponse.DirectResults; + + if (!string.IsNullOrEmpty(_directResults.OperationStatus?.DisplayMessage)) + { + throw new HiveServer2Exception(_directResults.OperationStatus!.DisplayMessage) + .SetSqlState(_directResults.OperationStatus.SqlState) + .SetNativeError(_directResults.OperationStatus.ErrorCode); + } + } } protected internal int PollTimeMilliseconds { get; private set; } = HiveServer2Connection.PollTimeMillisecondsDefault; @@ -279,6 +303,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 protected internal string? ForeignCatalogName { get; set; } protected internal string? ForeignSchemaName { get; set; } protected internal string? ForeignTableName { get; set; } + protected internal TSparkDirectResults? 
_directResults { get; set; } public HiveServer2Connection Connection { get; private set; } @@ -416,7 +441,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 TRowSet rowSet; // For GetColumns, we need to enhance the result with BASE_TYPE_NAME - if (Connection.AreResultsAvailableDirectly() && resp.DirectResults?.ResultSet?.Results != null) + if (Connection.AreResultsAvailableDirectly && resp.DirectResults?.ResultSet?.Results != null) { // Get data from direct results metadata = resp.DirectResults.ResultSetMetadata; @@ -454,7 +479,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 private async Task<QueryResult> GetQueryResult(TSparkDirectResults? directResults, CancellationToken cancellationToken) { Schema schema; - if (Connection.AreResultsAvailableDirectly() && directResults?.ResultSet?.Results != null) + if (Connection.AreResultsAvailableDirectly && directResults?.ResultSet?.Results != null) { TGetResultSetMetadataResp resultSetMetadata = directResults.ResultSetMetadata; schema = Connection.SchemaParser.GetArrowSchema(resultSetMetadata.Schema, Connection.DataTypeConversion); diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs index 925073e35..c7e25861e 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs @@ -117,7 +117,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark protected override bool IsColumnSizeValidForDecimal => false; - protected internal override bool AreResultsAvailableDirectly() => true; + protected internal override bool AreResultsAvailableDirectly => true; protected override void SetDirectResults(TGetColumnsReq request) => request.GetDirectResults = sparkGetDirectResults; diff --git a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs index abca66d37..1e3861833 100644 --- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs 
+++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs @@ -1,4 +1,4 @@ -/* +/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. diff --git a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs index b0ad05a6a..3da5608ed 100644 --- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs +++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs @@ -129,6 +129,14 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch { try { + // Process direct results first, if available + if (_statement.HasDirectResults && _statement.DirectResults?.ResultSet?.Results?.ResultLinks?.Count > 0) + { + // Yield execution so the download queue doesn't get blocked before downloader is started + await Task.Yield(); + ProcessDirectResultsAsync(cancellationToken); + } + // Continue fetching as needed while (_hasMoreResults && !cancellationToken.IsCancellationRequested) { @@ -228,5 +236,25 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch _hasMoreResults = false; } } + + private void ProcessDirectResultsAsync(CancellationToken cancellationToken) + { + List<TSparkArrowResultLink> resultLinks = _statement.DirectResults!.ResultSet.Results.ResultLinks; + + foreach (var link in resultLinks) + { + var downloadResult = new DownloadResult(link, _memoryManager); + _downloadQueue.Add(downloadResult, cancellationToken); + } + + // Update the start offset for the next fetch + if (resultLinks.Count > 0) + { + var lastLink = resultLinks[resultLinks.Count - 1]; + _startOffset = lastLink.StartRowOffset + lastLink.RowCount; + } + + _hasMoreResults = _statement.DirectResults!.ResultSet.HasMoreRows; + } } } diff --git a/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs 
b/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs index cfc92b98b..ee77dce9d 100644 --- a/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs +++ b/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs @@ -33,5 +33,16 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch /// Gets the client. /// </summary> TCLIService.IAsync Client { get; } + + /// <summary> + /// Gets the direct results. + /// </summary> + TSparkDirectResults? DirectResults { get; } + + /// <summary> + /// Checks if direct results are available. + /// </summary> + /// <returns>True if direct results are available and contain result data, false otherwise.</returns> + bool HasDirectResults { get; } } } diff --git a/csharp/src/Drivers/Databricks/DatabricksConnection.cs b/csharp/src/Drivers/Databricks/DatabricksConnection.cs index aefe2df89..cd7fc02a0 100644 --- a/csharp/src/Drivers/Databricks/DatabricksConnection.cs +++ b/csharp/src/Drivers/Databricks/DatabricksConnection.cs @@ -33,6 +33,13 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks internal class DatabricksConnection : SparkHttpConnection { private bool _applySSPWithQueries = false; + private bool _enableDirectResults = true; + + internal static TSparkGetDirectResults defaultGetDirectResults = new() + { + MaxRows = 2000000, + MaxBytes = 404857600 + }; // CloudFetch configuration private const long DefaultMaxBytesPerFile = 20 * 1024 * 1024; // 20MB @@ -62,6 +69,18 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks } } + if (Properties.TryGetValue(DatabricksParameters.EnableDirectResults, out string? enableDirectResultsStr)) + { + if (bool.TryParse(enableDirectResultsStr, out bool enableDirectResultsValue)) + { + _enableDirectResults = enableDirectResultsValue; + } + else + { + throw new ArgumentException($"Parameter '{DatabricksParameters.EnableDirectResults}' value '{enableDirectResultsStr}' could not be parsed. 
Valid values are 'true' and 'false'."); + } + } + // Parse CloudFetch options from connection properties if (Properties.TryGetValue(DatabricksParameters.UseCloudFetch, out string? useCloudFetchStr)) { @@ -110,6 +129,11 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks /// </summary> internal bool ApplySSPWithQueries => _applySSPWithQueries; + /// <summary> + /// Gets whether direct results are enabled. + /// </summary> + internal bool EnableDirectResults => _enableDirectResults; + /// <summary> /// Gets whether CloudFetch is enabled. /// </summary> @@ -145,6 +169,22 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks return baseHandler; } + protected internal override bool AreResultsAvailableDirectly => _enableDirectResults; + + protected override void SetDirectResults(TGetColumnsReq request) => request.GetDirectResults = defaultGetDirectResults; + + protected override void SetDirectResults(TGetCatalogsReq request) => request.GetDirectResults = defaultGetDirectResults; + + protected override void SetDirectResults(TGetSchemasReq request) => request.GetDirectResults = defaultGetDirectResults; + + protected override void SetDirectResults(TGetTablesReq request) => request.GetDirectResults = defaultGetDirectResults; + + protected override void SetDirectResults(TGetTableTypesReq request) => request.GetDirectResults = defaultGetDirectResults; + + protected override void SetDirectResults(TGetPrimaryKeysReq request) => request.GetDirectResults = defaultGetDirectResults; + + protected override void SetDirectResults(TGetCrossReferenceReq request) => request.GetDirectResults = defaultGetDirectResults; + internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? 
metadataResp = null) { // Get result format from metadata response if available diff --git a/csharp/src/Drivers/Databricks/DatabricksParameters.cs b/csharp/src/Drivers/Databricks/DatabricksParameters.cs index f45350b72..7c6e9a69f 100644 --- a/csharp/src/Drivers/Databricks/DatabricksParameters.cs +++ b/csharp/src/Drivers/Databricks/DatabricksParameters.cs @@ -61,6 +61,12 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks /// </summary> public const string CloudFetchTimeoutMinutes = "adbc.databricks.cloudfetch.timeout_minutes"; + /// <summary> + /// Whether to enable the use of direct results when executing queries. + /// Default value is true if not specified. + /// </summary> + public const string EnableDirectResults = "adbc.databricks.enable_direct_results"; + /// <summary> /// Whether to apply service side properties (SSP) with queries. If false, SSP will be applied /// by setting the Thrift configuration when the session is opened. diff --git a/csharp/src/Drivers/Databricks/DatabricksReader.cs b/csharp/src/Drivers/Databricks/DatabricksReader.cs index 56abfbb20..cdd131111 100644 --- a/csharp/src/Drivers/Databricks/DatabricksReader.cs +++ b/csharp/src/Drivers/Databricks/DatabricksReader.cs @@ -39,6 +39,17 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks this.statement = statement; this.schema = schema; this.isLz4Compressed = isLz4Compressed; + + // If we have direct results, initialize the batches from them + if (statement.HasDirectResults) + { + this.batches = statement.DirectResults!.ResultSet.Results.ArrowBatches; + + if (!statement.DirectResults.ResultSet.HasMoreRows) + { + this.statement = null; + } + } } public Schema Schema { get { return schema; } } diff --git a/csharp/src/Drivers/Databricks/DatabricksStatement.cs b/csharp/src/Drivers/Databricks/DatabricksStatement.cs index 447689e51..72cdb8ac7 100644 --- a/csharp/src/Drivers/Databricks/DatabricksStatement.cs +++ b/csharp/src/Drivers/Databricks/DatabricksStatement.cs @@ -49,6 +49,22 @@ namespace 
Apache.Arrow.Adbc.Drivers.Databricks statement.CanDownloadResult = useCloudFetch; statement.CanDecompressLZ4Result = canDecompressLz4; statement.MaxBytesPerFile = maxBytesPerFile; + + if (Connection.AreResultsAvailableDirectly) + { + statement.GetDirectResults = DatabricksConnection.defaultGetDirectResults; + } + } + + /// <summary> + /// Checks if direct results are available. + /// </summary> + /// <returns>True if direct results are available and contain result data, false otherwise.</returns> + public bool HasDirectResults => DirectResults?.ResultSet != null && DirectResults?.ResultSetMetadata != null; + + public TSparkDirectResults? DirectResults + { + get { return _directResults; } } // Cast the Client to IAsync for CloudFetch compatibility diff --git a/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs b/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs index 0d9bbfa90..96b040274 100644 --- a/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs +++ b/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs @@ -16,6 +16,7 @@ */ using System; +using System.Collections.Generic; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Databricks; using Xunit; @@ -35,42 +36,40 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable)); } - /// <summary> - /// Integration test for running a large query against a real Databricks cluster. 
- /// </summary> - [Fact] - public async Task TestRealDatabricksCloudFetchSmallResultSet() - { - await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM range(1000)", 1000); - } - - [Fact] - public async Task TestRealDatabricksCloudFetchLargeResultSet() + public static IEnumerable<object[]> TestCases() { - await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM main.tpcds_sf10_delta.catalog_sales LIMIT 1000000", 1000000); - } + // Test cases format: (query, expected row count, use cloud fetch, enable direct results) - [Fact] - public async Task TestRealDatabricksNoCloudFetchSmallResultSet() - { - await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM range(1000)", 1000, false); - } + string smallQuery = $"SELECT * FROM range(1000)"; + yield return new object[] { smallQuery, 1000, true, true }; + yield return new object[] { smallQuery, 1000, false, true }; + yield return new object[] { smallQuery, 1000, true, false }; + yield return new object[] { smallQuery, 1000, false, false }; - [Fact] - public async Task TestRealDatabricksNoCloudFetchLargeResultSet() - { - await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM main.tpcds_sf10_delta.catalog_sales LIMIT 1000000", 1000000, false); + string largeQuery = $"SELECT * FROM main.tpcds_sf10_delta.catalog_sales LIMIT 1000000"; + yield return new object[] { largeQuery, 1000000, true, true }; + yield return new object[] { largeQuery, 1000000, false, true }; + yield return new object[] { largeQuery, 1000000, true, false }; + yield return new object[] { largeQuery, 1000000, false, false }; } - private async Task TestRealDatabricksCloudFetchLargeQuery(string query, int rowCount, bool useCloudFetch = true) + /// <summary> + /// Integration test for running queries against a real Databricks cluster with different CloudFetch settings. 
+ /// </summary> + [Theory] + [MemberData(nameof(TestCases))] + private async Task TestRealDatabricksCloudFetch(string query, int rowCount, bool useCloudFetch, bool enableDirectResults) { - // Create a statement with CloudFetch enabled - var statement = Connection.CreateStatement(); - statement.SetOption(DatabricksParameters.UseCloudFetch, useCloudFetch.ToString()); - statement.SetOption(DatabricksParameters.CanDecompressLz4, "true"); - statement.SetOption(DatabricksParameters.MaxBytesPerFile, "10485760"); // 10MB + var connection = NewConnection(TestConfiguration, new Dictionary<string, string> + { + [DatabricksParameters.UseCloudFetch] = useCloudFetch.ToString(), + [DatabricksParameters.EnableDirectResults] = enableDirectResults.ToString(), + [DatabricksParameters.CanDecompressLz4] = "true", + [DatabricksParameters.MaxBytesPerFile] = "10485760" // 10MB + }); // Execute a query that generates a large result set using range function + var statement = connection.CreateStatement(); statement.SqlQuery = query; // Execute the query and get the result diff --git a/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs b/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs index 859ee7e84..5c334957f 100644 --- a/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs +++ b/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs @@ -300,6 +300,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks Add(new(new() { [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [DatabricksParameters.CanDecompressLz4] = "notabool"}, typeof(ArgumentException))); Add(new(new() { [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [DatabricksParameters.MaxBytesPerFile] = "notanumber" }, typeof(ArgumentException))); Add(new(new() { [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [DatabricksParameters.MaxBytesPerFile] = "-100" }, typeof(ArgumentOutOfRangeException))); + 
Add(new(new() { [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [DatabricksParameters.EnableDirectResults] = "notabool" }, typeof(ArgumentException))); Add(new(new() { /*[SparkParameters.Type] = SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = "-1" }, typeof(ArgumentOutOfRangeException))); Add(new(new() { /*[SparkParameters.Type] = SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = IPEndPoint.MinPort.ToString(CultureInfo.InvariantCulture) }, typeof(ArgumentOutOfRangeException))); Add(new(new() { /*[SparkParameters.Type] = SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = (IPEndPoint.MaxPort + 1).ToString(CultureInfo.InvariantCulture) }, typeof(ArgumentOutOfRangeException)));