kushaman commented on code in PR #2365: URL: https://github.com/apache/arrow-adbc/pull/2365#discussion_r1919793025
########## csharp/test/Drivers/Apache/Common/StringValueTests.cs: ########## @@ -90,15 +91,15 @@ await ValidateInsertSelectDeleteSingleValueAsync( [SkippableTheory] [InlineData(null)] [InlineData("")] - [InlineData("你好")] [InlineData(" Leading and trailing spaces ")] - protected virtual async Task TestCharData(string? value) + internal virtual async Task TestCharData(string? value) Review Comment: @birschick-bq, can you please check? ########## csharp/test/Drivers/Apache/Common/DriverTests.cs: ########## @@ -50,7 +50,7 @@ public abstract class DriverTests<TConfig, TEnv> : TestBase<TConfig, TEnv> /// <summary> /// Supported Spark data types as a subset of <see cref="ColumnTypeId"/> /// </summary> - private enum SupportedSparkDataType : short + internal enum SupportedSparkDataType : short Review Comment: Yes, it is common to Spark and Impala. I have changed it to SupportedDriverDataType. ########## csharp/test/Drivers/Apache/Impala/DriverTests.cs: ########## @@ -0,0 +1,132 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Generic; +using Apache.Arrow.Adbc.Drivers.Apache.Impala; +using Apache.Arrow.Adbc.Tests.Metadata; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala +{ + public class DriverTests : Common.DriverTests<ApacheTestConfiguration, ImpalaTestEnvironment> + { + public DriverTests(ITestOutputHelper? outputHelper) + : base(outputHelper, new ImpalaTestEnvironment.Factory()) + { + } + + [SkippableTheory] + [MemberData(nameof(CatalogNamePatternData))] + public override void CanGetObjectsCatalogs(string? pattern) + { + GetObjectsCatalogsTest(pattern); + } + + [SkippableTheory] + [MemberData(nameof(DbSchemasNamePatternData))] + public override void CanGetObjectsDbSchemas(string dbSchemaPattern) + { + GetObjectsDbSchemasTest(dbSchemaPattern); + } + + [SkippableTheory] + [MemberData(nameof(TableNamePatternData))] + public override void CanGetObjectsTables(string tableNamePattern) + { + GetObjectsTablesTest(tableNamePattern); + } + + public override void CanDetectInvalidServer() + { + AdbcDriver driver = NewDriver; + Assert.NotNull(driver); + Dictionary<string, string> parameters = GetDriverParameters(TestConfiguration); + + bool hasUri = parameters.TryGetValue(AdbcOptions.Uri, out var uri) && !string.IsNullOrEmpty(uri); + bool hasHostName = parameters.TryGetValue(ImpalaParameters.HostName, out var hostName) && !string.IsNullOrEmpty(hostName); + if (hasUri) + { + parameters[AdbcOptions.Uri] = "http://unknownhost.azure.com/cliservice"; + } + else if (hasHostName) + { + parameters[ImpalaParameters.HostName] = "unknownhost.azure.com"; + } + else + { + Assert.Fail($"Unexpected configuration. 
Must provide '{AdbcOptions.Uri}' or '{ImpalaParameters.HostName}'."); + } + + AdbcDatabase database = driver.Open(parameters); + AggregateException exception = Assert.ThrowsAny<AggregateException>(() => database.Connect(parameters)); + OutputHelper?.WriteLine(exception.Message); + } + + public override void CanDetectInvalidAuthentication() + { + AdbcDriver driver = NewDriver; + Assert.NotNull(driver); + Dictionary<string, string> parameters = GetDriverParameters(TestConfiguration); + + bool hasUsername = parameters.TryGetValue(AdbcOptions.Username, out var username) && !string.IsNullOrEmpty(username); + bool hasPassword = parameters.TryGetValue(AdbcOptions.Password, out var password) && !string.IsNullOrEmpty(password); + if (hasUsername && hasPassword) + { + parameters[AdbcOptions.Password] = "invalid-password"; + } + else + { + Assert.Fail($"Unexpected configuration. Must provide '{AdbcOptions.Username}' and '{AdbcOptions.Password}'."); + } + + AdbcDatabase database = driver.Open(parameters); + AggregateException exception = Assert.ThrowsAny<AggregateException>(() => database.Connect(parameters)); + OutputHelper?.WriteLine(exception.Message); + } + + protected override IReadOnlyList<int> GetUpdateExpectedResults() + { + int affectedRows = ValidateAffectedRows ? 1 : -1; + return ClientTests.GetUpdateExpectedResults(affectedRows); + } + + Review Comment: Removed. ########## csharp/test/Drivers/Apache/Impala/StatementTests.cs: ########## @@ -0,0 +1,51 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Common; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala +{ + public class StatementTests : Common.StatementTests<ApacheTestConfiguration, ImpalaTestEnvironment> + { + public StatementTests(ITestOutputHelper? outputHelper) + : base(outputHelper, new ImpalaTestEnvironment.Factory()) + { + } + + [SkippableTheory] + [ClassData(typeof(LongRunningStatementTimeoutTestData))] + internal override void StatementTimeoutTest(StatementWithExceptions statementWithExceptions) + { + base.StatementTimeoutTest(statementWithExceptions); + } + + internal class LongRunningStatementTimeoutTestData : ShortRunningStatementTimeoutTestData + { + public LongRunningStatementTimeoutTestData() + { + // TODO: Determine if this long-running query will work as expected on Impala. Review Comment: @birschick-bq , can we remove these overrides. ########## csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs: ########## @@ -208,5 +707,616 @@ internal static async Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync( TGetResultSetMetadataResp response = await client.GetResultSetMetadata(request, cancellationToken); return response; } + + /// <summary> + /// Gets the data-source specific columns names for the GetColumns metadata result. 
+ /// </summary> + /// <returns></returns> + protected abstract ColumnsMetadataColumnNames GetColumnsMetadataColumnNames(); + + /// <summary> + /// Gets the default product version + /// </summary> + /// <returns></returns> + protected abstract string GetProductVersionDefault(); + + /// <summary> + /// Gets the current product version. + /// </summary> + /// <returns></returns> + protected internal string GetProductVersion() + { + FileVersionInfo fileVersionInfo = FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location); + return fileVersionInfo.ProductVersion ?? GetProductVersionDefault(); + } + + protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, string? port) + { + // Uri property takes precedent. + if (!string.IsNullOrWhiteSpace(uri)) + { + var uriValue = new Uri(uri); + if (uriValue.Scheme != Uri.UriSchemeHttp && uriValue.Scheme != Uri.UriSchemeHttps) + throw new ArgumentOutOfRangeException( + AdbcOptions.Uri, + uri, + $"Unsupported scheme '{uriValue.Scheme}'"); + return uriValue; + } + + bool isPortSet = !string.IsNullOrEmpty(port); + bool isValidPortNumber = int.TryParse(port, out int portNumber) && portNumber > 0; + bool isDefaultHttpsPort = !isPortSet || (isValidPortNumber && portNumber == 443); + string uriScheme = isDefaultHttpsPort ? 
Uri.UriSchemeHttps : Uri.UriSchemeHttp; + int uriPort; + if (!isPortSet) + uriPort = -1; + else if (isValidPortNumber) + uriPort = portNumber; + else + throw new ArgumentOutOfRangeException(nameof(port), portNumber, $"Port number is not in a valid range."); + + Uri baseAddress = new UriBuilder(uriScheme, hostName, uriPort, path).Uri; + return baseAddress; + } + + // Note data source's Position may be one-indexed or zero-indexed + protected IReadOnlyDictionary<string, int> GetColumnIndexMap(List<TColumnDesc> columns) => columns + .Select(t => new { Index = t.Position - PositionRequiredOffset, t.ColumnName }) + .ToDictionary(t => t.ColumnName, t => t.Index); + + protected abstract Task<TRowSet> GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetCatalogsResp getCatalogsResp, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetSchemasResp getSchemasResp, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default); + + protected abstract bool AreResultsAvailableDirectly(); + + protected abstract 
TSparkGetDirectResults GetDirectResults(); Review Comment: Done. Removed "TSparkGetDirectResults" from here. ########## csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs: ########## @@ -208,5 +707,616 @@ internal static async Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync( TGetResultSetMetadataResp response = await client.GetResultSetMetadata(request, cancellationToken); return response; } + + /// <summary> + /// Gets the data-source specific columns names for the GetColumns metadata result. + /// </summary> + /// <returns></returns> + protected abstract ColumnsMetadataColumnNames GetColumnsMetadataColumnNames(); + + /// <summary> + /// Gets the default product version + /// </summary> + /// <returns></returns> + protected abstract string GetProductVersionDefault(); + + /// <summary> + /// Gets the current product version. + /// </summary> + /// <returns></returns> + protected internal string GetProductVersion() + { + FileVersionInfo fileVersionInfo = FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location); + return fileVersionInfo.ProductVersion ?? GetProductVersionDefault(); + } + + protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, string? port) + { + // Uri property takes precedent. + if (!string.IsNullOrWhiteSpace(uri)) + { + var uriValue = new Uri(uri); + if (uriValue.Scheme != Uri.UriSchemeHttp && uriValue.Scheme != Uri.UriSchemeHttps) + throw new ArgumentOutOfRangeException( + AdbcOptions.Uri, + uri, + $"Unsupported scheme '{uriValue.Scheme}'"); + return uriValue; + } + + bool isPortSet = !string.IsNullOrEmpty(port); + bool isValidPortNumber = int.TryParse(port, out int portNumber) && portNumber > 0; + bool isDefaultHttpsPort = !isPortSet || (isValidPortNumber && portNumber == 443); + string uriScheme = isDefaultHttpsPort ? 
Uri.UriSchemeHttps : Uri.UriSchemeHttp; + int uriPort; + if (!isPortSet) + uriPort = -1; + else if (isValidPortNumber) + uriPort = portNumber; + else + throw new ArgumentOutOfRangeException(nameof(port), portNumber, $"Port number is not in a valid range."); + + Uri baseAddress = new UriBuilder(uriScheme, hostName, uriPort, path).Uri; + return baseAddress; + } + + // Note data source's Position may be one-indexed or zero-indexed + protected IReadOnlyDictionary<string, int> GetColumnIndexMap(List<TColumnDesc> columns) => columns + .Select(t => new { Index = t.Position - PositionRequiredOffset, t.ColumnName }) + .ToDictionary(t => t.ColumnName, t => t.Index); + + protected abstract Task<TRowSet> GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetCatalogsResp getCatalogsResp, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetSchemasResp getSchemasResp, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default); + + protected abstract bool AreResultsAvailableDirectly(); + + protected abstract 
TSparkGetDirectResults GetDirectResults(); Review Comment: @birschick-bq , pls review ########## csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs: ########## @@ -15,22 +15,29 @@ * limitations under the License. */ +using System; using System.Collections.Generic; +using System.Linq; namespace Apache.Arrow.Adbc.Drivers.Apache.Impala { public class ImpalaDatabase : AdbcDatabase { readonly IReadOnlyDictionary<string, string> properties; - internal ImpalaDatabase(IReadOnlyDictionary<string, string> properties) + public ImpalaDatabase(IReadOnlyDictionary<string, string> properties) { this.properties = properties; } - public override AdbcConnection Connect(IReadOnlyDictionary<string, string>? properties) + public override AdbcConnection Connect(IReadOnlyDictionary<string, string>? options) { - ImpalaConnection connection = new ImpalaConnection(this.properties); + IReadOnlyDictionary<string, string> mergedProperties = options == null Review Comment: @birschick-bq , Can you please check. ########## csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs: ########## @@ -128,12 +360,279 @@ internal async Task OpenAsync() public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList<string>? tableTypes, string? 
columnNamePattern) { - throw new NotImplementedException(); + Dictionary<string, Dictionary<string, Dictionary<string, TableInfo>>> catalogMap = new Dictionary<string, Dictionary<string, Dictionary<string, TableInfo>>>(); + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + if (GetObjectsPatternsRequireLowerCase) + { + catalogPattern = catalogPattern?.ToLower(); + dbSchemaPattern = dbSchemaPattern?.ToLower(); + tableNamePattern = tableNamePattern?.ToLower(); + columnNamePattern = columnNamePattern?.ToLower(); + } + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) + { + TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(SessionHandle); + if (AreResultsAvailableDirectly()) + { + getCatalogsReq.GetDirectResults = GetDirectResults(); + } + + TGetCatalogsResp getCatalogsResp = Client.GetCatalogs(getCatalogsReq, cancellationToken).Result; + + if (getCatalogsResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getCatalogsResp.Status.ErrorMessage); + } + var catalogsMetadata = GetResultSetMetadataAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(catalogsMetadata.Schema.Columns); + + string catalogRegexp = PatternToRegEx(catalogPattern); + TRowSet rowSet = GetRowSetAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyList<string> list = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + for (int i = 0; i < list.Count; i++) + { + string col = list[i]; + string catalog = col; + + if (Regex.IsMatch(catalog, catalogRegexp, RegexOptions.IgnoreCase)) + { + catalogMap.Add(catalog, new Dictionary<string, Dictionary<string, TableInfo>>()); + } + } + // Handle the case where server does not support 'catalog' in the namespace. 
+ if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern)) + { + catalogMap.Add(string.Empty, []); + } + } + + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.DbSchemas) + { + TGetSchemasReq getSchemasReq = new TGetSchemasReq(SessionHandle); + getSchemasReq.CatalogName = catalogPattern; + getSchemasReq.SchemaName = dbSchemaPattern; + if (AreResultsAvailableDirectly()) + { + getSchemasReq.GetDirectResults = GetDirectResults(); + } + + TGetSchemasResp getSchemasResp = Client.GetSchemas(getSchemasReq, cancellationToken).Result; + if (getSchemasResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getSchemasResp.Status.ErrorMessage); + } + + TGetResultSetMetadataResp schemaMetadata = GetResultSetMetadataAsync(getSchemasResp, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(schemaMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getSchemasResp, cancellationToken).Result; + + IReadOnlyList<string> catalogList = rowSet.Columns[columnMap[TableCatalog]].StringVal.Values; + IReadOnlyList<string> schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + // It seems Spark sometimes returns empty string for catalog on some schema (temporary tables). 
+ catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new Dictionary<string, TableInfo>()); + } + } + + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Tables) + { + TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle); + getTablesReq.CatalogName = catalogPattern; + getTablesReq.SchemaName = dbSchemaPattern; + getTablesReq.TableName = tableNamePattern; + if (AreResultsAvailableDirectly()) + { + getTablesReq.GetDirectResults = GetDirectResults(); + } + + TGetTablesResp getTablesResp = Client.GetTables(getTablesReq, cancellationToken).Result; + if (getTablesResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getTablesResp.Status.ErrorMessage); + } + + TGetResultSetMetadataResp tableMetadata = GetResultSetMetadataAsync(getTablesResp, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(tableMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getTablesResp, cancellationToken).Result; + + IReadOnlyList<string> catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + IReadOnlyList<string> schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + IReadOnlyList<string> tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; + IReadOnlyList<string> tableTypeList = rowSet.Columns[columnMap[TableType]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string tableType = tableTypeList[i]; + TableInfo tableInfo = new(tableType); + catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName, tableInfo); + } + } + + if (depth == GetObjectsDepth.All) + { + TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle); + columnsReq.CatalogName = catalogPattern; + columnsReq.SchemaName = dbSchemaPattern; + columnsReq.TableName = tableNamePattern; + if (AreResultsAvailableDirectly()) + { + 
columnsReq.GetDirectResults = GetDirectResults(); + } + + if (!string.IsNullOrEmpty(columnNamePattern)) + columnsReq.ColumnName = columnNamePattern; + + var columnsResponse = Client.GetColumns(columnsReq, cancellationToken).Result; + if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(columnsResponse.Status.ErrorMessage); + } + + TGetResultSetMetadataResp columnsMetadata = GetResultSetMetadataAsync(columnsResponse, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(columnsMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(columnsResponse, cancellationToken).Result; + + ColumnsMetadataColumnNames columnNames = GetColumnsMetadataColumnNames(); + IReadOnlyList<string> catalogList = rowSet.Columns[columnMap[columnNames.TableCatalog]].StringVal.Values; + IReadOnlyList<string> schemaList = rowSet.Columns[columnMap[columnNames.TableSchema]].StringVal.Values; + IReadOnlyList<string> tableList = rowSet.Columns[columnMap[columnNames.TableName]].StringVal.Values; + IReadOnlyList<string> columnNameList = rowSet.Columns[columnMap[columnNames.ColumnName]].StringVal.Values; + ReadOnlySpan<int> columnTypeList = rowSet.Columns[columnMap[columnNames.DataType]].I32Val.Values.Values; + IReadOnlyList<string> typeNameList = rowSet.Columns[columnMap[columnNames.TypeName]].StringVal.Values; + ReadOnlySpan<int> nullableList = rowSet.Columns[columnMap[columnNames.Nullable]].I32Val.Values.Values; + IReadOnlyList<string> columnDefaultList = rowSet.Columns[columnMap[columnNames.ColumnDef]].StringVal.Values; + ReadOnlySpan<int> ordinalPosList = rowSet.Columns[columnMap[columnNames.OrdinalPosition]].I32Val.Values.Values; + IReadOnlyList<string> isNullableList = rowSet.Columns[columnMap[columnNames.IsNullable]].StringVal.Values; + IReadOnlyList<string> isAutoIncrementList = rowSet.Columns[columnMap[columnNames.IsAutoIncrement]].StringVal.Values; + ReadOnlySpan<int> columnSizeList = 
rowSet.Columns[columnMap[columnNames.ColumnSize]].I32Val.Values.Values; + ReadOnlySpan<int> decimalDigitsList = rowSet.Columns[columnMap[columnNames.DecimalDigits]].I32Val.Values.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + // For systems that don't support 'catalog' in the namespace + string catalog = catalogList[i] ?? string.Empty; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string columnName = columnNameList[i]; + short colType = (short)columnTypeList[i]; + string typeName = typeNameList[i]; + short nullable = (short)nullableList[i]; + string? isAutoIncrementString = isAutoIncrementList[i]; + bool isAutoIncrement = (!string.IsNullOrEmpty(isAutoIncrementString) && (isAutoIncrementString.Equals("YES", StringComparison.InvariantCultureIgnoreCase) || isAutoIncrementString.Equals("TRUE", StringComparison.InvariantCultureIgnoreCase))); + string isNullable = isNullableList[i] ?? "YES"; + string columnDefault = columnDefaultList[i] ?? ""; + // Spark/Databricks reports ordinal index zero-indexed, instead of one-indexed + int ordinalPos = ordinalPosList[i] + PositionRequiredOffset; + int columnSize = columnSizeList[i]; + int decimalDigits = decimalDigitsList[i]; + TableInfo? 
tableInfo = catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName); + tableInfo?.ColumnName.Add(columnName); + tableInfo?.ColType.Add(colType); + tableInfo?.Nullable.Add(nullable); + tableInfo?.IsAutoIncrement.Add(isAutoIncrement); + tableInfo?.IsNullable.Add(isNullable); + tableInfo?.ColumnDefault.Add(columnDefault); + tableInfo?.OrdinalPosition.Add(ordinalPos); + SetPrecisionScaleAndTypeName(colType, typeName, tableInfo, columnSize, decimalDigits); + } + } + + StringArray.Builder catalogNameBuilder = new StringArray.Builder(); + List<IArrowArray?> catalogDbSchemasValues = new List<IArrowArray?>(); + + foreach (KeyValuePair<string, Dictionary<string, Dictionary<string, TableInfo>>> catalogEntry in catalogMap) + { + catalogNameBuilder.Append(catalogEntry.Key); + + if (depth == GetObjectsDepth.Catalogs) + { + catalogDbSchemasValues.Add(null); + } + else + { + catalogDbSchemasValues.Add(GetDbSchemas( + depth, catalogEntry.Value)); + } + } + + Schema schema = StandardSchemas.GetObjectsSchema; + IReadOnlyList<IArrowArray> dataArrays = schema.Validate( + new List<IArrowArray> + { + catalogNameBuilder.Build(), + catalogDbSchemasValues.BuildListArrayForType(new StructType(StandardSchemas.DbSchemaSchema)), + }); + + return new HiveInfoArrowStream(schema, dataArrays); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. 
'{ex.Message}'", ex); + } } public override IArrowArrayStream GetTableTypes() { - throw new NotImplementedException(); + TGetTableTypesReq req = new() + { + SessionHandle = SessionHandle ?? throw new InvalidOperationException("session not created"), + }; + + if (AreResultsAvailableDirectly()) + { + req.GetDirectResults = GetDirectResults(); + } + + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + TGetTableTypesResp resp = Client.GetTableTypes(req, cancellationToken).Result; + + if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new HiveServer2Exception(resp.Status.ErrorMessage) + .SetNativeError(resp.Status.ErrorCode) + .SetSqlState(resp.Status.SqlState); + } + + TRowSet rowSet = GetRowSetAsync(resp, cancellationToken).Result; + StringArray tableTypes = rowSet.Columns[0].StringVal.Values; + + StringArray.Builder tableTypesBuilder = new StringArray.Builder(); + tableTypesBuilder.AppendRange(tableTypes); + + IArrowArray[] dataArrays = new IArrowArray[] + { + tableTypesBuilder.Build() + }; + + return new HiveInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays); + } + catch (Exception ex) Review Comment: Done. Created ExceptionHelper class for this. ########## csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs: ########## @@ -0,0 +1,219 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Net.Security; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Ipc; +using Apache.Hive.Service.Rpc.Thrift; +using Thrift; +using Thrift.Protocol; +using Thrift.Transport; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Impala +{ + internal class ImpalaHttpConnection : ImpalaConnection + { + private const string BasicAuthenticationScheme = "Basic"; + + public ImpalaHttpConnection(IReadOnlyDictionary<string, string> properties) : base(properties) + { + } + + protected override void ValidateAuthentication() + { + // Validate authentication parameters + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? password); + Properties.TryGetValue(ImpalaParameters.AuthType, out string? authType); + bool isValidAuthType = ImpalaAuthTypeParser.TryParse(authType, out ImpalaAuthType authTypeValue); + switch (authTypeValue) + { + case ImpalaAuthType.Basic: + if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password)) + throw new ArgumentException( + $"Parameter '{ImpalaParameters.AuthType}' is set to '{ImpalaAuthTypeConstants.Basic}' but parameters '{AdbcOptions.Username}' or '{AdbcOptions.Password}' are not set. 
Please provide a values for these parameters.", + nameof(Properties)); + break; + case ImpalaAuthType.UsernameOnly: + if (string.IsNullOrWhiteSpace(username)) + throw new ArgumentException( + $"Parameter '{ImpalaParameters.AuthType}' is set to '{ImpalaAuthTypeConstants.UsernameOnly}' but parameter '{AdbcOptions.Username}' is not set. Please provide a values for this parameter.", + nameof(Properties)); + break; + case ImpalaAuthType.None: + break; + case ImpalaAuthType.Empty: + if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password)) + throw new ArgumentException( + $"Parameters must include valid authentiation settings. Please provide '{AdbcOptions.Username}' and '{AdbcOptions.Password}'.", + nameof(Properties)); + break; + default: + throw new ArgumentOutOfRangeException(ImpalaParameters.AuthType, authType, $"Unsupported {ImpalaParameters.AuthType} value."); + } + } + + protected override void ValidateConnection() + { + // HostName or Uri is required parameter + Properties.TryGetValue(AdbcOptions.Uri, out string? uri); + Properties.TryGetValue(ImpalaParameters.HostName, out string? hostName); + if ((Uri.CheckHostName(hostName) == UriHostNameType.Unknown) + && (string.IsNullOrEmpty(uri) || !Uri.TryCreate(uri, UriKind.Absolute, out Uri? _))) + { + throw new ArgumentException( + $"Required parameter '{ImpalaParameters.HostName}' or '{AdbcOptions.Uri}' is missing or invalid. Please provide a valid hostname or URI for the data source.", + nameof(Properties)); + } + + // Validate port range + Properties.TryGetValue(ImpalaParameters.Port, out string? port); + if (int.TryParse(port, out int portNumber) && (portNumber <= IPEndPoint.MinPort || portNumber > IPEndPoint.MaxPort)) + throw new ArgumentOutOfRangeException( + nameof(Properties), + port, + $"Parameter '{ImpalaParameters.Port}' value is not in the valid range of 1 .. 
{IPEndPoint.MaxPort}."); + + // Ensure the parameters will produce a valid address + Properties.TryGetValue(ImpalaParameters.Path, out string? path); + _ = new HttpClient() + { + BaseAddress = GetBaseAddress(uri, hostName, path, port) + }; + } + + protected override void ValidateOptions() + { + Properties.TryGetValue(ImpalaParameters.DataTypeConv, out string? dataTypeConv); + DataTypeConversion = DataTypeConversionParser.Parse(dataTypeConv); + Properties.TryGetValue(ImpalaParameters.TLSOptions, out string? tlsOptions); + TlsOptions = TlsOptionsParser.Parse(tlsOptions); + Properties.TryGetValue(ImpalaParameters.ConnectTimeoutMilliseconds, out string? connectTimeoutMs); + if (connectTimeoutMs != null) + { + ConnectTimeoutMilliseconds = int.TryParse(connectTimeoutMs, NumberStyles.Integer, CultureInfo.InvariantCulture, out int connectTimeoutMsValue) && (connectTimeoutMsValue >= 0) + ? connectTimeoutMsValue + : throw new ArgumentOutOfRangeException(ImpalaParameters.ConnectTimeoutMilliseconds, connectTimeoutMs, $"must be a value of 0 (infinite) or between 1 .. {int.MaxValue}. default is 30000 milliseconds."); + } + } + + internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion); + + protected override TTransport CreateTransport() + { + // Assumption: parameters have already been validated. + Properties.TryGetValue(ImpalaParameters.HostName, out string? hostName); + Properties.TryGetValue(ImpalaParameters.Path, out string? path); + Properties.TryGetValue(ImpalaParameters.Port, out string? port); + Properties.TryGetValue(ImpalaParameters.AuthType, out string? authType); + bool isValidAuthType = ImpalaAuthTypeParser.TryParse(authType, out ImpalaAuthType authTypeValue); + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? 
password); + Properties.TryGetValue(AdbcOptions.Uri, out string? uri); + + Uri baseAddress = GetBaseAddress(uri, hostName, path, port); + AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, username, password); + + HttpClientHandler httpClientHandler = NewHttpClientHandler(); + HttpClient httpClient = new(httpClientHandler); + httpClient.BaseAddress = baseAddress; + httpClient.DefaultRequestHeaders.Authorization = authenticationHeaderValue; + httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(s_userAgent); + httpClient.DefaultRequestHeaders.AcceptEncoding.Clear(); + httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new StringWithQualityHeaderValue("identity")); + httpClient.DefaultRequestHeaders.ExpectContinue = false; + + TConfiguration config = new(); + ThriftHttpTransport transport = new(httpClient, config) + { + // This value can only be set before the first call/request. So if a new value for query timeout + // is set, we won't be able to update the value. Setting to ~infinite and relying on cancellation token + // to ensure cancelled correctly. + ConnectTimeout = int.MaxValue, + }; + return transport; + } + + private HttpClientHandler NewHttpClientHandler() + { + HttpClientHandler httpClientHandler = new(); + if (TlsOptions != HiveServer2TlsOption.Empty) + { + httpClientHandler.ServerCertificateCustomValidationCallback = (request, certificate, chain, policyErrors) => + { + if (policyErrors == SslPolicyErrors.None) return true; + + return Review Comment: Thanks for pointing this out. Done. 
########## csharp/test/Drivers/Apache/Common/ComplexTypesValueTests.cs: ########## @@ -28,7 +28,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common /// <summary> /// Validates that specific complex structured types can be inserted, retrieved and targeted correctly /// </summary> - public class ComplexTypesValueTests<TConfig, TEnv> : TestBase<TConfig, TEnv> + public abstract class ComplexTypesValueTests<TConfig, TEnv> : TestBase<TConfig, TEnv> Review Comment: I believe Spark has a ComplexTypesValueTests class that derives from this one. @birschick-bq, please confirm. ########## csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs: ########## @@ -36,6 +45,229 @@ internal abstract class HiveServer2Connection : AdbcConnection private readonly Lazy<string> _vendorVersion; private readonly Lazy<string> _vendorName; + readonly AdbcInfoCode[] infoSupportedCodes = [ + AdbcInfoCode.DriverName, + AdbcInfoCode.DriverVersion, + AdbcInfoCode.DriverArrowVersion, + AdbcInfoCode.VendorName, + AdbcInfoCode.VendorSql, + AdbcInfoCode.VendorVersion, + ]; + + internal const string ColumnDef = "COLUMN_DEF"; + internal const string ColumnName = "COLUMN_NAME"; + internal const string DataType = "DATA_TYPE"; + internal const string IsAutoIncrement = "IS_AUTO_INCREMENT"; + internal const string IsNullable = "IS_NULLABLE"; + internal const string OrdinalPosition = "ORDINAL_POSITION"; + internal const string TableCat = "TABLE_CAT"; + internal const string TableCatalog = "TABLE_CATALOG"; + internal const string TableName = "TABLE_NAME"; + internal const string TableSchem = "TABLE_SCHEM"; + internal const string TableMd = "TABLE_MD"; + internal const string TableType = "TABLE_TYPE"; + internal const string TypeName = "TYPE_NAME"; + internal const string Nullable = "NULLABLE"; + internal const string ColumnSize = "COLUMN_SIZE"; + internal const string DecimalDigits = "DECIMAL_DIGITS"; + internal const string BufferLength = "BUFFER_LENGTH"; + + /// <summary> + /// The GetColumns metadata call 
returns a result with different column names + /// on different data sources. Populate this structure with the actual column names. + /// </summary> + internal struct ColumnsMetadataColumnNames + { + public string TableCatalog { get; internal set; } + public string TableSchema { get; internal set; } + public string TableName { get; internal set; } + public string ColumnName { get; internal set; } + public string DataType { get; internal set; } + public string TypeName { get; internal set; } + public string Nullable { get; internal set; } + public string ColumnDef { get; internal set; } + public string OrdinalPosition { get; internal set; } + public string IsNullable { get; internal set; } + public string IsAutoIncrement { get; internal set; } + public string ColumnSize { get; set; } + public string DecimalDigits { get; set; } + } + + /// <summary> + /// The data type definitions based on the <see href="https://docs.oracle.com/en%2Fjava%2Fjavase%2F21%2Fdocs%2Fapi%2F%2F/java.sql/java/sql/Types.html">JDBC Types</see> constants. + /// </summary> + /// <remarks> + /// This enumeration can be used to determine the drivers specific data types that are contained in fields <c>xdbc_data_type</c> and <c>xdbc_sql_data_type</c> + /// in the column metadata <see cref="StandardSchemas.ColumnSchema"/>. This column metadata is returned as a result of a call to + /// <see cref="AdbcConnection.GetObjects(GetObjectsDepth, string?, string?, string?, IReadOnlyList{string}?, string?)"/> + /// when <c>depth</c> is set to <see cref="AdbcConnection.GetObjectsDepth.All"/>. + /// </remarks> + internal enum ColumnTypeId Review Comment: Sure, will take this up as a followup. ########## csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs: ########## @@ -128,12 +360,279 @@ internal async Task OpenAsync() public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList<string>? tableTypes, string? 
columnNamePattern) { - throw new NotImplementedException(); + Dictionary<string, Dictionary<string, Dictionary<string, TableInfo>>> catalogMap = new Dictionary<string, Dictionary<string, Dictionary<string, TableInfo>>>(); + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + if (GetObjectsPatternsRequireLowerCase) + { + catalogPattern = catalogPattern?.ToLower(); + dbSchemaPattern = dbSchemaPattern?.ToLower(); + tableNamePattern = tableNamePattern?.ToLower(); + columnNamePattern = columnNamePattern?.ToLower(); + } + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) + { + TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(SessionHandle); + if (AreResultsAvailableDirectly()) + { + getCatalogsReq.GetDirectResults = GetDirectResults(); + } + + TGetCatalogsResp getCatalogsResp = Client.GetCatalogs(getCatalogsReq, cancellationToken).Result; + + if (getCatalogsResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getCatalogsResp.Status.ErrorMessage); + } + var catalogsMetadata = GetResultSetMetadataAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(catalogsMetadata.Schema.Columns); + + string catalogRegexp = PatternToRegEx(catalogPattern); + TRowSet rowSet = GetRowSetAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyList<string> list = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + for (int i = 0; i < list.Count; i++) + { + string col = list[i]; + string catalog = col; + + if (Regex.IsMatch(catalog, catalogRegexp, RegexOptions.IgnoreCase)) + { + catalogMap.Add(catalog, new Dictionary<string, Dictionary<string, TableInfo>>()); + } + } + // Handle the case where server does not support 'catalog' in the namespace. 
+ if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern)) + { + catalogMap.Add(string.Empty, []); + } + } + + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.DbSchemas) + { + TGetSchemasReq getSchemasReq = new TGetSchemasReq(SessionHandle); + getSchemasReq.CatalogName = catalogPattern; + getSchemasReq.SchemaName = dbSchemaPattern; + if (AreResultsAvailableDirectly()) + { + getSchemasReq.GetDirectResults = GetDirectResults(); + } + + TGetSchemasResp getSchemasResp = Client.GetSchemas(getSchemasReq, cancellationToken).Result; + if (getSchemasResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getSchemasResp.Status.ErrorMessage); + } + + TGetResultSetMetadataResp schemaMetadata = GetResultSetMetadataAsync(getSchemasResp, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(schemaMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getSchemasResp, cancellationToken).Result; + + IReadOnlyList<string> catalogList = rowSet.Columns[columnMap[TableCatalog]].StringVal.Values; + IReadOnlyList<string> schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + // It seems Spark sometimes returns empty string for catalog on some schema (temporary tables). 
+ catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new Dictionary<string, TableInfo>()); + } + } + + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Tables) + { + TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle); + getTablesReq.CatalogName = catalogPattern; + getTablesReq.SchemaName = dbSchemaPattern; + getTablesReq.TableName = tableNamePattern; + if (AreResultsAvailableDirectly()) + { + getTablesReq.GetDirectResults = GetDirectResults(); + } + + TGetTablesResp getTablesResp = Client.GetTables(getTablesReq, cancellationToken).Result; + if (getTablesResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getTablesResp.Status.ErrorMessage); + } + + TGetResultSetMetadataResp tableMetadata = GetResultSetMetadataAsync(getTablesResp, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(tableMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getTablesResp, cancellationToken).Result; + + IReadOnlyList<string> catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + IReadOnlyList<string> schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + IReadOnlyList<string> tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; + IReadOnlyList<string> tableTypeList = rowSet.Columns[columnMap[TableType]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string tableType = tableTypeList[i]; + TableInfo tableInfo = new(tableType); + catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName, tableInfo); + } + } + + if (depth == GetObjectsDepth.All) + { + TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle); + columnsReq.CatalogName = catalogPattern; + columnsReq.SchemaName = dbSchemaPattern; + columnsReq.TableName = tableNamePattern; + if (AreResultsAvailableDirectly()) + { + 
columnsReq.GetDirectResults = GetDirectResults(); + } + + if (!string.IsNullOrEmpty(columnNamePattern)) + columnsReq.ColumnName = columnNamePattern; + + var columnsResponse = Client.GetColumns(columnsReq, cancellationToken).Result; + if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(columnsResponse.Status.ErrorMessage); + } + + TGetResultSetMetadataResp columnsMetadata = GetResultSetMetadataAsync(columnsResponse, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(columnsMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(columnsResponse, cancellationToken).Result; + + ColumnsMetadataColumnNames columnNames = GetColumnsMetadataColumnNames(); + IReadOnlyList<string> catalogList = rowSet.Columns[columnMap[columnNames.TableCatalog]].StringVal.Values; + IReadOnlyList<string> schemaList = rowSet.Columns[columnMap[columnNames.TableSchema]].StringVal.Values; + IReadOnlyList<string> tableList = rowSet.Columns[columnMap[columnNames.TableName]].StringVal.Values; + IReadOnlyList<string> columnNameList = rowSet.Columns[columnMap[columnNames.ColumnName]].StringVal.Values; + ReadOnlySpan<int> columnTypeList = rowSet.Columns[columnMap[columnNames.DataType]].I32Val.Values.Values; + IReadOnlyList<string> typeNameList = rowSet.Columns[columnMap[columnNames.TypeName]].StringVal.Values; + ReadOnlySpan<int> nullableList = rowSet.Columns[columnMap[columnNames.Nullable]].I32Val.Values.Values; + IReadOnlyList<string> columnDefaultList = rowSet.Columns[columnMap[columnNames.ColumnDef]].StringVal.Values; + ReadOnlySpan<int> ordinalPosList = rowSet.Columns[columnMap[columnNames.OrdinalPosition]].I32Val.Values.Values; + IReadOnlyList<string> isNullableList = rowSet.Columns[columnMap[columnNames.IsNullable]].StringVal.Values; + IReadOnlyList<string> isAutoIncrementList = rowSet.Columns[columnMap[columnNames.IsAutoIncrement]].StringVal.Values; + ReadOnlySpan<int> columnSizeList = 
rowSet.Columns[columnMap[columnNames.ColumnSize]].I32Val.Values.Values; + ReadOnlySpan<int> decimalDigitsList = rowSet.Columns[columnMap[columnNames.DecimalDigits]].I32Val.Values.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + // For systems that don't support 'catalog' in the namespace + string catalog = catalogList[i] ?? string.Empty; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string columnName = columnNameList[i]; + short colType = (short)columnTypeList[i]; + string typeName = typeNameList[i]; + short nullable = (short)nullableList[i]; + string? isAutoIncrementString = isAutoIncrementList[i]; + bool isAutoIncrement = (!string.IsNullOrEmpty(isAutoIncrementString) && (isAutoIncrementString.Equals("YES", StringComparison.InvariantCultureIgnoreCase) || isAutoIncrementString.Equals("TRUE", StringComparison.InvariantCultureIgnoreCase))); + string isNullable = isNullableList[i] ?? "YES"; + string columnDefault = columnDefaultList[i] ?? ""; + // Spark/Databricks reports ordinal index zero-indexed, instead of one-indexed + int ordinalPos = ordinalPosList[i] + PositionRequiredOffset; + int columnSize = columnSizeList[i]; + int decimalDigits = decimalDigitsList[i]; + TableInfo? 
tableInfo = catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName); + tableInfo?.ColumnName.Add(columnName); + tableInfo?.ColType.Add(colType); + tableInfo?.Nullable.Add(nullable); + tableInfo?.IsAutoIncrement.Add(isAutoIncrement); + tableInfo?.IsNullable.Add(isNullable); + tableInfo?.ColumnDefault.Add(columnDefault); + tableInfo?.OrdinalPosition.Add(ordinalPos); + SetPrecisionScaleAndTypeName(colType, typeName, tableInfo, columnSize, decimalDigits); + } + } + + StringArray.Builder catalogNameBuilder = new StringArray.Builder(); + List<IArrowArray?> catalogDbSchemasValues = new List<IArrowArray?>(); + + foreach (KeyValuePair<string, Dictionary<string, Dictionary<string, TableInfo>>> catalogEntry in catalogMap) + { + catalogNameBuilder.Append(catalogEntry.Key); + + if (depth == GetObjectsDepth.Catalogs) + { + catalogDbSchemasValues.Add(null); + } + else + { + catalogDbSchemasValues.Add(GetDbSchemas( + depth, catalogEntry.Value)); + } + } + + Schema schema = StandardSchemas.GetObjectsSchema; + IReadOnlyList<IArrowArray> dataArrays = schema.Validate( + new List<IArrowArray> + { + catalogNameBuilder.Build(), + catalogDbSchemasValues.BuildListArrayForType(new StructType(StandardSchemas.DbSchemaSchema)), + }); + + return new HiveInfoArrowStream(schema, dataArrays); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. 
'{ex.Message}'", ex); + } } public override IArrowArrayStream GetTableTypes() { - throw new NotImplementedException(); + TGetTableTypesReq req = new() + { + SessionHandle = SessionHandle ?? throw new InvalidOperationException("session not created"), + }; + + if (AreResultsAvailableDirectly()) + { + req.GetDirectResults = GetDirectResults(); + } + + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + TGetTableTypesResp resp = Client.GetTableTypes(req, cancellationToken).Result; + + if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new HiveServer2Exception(resp.Status.ErrorMessage) + .SetNativeError(resp.Status.ErrorCode) + .SetSqlState(resp.Status.SqlState); + } + + TRowSet rowSet = GetRowSetAsync(resp, cancellationToken).Result; + StringArray tableTypes = rowSet.Columns[0].StringVal.Values; + + StringArray.Builder tableTypesBuilder = new StringArray.Builder(); + tableTypesBuilder.AppendRange(tableTypes); + + IArrowArray[] dataArrays = new IArrowArray[] + { + tableTypesBuilder.Build() + }; + + return new HiveInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays); + } + catch (Exception ex) Review Comment: @birschick-bq pls review ########## csharp/test/Drivers/Apache/Spark/DriverTests.cs: ########## @@ -74,5 +128,27 @@ public static IEnumerable<object[]> TableNamePatternData() string? tableName = new DriverTests(null).TestConfiguration?.Metadata?.Table; return GetPatterns(tableName); } + + protected override bool TypeHasDecimalDigits(Metadata.AdbcColumn column) + { + HashSet<short> typesHaveDecimalDigits = new() + { + (short)SupportedSparkDataType.DECIMAL, + (short)SupportedSparkDataType.NUMERIC, + }; + return typesHaveDecimalDigits.Contains(column.XdbcDataType!.Value); Review Comment: Done. @birschick-bq , pls check once. 
########## csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs: ########## @@ -0,0 +1,219 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Net.Security; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Ipc; +using Apache.Hive.Service.Rpc.Thrift; +using Thrift; +using Thrift.Protocol; +using Thrift.Transport; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Impala +{ + internal class ImpalaHttpConnection : ImpalaConnection + { + private const string BasicAuthenticationScheme = "Basic"; + + public ImpalaHttpConnection(IReadOnlyDictionary<string, string> properties) : base(properties) + { + } + + protected override void ValidateAuthentication() + { + // Validate authentication parameters + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? password); + Properties.TryGetValue(ImpalaParameters.AuthType, out string? 
authType); + bool isValidAuthType = ImpalaAuthTypeParser.TryParse(authType, out ImpalaAuthType authTypeValue); + switch (authTypeValue) + { + case ImpalaAuthType.Basic: + if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password)) + throw new ArgumentException( + $"Parameter '{ImpalaParameters.AuthType}' is set to '{ImpalaAuthTypeConstants.Basic}' but parameters '{AdbcOptions.Username}' or '{AdbcOptions.Password}' are not set. Please provide a values for these parameters.", + nameof(Properties)); + break; + case ImpalaAuthType.UsernameOnly: + if (string.IsNullOrWhiteSpace(username)) + throw new ArgumentException( + $"Parameter '{ImpalaParameters.AuthType}' is set to '{ImpalaAuthTypeConstants.UsernameOnly}' but parameter '{AdbcOptions.Username}' is not set. Please provide a values for this parameter.", + nameof(Properties)); + break; + case ImpalaAuthType.None: + break; + case ImpalaAuthType.Empty: + if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password)) + throw new ArgumentException( + $"Parameters must include valid authentiation settings. Please provide '{AdbcOptions.Username}' and '{AdbcOptions.Password}'.", + nameof(Properties)); + break; + default: + throw new ArgumentOutOfRangeException(ImpalaParameters.AuthType, authType, $"Unsupported {ImpalaParameters.AuthType} value."); + } + } + + protected override void ValidateConnection() + { + // HostName or Uri is required parameter + Properties.TryGetValue(AdbcOptions.Uri, out string? uri); + Properties.TryGetValue(ImpalaParameters.HostName, out string? hostName); + if ((Uri.CheckHostName(hostName) == UriHostNameType.Unknown) + && (string.IsNullOrEmpty(uri) || !Uri.TryCreate(uri, UriKind.Absolute, out Uri? _))) + { + throw new ArgumentException( + $"Required parameter '{ImpalaParameters.HostName}' or '{AdbcOptions.Uri}' is missing or invalid. 
Please provide a valid hostname or URI for the data source.", + nameof(Properties)); + } + + // Validate port range + Properties.TryGetValue(ImpalaParameters.Port, out string? port); + if (int.TryParse(port, out int portNumber) && (portNumber <= IPEndPoint.MinPort || portNumber > IPEndPoint.MaxPort)) + throw new ArgumentOutOfRangeException( + nameof(Properties), + port, + $"Parameter '{ImpalaParameters.Port}' value is not in the valid range of 1 .. {IPEndPoint.MaxPort}."); + + // Ensure the parameters will produce a valid address + Properties.TryGetValue(ImpalaParameters.Path, out string? path); + _ = new HttpClient() + { + BaseAddress = GetBaseAddress(uri, hostName, path, port) + }; + } + + protected override void ValidateOptions() + { + Properties.TryGetValue(ImpalaParameters.DataTypeConv, out string? dataTypeConv); + DataTypeConversion = DataTypeConversionParser.Parse(dataTypeConv); + Properties.TryGetValue(ImpalaParameters.TLSOptions, out string? tlsOptions); + TlsOptions = TlsOptionsParser.Parse(tlsOptions); + Properties.TryGetValue(ImpalaParameters.ConnectTimeoutMilliseconds, out string? connectTimeoutMs); + if (connectTimeoutMs != null) + { + ConnectTimeoutMilliseconds = int.TryParse(connectTimeoutMs, NumberStyles.Integer, CultureInfo.InvariantCulture, out int connectTimeoutMsValue) && (connectTimeoutMsValue >= 0) + ? connectTimeoutMsValue + : throw new ArgumentOutOfRangeException(ImpalaParameters.ConnectTimeoutMilliseconds, connectTimeoutMs, $"must be a value of 0 (infinite) or between 1 .. {int.MaxValue}. default is 30000 milliseconds."); + } + } + + internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion); + + protected override TTransport CreateTransport() + { + // Assumption: parameters have already been validated. + Properties.TryGetValue(ImpalaParameters.HostName, out string? 
hostName); + Properties.TryGetValue(ImpalaParameters.Path, out string? path); + Properties.TryGetValue(ImpalaParameters.Port, out string? port); + Properties.TryGetValue(ImpalaParameters.AuthType, out string? authType); + bool isValidAuthType = ImpalaAuthTypeParser.TryParse(authType, out ImpalaAuthType authTypeValue); + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? password); + Properties.TryGetValue(AdbcOptions.Uri, out string? uri); + + Uri baseAddress = GetBaseAddress(uri, hostName, path, port); + AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, username, password); + + HttpClientHandler httpClientHandler = NewHttpClientHandler(); + HttpClient httpClient = new(httpClientHandler); + httpClient.BaseAddress = baseAddress; + httpClient.DefaultRequestHeaders.Authorization = authenticationHeaderValue; + httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(s_userAgent); + httpClient.DefaultRequestHeaders.AcceptEncoding.Clear(); + httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new StringWithQualityHeaderValue("identity")); + httpClient.DefaultRequestHeaders.ExpectContinue = false; + + TConfiguration config = new(); + ThriftHttpTransport transport = new(httpClient, config) + { + // This value can only be set before the first call/request. So if a new value for query timeout + // is set, we won't be able to update the value. Setting to ~infinite and relying on cancellation token + // to ensure cancelled correctly. 
+ ConnectTimeout = int.MaxValue, + }; + return transport; + } + + private HttpClientHandler NewHttpClientHandler() + { + HttpClientHandler httpClientHandler = new(); + if (TlsOptions != HiveServer2TlsOption.Empty) + { + httpClientHandler.ServerCertificateCustomValidationCallback = (request, certificate, chain, policyErrors) => + { + if (policyErrors == SslPolicyErrors.None) return true; + + return + (!policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors) || TlsOptions.HasFlag(HiveServer2TlsOption.AllowSelfSigned)) + && (!policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch) || TlsOptions.HasFlag(HiveServer2TlsOption.AllowHostnameMismatch)); + }; + } + + return httpClientHandler; + } + + private static AuthenticationHeaderValue? GetAuthenticationHeaderValue(ImpalaAuthType authType, string? username, string? password) + { + if (!string.IsNullOrEmpty(username) && !string.IsNullOrEmpty(password) && (authType == ImpalaAuthType.Empty || authType == ImpalaAuthType.Basic)) + { + return new AuthenticationHeaderValue(BasicAuthenticationScheme, Convert.ToBase64String(Encoding.UTF8.GetBytes($"{username}:{password}"))); + } + else if (!string.IsNullOrEmpty(username) && (authType == ImpalaAuthType.Empty || authType == ImpalaAuthType.UsernameOnly)) + { + return new AuthenticationHeaderValue(BasicAuthenticationScheme, Convert.ToBase64String(Encoding.UTF8.GetBytes($"{username}:"))); + } + else if (authType == ImpalaAuthType.None) + { + return null; + } + else + { + throw new AdbcException("Missing connection properties. Must contain 'username' and 'password'"); + } + } + + protected override async Task<TProtocol> CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default) + { + Trace.TraceError($"create protocol with {Properties.Count} properties."); Review Comment: Agreed. Removed now. 
########## csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs: ########## @@ -208,5 +707,616 @@ internal static async Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync( TGetResultSetMetadataResp response = await client.GetResultSetMetadata(request, cancellationToken); return response; } + + /// <summary> + /// Gets the data-source specific columns names for the GetColumns metadata result. + /// </summary> + /// <returns></returns> + protected abstract ColumnsMetadataColumnNames GetColumnsMetadataColumnNames(); + + /// <summary> + /// Gets the default product version + /// </summary> + /// <returns></returns> + protected abstract string GetProductVersionDefault(); + + /// <summary> + /// Gets the current product version. + /// </summary> + /// <returns></returns> + protected internal string GetProductVersion() + { + FileVersionInfo fileVersionInfo = FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location); + return fileVersionInfo.ProductVersion ?? GetProductVersionDefault(); + } + + protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, string? port) + { + // Uri property takes precedent. + if (!string.IsNullOrWhiteSpace(uri)) + { + var uriValue = new Uri(uri); + if (uriValue.Scheme != Uri.UriSchemeHttp && uriValue.Scheme != Uri.UriSchemeHttps) + throw new ArgumentOutOfRangeException( + AdbcOptions.Uri, + uri, + $"Unsupported scheme '{uriValue.Scheme}'"); + return uriValue; + } + + bool isPortSet = !string.IsNullOrEmpty(port); + bool isValidPortNumber = int.TryParse(port, out int portNumber) && portNumber > 0; + bool isDefaultHttpsPort = !isPortSet || (isValidPortNumber && portNumber == 443); + string uriScheme = isDefaultHttpsPort ? 
Uri.UriSchemeHttps : Uri.UriSchemeHttp; + int uriPort; + if (!isPortSet) + uriPort = -1; + else if (isValidPortNumber) + uriPort = portNumber; + else + throw new ArgumentOutOfRangeException(nameof(port), portNumber, $"Port number is not in a valid range."); + + Uri baseAddress = new UriBuilder(uriScheme, hostName, uriPort, path).Uri; + return baseAddress; + } + + // Note data source's Position may be one-indexed or zero-indexed + protected IReadOnlyDictionary<string, int> GetColumnIndexMap(List<TColumnDesc> columns) => columns + .Select(t => new { Index = t.Position - PositionRequiredOffset, t.ColumnName }) + .ToDictionary(t => t.ColumnName, t => t.Index); + + protected abstract Task<TRowSet> GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetCatalogsResp getCatalogsResp, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetSchemasResp getSchemasResp, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default); + + protected abstract bool AreResultsAvailableDirectly(); + + protected abstract 
TSparkGetDirectResults GetDirectResults(); + + protected internal abstract int PositionRequiredOffset { get; } + + protected abstract string InfoDriverName { get; } + + protected abstract string InfoDriverArrowVersion { get; } + + protected abstract string ProductVersion { get; } + + protected abstract bool GetObjectsPatternsRequireLowerCase { get; } + + protected abstract bool IsColumnSizeValidForDecimal { get; } + + private static string PatternToRegEx(string? pattern) + { + if (pattern == null) + return ".*"; + + StringBuilder builder = new StringBuilder("(?i)^"); + string convertedPattern = pattern.Replace("_", ".").Replace("%", ".*"); + builder.Append(convertedPattern); + builder.Append('$'); + + return builder.ToString(); + } + + private static StructArray GetDbSchemas( + GetObjectsDepth depth, + Dictionary<string, Dictionary<string, TableInfo>> schemaMap) + { + StringArray.Builder dbSchemaNameBuilder = new StringArray.Builder(); + List<IArrowArray?> dbSchemaTablesValues = new List<IArrowArray?>(); + ArrowBuffer.BitmapBuilder nullBitmapBuffer = new ArrowBuffer.BitmapBuilder(); + int length = 0; + + Review Comment: Thanks for pointing this out. Removed the blank lines. ########## csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs: ########## @@ -15,22 +15,29 @@ * limitations under the License. */ +using System; using System.Collections.Generic; +using System.Linq; namespace Apache.Arrow.Adbc.Drivers.Apache.Impala { public class ImpalaDatabase : AdbcDatabase { readonly IReadOnlyDictionary<string, string> properties; - internal ImpalaDatabase(IReadOnlyDictionary<string, string> properties) + public ImpalaDatabase(IReadOnlyDictionary<string, string> properties) Review Comment: I have changed it to internal. @birschick-bq, please review. 
########## csharp/test/Drivers/Apache/Common/BinaryBooleanValueTests.cs: ########## @@ -46,26 +47,47 @@ public static IEnumerable<object[]> ByteArrayData(int size) yield return new object[] { bytes }; } + public static IEnumerable<object[]> AsciiArrayData(int size) + { + const string values = "abcdefghijklmnopqrstuvwxyz0123456789"; + StringBuilder builder = new StringBuilder(); + Random rnd = new(); + byte[] bytes = new byte[size]; + for (int i = 0; i < size; i++) + { + Review Comment: Removed. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org