nsivarajan commented on code in PR #61329: URL: https://github.com/apache/doris/pull/61329#discussion_r2935553143
########## common/cpp/oss_credential_provider.cpp: ########## @@ -0,0 +1,456 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "cpp/oss_credential_provider.h" + +#ifdef USE_OSS + +#include <alibabacloud/oss/auth/Credentials.h> +#include <alibabacloud/oss/auth/CredentialsProvider.h> +#include <curl/curl.h> +#include <rapidjson/document.h> +#include <time.h> + +#include <alibabacloud/Sts20150401.hpp> +#include <alibabacloud/credentials/Client.hpp> +#include <alibabacloud/utils/models/Config.hpp> +#include <darabonba/Runtime.hpp> +#include <iomanip> +#include <sstream> +#include <stdexcept> + +#include "common/logging.h" + +namespace { +std::string mask_credential(const std::string& cred) { + if (cred.empty()) return ""; + size_t len = cred.length(); + if (len <= 8) { + if (len <= 4) return std::string(len, '*'); + return cred.substr(0, 2) + std::string(len - 4, '*') + cred.substr(len - 2); + } + return cred.substr(0, 4) + std::string(len - 8, '*') + cred.substr(len - 4); +} +} // namespace + +namespace doris { + +static size_t curl_write_callback(void* contents, size_t size, size_t nmemb, std::string* userp) { + size_t total_size = size * nmemb; + userp->append(static_cast<char*>(contents), total_size); + return total_size; +} + +// ---- ECSMetadataCredentialsProvider ---- + +ECSMetadataCredentialsProvider::ECSMetadataCredentialsProvider() + : _cached_credentials(nullptr), _expiration(std::chrono::system_clock::now()) { + LOG(INFO) << "ECSMetadataCredentialsProvider initialized"; +} + +bool ECSMetadataCredentialsProvider::_is_expired() const { + auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast<std::chrono::seconds>(_expiration - now).count() <= + REFRESH_BEFORE_EXPIRY_SECONDS; +} + +AlibabaCloud::OSS::Credentials ECSMetadataCredentialsProvider::getCredentials() { + { + std::lock_guard<std::mutex> lock(_mtx); + if (_cached_credentials && !_is_expired()) { + VLOG(2) << "Returning cached ECS metadata credentials"; + return *_cached_credentials; + } + if (_cached_credentials) { + auto t = std::chrono::system_clock::to_time_t(_expiration); + struct tm tm_buf; + LOG(INFO) << "ECS metadata credentials expiring (" + << std::put_time(localtime_r(&t, &tm_buf), "%Y-%m-%d %H:%M:%S") + << "), refreshing"; + } else { + LOG(INFO) << "Fetching ECS metadata credentials (first time)"; + } + } + + std::unique_ptr<AlibabaCloud::OSS::Credentials> new_credentials; + std::chrono::system_clock::time_point new_expiration; + + if (_fetch_credentials_outside_lock(new_credentials, new_expiration) != 0) { + std::lock_guard<std::mutex> lock(_mtx); + if (_cached_credentials) { + LOG(WARNING) << "Using expired ECS metadata credentials as fallback"; + return *_cached_credentials; + } + LOG(ERROR) << "Failed to fetch credentials from ECS metadata service and no cached " + "fallback available"; + return AlibabaCloud::OSS::Credentials("", "", ""); + } + + { + std::lock_guard<std::mutex> lock(_mtx); + if (_cached_credentials && !_is_expired()) { + return *_cached_credentials; + } + _cached_credentials = std::move(new_credentials); + _expiration = new_expiration; + auto t = std::chrono::system_clock::to_time_t(_expiration); + struct tm tm_buf; + LOG(INFO) << "ECS metadata credentials refreshed, expiry: " + << std::put_time(localtime_r(&t, &tm_buf), "%Y-%m-%d %H:%M:%S"); + return *_cached_credentials; + } +} + +int ECSMetadataCredentialsProvider::_http_get(const std::string& url, std::string& response) { + CURL* curl = curl_easy_init(); + if (!curl) { + LOG(ERROR) << "Failed to initialize CURL"; + return -1; + } + + response.clear(); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, curl_write_callback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); + curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, METADATA_SERVICE_TIMEOUT_MS); + curl_easy_setopt(curl, CURLOPT_NOSIGNAL, 1L); + + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + LOG(ERROR) << "ECS metadata HTTP GET failed: " << curl_easy_strerror(res); + curl_easy_cleanup(curl); + return -1; + } + + long http_code = 0; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + curl_easy_cleanup(curl); + + if (http_code != 200) { + LOG(ERROR) << "ECS metadata service returned HTTP " << http_code; + return -1; + } + return 0; +} + +int ECSMetadataCredentialsProvider::_get_role_name(std::string& role_name) { + std::string url = std::string("http://") + METADATA_SERVICE_HOST + METADATA_SERVICE_PATH; + std::string response; + if (_http_get(url, response) != 0) { + return -1; + } + + role_name = response; + role_name.erase(role_name.begin(), + std::find_if(role_name.begin(), role_name.end(), + [](unsigned char ch) { return !std::isspace(ch); })); + role_name.erase(std::find_if(role_name.rbegin(), role_name.rend(), + [](unsigned char ch) { return !std::isspace(ch); }) + .base(), + role_name.end()); + + if (role_name.empty()) { + LOG(ERROR) << "No RAM role attached to this ECS instance"; + return -1; + } + + size_t newline_pos = role_name.find('\n'); + if (newline_pos != std::string::npos) { + std::string all_roles = role_name; + role_name = role_name.substr(0, newline_pos); + LOG(WARNING) << "Multiple RAM roles found, using first: " << role_name + << " (all: " << all_roles << ")"; + } + + LOG(INFO) << "ECS RAM role: " << role_name; + return 0; +} + +int ECSMetadataCredentialsProvider::_fetch_credentials_outside_lock( + std::unique_ptr<AlibabaCloud::OSS::Credentials>& out_credentials, + std::chrono::system_clock::time_point& out_expiration) { + std::string role_name; + if (_get_role_name(role_name) != 0) { + return -1; + } + + std::string url = + std::string("http://") + METADATA_SERVICE_HOST + METADATA_SERVICE_PATH + role_name; + std::string response; + if (_http_get(url, response) != 0) { + return -1; + } + + rapidjson::Document doc; + doc.Parse(response.c_str()); + + if (doc.HasParseError()) { + LOG(ERROR) << "Failed to parse ECS metadata JSON response"; + return -1; + } + if (!doc.HasMember("Code") || std::string(doc["Code"].GetString()) != "Success") { + LOG(ERROR) << "ECS metadata error: " + << (doc.HasMember("Message") ? doc["Message"].GetString() : "unknown"); + return -1; + } + if (!doc.HasMember("AccessKeyId") || !doc.HasMember("AccessKeySecret") || + !doc.HasMember("SecurityToken") || !doc.HasMember("Expiration")) { + LOG(ERROR) << "ECS metadata response missing required fields"; + return -1; + } + + std::string ak = doc["AccessKeyId"].GetString(); + std::string sk = doc["AccessKeySecret"].GetString(); + std::string token = doc["SecurityToken"].GetString(); + std::string expiry_str = doc["Expiration"].GetString(); + + if (ak.empty() || sk.empty() || token.empty()) { + LOG(ERROR) << "ECS metadata returned empty credentials"; + return -1; + } + + std::tm tm = {}; + std::istringstream ss(expiry_str); + ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%SZ"); + if (ss.fail()) { + LOG(ERROR) << "Failed to parse expiration from ECS metadata: " << expiry_str; + return -1; + } + + out_expiration = std::chrono::system_clock::from_time_t(timegm(&tm)); + out_credentials = std::make_unique<AlibabaCloud::OSS::Credentials>(ak, sk, token); + VLOG(1) << "ECS metadata credentials: ak=" << mask_credential(ak) << ", expiry=" << expiry_str; + return 0; +} + +// ---- OSSSTSCredentialProvider ---- + +OSSSTSCredentialProvider::OSSSTSCredentialProvider(const std::string& role_arn, + const std::string& region, + const std::string& external_id) + : _cached_credentials(nullptr), + _expiration(std::chrono::system_clock::now()), + _role_arn(role_arn), + _region(region), + _external_id(external_id) { + if (_role_arn.empty()) { + throw std::invalid_argument("RAM role ARN cannot be empty for STS AssumeRole"); + } + LOG(INFO) << "OSSSTSCredentialProvider: role_arn=" << _role_arn << ", region=" << _region + << ", external_id=" + << (_external_id.empty() ? "(none)" : mask_credential(_external_id)); +} + +bool OSSSTSCredentialProvider::_is_expired() const { + auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast<std::chrono::seconds>(_expiration - now).count() <= + REFRESH_BEFORE_EXPIRY_SECONDS; +} + +AlibabaCloud::OSS::Credentials OSSSTSCredentialProvider::getCredentials() { + { + std::lock_guard<std::mutex> lock(_mtx); + if (_cached_credentials && !_is_expired()) { + VLOG(2) << "Returning cached STS AssumeRole credentials"; + return *_cached_credentials; + } + if (_cached_credentials) { + auto t = std::chrono::system_clock::to_time_t(_expiration); + struct tm tm_buf; + LOG(INFO) << "STS credentials expiring (" + << std::put_time(localtime_r(&t, &tm_buf), "%Y-%m-%d %H:%M:%S") + << "), refreshing"; + } else { + LOG(INFO) << "Fetching STS AssumeRole credentials (first time)"; + } + } + + std::unique_ptr<AlibabaCloud::OSS::Credentials> new_credentials; + std::chrono::system_clock::time_point new_expiration; + + if (_fetch_credentials_from_sts(new_credentials, new_expiration) != 0) { + std::lock_guard<std::mutex> lock(_mtx); + if (_cached_credentials) { + LOG(WARNING) << "Using expired STS credentials as fallback"; + return *_cached_credentials; + } + LOG(ERROR) << "Failed to fetch STS AssumeRole credentials and no cached fallback available"; + return AlibabaCloud::OSS::Credentials("", "", ""); + } + + { + std::lock_guard<std::mutex> lock(_mtx); + if (_cached_credentials && !_is_expired()) { + return *_cached_credentials; + } + _cached_credentials = std::move(new_credentials); + _expiration = new_expiration; + auto t = std::chrono::system_clock::to_time_t(_expiration); + struct tm tm_buf; + LOG(INFO) << "STS AssumeRole credentials refreshed, expiry: " + << std::put_time(localtime_r(&t, &tm_buf), "%Y-%m-%d %H:%M:%S"); + return *_cached_credentials; + } +} + +int OSSSTSCredentialProvider::_fetch_credentials_from_sts( + std::unique_ptr<AlibabaCloud::OSS::Credentials>& out_credentials, + std::chrono::system_clock::time_point& out_expiration) { + try { + AlibabaCloud::Credentials::Models::Config cred_config; + cred_config.setType("ecs_ram_role"); + AlibabaCloud::Credentials::Client cred_client(cred_config); + AlibabaCloud::Credentials::Models::CredentialModel base_cred = cred_client.getCredential(); + LOG(INFO) << "STS AssumeRole base credentials from provider: " + << base_cred.getProviderName(); + + AlibabaCloud::OpenApi::Utils::Models::Config config; + config.setAccessKeyId(base_cred.getAccessKeyId()); + config.setAccessKeySecret(base_cred.getAccessKeySecret()); + if (!base_cred.getSecurityToken().empty()) { + config.setSecurityToken(base_cred.getSecurityToken()); + } + config.setRegionId(_region); + config.setEndpoint("sts." + _region + ".aliyuncs.com"); + + AlibabaCloud::Sts20150401::Client client(config); + + AlibabaCloud::Sts20150401::Models::AssumeRoleRequest request; + request.setRoleArn(_role_arn); + request.setRoleSessionName("doris-oss-session"); + request.setDurationSeconds(SESSION_DURATION_SECONDS); + if (!_external_id.empty()) { + request.setExternalId(_external_id); + } + + Darabonba::RuntimeOptions runtime; + runtime.setIgnoreSSL(true); + Review Comment: Nice Suggestion, fixed now ########## be/src/io/fs/oss_file_writer.cpp: ########## @@ -0,0 +1,503 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "io/fs/oss_file_writer.h" + +#include <alibabacloud/oss/OssClient.h> +#include <bvar/reducer.h> +#include <fmt/format.h> +#include <glog/logging.h> + +#include <algorithm> +#include <sstream> + +#include "common/config.h" +#include "common/status.h" +#include "io/cache/file_cache_common.h" +#include "io/fs/file_writer.h" +#include "io/fs/path.h" +#include "io/fs/s3_file_bufferpool.h" +#include "runtime/exec_env.h" +#include "util/debug_points.h" + +namespace doris::io { + +// Existing metrics +bvar::Adder<uint64_t> oss_file_writer_total("oss_file_writer_total_num"); +bvar::Adder<uint64_t> oss_bytes_written_total("oss_file_writer_bytes_written"); +bvar::Adder<uint64_t> oss_file_created_total("oss_file_writer_file_created"); +bvar::Adder<uint64_t> oss_file_being_written("oss_file_writer_file_being_written"); + +// New async metrics +bvar::Adder<int64_t> oss_file_writer_async_close_queuing("oss_file_writer_async_close_queuing"); +bvar::Adder<int64_t> oss_file_writer_async_close_processing( + "oss_file_writer_async_close_processing"); + +OSSFileWriter::OSSFileWriter(std::shared_ptr<OSSClientHolder> client, std::string bucket, + std::string key, const FileWriterOptions* opts) + : _path(fmt::format("oss://{}/{}", bucket, key)), + _bucket(std::move(bucket)), + _key(std::move(key)), + _client(std::move(client)), + _buffer_size(config::s3_write_buffer_size), + _used_by_oss_committer(opts ? opts->used_by_s3_committer : false) { + oss_file_writer_total << 1; + oss_file_being_written << 1; + + _completed_parts.reserve(100); + + init_cache_builder(opts, _path); +} + +OSSFileWriter::~OSSFileWriter() { + // Wait for any pending async operations to complete + _wait_until_finish("~OSSFileWriter"); + + // Abort multipart upload if not completed + if (state() != State::CLOSED && !_upload_id.empty()) { + LOG(WARNING) << "OSSFileWriter destroyed without close(), aborting multipart upload: " + << _path.native() << " upload_id: " << _upload_id; + static_cast<void>(_abort_multipart_upload()); + } + + if (state() == State::OPENED && !_failed) { + oss_bytes_written_total << _bytes_appended; + } + oss_file_being_written << -1; +} + +Status OSSFileWriter::appendv(const Slice* data, size_t data_cnt) { + if (state() != State::OPENED) { + return Status::InternalError("append to closed file: {}", _path.native()); + } + + for (size_t i = 0; i < data_cnt; i++) { + size_t data_size = data[i].size; + const char* data_ptr = data[i].data; + size_t pos = 0; + + while (pos < data_size) { + if (_failed) { + return _st; + } + + // Create new buffer if needed + if (!_pending_buf) { + RETURN_IF_ERROR(_build_upload_buffer()); + } + + size_t remaining = data_size - pos; + size_t buffer_remaining = _buffer_size - _pending_buf->get_size(); + size_t to_append = std::min(remaining, buffer_remaining); + + Slice s(data_ptr + pos, to_append); + RETURN_IF_ERROR(_pending_buf->append_data(s)); + + pos += to_append; + _bytes_appended += to_append; + + if (_pending_buf->get_size() == _buffer_size) { + // Create multipart upload on first buffer flush + if (_cur_part_num == 1) { + RETURN_IF_ERROR(_create_multipart_upload()); + } + + _cur_part_num++; + _countdown_event.add_count(); + RETURN_IF_ERROR(FileBuffer::submit(std::move(_pending_buf))); + _pending_buf = nullptr; + } + } + } + + return Status::OK(); +} + +Status OSSFileWriter::_build_upload_buffer() { + auto builder = FileBufferBuilder(); + builder.set_type(BufferType::UPLOAD) + .set_upload_callback([part_num = _cur_part_num, this](UploadFileBuffer& buf) { + _upload_one_part(part_num, buf); + }) + .set_file_offset(_bytes_appended) + .set_sync_after_complete_task([this](auto&& s) { + return _complete_part_task_callback(std::forward<decltype(s)>(s)); + }) + .set_is_cancelled([this]() { return _failed.load(); }); + + if (cache_builder() != nullptr) { + int64_t tablet_id = get_tablet_id(_path.native()).value_or(0); + builder.set_allocate_file_blocks_holder([builder = *cache_builder(), + offset = _bytes_appended, + tablet_id = tablet_id]() -> FileBlocksHolderPtr { + return builder.allocate_cache_holder(offset, config::s3_write_buffer_size, tablet_id); + }); + } + + RETURN_IF_ERROR(builder.build(&_pending_buf)); + return Status::OK(); +} + +void OSSFileWriter::_upload_one_part(int64_t part_num, UploadFileBuffer& buf) { + if (buf.is_cancelled()) { + LOG(INFO) << "OSS file " << _path.native() << " skip part " << part_num + << " because previous failure"; + return; + } + + // Debug point: Simulate upload failure + DBUG_EXECUTE_IF("OSSFileWriter::_upload_one_part.upload_error", { + auto fail_part = dp->param<int64_t>("fail_part_num", 0); + if (fail_part == 0 || fail_part == part_num) { + LOG(WARNING) << "Debug point: Simulating OSS upload failure for part " << part_num; + buf.set_status(Status::IOError("Debug OSS upload error for part {}", part_num)); + return; + } + }); + + // Debug point: Simulate slow upload + DBUG_EXECUTE_IF("OSSFileWriter::_upload_one_part.slow_upload", { + auto sleep_ms = dp->param<int>("sleep_ms", 1000); + LOG(INFO) << "Debug point: Simulating slow OSS upload, sleeping " << sleep_ms << "ms"; + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms)); + }); + + auto client = _client->get(); + if (nullptr == client) { + buf.set_status(Status::InternalError("OSS client not initialized")); + return; + } + + auto stream = buf.get_stream(); + if (!stream) { + buf.set_status(Status::InternalError("Failed to get stream from upload buffer for part {}", + part_num)); + return; + } + + AlibabaCloud::OSS::UploadPartRequest request(_bucket, _key, stream); + request.setPartNumber(static_cast<int32_t>(part_num)); + request.setUploadId(_upload_id); + request.setContentLength(buf.get_size()); + + auto outcome = client->UploadPart(request); + if (!outcome.isSuccess()) { + std::string err = fmt::format("OSS UploadPart {} failed: {} - {}", part_num, + outcome.error().Code(), outcome.error().Message()); + LOG(WARNING) << err << ", path: " << _path.native(); + buf.set_status(Status::IOError(err)); + return; + } + + oss_bytes_written_total << buf.get_size(); + + AlibabaCloud::OSS::Part part(static_cast<int32_t>(part_num), outcome.result().ETag()); + + { + std::lock_guard<std::mutex> lock(_completed_lock); + _completed_parts.push_back(part); + } + + VLOG_DEBUG << "OSS UploadPart " << part_num << " completed: " << _path.native() + << " size: " << buf.get_size(); +} + +bool OSSFileWriter::_complete_part_task_callback(Status s) { + if (!s.ok()) { + _failed = true; + _st = std::move(s); + LOG(WARNING) << "OSS async upload failed: " << _path.native() << " error: " << _st; Review Comment: implemented -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
