0AyanamiRei commented on code in PR #60044:
URL: https://github.com/apache/doris/pull/60044#discussion_r2791075032


##########
be/src/runtime/routine_load/aws_msk_iam_auth.cpp:
##########
@@ -0,0 +1,492 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "runtime/routine_load/aws_msk_iam_auth.h"
+
+#include <aws/core/auth/AWSCredentials.h>
+#include <aws/core/auth/AWSCredentialsProvider.h>
+#include <aws/core/auth/AWSCredentialsProviderChain.h>
+#include <aws/core/auth/STSCredentialsProvider.h>
+#include <aws/core/platform/Environment.h>
+#include <aws/identity-management/auth/STSAssumeRoleCredentialsProvider.h>
+#include <aws/sts/STSClient.h>
+#include <aws/sts/model/AssumeRoleRequest.h>
+#include <openssl/hmac.h>
+#include <openssl/sha.h>
+
+#include <algorithm>
+#include <chrono>
+#include <iomanip>
+#include <sstream>
+
+#include "common/logging.h"
+
+namespace doris {
+
+// Stores the supplied configuration (moved into _config) and then builds the
+// credentials provider by calling _create_credentials_provider().
+AwsMskIamAuth::AwsMskIamAuth(Config config) : _config(std::move(config)) {
+    _credentials_provider = _create_credentials_provider();
+}
+
+std::shared_ptr<Aws::Auth::AWSCredentialsProvider> 
AwsMskIamAuth::_create_credentials_provider() {

Review Comment:
   AWS credential retrieval currently lives in s3_util, but that implementation 
is not suitable for the load (import) path. For now, aws_msk_iam_auth gets its 
own, separate credential retrieval function.
   
   获取aws凭证目前在s3_util中有, 但是并不适合在导入这边使用, 暂时在aws_msk_iam_auth中添加单独的凭证获取函数



##########
be/src/runtime/routine_load/aws_msk_iam_auth.cpp:
##########
@@ -0,0 +1,437 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "runtime/routine_load/aws_msk_iam_auth.h"
+
+#include <aws/core/auth/AWSCredentials.h>
+#include <aws/core/auth/AWSCredentialsProvider.h>
+#include <aws/core/auth/AWSCredentialsProviderChain.h>
+#include <aws/core/auth/STSCredentialsProvider.h>
+#include <aws/identity-management/auth/STSAssumeRoleCredentialsProvider.h>
+#include <aws/sts/STSClient.h>
+#include <aws/sts/model/AssumeRoleRequest.h>
+#include <openssl/hmac.h>
+#include <openssl/sha.h>
+
+#include <chrono>
+#include <iomanip>
+#include <sstream>
+
+#include "common/logging.h"
+
+namespace doris {
+
+// Stores the supplied configuration (moved into _config) and then builds the
+// credentials provider from it. _create_credentials_provider() returns nullptr
+// when neither an AK/SK pair nor a role ARN is configured, so
+// _credentials_provider may be null after construction.
+AwsMskIamAuth::AwsMskIamAuth(Config config) : _config(std::move(config)) {
+    _credentials_provider = _create_credentials_provider();
+}
+
+// Builds the AWS credentials provider selected by the configuration.
+//
+// Precedence:
+//   1. Static access key + secret key, when both are non-empty.
+//   2. STS AssumeRole, when role_arn is non-empty.
+// Returns nullptr (after logging an error) when neither applies.
+std::shared_ptr<Aws::Auth::AWSCredentialsProvider> AwsMskIamAuth::_create_credentials_provider() {
+    const bool has_static_keys = !_config.access_key.empty() && !_config.secret_key.empty();
+    if (has_static_keys) {
+        // Only the first four characters of the access key id are logged.
+        LOG(INFO) << "Using explicit AWS credentials (Access Key ID: "
+                  << _config.access_key.substr(0, 4) << "****)";
+
+        Aws::Auth::AWSCredentials static_credentials(_config.access_key.c_str(),
+                                                     _config.secret_key.c_str());
+        return std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(static_credentials);
+    }
+
+    if (_config.role_arn.empty()) {
+        // Neither supported authentication method is configured.
+        LOG(ERROR) << "AWS MSK IAM authentication requires either: "
+                   << "1) region + access_key + secret_key, or "
+                   << "2) region + role_arn";
+        return nullptr;
+    }
+
+    LOG(INFO) << "Using AWS STS Assume Role: " << _config.role_arn;
+
+    // Leave the SDK's default region in place unless one was configured.
+    Aws::Client::ClientConfiguration sts_config;
+    if (!_config.region.empty()) {
+        sts_config.region = _config.region;
+    }
+
+    // The STS client itself authenticates with the instance profile credentials.
+    auto base_provider = std::make_shared<Aws::Auth::InstanceProfileCredentialsProvider>();
+    auto sts_client = std::make_shared<Aws::STS::STSClient>(base_provider, sts_config);
+
+    return std::make_shared<Aws::Auth::STSAssumeRoleCredentialsProvider>(
+            _config.role_arn, Aws::String() /* session_name */, Aws::String() /* external_id */,
+            Aws::Auth::DEFAULT_CREDS_LOAD_FREQ_SECONDS, sts_client);
+}
+
+// Copies the cached AWS credentials into *credentials, refreshing the cache
+// from the credentials provider first when _should_refresh_credentials()
+// reports that a refresh is due. The whole call runs under _mutex.
+//
+// Returns Status::InternalError when no credentials provider is available or
+// when the provider yields an empty access key id; Status::OK() otherwise.
+Status AwsMskIamAuth::get_credentials(Aws::Auth::AWSCredentials* credentials) {
+    std::lock_guard<std::mutex> lock(_mutex);
+
+    if (_should_refresh_credentials()) {
+        // _create_credentials_provider() returns nullptr for an unusable
+        // configuration; fail cleanly instead of dereferencing a null pointer.
+        if (_credentials_provider == nullptr) {
+            return Status::InternalError(
+                    "AWS credentials provider is not configured for MSK IAM authentication");
+        }
+
+        _cached_credentials = _credentials_provider->GetAWSCredentials();
+
+        if (_cached_credentials.GetAWSAccessKeyId().empty()) {
+            return Status::InternalError("Failed to get AWS credentials");
+        }
+
+        // The cache is treated as valid for a fixed one hour from now; the
+        // credentials' own expiration is not consulted here.
+        _credentials_expiry = std::chrono::system_clock::now() + std::chrono::hours(1);
+
+        LOG(INFO) << "Refreshed AWS credentials for MSK IAM authentication";
+    }
+
+    *credentials = _cached_credentials;
+    return Status::OK();
+}
+
+// Reports whether the cached credentials must be (re)loaded: either nothing
+// is cached yet, or the current time has reached the cached expiry minus the
+// configured refresh margin (token_refresh_margin_ms).
+bool AwsMskIamAuth::_should_refresh_credentials() {
+    if (_cached_credentials.GetAWSAccessKeyId().empty()) {
+        return true;
+    }
+
+    const auto margin = std::chrono::milliseconds(_config.token_refresh_margin_ms);
+    const auto refresh_deadline = _credentials_expiry - margin;
+    return std::chrono::system_clock::now() >= refresh_deadline;
+}
+
+// Generates an AWS MSK IAM authentication token.
+//
+// The token is the base64url encoding (no padding) of a SigV4 presigned URL
+// for "kafka.<region>.amazonaws.com" with Action=kafka-cluster:Connect.
+// Reference: https://github.com/aws/aws-msk-iam-sasl-signer-python
+//
+// @param broker_hostname   not used by this implementation; the signed host
+//                          is derived from _config.region.
+// @param token             receives the encoded token.
+// @param token_lifetime_ms receives the token validity in milliseconds.
+// @return the error from get_credentials() if credentials cannot be
+//         obtained, Status::OK() otherwise.
+Status AwsMskIamAuth::generate_token(const std::string& broker_hostname, std::string* token,
+                                     int64_t* token_lifetime_ms) {
+    Aws::Auth::AWSCredentials credentials;
+    RETURN_IF_ERROR(get_credentials(&credentials));
+
+    std::string timestamp = _get_timestamp();
+    std::string date_stamp = _get_date_stamp(timestamp);
+
+    // Token expiry in seconds (900 seconds = 15 minutes, matching the AWS MSK
+    // IAM signer reference implementation).
+    static constexpr int TOKEN_EXPIRY_SECONDS = 900;
+
+    // The host is both the signed "host" header and the endpoint of the URL.
+    std::string host = "kafka." + _config.region + ".amazonaws.com";
+    std::string endpoint_url = "https://" + host + "/";
+
+    // Build credential scope
+    std::string credential_scope =
+            date_stamp + "/" + _config.region + "/kafka-cluster/aws4_request";
+
+    // Build the canonical query string (parameters sorted alphabetically).
+    // Every query parameter that is signed must appear here, including the
+    // session token when temporary credentials are in use.
+    std::stringstream canonical_query_ss;
+    canonical_query_ss << "Action=kafka-cluster%3AConnect"; // URL-encoded ':'
+    canonical_query_ss << "&X-Amz-Algorithm=AWS4-HMAC-SHA256";
+
+    std::string credential = std::string(credentials.GetAWSAccessKeyId()) + "/" + credential_scope;
+    canonical_query_ss << "&X-Amz-Credential=" << _url_encode(credential);
+    canonical_query_ss << "&X-Amz-Date=" << timestamp;
+    canonical_query_ss << "&X-Amz-Expires=" << TOKEN_EXPIRY_SECONDS;
+
+    // The security token must be added before the signature is calculated.
+    if (!credentials.GetSessionToken().empty()) {
+        canonical_query_ss << "&X-Amz-Security-Token="
+                           << _url_encode(std::string(credentials.GetSessionToken()));
+    }
+
+    canonical_query_ss << "&X-Amz-SignedHeaders=host";
+
+    std::string canonical_query_string = canonical_query_ss.str();
+
+    // Build the canonical headers
+    std::string canonical_headers = "host:" + host + "\n";
+    std::string signed_headers = "host";
+
+    // Build the canonical request (empty payload)
+    std::string method = "GET";
+    std::string uri = "/";
+    std::string payload_hash = _sha256("");
+
+    std::string canonical_request = method + "\n" + uri + "\n" + canonical_query_string + "\n" +
+                                    canonical_headers + "\n" + signed_headers + "\n" + payload_hash;
+
+    // Build the string to sign
+    std::string algorithm = "AWS4-HMAC-SHA256";
+    std::string canonical_request_hash = _sha256(canonical_request);
+    std::string string_to_sign =
+            algorithm + "\n" + timestamp + "\n" + credential_scope + "\n" + canonical_request_hash;
+
+    // Calculate signature
+    std::string signing_key = _calculate_signing_key(std::string(credentials.GetAWSSecretKey()),
+                                                     date_stamp, _config.region, "kafka-cluster");
+    std::string signature = _hmac_sha256_hex(signing_key, string_to_sign);
+
+    // Build the final presigned URL. User-Agent is appended AFTER the
+    // signature because it is not part of the signed content (matching the
+    // reference implementation).
+    std::string signed_url = endpoint_url + "?" + canonical_query_string +
+                             "&X-Amz-Signature=" + signature +
+                             "&User-Agent=doris-msk-iam-auth%2F1.0";
+
+    // Base64url encode the signed URL (without padding)
+    *token = _base64url_encode(signed_url);
+
+    // Token lifetime in milliseconds
+    *token_lifetime_ms = static_cast<int64_t>(TOKEN_EXPIRY_SECONDS) * 1000;
+
+    // Never log the presigned URL: it carries the signature, the access key
+    // id and (for temporary credentials) the session token.
+    LOG(INFO) << "Generated AWS MSK IAM token for host " << host << ", valid for "
+              << TOKEN_EXPIRY_SECONDS << " seconds";
+
+    return Status::OK();
+}
+
+// Computes HMAC-SHA256 of `data` under `key` via _hmac_sha256() and returns
+// the digest as lowercase hexadecimal text, two characters per byte.
+std::string AwsMskIamAuth::_hmac_sha256_hex(const std::string& key, const std::string& data) {
+    static const char kHexDigits[] = "0123456789abcdef";
+
+    const std::string digest = _hmac_sha256(key, data);
+
+    std::string hex;
+    hex.reserve(digest.size() * 2);
+    for (char raw_byte : digest) {
+        const auto value = static_cast<unsigned char>(raw_byte);
+        hex.push_back(kHexDigits[value >> 4]);
+        hex.push_back(kHexDigits[value & 0x0F]);
+    }
+    return hex;
+}
+
+// Percent-encodes `value`: alphanumerics and '-', '_', '.', '~' are copied
+// unchanged; every other byte becomes '%' followed by two uppercase hex
+// digits.
+std::string AwsMskIamAuth::_url_encode(const std::string& value) {
+    static const char kUpperHex[] = "0123456789ABCDEF";
+
+    std::string encoded;
+    for (char ch : value) {
+        const auto byte = static_cast<unsigned char>(ch);
+        const bool keep_as_is =
+                isalnum(byte) || ch == '-' || ch == '_' || ch == '.' || ch == '~';
+        if (keep_as_is) {
+            encoded.push_back(ch);
+            continue;
+        }
+        encoded.push_back('%');
+        encoded.push_back(kUpperHex[byte >> 4]);
+        encoded.push_back(kUpperHex[byte & 0x0F]);
+    }
+    return encoded;
+}
+
+// Encodes `input` as URL-safe base64 ('-' and '_' in place of '+' and '/')
+// without '=' padding.
+std::string AwsMskIamAuth::_base64url_encode(const std::string& input) {
+    // URL-safe alphabet, used directly so no post-processing pass is needed.
+    static const char* kUrlSafeAlphabet =
+            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
+
+    const size_t total = input.size();
+    std::string encoded;
+    encoded.reserve(((total + 2) / 3) * 4);
+
+    auto byte_at = [&input](size_t idx) {
+        return static_cast<uint32_t>(static_cast<unsigned char>(input[idx]));
+    };
+
+    for (size_t pos = 0; pos < total; pos += 3) {
+        const size_t remaining = total - pos;
+
+        // Pack up to three input bytes into one 24-bit group.
+        uint32_t group = byte_at(pos) << 16;
+        if (remaining > 1) {
+            group |= byte_at(pos + 1) << 8;
+        }
+        if (remaining > 2) {
+            group |= byte_at(pos + 2);
+        }
+
+        // Emit only the sextets that carry input bits; a short final group
+        // yields two or three characters and no padding.
+        encoded.push_back(kUrlSafeAlphabet[(group >> 18) & 0x3F]);
+        encoded.push_back(kUrlSafeAlphabet[(group >> 12) & 0x3F]);
+        if (remaining > 1) {
+            encoded.push_back(kUrlSafeAlphabet[(group >> 6) & 0x3F]);
+        }
+        if (remaining > 2) {
+            encoded.push_back(kUrlSafeAlphabet[group & 0x3F]);
+        }
+    }
+
+    return encoded;
+}
+
+// Derives the signing key as a chain of HMAC-SHA256 operations: starting from
+// "AWS4" + secret_key, the key is successively applied to date_stamp, region,
+// service and the literal "aws4_request". Returns the raw (binary) key.
+std::string AwsMskIamAuth::_calculate_signing_key(const std::string& secret_key,
+                                                  const std::string& date_stamp,
+                                                  const std::string& region,
+                                                  const std::string& service) {
+    std::string derived_key = "AWS4" + secret_key;
+
+    // Each step keys the HMAC with the previous step's output.
+    const std::string* const scope_parts[] = {&date_stamp, &region, &service};
+    for (const std::string* part : scope_parts) {
+        derived_key = _hmac_sha256(derived_key, *part);
+    }
+
+    return _hmac_sha256(derived_key, "aws4_request");
+}
+
+// Computes HMAC-SHA256 of `data` under `key` and returns the raw (binary)
+// digest. Returns an empty string, after logging an error, if OpenSSL reports
+// a failure.
+std::string AwsMskIamAuth::_hmac_sha256(const std::string& key, const std::string& data) {
+    // Write into a local buffer. Passing a null output pointer would make
+    // OpenSSL return a pointer to a static array, which is not thread-safe.
+    unsigned char digest[EVP_MAX_MD_SIZE];
+    unsigned int digest_len = 0;
+
+    const unsigned char* result =
+            HMAC(EVP_sha256(), key.data(), static_cast<int>(key.size()),
+                 reinterpret_cast<const unsigned char*>(data.data()), data.size(), digest,
+                 &digest_len);
+    if (result == nullptr) {
+        LOG(ERROR) << "HMAC-SHA256 computation failed";
+        return std::string();
+    }
+
+    return std::string(reinterpret_cast<const char*>(digest), digest_len);
+}
+
+// Returns the SHA-256 digest of `data` as lowercase hexadecimal text.
+std::string AwsMskIamAuth::_sha256(const std::string& data) {
+    static const char kHexDigits[] = "0123456789abcdef";
+
+    unsigned char digest[SHA256_DIGEST_LENGTH];
+    SHA256(reinterpret_cast<const unsigned char*>(data.c_str()), data.length(), digest);
+
+    std::string hex;
+    hex.reserve(SHA256_DIGEST_LENGTH * 2);
+    for (unsigned char byte : digest) {
+        hex.push_back(kHexDigits[byte >> 4]);
+        hex.push_back(kHexDigits[byte & 0x0F]);
+    }
+    return hex;
+}
+
+// Returns the current UTC time formatted as "YYYYMMDDTHHMMSSZ".
+std::string AwsMskIamAuth::_get_timestamp() {
+    const std::time_t epoch_seconds =
+            std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
+
+    std::tm utc_parts;
+    gmtime_r(&epoch_seconds, &utc_parts);
+
+    // 16 characters of output plus the terminating NUL fit comfortably.
+    char formatted[32];
+    const size_t written = strftime(formatted, sizeof(formatted), "%Y%m%dT%H%M%SZ", &utc_parts);
+    return std::string(formatted, written);
+}
+
+// Returns the date part of a timestamp: its first eight characters
+// (YYYYMMDD of YYYYMMDDTHHMMSSZ), or the whole string if it is shorter.
+std::string AwsMskIamAuth::_get_date_stamp(const std::string& timestamp) {
+    constexpr size_t kDateLength = 8;
+    return std::string(timestamp, 0, kDateLength);
+}
+
+// AwsMskIamOAuthCallback implementation
+
+// File-local constants: keys looked up in the custom properties map when
+// configuring AWS MSK IAM authentication.
+namespace {
+// Property keys for AWS MSK IAM authentication
+// Kafka client settings that select the authentication mode.
+constexpr const char* PROP_SECURITY_PROTOCOL = "security.protocol";
+constexpr const char* PROP_SASL_MECHANISM = "sasl.mechanism";
+// AWS-specific settings: region, static credentials, and role to assume.
+constexpr const char* PROP_AWS_REGION = "aws.region";
+constexpr const char* PROP_AWS_ACCESS_KEY = "aws.access.key";
+constexpr const char* PROP_AWS_SECRET_KEY = "aws.secret.key";
+constexpr const char* PROP_AWS_ROLE_ARN = "aws.msk.iam.role.arn";
+} // namespace
+
+std::unique_ptr<AwsMskIamOAuthCallback> 
AwsMskIamOAuthCallback::create_from_properties(
+        const std::unordered_map<std::string, std::string>& custom_properties,
+        const std::string& brokers) {
+    auto security_protocol_it = custom_properties.find(PROP_SECURITY_PROTOCOL);
+    auto sasl_mechanism_it = custom_properties.find(PROP_SASL_MECHANISM);
+
+    // Check if this is AWS MSK IAM authentication
+    // Conditions: security.protocol = SASL_SSL and sasl.mechanism = 
OAUTHBEARER
+    bool is_sasl_ssl = security_protocol_it != custom_properties.end() &&
+                       security_protocol_it->second == "SASL_SSL";
+    bool is_oauthbearer = sasl_mechanism_it != custom_properties.end() &&
+                          sasl_mechanism_it->second == "OAUTHBEARER";
+
+    if (!is_sasl_ssl || !is_oauthbearer) {
+        // Not AWS MSK IAM authentication
+        return nullptr;
+    }
+
+    // Extract broker hostname for token generation
+    std::string broker_hostname = brokers;

Review Comment:
   From broker_hostname we can tell whether the broker is a private endpoint or 
a public endpoint, and use that to decide whether the AWS credentials should 
come from the InstanceProfile or from an explicitly provided access key and 
secret key.
   
   
我们可以从broker_hostname中获取该host是私有端点还是公共端点,以此判断AWS凭证的获取是通过InstanceProfile还是通过显式传入的aksk。



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to