toddmeng-db commented on code in PR #2743: URL: https://github.com/apache/arrow-adbc/pull/2743#discussion_r2062908355
########## csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsService.cs: ########## @@ -0,0 +1,231 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth +{ + /// <summary> + /// Service for obtaining OAuth access tokens using the client credentials grant type. + /// </summary> + internal class OAuthClientCredentialsService : IDisposable + { + private readonly Lazy<HttpClient> _httpClient; + private readonly string _clientId; + private readonly string _clientSecret; + private readonly Uri _baseUri; + private readonly string _tokenEndpoint; + private readonly int _timeoutMinutes; + private readonly SemaphoreSlim _tokenLock = new SemaphoreSlim(1, 1); + private TokenInfo? _cachedToken; + + private class TokenInfo + { + public string? AccessToken { get; set; } + public DateTime ExpiresAt { get; set; } + + public bool IsExpired => DateTime.UtcNow >= ExpiresAt; + + // Add buffer time to refresh token before actual expiration + public bool NeedsRefresh => DateTime.UtcNow >= ExpiresAt.AddMinutes(-5); + } + + /// <summary> + /// Initializes a new instance of the <see cref="OAuthClientCredentialsService"/> class. + /// </summary> + /// <param name="clientId">The OAuth client ID.</param> + /// <param name="clientSecret">The OAuth client secret.</param> + /// <param name="baseUri">The base URI of the Databricks workspace.</param> + public OAuthClientCredentialsService( + string clientId, + string clientSecret, + Uri baseUri, + int timeoutMinutes = 1, + HttpClient? httpClient = null) + { + _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); + _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret)); + _baseUri = baseUri ?? throw new ArgumentNullException(nameof(baseUri)); + _timeoutMinutes = timeoutMinutes; + _tokenEndpoint = DetermineTokenEndpoint(); + + _httpClient = httpClient != null + ? new Lazy<HttpClient>(() => httpClient) + : new Lazy<HttpClient>(() => + { + var client = new HttpClient(); + client.Timeout = TimeSpan.FromMinutes(_timeoutMinutes); + return client; + }); + } + + private HttpClient HttpClient => _httpClient.Value; + + private string DetermineTokenEndpoint() + { + // For workspace URLs, the token endpoint is always /oidc/v1/token + // TODO: Might be different for Azure AAD SPs + return $"{_baseUri.Scheme}://{_baseUri.Host}/oidc/v1/token"; + } + + private string? GetValidCachedToken() + { + return _cachedToken != null && !_cachedToken.NeedsRefresh && _cachedToken.AccessToken != null + ? _cachedToken.AccessToken + : null; + } + + + private async Task<string> RefreshTokenInternalAsync(CancellationToken cancellationToken) + { + var request = CreateTokenRequest(); + + HttpResponseMessage response; + try + { + response = await HttpClient.SendAsync(request, cancellationToken); + response.EnsureSuccessStatusCode(); + } + catch (HttpRequestException ex) + { + throw new DatabricksException($"Failed to acquire OAuth access token: {ex.Message}", ex); + } + + string content = await response.Content.ReadAsStringAsync(); + + try + { + _cachedToken = ParseTokenResponse(content); + return _cachedToken.AccessToken!; + } + catch (JsonException ex) + { + throw new DatabricksException($"Failed to parse OAuth response: {ex.Message}", ex); + } + } + + private HttpRequestMessage CreateTokenRequest() + { + var requestContent = new FormUrlEncodedContent(new[] + { + new KeyValuePair<string, string>("grant_type", "client_credentials"), + new KeyValuePair<string, string>("scope", "sql") + }); + + var request = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint) + { + Content = requestContent + }; + + // Use Basic Auth with client ID and secret + var authHeader = Convert.ToBase64String( + System.Text.Encoding.ASCII.GetBytes($"{_clientId}:{_clientSecret}")); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", authHeader); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + + return request; + } + + private TokenInfo ParseTokenResponse(string content) + { + using var jsonDoc = JsonDocument.Parse(content); + + if (!jsonDoc.RootElement.TryGetProperty("access_token", out var accessTokenElement)) + { + throw new DatabricksException("OAuth response did not contain an access_token"); + } + + string? accessToken = accessTokenElement.GetString(); + if (string.IsNullOrEmpty(accessToken)) + { + throw new DatabricksException("OAuth access_token was null or empty"); + } + + // Get expiration time from response + if (!jsonDoc.RootElement.TryGetProperty("expires_in", out var expiresInElement)) + { + throw new DatabricksException("OAuth response did not contain expires_in"); + } + + int expiresIn = expiresInElement.GetInt32(); + if (expiresIn <= 0) + { + throw new DatabricksException("OAuth expires_in value must be positive"); + } + + return new TokenInfo + { + AccessToken = accessToken!, + ExpiresAt = DateTime.UtcNow.AddSeconds(expiresIn) + }; + } + + /// <summary> + /// Gets an OAuth access token using the client credentials grant type. + /// </summary> + /// <param name="cancellationToken">A cancellation token to cancel the operation.</param> + /// <returns>The access token.</returns> + /// <exception cref="DatabricksException">Thrown when the token request fails or the response is invalid.</exception> + public async Task<string> GetAccessTokenAsync(CancellationToken cancellationToken = default) Review Comment: Seems like it's a bit awkward to implement a non-async method when there's an http call involved in c#? ``` public Credentials GetCredentials() { var response = Task.Run(() => _httpClient.GetAsync(...)) .Unwrap() .GetAwaiter() .GetResult(); // process response... return credentials; } ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org