CurtHagenlocher commented on code in PR #2365: URL: https://github.com/apache/arrow-adbc/pull/2365#discussion_r1914888621
########## csharp/test/Drivers/Apache/Impala/StatementTests.cs: ########## @@ -0,0 +1,51 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using Apache.Arrow.Adbc.Tests.Drivers.Apache.Common; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala +{ + public class StatementTests : Common.StatementTests<ApacheTestConfiguration, ImpalaTestEnvironment> + { + public StatementTests(ITestOutputHelper? outputHelper) + : base(outputHelper, new ImpalaTestEnvironment.Factory()) + { + } + + [SkippableTheory] + [ClassData(typeof(LongRunningStatementTimeoutTestData))] + internal override void StatementTimeoutTest(StatementWithExceptions statementWithExceptions) + { + base.StatementTimeoutTest(statementWithExceptions); + } + + internal class LongRunningStatementTimeoutTestData : ShortRunningStatementTimeoutTestData + { + public LongRunningStatementTimeoutTestData() + { + // TODO: Determine if this long-running query will work as expected on Impala. Review Comment: Should there be an issue filed to follow up? Is it worth even including these overrides given that they don't do anything extra? 
########## csharp/test/Drivers/Apache/Common/StringValueTests.cs: ########## @@ -90,15 +91,15 @@ await ValidateInsertSelectDeleteSingleValueAsync( [SkippableTheory] [InlineData(null)] [InlineData("")] - [InlineData("你好")] [InlineData(" Leading and trailing spaces ")] - protected virtual async Task TestCharData(string? value) + internal virtual async Task TestCharData(string? value) Review Comment: I see this same pattern is used in multiple places. I'm a little suspicious of the interaction between `InlineData` and virtual functions. Have we validated that the override will incorporate all the `InlineData` from the base method in addition to its new data? (I tend to think that even if this does work, it's better to refactor into a helper method that just does the validation and multiple test methods that call the helper. It's easier on the reader of the code.) ########## csharp/test/Drivers/Apache/Common/BinaryBooleanValueTests.cs: ########## @@ -46,26 +47,47 @@ public static IEnumerable<object[]> ByteArrayData(int size) yield return new object[] { bytes }; } + public static IEnumerable<object[]> AsciiArrayData(int size) + { + const string values = "abcdefghijklmnopqrstuvwxyz0123456789"; + StringBuilder builder = new StringBuilder(); + Random rnd = new(); + byte[] bytes = new byte[size]; + for (int i = 0; i < size; i++) + { + Review Comment: nit: extra blank line ########## csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs: ########## @@ -208,5 +707,616 @@ internal static async Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync( TGetResultSetMetadataResp response = await client.GetResultSetMetadata(request, cancellationToken); return response; } + + /// <summary> + /// Gets the data-source specific columns names for the GetColumns metadata result. 
+ /// </summary> + /// <returns></returns> + protected abstract ColumnsMetadataColumnNames GetColumnsMetadataColumnNames(); + + /// <summary> + /// Gets the default product version + /// </summary> + /// <returns></returns> + protected abstract string GetProductVersionDefault(); + + /// <summary> + /// Gets the current product version. + /// </summary> + /// <returns></returns> + protected internal string GetProductVersion() + { + FileVersionInfo fileVersionInfo = FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location); + return fileVersionInfo.ProductVersion ?? GetProductVersionDefault(); + } + + protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, string? port) + { + // Uri property takes precedent. + if (!string.IsNullOrWhiteSpace(uri)) + { + var uriValue = new Uri(uri); + if (uriValue.Scheme != Uri.UriSchemeHttp && uriValue.Scheme != Uri.UriSchemeHttps) + throw new ArgumentOutOfRangeException( + AdbcOptions.Uri, + uri, + $"Unsupported scheme '{uriValue.Scheme}'"); + return uriValue; + } + + bool isPortSet = !string.IsNullOrEmpty(port); + bool isValidPortNumber = int.TryParse(port, out int portNumber) && portNumber > 0; + bool isDefaultHttpsPort = !isPortSet || (isValidPortNumber && portNumber == 443); + string uriScheme = isDefaultHttpsPort ? 
Uri.UriSchemeHttps : Uri.UriSchemeHttp; + int uriPort; + if (!isPortSet) + uriPort = -1; + else if (isValidPortNumber) + uriPort = portNumber; + else + throw new ArgumentOutOfRangeException(nameof(port), portNumber, $"Port number is not in a valid range."); + + Uri baseAddress = new UriBuilder(uriScheme, hostName, uriPort, path).Uri; + return baseAddress; + } + + // Note data source's Position may be one-indexed or zero-indexed + protected IReadOnlyDictionary<string, int> GetColumnIndexMap(List<TColumnDesc> columns) => columns + .Select(t => new { Index = t.Position - PositionRequiredOffset, t.ColumnName }) + .ToDictionary(t => t.ColumnName, t => t.Index); + + protected abstract Task<TRowSet> GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetTablesResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetCatalogsResp getCatalogsResp, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetSchemasResp getSchemasResp, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default); + + protected abstract bool AreResultsAvailableDirectly(); + + protected abstract 
TSparkGetDirectResults GetDirectResults(); Review Comment: It seems somewhat unlikely that a method returning `TSparkGetDirectResults` belongs in the base class implementation. Is there a better factoring possible here? ########## csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs: ########## @@ -15,22 +15,29 @@ * limitations under the License. */ +using System; using System.Collections.Generic; +using System.Linq; namespace Apache.Arrow.Adbc.Drivers.Apache.Impala { public class ImpalaDatabase : AdbcDatabase { readonly IReadOnlyDictionary<string, string> properties; - internal ImpalaDatabase(IReadOnlyDictionary<string, string> properties) + public ImpalaDatabase(IReadOnlyDictionary<string, string> properties) Review Comment: I'm not sure how we ended up with so much inconsistency around which types are public. The database-specific Database classes are public while the Connection and Statement classes are not :(. ########## csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs: ########## @@ -0,0 +1,219 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Net.Security; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Ipc; +using Apache.Hive.Service.Rpc.Thrift; +using Thrift; +using Thrift.Protocol; +using Thrift.Transport; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Impala +{ + internal class ImpalaHttpConnection : ImpalaConnection + { + private const string BasicAuthenticationScheme = "Basic"; + + public ImpalaHttpConnection(IReadOnlyDictionary<string, string> properties) : base(properties) + { + } + + protected override void ValidateAuthentication() + { + // Validate authentication parameters + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? password); + Properties.TryGetValue(ImpalaParameters.AuthType, out string? authType); + bool isValidAuthType = ImpalaAuthTypeParser.TryParse(authType, out ImpalaAuthType authTypeValue); + switch (authTypeValue) + { + case ImpalaAuthType.Basic: + if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password)) + throw new ArgumentException( + $"Parameter '{ImpalaParameters.AuthType}' is set to '{ImpalaAuthTypeConstants.Basic}' but parameters '{AdbcOptions.Username}' or '{AdbcOptions.Password}' are not set. Please provide a values for these parameters.", + nameof(Properties)); + break; + case ImpalaAuthType.UsernameOnly: + if (string.IsNullOrWhiteSpace(username)) + throw new ArgumentException( + $"Parameter '{ImpalaParameters.AuthType}' is set to '{ImpalaAuthTypeConstants.UsernameOnly}' but parameter '{AdbcOptions.Username}' is not set. 
Please provide a values for this parameter.", + nameof(Properties)); + break; + case ImpalaAuthType.None: + break; + case ImpalaAuthType.Empty: + if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password)) + throw new ArgumentException( + $"Parameters must include valid authentiation settings. Please provide '{AdbcOptions.Username}' and '{AdbcOptions.Password}'.", + nameof(Properties)); + break; + default: + throw new ArgumentOutOfRangeException(ImpalaParameters.AuthType, authType, $"Unsupported {ImpalaParameters.AuthType} value."); + } + } + + protected override void ValidateConnection() + { + // HostName or Uri is required parameter + Properties.TryGetValue(AdbcOptions.Uri, out string? uri); + Properties.TryGetValue(ImpalaParameters.HostName, out string? hostName); + if ((Uri.CheckHostName(hostName) == UriHostNameType.Unknown) + && (string.IsNullOrEmpty(uri) || !Uri.TryCreate(uri, UriKind.Absolute, out Uri? _))) + { + throw new ArgumentException( + $"Required parameter '{ImpalaParameters.HostName}' or '{AdbcOptions.Uri}' is missing or invalid. Please provide a valid hostname or URI for the data source.", + nameof(Properties)); + } + + // Validate port range + Properties.TryGetValue(ImpalaParameters.Port, out string? port); + if (int.TryParse(port, out int portNumber) && (portNumber <= IPEndPoint.MinPort || portNumber > IPEndPoint.MaxPort)) + throw new ArgumentOutOfRangeException( + nameof(Properties), + port, + $"Parameter '{ImpalaParameters.Port}' value is not in the valid range of 1 .. {IPEndPoint.MaxPort}."); + + // Ensure the parameters will produce a valid address + Properties.TryGetValue(ImpalaParameters.Path, out string? path); + _ = new HttpClient() + { + BaseAddress = GetBaseAddress(uri, hostName, path, port) + }; + } + + protected override void ValidateOptions() + { + Properties.TryGetValue(ImpalaParameters.DataTypeConv, out string? 
dataTypeConv); + DataTypeConversion = DataTypeConversionParser.Parse(dataTypeConv); + Properties.TryGetValue(ImpalaParameters.TLSOptions, out string? tlsOptions); + TlsOptions = TlsOptionsParser.Parse(tlsOptions); + Properties.TryGetValue(ImpalaParameters.ConnectTimeoutMilliseconds, out string? connectTimeoutMs); + if (connectTimeoutMs != null) + { + ConnectTimeoutMilliseconds = int.TryParse(connectTimeoutMs, NumberStyles.Integer, CultureInfo.InvariantCulture, out int connectTimeoutMsValue) && (connectTimeoutMsValue >= 0) + ? connectTimeoutMsValue + : throw new ArgumentOutOfRangeException(ImpalaParameters.ConnectTimeoutMilliseconds, connectTimeoutMs, $"must be a value of 0 (infinite) or between 1 .. {int.MaxValue}. default is 30000 milliseconds."); + } + } + + internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion); + + protected override TTransport CreateTransport() + { + // Assumption: parameters have already been validated. + Properties.TryGetValue(ImpalaParameters.HostName, out string? hostName); + Properties.TryGetValue(ImpalaParameters.Path, out string? path); + Properties.TryGetValue(ImpalaParameters.Port, out string? port); + Properties.TryGetValue(ImpalaParameters.AuthType, out string? authType); + bool isValidAuthType = ImpalaAuthTypeParser.TryParse(authType, out ImpalaAuthType authTypeValue); + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? password); + Properties.TryGetValue(AdbcOptions.Uri, out string? uri); + + Uri baseAddress = GetBaseAddress(uri, hostName, path, port); + AuthenticationHeaderValue? 
authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, username, password); + + HttpClientHandler httpClientHandler = NewHttpClientHandler(); + HttpClient httpClient = new(httpClientHandler); + httpClient.BaseAddress = baseAddress; + httpClient.DefaultRequestHeaders.Authorization = authenticationHeaderValue; + httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(s_userAgent); + httpClient.DefaultRequestHeaders.AcceptEncoding.Clear(); + httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new StringWithQualityHeaderValue("identity")); + httpClient.DefaultRequestHeaders.ExpectContinue = false; + + TConfiguration config = new(); + ThriftHttpTransport transport = new(httpClient, config) + { + // This value can only be set before the first call/request. So if a new value for query timeout + // is set, we won't be able to update the value. Setting to ~infinite and relying on cancellation token + // to ensure cancelled correctly. + ConnectTimeout = int.MaxValue, + }; + return transport; + } + + private HttpClientHandler NewHttpClientHandler() + { + HttpClientHandler httpClientHandler = new(); + if (TlsOptions != HiveServer2TlsOption.Empty) + { + httpClientHandler.ServerCertificateCustomValidationCallback = (request, certificate, chain, policyErrors) => + { + if (policyErrors == SslPolicyErrors.None) return true; + + return Review Comment: Consider rewriting this logic to be less terse and more readable, e.g. ``` if (policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors) && !TlsOptions.HasFlag(HiveServer2TlsOption.AllowSelfSigned)) return false; if (policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch) && !TlsOptions.HasFlag(HiveServer2TlsOption.AllowHostnameMismatch)) return false; return true; ``` ########## csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs: ########## @@ -15,22 +15,29 @@ * limitations under the License. 
*/ +using System; using System.Collections.Generic; +using System.Linq; namespace Apache.Arrow.Adbc.Drivers.Apache.Impala { public class ImpalaDatabase : AdbcDatabase { readonly IReadOnlyDictionary<string, string> properties; - internal ImpalaDatabase(IReadOnlyDictionary<string, string> properties) + public ImpalaDatabase(IReadOnlyDictionary<string, string> properties) { this.properties = properties; } - public override AdbcConnection Connect(IReadOnlyDictionary<string, string>? properties) + public override AdbcConnection Connect(IReadOnlyDictionary<string, string>? options) { - ImpalaConnection connection = new ImpalaConnection(this.properties); + IReadOnlyDictionary<string, string> mergedProperties = options == null Review Comment: I know that `SparkDatabase` is doing this too, but it seems semantically wrong to me to mix the database-level properties with the connection options inside the same dictionary. @davidhcoe? ########## csharp/test/Drivers/Apache/Spark/DriverTests.cs: ########## @@ -74,5 +128,27 @@ public static IEnumerable<object[]> TableNamePatternData() string? tableName = new DriverTests(null).TestConfiguration?.Metadata?.Table; return GetPatterns(tableName); } + + protected override bool TypeHasDecimalDigits(Metadata.AdbcColumn column) + { + HashSet<short> typesHaveDecimalDigits = new() + { + (short)SupportedSparkDataType.DECIMAL, + (short)SupportedSparkDataType.NUMERIC, + }; + return typesHaveDecimalDigits.Contains(column.XdbcDataType!.Value); Review Comment: This is test code and so not performance critical. But for what it's worth, if the hash set is being built each time then it's probably a bit slower than a switch statement. 
########## csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs: ########## @@ -36,6 +45,229 @@ internal abstract class HiveServer2Connection : AdbcConnection private readonly Lazy<string> _vendorVersion; private readonly Lazy<string> _vendorName; + readonly AdbcInfoCode[] infoSupportedCodes = [ + AdbcInfoCode.DriverName, + AdbcInfoCode.DriverVersion, + AdbcInfoCode.DriverArrowVersion, + AdbcInfoCode.VendorName, + AdbcInfoCode.VendorSql, + AdbcInfoCode.VendorVersion, + ]; + + internal const string ColumnDef = "COLUMN_DEF"; + internal const string ColumnName = "COLUMN_NAME"; + internal const string DataType = "DATA_TYPE"; + internal const string IsAutoIncrement = "IS_AUTO_INCREMENT"; + internal const string IsNullable = "IS_NULLABLE"; + internal const string OrdinalPosition = "ORDINAL_POSITION"; + internal const string TableCat = "TABLE_CAT"; + internal const string TableCatalog = "TABLE_CATALOG"; + internal const string TableName = "TABLE_NAME"; + internal const string TableSchem = "TABLE_SCHEM"; + internal const string TableMd = "TABLE_MD"; + internal const string TableType = "TABLE_TYPE"; + internal const string TypeName = "TYPE_NAME"; + internal const string Nullable = "NULLABLE"; + internal const string ColumnSize = "COLUMN_SIZE"; + internal const string DecimalDigits = "DECIMAL_DIGITS"; + internal const string BufferLength = "BUFFER_LENGTH"; + + /// <summary> + /// The GetColumns metadata call returns a result with different column names + /// on different data sources. Populate this structure with the actual column names. 
+ /// </summary> + internal struct ColumnsMetadataColumnNames + { + public string TableCatalog { get; internal set; } + public string TableSchema { get; internal set; } + public string TableName { get; internal set; } + public string ColumnName { get; internal set; } + public string DataType { get; internal set; } + public string TypeName { get; internal set; } + public string Nullable { get; internal set; } + public string ColumnDef { get; internal set; } + public string OrdinalPosition { get; internal set; } + public string IsNullable { get; internal set; } + public string IsAutoIncrement { get; internal set; } + public string ColumnSize { get; set; } + public string DecimalDigits { get; set; } + } + + /// <summary> + /// The data type definitions based on the <see href="https://docs.oracle.com/en%2Fjava%2Fjavase%2F21%2Fdocs%2Fapi%2F%2F/java.sql/java/sql/Types.html">JDBC Types</see> constants. + /// </summary> + /// <remarks> + /// This enumeration can be used to determine the drivers specific data types that are contained in fields <c>xdbc_data_type</c> and <c>xdbc_sql_data_type</c> + /// in the column metadata <see cref="StandardSchemas.ColumnSchema"/>. This column metadata is returned as a result of a call to + /// <see cref="AdbcConnection.GetObjects(GetObjectsDepth, string?, string?, string?, IReadOnlyList{string}?, string?)"/> + /// when <c>depth</c> is set to <see cref="AdbcConnection.GetObjectsDepth.All"/>. + /// </remarks> + internal enum ColumnTypeId Review Comment: Consider moving this into a top-level type and naming it something like `XdbcTypeCode`. It's more unwieldy to use as a nested type, and it's potentially useful to other drivers as well. (This could also be done as a follow-up.)
########## csharp/test/Drivers/Apache/Common/NumericValueTests.cs: ########## @@ -227,10 +228,7 @@ public async Task TestRoundingNumbers(decimal input, decimal output) [InlineData(1.234E+2)] [InlineData(double.NegativeInfinity)] [InlineData(double.PositiveInfinity)] - [InlineData(double.NaN)] - [InlineData(double.MinValue)] - [InlineData(double.MaxValue)] - public async Task TestDoubleValuesInsertSelectDelete(double value) + public virtual async Task TestDoubleValuesInsertSelectDelete(double value) Review Comment: Are the deleted test values simply not supported in Impala? That wouldn't be surprising for NaN or the infinities, but is somewhat surprising for `double.MinValue` and `double.MaxValue`. ########## csharp/test/Drivers/Apache/Common/DriverTests.cs: ########## @@ -50,7 +50,7 @@ public abstract class DriverTests<TConfig, TEnv> : TestBase<TConfig, TEnv> /// <summary> /// Supported Spark data types as a subset of <see cref="ColumnTypeId"/> /// </summary> - private enum SupportedSparkDataType : short + internal enum SupportedSparkDataType : short Review Comment: Should the comments and type names in the base class be referencing Spark explicitly when they're used for both Spark and Impala? ########## csharp/test/Drivers/Apache/Impala/Resources/ImpalaData.sql: ########## @@ -16,118 +16,122 @@ DROP TABLE IF EXISTS {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE}; +-- Note: +-- Impala supports complext type (ARRAY, MAP, STRUCT), BUT, Review Comment: typo; consider `complex types`. Is it not possible to use `get_json_object` to load structured data? ########## csharp/test/Drivers/Apache/Impala/DriverTests.cs: ########## @@ -0,0 +1,132 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using Apache.Arrow.Adbc.Drivers.Apache.Impala; +using Apache.Arrow.Adbc.Tests.Metadata; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Impala +{ + public class DriverTests : Common.DriverTests<ApacheTestConfiguration, ImpalaTestEnvironment> + { + public DriverTests(ITestOutputHelper? outputHelper) + : base(outputHelper, new ImpalaTestEnvironment.Factory()) + { + } + + [SkippableTheory] + [MemberData(nameof(CatalogNamePatternData))] + public override void CanGetObjectsCatalogs(string? 
pattern) + { + GetObjectsCatalogsTest(pattern); + } + + [SkippableTheory] + [MemberData(nameof(DbSchemasNamePatternData))] + public override void CanGetObjectsDbSchemas(string dbSchemaPattern) + { + GetObjectsDbSchemasTest(dbSchemaPattern); + } + + [SkippableTheory] + [MemberData(nameof(TableNamePatternData))] + public override void CanGetObjectsTables(string tableNamePattern) + { + GetObjectsTablesTest(tableNamePattern); + } + + public override void CanDetectInvalidServer() + { + AdbcDriver driver = NewDriver; + Assert.NotNull(driver); + Dictionary<string, string> parameters = GetDriverParameters(TestConfiguration); + + bool hasUri = parameters.TryGetValue(AdbcOptions.Uri, out var uri) && !string.IsNullOrEmpty(uri); + bool hasHostName = parameters.TryGetValue(ImpalaParameters.HostName, out var hostName) && !string.IsNullOrEmpty(hostName); + if (hasUri) + { + parameters[AdbcOptions.Uri] = "http://unknownhost.azure.com/cliservice"; + } + else if (hasHostName) + { + parameters[ImpalaParameters.HostName] = "unknownhost.azure.com"; + } + else + { + Assert.Fail($"Unexpected configuration. Must provide '{AdbcOptions.Uri}' or '{ImpalaParameters.HostName}'."); + } + + AdbcDatabase database = driver.Open(parameters); + AggregateException exception = Assert.ThrowsAny<AggregateException>(() => database.Connect(parameters)); + OutputHelper?.WriteLine(exception.Message); + } + + public override void CanDetectInvalidAuthentication() + { + AdbcDriver driver = NewDriver; + Assert.NotNull(driver); + Dictionary<string, string> parameters = GetDriverParameters(TestConfiguration); + + bool hasUsername = parameters.TryGetValue(AdbcOptions.Username, out var username) && !string.IsNullOrEmpty(username); + bool hasPassword = parameters.TryGetValue(AdbcOptions.Password, out var password) && !string.IsNullOrEmpty(password); + if (hasUsername && hasPassword) + { + parameters[AdbcOptions.Password] = "invalid-password"; + } + else + { + Assert.Fail($"Unexpected configuration. 
Must provide '{AdbcOptions.Username}' and '{AdbcOptions.Password}'."); + } + + AdbcDatabase database = driver.Open(parameters); + AggregateException exception = Assert.ThrowsAny<AggregateException>(() => database.Connect(parameters)); + OutputHelper?.WriteLine(exception.Message); + } + + protected override IReadOnlyList<int> GetUpdateExpectedResults() + { + int affectedRows = ValidateAffectedRows ? 1 : -1; + return ClientTests.GetUpdateExpectedResults(affectedRows); + } + + Review Comment: nit: extra blank line ########## csharp/test/Drivers/Apache/Common/ComplexTypesValueTests.cs: ########## @@ -28,7 +28,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Common /// <summary> /// Validates that specific complex structured types can be inserted, retrieved and targeted correctly /// </summary> - public class ComplexTypesValueTests<TConfig, TEnv> : TestBase<TConfig, TEnv> + public abstract class ComplexTypesValueTests<TConfig, TEnv> : TestBase<TConfig, TEnv> Review Comment: There's nothing which derives from this. Won't that prevent this class from being instantiated and tested (now that it's abstract)? ########## csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs: ########## @@ -208,5 +707,616 @@ internal static async Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync( TGetResultSetMetadataResp response = await client.GetResultSetMetadata(request, cancellationToken); return response; } + + /// <summary> + /// Gets the data-source specific columns names for the GetColumns metadata result. + /// </summary> + /// <returns></returns> + protected abstract ColumnsMetadataColumnNames GetColumnsMetadataColumnNames(); + + /// <summary> + /// Gets the default product version + /// </summary> + /// <returns></returns> + protected abstract string GetProductVersionDefault(); + + /// <summary> + /// Gets the current product version. 
+ /// </summary> + /// <returns></returns> + protected internal string GetProductVersion() + { + FileVersionInfo fileVersionInfo = FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location); + return fileVersionInfo.ProductVersion ?? GetProductVersionDefault(); + } + + protected static Uri GetBaseAddress(string? uri, string? hostName, string? path, string? port) + { + // Uri property takes precedent. + if (!string.IsNullOrWhiteSpace(uri)) + { + var uriValue = new Uri(uri); + if (uriValue.Scheme != Uri.UriSchemeHttp && uriValue.Scheme != Uri.UriSchemeHttps) + throw new ArgumentOutOfRangeException( + AdbcOptions.Uri, + uri, + $"Unsupported scheme '{uriValue.Scheme}'"); + return uriValue; + } + + bool isPortSet = !string.IsNullOrEmpty(port); + bool isValidPortNumber = int.TryParse(port, out int portNumber) && portNumber > 0; + bool isDefaultHttpsPort = !isPortSet || (isValidPortNumber && portNumber == 443); + string uriScheme = isDefaultHttpsPort ? Uri.UriSchemeHttps : Uri.UriSchemeHttp; + int uriPort; + if (!isPortSet) + uriPort = -1; + else if (isValidPortNumber) + uriPort = portNumber; + else + throw new ArgumentOutOfRangeException(nameof(port), portNumber, $"Port number is not in a valid range."); + + Uri baseAddress = new UriBuilder(uriScheme, hostName, uriPort, path).Uri; + return baseAddress; + } + + // Note data source's Position may be one-indexed or zero-indexed + protected IReadOnlyDictionary<string, int> GetColumnIndexMap(List<TColumnDesc> columns) => columns + .Select(t => new { Index = t.Position - PositionRequiredOffset, t.ColumnName }) + .ToDictionary(t => t.ColumnName, t => t.Index); + + protected abstract Task<TRowSet> GetRowSetAsync(TGetTableTypesResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetTablesResp response, CancellationToken 
cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetCatalogsResp getCatalogsResp, CancellationToken cancellationToken = default); + protected abstract Task<TRowSet> GetRowSetAsync(TGetSchemasResp getSchemasResp, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken cancellationToken = default); + protected abstract Task<TGetResultSetMetadataResp> GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken cancellationToken = default); + + protected abstract bool AreResultsAvailableDirectly(); + + protected abstract TSparkGetDirectResults GetDirectResults(); + + protected internal abstract int PositionRequiredOffset { get; } + + protected abstract string InfoDriverName { get; } + + protected abstract string InfoDriverArrowVersion { get; } + + protected abstract string ProductVersion { get; } + + protected abstract bool GetObjectsPatternsRequireLowerCase { get; } + + protected abstract bool IsColumnSizeValidForDecimal { get; } + + private static string PatternToRegEx(string? 
pattern) + { + if (pattern == null) + return ".*"; + + StringBuilder builder = new StringBuilder("(?i)^"); + string convertedPattern = pattern.Replace("_", ".").Replace("%", ".*"); + builder.Append(convertedPattern); + builder.Append('$'); + + return builder.ToString(); + } + + private static StructArray GetDbSchemas( + GetObjectsDepth depth, + Dictionary<string, Dictionary<string, TableInfo>> schemaMap) + { + StringArray.Builder dbSchemaNameBuilder = new StringArray.Builder(); + List<IArrowArray?> dbSchemaTablesValues = new List<IArrowArray?>(); + ArrowBuffer.BitmapBuilder nullBitmapBuffer = new ArrowBuffer.BitmapBuilder(); + int length = 0; + + Review Comment: nit: extra blank line. These two methods (`GetDbSchemas` and `GetTableSchemas`) are full of duplicate and extra blank lines. Please clean them up. ########## csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs: ########## @@ -128,12 +360,279 @@ internal async Task OpenAsync() public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList<string>? tableTypes, string? 
columnNamePattern) { - throw new NotImplementedException(); + Dictionary<string, Dictionary<string, Dictionary<string, TableInfo>>> catalogMap = new Dictionary<string, Dictionary<string, Dictionary<string, TableInfo>>>(); + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + if (GetObjectsPatternsRequireLowerCase) + { + catalogPattern = catalogPattern?.ToLower(); + dbSchemaPattern = dbSchemaPattern?.ToLower(); + tableNamePattern = tableNamePattern?.ToLower(); + columnNamePattern = columnNamePattern?.ToLower(); + } + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) + { + TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(SessionHandle); + if (AreResultsAvailableDirectly()) + { + getCatalogsReq.GetDirectResults = GetDirectResults(); + } + + TGetCatalogsResp getCatalogsResp = Client.GetCatalogs(getCatalogsReq, cancellationToken).Result; + + if (getCatalogsResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getCatalogsResp.Status.ErrorMessage); + } + var catalogsMetadata = GetResultSetMetadataAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(catalogsMetadata.Schema.Columns); + + string catalogRegexp = PatternToRegEx(catalogPattern); + TRowSet rowSet = GetRowSetAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyList<string> list = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + for (int i = 0; i < list.Count; i++) + { + string col = list[i]; + string catalog = col; + + if (Regex.IsMatch(catalog, catalogRegexp, RegexOptions.IgnoreCase)) + { + catalogMap.Add(catalog, new Dictionary<string, Dictionary<string, TableInfo>>()); + } + } + // Handle the case where server does not support 'catalog' in the namespace. 
+ if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern)) + { + catalogMap.Add(string.Empty, []); + } + } + + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.DbSchemas) + { + TGetSchemasReq getSchemasReq = new TGetSchemasReq(SessionHandle); + getSchemasReq.CatalogName = catalogPattern; + getSchemasReq.SchemaName = dbSchemaPattern; + if (AreResultsAvailableDirectly()) + { + getSchemasReq.GetDirectResults = GetDirectResults(); + } + + TGetSchemasResp getSchemasResp = Client.GetSchemas(getSchemasReq, cancellationToken).Result; + if (getSchemasResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getSchemasResp.Status.ErrorMessage); + } + + TGetResultSetMetadataResp schemaMetadata = GetResultSetMetadataAsync(getSchemasResp, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(schemaMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getSchemasResp, cancellationToken).Result; + + IReadOnlyList<string> catalogList = rowSet.Columns[columnMap[TableCatalog]].StringVal.Values; + IReadOnlyList<string> schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + // It seems Spark sometimes returns empty string for catalog on some schema (temporary tables). 
+ catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new Dictionary<string, TableInfo>()); + } + } + + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Tables) + { + TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle); + getTablesReq.CatalogName = catalogPattern; + getTablesReq.SchemaName = dbSchemaPattern; + getTablesReq.TableName = tableNamePattern; + if (AreResultsAvailableDirectly()) + { + getTablesReq.GetDirectResults = GetDirectResults(); + } + + TGetTablesResp getTablesResp = Client.GetTables(getTablesReq, cancellationToken).Result; + if (getTablesResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getTablesResp.Status.ErrorMessage); + } + + TGetResultSetMetadataResp tableMetadata = GetResultSetMetadataAsync(getTablesResp, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(tableMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getTablesResp, cancellationToken).Result; + + IReadOnlyList<string> catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + IReadOnlyList<string> schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + IReadOnlyList<string> tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; + IReadOnlyList<string> tableTypeList = rowSet.Columns[columnMap[TableType]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string tableType = tableTypeList[i]; + TableInfo tableInfo = new(tableType); + catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName, tableInfo); + } + } + + if (depth == GetObjectsDepth.All) + { + TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle); + columnsReq.CatalogName = catalogPattern; + columnsReq.SchemaName = dbSchemaPattern; + columnsReq.TableName = tableNamePattern; + if (AreResultsAvailableDirectly()) + { + 
columnsReq.GetDirectResults = GetDirectResults(); + } + + if (!string.IsNullOrEmpty(columnNamePattern)) + columnsReq.ColumnName = columnNamePattern; + + var columnsResponse = Client.GetColumns(columnsReq, cancellationToken).Result; + if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(columnsResponse.Status.ErrorMessage); + } + + TGetResultSetMetadataResp columnsMetadata = GetResultSetMetadataAsync(columnsResponse, cancellationToken).Result; + IReadOnlyDictionary<string, int> columnMap = GetColumnIndexMap(columnsMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(columnsResponse, cancellationToken).Result; + + ColumnsMetadataColumnNames columnNames = GetColumnsMetadataColumnNames(); + IReadOnlyList<string> catalogList = rowSet.Columns[columnMap[columnNames.TableCatalog]].StringVal.Values; + IReadOnlyList<string> schemaList = rowSet.Columns[columnMap[columnNames.TableSchema]].StringVal.Values; + IReadOnlyList<string> tableList = rowSet.Columns[columnMap[columnNames.TableName]].StringVal.Values; + IReadOnlyList<string> columnNameList = rowSet.Columns[columnMap[columnNames.ColumnName]].StringVal.Values; + ReadOnlySpan<int> columnTypeList = rowSet.Columns[columnMap[columnNames.DataType]].I32Val.Values.Values; + IReadOnlyList<string> typeNameList = rowSet.Columns[columnMap[columnNames.TypeName]].StringVal.Values; + ReadOnlySpan<int> nullableList = rowSet.Columns[columnMap[columnNames.Nullable]].I32Val.Values.Values; + IReadOnlyList<string> columnDefaultList = rowSet.Columns[columnMap[columnNames.ColumnDef]].StringVal.Values; + ReadOnlySpan<int> ordinalPosList = rowSet.Columns[columnMap[columnNames.OrdinalPosition]].I32Val.Values.Values; + IReadOnlyList<string> isNullableList = rowSet.Columns[columnMap[columnNames.IsNullable]].StringVal.Values; + IReadOnlyList<string> isAutoIncrementList = rowSet.Columns[columnMap[columnNames.IsAutoIncrement]].StringVal.Values; + ReadOnlySpan<int> columnSizeList = 
rowSet.Columns[columnMap[columnNames.ColumnSize]].I32Val.Values.Values; + ReadOnlySpan<int> decimalDigitsList = rowSet.Columns[columnMap[columnNames.DecimalDigits]].I32Val.Values.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + // For systems that don't support 'catalog' in the namespace + string catalog = catalogList[i] ?? string.Empty; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string columnName = columnNameList[i]; + short colType = (short)columnTypeList[i]; + string typeName = typeNameList[i]; + short nullable = (short)nullableList[i]; + string? isAutoIncrementString = isAutoIncrementList[i]; + bool isAutoIncrement = (!string.IsNullOrEmpty(isAutoIncrementString) && (isAutoIncrementString.Equals("YES", StringComparison.InvariantCultureIgnoreCase) || isAutoIncrementString.Equals("TRUE", StringComparison.InvariantCultureIgnoreCase))); + string isNullable = isNullableList[i] ?? "YES"; + string columnDefault = columnDefaultList[i] ?? ""; + // Spark/Databricks reports ordinal index zero-indexed, instead of one-indexed + int ordinalPos = ordinalPosList[i] + PositionRequiredOffset; + int columnSize = columnSizeList[i]; + int decimalDigits = decimalDigitsList[i]; + TableInfo? 
tableInfo = catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName); + tableInfo?.ColumnName.Add(columnName); + tableInfo?.ColType.Add(colType); + tableInfo?.Nullable.Add(nullable); + tableInfo?.IsAutoIncrement.Add(isAutoIncrement); + tableInfo?.IsNullable.Add(isNullable); + tableInfo?.ColumnDefault.Add(columnDefault); + tableInfo?.OrdinalPosition.Add(ordinalPos); + SetPrecisionScaleAndTypeName(colType, typeName, tableInfo, columnSize, decimalDigits); + } + } + + StringArray.Builder catalogNameBuilder = new StringArray.Builder(); + List<IArrowArray?> catalogDbSchemasValues = new List<IArrowArray?>(); + + foreach (KeyValuePair<string, Dictionary<string, Dictionary<string, TableInfo>>> catalogEntry in catalogMap) + { + catalogNameBuilder.Append(catalogEntry.Key); + + if (depth == GetObjectsDepth.Catalogs) + { + catalogDbSchemasValues.Add(null); + } + else + { + catalogDbSchemasValues.Add(GetDbSchemas( + depth, catalogEntry.Value)); + } + } + + Schema schema = StandardSchemas.GetObjectsSchema; + IReadOnlyList<IArrowArray> dataArrays = schema.Validate( + new List<IArrowArray> + { + catalogNameBuilder.Build(), + catalogDbSchemasValues.BuildListArrayForType(new StructType(StandardSchemas.DbSchemaSchema)), + }); + + return new HiveInfoArrowStream(schema, dataArrays); + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) + { + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); + } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. 
'{ex.Message}'", ex); + } } public override IArrowArrayStream GetTableTypes() { - throw new NotImplementedException(); + TGetTableTypesReq req = new() + { + SessionHandle = SessionHandle ?? throw new InvalidOperationException("session not created"), + }; + + if (AreResultsAvailableDirectly()) + { + req.GetDirectResults = GetDirectResults(); + } + + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try + { + TGetTableTypesResp resp = Client.GetTableTypes(req, cancellationToken).Result; + + if (resp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new HiveServer2Exception(resp.Status.ErrorMessage) + .SetNativeError(resp.Status.ErrorCode) + .SetSqlState(resp.Status.SqlState); + } + + TRowSet rowSet = GetRowSetAsync(resp, cancellationToken).Result; + StringArray tableTypes = rowSet.Columns[0].StringVal.Values; + + StringArray.Builder tableTypesBuilder = new StringArray.Builder(); + tableTypesBuilder.AppendRange(tableTypes); + + IArrowArray[] dataArrays = new IArrowArray[] + { + tableTypesBuilder.Build() + }; + + return new HiveInfoArrowStream(StandardSchemas.TableTypesSchema, dataArrays); + } + catch (Exception ex) Review Comment: It looks like the same exception logic appears four separate times. Can these be factored into a single helper that's called as an exception filter? ########## csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs: ########## @@ -0,0 +1,219 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Net.Security; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Ipc; +using Apache.Hive.Service.Rpc.Thrift; +using Thrift; +using Thrift.Protocol; +using Thrift.Transport; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Impala +{ + internal class ImpalaHttpConnection : ImpalaConnection + { + private const string BasicAuthenticationScheme = "Basic"; + + public ImpalaHttpConnection(IReadOnlyDictionary<string, string> properties) : base(properties) + { + } + + protected override void ValidateAuthentication() + { + // Validate authentication parameters + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? password); + Properties.TryGetValue(ImpalaParameters.AuthType, out string? authType); + bool isValidAuthType = ImpalaAuthTypeParser.TryParse(authType, out ImpalaAuthType authTypeValue); + switch (authTypeValue) + { + case ImpalaAuthType.Basic: + if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password)) + throw new ArgumentException( + $"Parameter '{ImpalaParameters.AuthType}' is set to '{ImpalaAuthTypeConstants.Basic}' but parameters '{AdbcOptions.Username}' or '{AdbcOptions.Password}' are not set. 
Please provide a values for these parameters.", + nameof(Properties)); + break; + case ImpalaAuthType.UsernameOnly: + if (string.IsNullOrWhiteSpace(username)) + throw new ArgumentException( + $"Parameter '{ImpalaParameters.AuthType}' is set to '{ImpalaAuthTypeConstants.UsernameOnly}' but parameter '{AdbcOptions.Username}' is not set. Please provide a values for this parameter.", + nameof(Properties)); + break; + case ImpalaAuthType.None: + break; + case ImpalaAuthType.Empty: + if (string.IsNullOrWhiteSpace(username) || string.IsNullOrWhiteSpace(password)) + throw new ArgumentException( + $"Parameters must include valid authentiation settings. Please provide '{AdbcOptions.Username}' and '{AdbcOptions.Password}'.", + nameof(Properties)); + break; + default: + throw new ArgumentOutOfRangeException(ImpalaParameters.AuthType, authType, $"Unsupported {ImpalaParameters.AuthType} value."); + } + } + + protected override void ValidateConnection() + { + // HostName or Uri is required parameter + Properties.TryGetValue(AdbcOptions.Uri, out string? uri); + Properties.TryGetValue(ImpalaParameters.HostName, out string? hostName); + if ((Uri.CheckHostName(hostName) == UriHostNameType.Unknown) + && (string.IsNullOrEmpty(uri) || !Uri.TryCreate(uri, UriKind.Absolute, out Uri? _))) + { + throw new ArgumentException( + $"Required parameter '{ImpalaParameters.HostName}' or '{AdbcOptions.Uri}' is missing or invalid. Please provide a valid hostname or URI for the data source.", + nameof(Properties)); + } + + // Validate port range + Properties.TryGetValue(ImpalaParameters.Port, out string? port); + if (int.TryParse(port, out int portNumber) && (portNumber <= IPEndPoint.MinPort || portNumber > IPEndPoint.MaxPort)) + throw new ArgumentOutOfRangeException( + nameof(Properties), + port, + $"Parameter '{ImpalaParameters.Port}' value is not in the valid range of 1 .. 
{IPEndPoint.MaxPort}."); + + // Ensure the parameters will produce a valid address + Properties.TryGetValue(ImpalaParameters.Path, out string? path); + _ = new HttpClient() + { + BaseAddress = GetBaseAddress(uri, hostName, path, port) + }; + } + + protected override void ValidateOptions() + { + Properties.TryGetValue(ImpalaParameters.DataTypeConv, out string? dataTypeConv); + DataTypeConversion = DataTypeConversionParser.Parse(dataTypeConv); + Properties.TryGetValue(ImpalaParameters.TLSOptions, out string? tlsOptions); + TlsOptions = TlsOptionsParser.Parse(tlsOptions); + Properties.TryGetValue(ImpalaParameters.ConnectTimeoutMilliseconds, out string? connectTimeoutMs); + if (connectTimeoutMs != null) + { + ConnectTimeoutMilliseconds = int.TryParse(connectTimeoutMs, NumberStyles.Integer, CultureInfo.InvariantCulture, out int connectTimeoutMsValue) && (connectTimeoutMsValue >= 0) + ? connectTimeoutMsValue + : throw new ArgumentOutOfRangeException(ImpalaParameters.ConnectTimeoutMilliseconds, connectTimeoutMs, $"must be a value of 0 (infinite) or between 1 .. {int.MaxValue}. default is 30000 milliseconds."); + } + } + + internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion); + + protected override TTransport CreateTransport() + { + // Assumption: parameters have already been validated. + Properties.TryGetValue(ImpalaParameters.HostName, out string? hostName); + Properties.TryGetValue(ImpalaParameters.Path, out string? path); + Properties.TryGetValue(ImpalaParameters.Port, out string? port); + Properties.TryGetValue(ImpalaParameters.AuthType, out string? authType); + bool isValidAuthType = ImpalaAuthTypeParser.TryParse(authType, out ImpalaAuthType authTypeValue); + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? 
password); + Properties.TryGetValue(AdbcOptions.Uri, out string? uri); + + Uri baseAddress = GetBaseAddress(uri, hostName, path, port); + AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, username, password); + + HttpClientHandler httpClientHandler = NewHttpClientHandler(); + HttpClient httpClient = new(httpClientHandler); + httpClient.BaseAddress = baseAddress; + httpClient.DefaultRequestHeaders.Authorization = authenticationHeaderValue; + httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(s_userAgent); + httpClient.DefaultRequestHeaders.AcceptEncoding.Clear(); + httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new StringWithQualityHeaderValue("identity")); + httpClient.DefaultRequestHeaders.ExpectContinue = false; + + TConfiguration config = new(); + ThriftHttpTransport transport = new(httpClient, config) + { + // This value can only be set before the first call/request. So if a new value for query timeout + // is set, we won't be able to update the value. Setting to ~infinite and relying on cancellation token + // to ensure cancelled correctly. + ConnectTimeout = int.MaxValue, + }; + return transport; + } + + private HttpClientHandler NewHttpClientHandler() + { + HttpClientHandler httpClientHandler = new(); + if (TlsOptions != HiveServer2TlsOption.Empty) + { + httpClientHandler.ServerCertificateCustomValidationCallback = (request, certificate, chain, policyErrors) => + { + if (policyErrors == SslPolicyErrors.None) return true; + + return + (!policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors) || TlsOptions.HasFlag(HiveServer2TlsOption.AllowSelfSigned)) + && (!policyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch) || TlsOptions.HasFlag(HiveServer2TlsOption.AllowHostnameMismatch)); + }; + } + + return httpClientHandler; + } + + private static AuthenticationHeaderValue? GetAuthenticationHeaderValue(ImpalaAuthType authType, string? username, string? 
password) + { + if (!string.IsNullOrEmpty(username) && !string.IsNullOrEmpty(password) && (authType == ImpalaAuthType.Empty || authType == ImpalaAuthType.Basic)) + { + return new AuthenticationHeaderValue(BasicAuthenticationScheme, Convert.ToBase64String(Encoding.UTF8.GetBytes($"{username}:{password}"))); + } + else if (!string.IsNullOrEmpty(username) && (authType == ImpalaAuthType.Empty || authType == ImpalaAuthType.UsernameOnly)) + { + return new AuthenticationHeaderValue(BasicAuthenticationScheme, Convert.ToBase64String(Encoding.UTF8.GetBytes($"{username}:"))); + } + else if (authType == ImpalaAuthType.None) + { + return null; + } + else + { + throw new AdbcException("Missing connection properties. Must contain 'username' and 'password'"); + } + } + + protected override async Task<TProtocol> CreateProtocolAsync(TTransport transport, CancellationToken cancellationToken = default) + { + Trace.TraceError($"create protocol with {Properties.Count} properties."); Review Comment: The decision to trace here and in ImpalaStandardConnection seems a bit arbitrary. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
