lidavidm commented on a change in pull request #10955: URL: https://github.com/apache/arrow/pull/10955#discussion_r716994853
########## File path: cpp/src/arrow/dataset/file_base.cc ########## @@ -115,7 +110,7 @@ Result<std::shared_ptr<FileFragment>> FileFormat::MakeFragment( } // TODO(ARROW-12355[CSV], ARROW-11772[IPC], ARROW-11843[Parquet]) The following Review comment: Side note, but all three JIRAs here are finished, maybe we should make the base method pure virtual? ########## File path: cpp/src/arrow/dataset/dataset_writer.cc ########## @@ -0,0 +1,524 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/dataset/dataset_writer.h" + +#include <list> +#include <mutex> +#include <unordered_map> + +#include "arrow/filesystem/path_util.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/util/future.h" +#include "arrow/util/logging.h" +#include "arrow/util/map.h" +#include "arrow/util/string.h" + +namespace arrow { +namespace dataset { + +namespace { + +constexpr util::string_view kIntegerToken = "{i}"; + +class Throttle { + public: + explicit Throttle(uint64_t max_value) : max_value_(max_value) {} + + bool Unthrottled() const { return max_value_ <= 0; } + + Future<> Acquire(uint64_t values) { + if (Unthrottled()) { + return Future<>::MakeFinished(); + } + std::lock_guard<std::mutex> lg(mutex_); + if (values + current_value_ > max_value_) { + in_waiting_ = values; + backpressure_ = Future<>::Make(); + } else { + current_value_ += values; + } + return backpressure_; + } + + void Release(uint64_t values) { + if (Unthrottled()) { + return; + } + Future<> to_complete; + { + std::lock_guard<std::mutex> lg(mutex_); + current_value_ -= values; + if (in_waiting_ > 0 && in_waiting_ + current_value_ <= max_value_) { + in_waiting_ = 0; + to_complete = backpressure_; + } + } + if (to_complete.is_valid()) { + to_complete.MarkFinished(); + } + } + + private: + Future<> backpressure_ = Future<>::MakeFinished(); + uint64_t max_value_; + uint64_t in_waiting_ = 0; + uint64_t current_value_ = 0; + std::mutex mutex_; +}; + +class DatasetWriterFileQueue : public util::AsyncDestroyable { + public: + explicit DatasetWriterFileQueue(const Future<std::shared_ptr<FileWriter>>& writer_fut, + const FileSystemDatasetWriteOptions& options, + std::mutex* visitors_mutex) + : options_(options), visitors_mutex_(visitors_mutex) { + running_task_ = Future<>::Make(); + writer_fut.AddCallback( + [this](const Result<std::shared_ptr<FileWriter>>& maybe_writer) { + if (maybe_writer.ok()) { + writer_ = *maybe_writer; + Flush(); + } else { + Abort(maybe_writer.status()); + } + }); + } + + Future<uint64_t> Push(std::shared_ptr<RecordBatch> batch) { + std::unique_lock<std::mutex> lk(mutex); + write_queue_.push_back(std::move(batch)); + Future<uint64_t> write_future = Future<uint64_t>::Make(); + write_futures_.push_back(write_future); + if 
(!running_task_.is_valid()) { + running_task_ = Future<>::Make(); + FlushUnlocked(std::move(lk)); + } + return write_future; + } + + Future<> DoDestroy() override { + std::lock_guard<std::mutex> lg(mutex); + if (!running_task_.is_valid()) { + RETURN_NOT_OK(DoFinish()); + return Future<>::MakeFinished(); + } + return running_task_.Then([this] { return DoFinish(); }); + } + + private: + Future<uint64_t> WriteNext() { + // May want to prototype / measure someday pushing the async write down further + return DeferNotOk( + io::default_io_context().executor()->Submit([this]() -> Result<uint64_t> { + DCHECK(running_task_.is_valid()); + std::unique_lock<std::mutex> lk(mutex); + const std::shared_ptr<RecordBatch>& to_write = write_queue_.front(); + Future<uint64_t> on_complete = write_futures_.front(); + uint64_t rows_to_write = to_write->num_rows(); + lk.unlock(); + Status status = writer_->Write(to_write); + lk.lock(); + write_queue_.pop_front(); + write_futures_.pop_front(); + lk.unlock(); + if (!status.ok()) { + on_complete.MarkFinished(status); + } else { + on_complete.MarkFinished(rows_to_write); + } + return rows_to_write; + })); + } + + Status DoFinish() { + { + std::lock_guard<std::mutex> lg(*visitors_mutex_); + RETURN_NOT_OK(options_.writer_pre_finish(writer_.get())); + } + RETURN_NOT_OK(writer_->Finish()); + { + std::lock_guard<std::mutex> lg(*visitors_mutex_); + return options_.writer_post_finish(writer_.get()); + } + } + + void Abort(Status err) { + std::vector<Future<uint64_t>> futures_to_abort; + Future<> old_running_task = running_task_; + { + std::lock_guard<std::mutex> lg(mutex); + write_queue_.clear(); + futures_to_abort = + std::vector<Future<uint64_t>>(write_futures_.begin(), write_futures_.end()); + write_futures_.clear(); + running_task_ = Future<>(); + } + for (auto& fut : futures_to_abort) { + fut.MarkFinished(err); + } + old_running_task.MarkFinished(std::move(err)); + } + + void Flush() { + std::unique_lock<std::mutex> lk(mutex); + FlushUnlocked(std::move(lk)); + } + + void FlushUnlocked(std::unique_lock<std::mutex> lk) { + if (write_queue_.empty()) { + Future<> old_running_task = running_task_; + running_task_ = Future<>(); + lk.unlock(); + old_running_task.MarkFinished(); + return; + } + WriteNext().AddCallback([this](const Result<uint64_t>& res) { + if (res.ok()) { + Flush(); + } else { + Abort(res.status()); + } + }); + } + + const FileSystemDatasetWriteOptions& options_; + std::mutex* visitors_mutex_; + std::shared_ptr<FileWriter> writer_; + std::mutex mutex; + std::list<std::shared_ptr<RecordBatch>> write_queue_; + std::list<Future<uint64_t>> write_futures_; + Future<> running_task_; +}; + +struct WriteTask { + std::string filename; + uint64_t num_rows; +}; + +class DatasetWriterDirectoryQueue : public util::AsyncDestroyable { + public: + DatasetWriterDirectoryQueue(std::string directory, std::shared_ptr<Schema> schema, + const FileSystemDatasetWriteOptions& write_options, + Throttle* open_files_throttle, std::mutex* visitors_mutex) + : directory_(std::move(directory)), + schema_(std::move(schema)), + write_options_(write_options), + open_files_throttle_(open_files_throttle), + visitors_mutex_(visitors_mutex) {} + + Result<std::shared_ptr<RecordBatch>> NextWritableChunk( + std::shared_ptr<RecordBatch> batch, std::shared_ptr<RecordBatch>* remainder, + bool* will_open_file) const { + DCHECK_GT(batch->num_rows(), 0); + uint64_t rows_available = std::numeric_limits<uint64_t>::max(); + *will_open_file = rows_written_ == 0; + if (write_options_.max_rows_per_file > 0) { + 
      rows_available = write_options_.max_rows_per_file - rows_written_;
+    }
+
+    std::shared_ptr<RecordBatch> to_queue;
+    if (rows_available < static_cast<uint64_t>(batch->num_rows())) {
+      to_queue = batch->Slice(0, static_cast<int64_t>(rows_available));
+      *remainder = batch->Slice(static_cast<int64_t>(rows_available));
+    } else {
+      to_queue = std::move(batch);
+    }
+    return to_queue;
+  }
+
+  Future<WriteTask> StartWrite(const std::shared_ptr<RecordBatch>& batch) {
+    rows_written_ += batch->num_rows();
+    WriteTask task{current_filename_, static_cast<uint64_t>(batch->num_rows())};
+    if (!latest_open_file_) {
+      ARROW_ASSIGN_OR_RAISE(latest_open_file_, OpenFileQueue(current_filename_));
+    }
+    return latest_open_file_->Push(batch).Then([task] { return task; });
+  }
+
+  Result<std::string> GetNextFilename() {
+    auto basename = ::arrow::internal::Replace(
+        write_options_.basename_template, kIntegerToken, std::to_string(file_counter_++));
+    if (!basename) {
+      return Status::Invalid("string interpolation of basename template failed");
+    }
+
+    return fs::internal::ConcatAbstractPath(directory_, *basename);
+  }
+
+  Status FinishCurrentFile() {
+    if (latest_open_file_) {
+      latest_open_file_ = nullptr;
+    }
+    rows_written_ = 0;
+    return GetNextFilename().Value(&current_filename_);
+  }
+
+  Result<std::shared_ptr<FileWriter>> OpenWriter(const std::string& filename) {
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<io::OutputStream> out_stream,
+                          write_options_.filesystem->OpenOutputStream(filename));
+    return write_options_.format()->MakeWriter(std::move(out_stream), schema_,
+                                               write_options_.file_write_options,
+                                               {write_options_.filesystem, filename});
+  }
+
+  Result<std::shared_ptr<DatasetWriterFileQueue>> OpenFileQueue(
+      const std::string& filename) {
+    Future<std::shared_ptr<FileWriter>> file_writer_fut =
+        init_future_.Then([this, filename] {
+          ::arrow::internal::Executor* io_executor =
+              write_options_.filesystem->io_context().executor();
+          return DeferNotOk(
+              io_executor->Submit([this, filename]() { return OpenWriter(filename); }));
+        });
+    auto file_queue = util::MakeSharedAsync<DatasetWriterFileQueue>(
+        file_writer_fut, write_options_, visitors_mutex_);
+    RETURN_NOT_OK(task_group_.AddTask(
+        file_queue->on_closed().Then([this] { open_files_throttle_->Release(1); })));
+    return file_queue;
+  }
+
+  uint64_t rows_written() const { return rows_written_; }
+
+  void PrepareDirectory() {
+    init_future_ =
+        DeferNotOk(write_options_.filesystem->io_context().executor()->Submit([this] {
+          RETURN_NOT_OK(write_options_.filesystem->CreateDir(directory_));
+          if (write_options_.existing_data_behavior == kDeleteMatchingPartitions) {
+            fs::FileSelector selector;
+            selector.base_dir = directory_;
+            selector.recursive = true;
+            return write_options_.filesystem->DeleteFiles(selector);
+          }
+          return Status::OK();
+        }));
+  }
+
+  static Result<std::unique_ptr<DatasetWriterDirectoryQueue,
+                                util::DestroyingDeleter<DatasetWriterDirectoryQueue>>>
+  Make(util::AsyncTaskGroup* task_group,
+       const FileSystemDatasetWriteOptions& write_options, Throttle* open_files_throttle,
+       std::shared_ptr<Schema> schema, std::string dir, std::mutex* visitors_mutex) {
+    auto dir_queue = util::MakeUniqueAsync<DatasetWriterDirectoryQueue>(
+        std::move(dir), std::move(schema), write_options, open_files_throttle,
+        visitors_mutex);
+    RETURN_NOT_OK(task_group->AddTask(dir_queue->on_closed()));
+    dir_queue->PrepareDirectory();
+    ARROW_ASSIGN_OR_RAISE(dir_queue->current_filename_, dir_queue->GetNextFilename());
+    // std::move required to make RTools 3.5 mingw
compiler happy + return std::move(dir_queue); + } + + Future<> DoDestroy() override { + latest_open_file_.reset(); + return task_group_.WaitForTasksToFinish(); + } + + private: + util::AsyncTaskGroup task_group_; + std::string directory_; + std::shared_ptr<Schema> schema_; + const FileSystemDatasetWriteOptions& write_options_; + Throttle* open_files_throttle_; + std::mutex* visitors_mutex_; + Future<> init_future_; + std::string current_filename_; + std::shared_ptr<DatasetWriterFileQueue> latest_open_file_; + uint64_t rows_written_ = 0; + uint32_t file_counter_ = 0; +}; + +Status ValidateBasenameTemplate(util::string_view basename_template) { + if (basename_template.find(fs::internal::kSep) != util::string_view::npos) { + return Status::Invalid("basename_template contained '/'"); + } + size_t token_start = basename_template.find(kIntegerToken); + if (token_start == util::string_view::npos) { + return Status::Invalid("basename_template did not contain '", kIntegerToken, "'"); + } + return Status::OK(); +} + +Status EnsureDestinationValid(const FileSystemDatasetWriteOptions& options) { + if (options.existing_data_behavior == kError) { + fs::FileSelector selector; + selector.base_dir = options.base_dir; + selector.recursive = true; + Result<std::vector<fs::FileInfo>> maybe_files = + options.filesystem->GetFileInfo(selector); + if (!maybe_files.ok()) { + // If the path doesn't exist then continue + return Status::OK(); + } + if (maybe_files->size() > 1) { + return Status::Invalid( + "Could not write to ", options.base_dir, + " as the directory is not empty and existing_data_behavior is kError"); + } + } + return Status::OK(); +} + +} // namespace + +class DatasetWriter::DatasetWriterImpl : public util::AsyncDestroyable { + public: + DatasetWriterImpl(FileSystemDatasetWriteOptions write_options, uint64_t max_rows_queued) + : write_options_(std::move(write_options)), + rows_in_flight_throttle_(max_rows_queued), + open_files_throttle_(write_options.max_open_files) {} + + Future<> WriteRecordBatch(std::shared_ptr<RecordBatch> batch, + const std::string& directory) { + RETURN_NOT_OK(CheckError()); + if (batch->num_rows() == 0) { + return Future<>::MakeFinished(); + } + if (!directory.empty()) { + auto full_path = + fs::internal::ConcatAbstractPath(write_options_.base_dir, directory); + return DoWriteRecordBatch(std::move(batch), full_path); + } else { + return DoWriteRecordBatch(std::move(batch), write_options_.base_dir); + } + } + + protected: + Status CloseLargestFile() { + std::shared_ptr<DatasetWriterDirectoryQueue> largest = nullptr; + uint64_t largest_num_rows = 0; + for (auto& dir_queue : directory_queues_) { + if (dir_queue.second->rows_written() > largest_num_rows) { + largest_num_rows = dir_queue.second->rows_written(); + largest = dir_queue.second; + } + } + DCHECK_NE(largest, nullptr); + return largest->FinishCurrentFile(); + } + + Future<> DoWriteRecordBatch(std::shared_ptr<RecordBatch> batch, + const std::string& directory) { + ARROW_ASSIGN_OR_RAISE( + auto dir_queue_itr, + ::arrow::internal::GetOrInsertGenerated( + &directory_queues_, directory, [this, &batch](const std::string& dir) { + return DatasetWriterDirectoryQueue::Make( + &task_group_, write_options_, &open_files_throttle_, batch->schema(), + dir, &visitors_mutex_); + })); + std::shared_ptr<DatasetWriterDirectoryQueue> dir_queue = dir_queue_itr->second; + std::vector<Future<WriteTask>> scheduled_writes; + Future<> backpressure; + while (batch) { + // Keep opening new files until batch is done. 
+      std::shared_ptr<RecordBatch> remainder;
+      bool will_open_file = false;
+      ARROW_ASSIGN_OR_RAISE(auto next_chunk, dir_queue->NextWritableChunk(
+                                                 batch, &remainder, &will_open_file));
+
+      backpressure = rows_in_flight_throttle_.Acquire(next_chunk->num_rows());
+      if (!backpressure.is_finished()) {
+        break;

Review comment: Why don't we overwrite batch with remainder here?

########## File path: cpp/src/arrow/util/counting_semaphore.cc ##########
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/counting_semaphore.h"
+
+#include <chrono>
+#include <condition_variable>
+#include <cstdint>
+#include <iostream>
+#include <mutex>
+
+#include "arrow/status.h"
+
+namespace arrow {
+namespace util {
+
+class CountingSemaphore::Impl {
+ public:
+  Impl(uint32_t initial_avail, double timeout_seconds)
+      : num_permits_(initial_avail), timeout_seconds_(timeout_seconds) {}
+
+  Status Acquire(uint32_t num_permits) {
+    std::unique_lock<std::mutex> lk(mutex_);
+    RETURN_NOT_OK(CheckClosed());
+    num_waiters_ += num_permits;
+    waiter_cv_.notify_all();
+    bool timed_out = !acquirer_cv_.wait_for(
+        lk, std::chrono::nanoseconds(static_cast<int64_t>(timeout_seconds_ * 1e9)),
+        [&] { return closed_ || num_permits <= num_permits_; });
+    num_waiters_ -= num_permits;

Review comment: Shouldn't we notify waiter_cv_ here too?

########## File path: cpp/src/arrow/dataset/file_base.cc ##########
@@ -115,7 +110,7 @@ Result<std::shared_ptr<FileFragment>> FileFormat::MakeFragment( } // TODO(ARROW-12355[CSV], ARROW-11772[IPC], ARROW-11843[Parquet]) The following

Review comment: Or else update the comment at least.

########## File path: cpp/src/arrow/dataset/dataset_writer.cc ##########
@@ -0,0 +1,524 @@
+    while (batch) {
+      // Keep opening new files until batch is done.
+      std::shared_ptr<RecordBatch> remainder;
+      bool will_open_file = false;
+      ARROW_ASSIGN_OR_RAISE(auto next_chunk, dir_queue->NextWritableChunk(
+                                                 batch, &remainder, &will_open_file));
+
+      backpressure = rows_in_flight_throttle_.Acquire(next_chunk->num_rows());
+      if (!backpressure.is_finished()) {
+        break;

Review comment: And also, we could just pass &batch to avoid having to consider this.

########## File path: cpp/src/arrow/dataset/dataset_writer.cc ##########
@@ -0,0 +1,524 @@
+Status EnsureDestinationValid(const FileSystemDatasetWriteOptions& options) {
+  if (options.existing_data_behavior == kError) {
+    fs::FileSelector
selector; + selector.base_dir = options.base_dir; + selector.recursive = true; + Result<std::vector<fs::FileInfo>> maybe_files = + options.filesystem->GetFileInfo(selector); + if (!maybe_files.ok()) { + // If the path doesn't exist then continue + return Status::OK(); + } + if (maybe_files->size() > 1) { + return Status::Invalid( + "Could not write to ", options.base_dir, + " as the directory is not empty and existing_data_behavior is kError"); + } + } + return Status::OK(); +} + +} // namespace + +class DatasetWriter::DatasetWriterImpl : public util::AsyncDestroyable { + public: + DatasetWriterImpl(FileSystemDatasetWriteOptions write_options, uint64_t max_rows_queued) + : write_options_(std::move(write_options)), + rows_in_flight_throttle_(max_rows_queued), + open_files_throttle_(write_options.max_open_files) {} Review comment: This should technically be write_options_ since we moved out of the parameter right? ########## File path: cpp/src/arrow/dataset/file_base.cc ########## @@ -327,222 +322,70 @@ Status FileWriter::Finish() { namespace { -constexpr util::string_view kIntegerToken = "{i}"; +Future<> WriteNextBatch(DatasetWriter* dataset_writer, TaggedRecordBatch batch, + const FileSystemDatasetWriteOptions& write_options) { + ARROW_ASSIGN_OR_RAISE(auto groups, + write_options.partitioning->Partition(batch.record_batch)); + batch.record_batch.reset(); // drop to hopefully conserve memory -Status ValidateBasenameTemplate(util::string_view basename_template) { - if (basename_template.find(fs::internal::kSep) != util::string_view::npos) { - return Status::Invalid("basename_template contained '/'"); - } - size_t token_start = basename_template.find(kIntegerToken); - if (token_start == util::string_view::npos) { - return Status::Invalid("basename_template did not contain '", kIntegerToken, "'"); - } - return Status::OK(); -} - -/// WriteQueue allows batches to be pushed from multiple threads while another thread -/// flushes some to disk. -class WriteQueue { - public: - WriteQueue(std::string partition_expression, size_t index, - std::shared_ptr<Schema> schema) - : partition_expression_(std::move(partition_expression)), - index_(index), - schema_(std::move(schema)) {} - - // Push a batch into the writer's queue of pending writes. - void Push(std::shared_ptr<RecordBatch> batch) { - auto push_lock = push_mutex_.Lock(); - pending_.push_back(std::move(batch)); - } - - // Flush all pending batches, or return immediately if another thread is already - // flushing this queue. - Status Flush(const FileSystemDatasetWriteOptions& write_options) { - if (auto writer_lock = writer_mutex_.TryLock()) { - if (writer_ == nullptr) { - // FileWriters are opened lazily to avoid blocking access to a scan-wide queue set - RETURN_NOT_OK(OpenWriter(write_options)); - } - - while (true) { - std::shared_ptr<RecordBatch> batch; - { - auto push_lock = push_mutex_.Lock(); - if (pending_.empty()) { - // Ensure the writer_lock is released before the push_lock. Otherwise another - // thread might successfully Push() a batch but then fail to Flush() it since - // the writer_lock is still held, leaving an unflushed batch in pending_. 
- writer_lock.Unlock(); - break; - } - batch = std::move(pending_.front()); - pending_.pop_front(); - } - RETURN_NOT_OK(writer_->Write(batch)); - } - } - return Status::OK(); - } - - const std::shared_ptr<FileWriter>& writer() const { return writer_; } - - private: - Status OpenWriter(const FileSystemDatasetWriteOptions& write_options) { - auto dir = - fs::internal::EnsureTrailingSlash(write_options.base_dir) + partition_expression_; - - auto basename = ::arrow::internal::Replace(write_options.basename_template, - kIntegerToken, std::to_string(index_)); - if (!basename) { - return Status::Invalid("string interpolation of basename template failed"); - } - - auto path = fs::internal::ConcatAbstractPath(dir, *basename); - - RETURN_NOT_OK(write_options.filesystem->CreateDir(dir)); - ARROW_ASSIGN_OR_RAISE(auto destination, - write_options.filesystem->OpenOutputStream(path)); - - ARROW_ASSIGN_OR_RAISE( - writer_, write_options.format()->MakeWriter(std::move(destination), schema_, - write_options.file_write_options, - {write_options.filesystem, path})); - return Status::OK(); - } - - util::Mutex writer_mutex_; - std::shared_ptr<FileWriter> writer_; - - util::Mutex push_mutex_; - std::deque<std::shared_ptr<RecordBatch>> pending_; - - // The (formatted) partition expression to which this queue corresponds - std::string partition_expression_; - - size_t index_; - - std::shared_ptr<Schema> schema_; -}; - -struct WriteState { - explicit WriteState(FileSystemDatasetWriteOptions write_options) - : write_options(std::move(write_options)) {} - - FileSystemDatasetWriteOptions write_options; - util::Mutex mutex; - std::unordered_map<std::string, std::unique_ptr<WriteQueue>> queues; -}; - -Status WriteNextBatch(WriteState* state, const std::shared_ptr<Fragment>& fragment, - std::shared_ptr<RecordBatch> batch) { - ARROW_ASSIGN_OR_RAISE(auto groups, state->write_options.partitioning->Partition(batch)); - batch.reset(); // drop to hopefully conserve memory - - if (groups.batches.size() > static_cast<size_t>(state->write_options.max_partitions)) { + if (groups.batches.size() > static_cast<size_t>(write_options.max_partitions)) { return Status::Invalid("Fragment would be written into ", groups.batches.size(), " partitions. 
This exceeds the maximum of ", - state->write_options.max_partitions); + write_options.max_partitions); } - std::unordered_set<WriteQueue*> need_flushed; - for (size_t i = 0; i < groups.batches.size(); ++i) { - auto partition_expression = - and_(std::move(groups.expressions[i]), fragment->partition_expression()); - auto batch = std::move(groups.batches[i]); - - ARROW_ASSIGN_OR_RAISE( - auto part, state->write_options.partitioning->Format(partition_expression)); - - WriteQueue* queue; - { - // lookup the queue to which batch should be appended - auto queues_lock = state->mutex.Lock(); - - queue = ::arrow::internal::GetOrInsertGenerated( - &state->queues, std::move(part), - [&](const std::string& emplaced_part) { - // lookup in `queues` also failed, - // generate a new WriteQueue - size_t queue_index = state->queues.size() - 1; - - return ::arrow::internal::make_unique<WriteQueue>( - emplaced_part, queue_index, batch->schema()); - }) - ->second.get(); - } - - queue->Push(std::move(batch)); - need_flushed.insert(queue); - } - - // flush all touched WriteQueues - for (auto queue : need_flushed) { - RETURN_NOT_OK(queue->Flush(state->write_options)); - } - return Status::OK(); -} + std::shared_ptr<size_t> counter = std::make_shared<size_t>(0); + std::shared_ptr<Fragment> fragment = std::move(batch.fragment); -Status WriteInternal(const ScanOptions& scan_options, WriteState* state, - ScanTaskVector scan_tasks) { - // Store a mapping from partitions (represened by their formatted partition expressions) - // to a WriteQueue which flushes batches into that partition's output file. In principle - // any thread could produce a batch for any partition, so each task alternates between - // pushing batches and flushing them to disk. - auto task_group = scan_options.TaskGroup(); - - for (const auto& scan_task : scan_tasks) { - task_group->Append([&, scan_task] { - std::function<Status(std::shared_ptr<RecordBatch>)> visitor = - [&](std::shared_ptr<RecordBatch> batch) { - return WriteNextBatch(state, scan_task->fragment(), std::move(batch)); - }; - return ::arrow::internal::RunSynchronously<Future<>>( - [&](Executor* executor) { return scan_task->SafeVisit(executor, visitor); }, - /*use_threads=*/false); + AsyncGenerator<std::shared_ptr<RecordBatch>> partitioned_batch_gen = + [groups, counter, fragment, &write_options, + dataset_writer]() -> Future<std::shared_ptr<RecordBatch>> { + auto index = *counter; + if (index >= groups.batches.size()) { + return AsyncGeneratorEnd<std::shared_ptr<RecordBatch>>(); + } + auto partition_expression = + and_(groups.expressions[index], fragment->partition_expression()); + auto next_batch = groups.batches[index]; + ARROW_ASSIGN_OR_RAISE(std::string destination, + write_options.partitioning->Format(partition_expression)); + (*counter)++; + return dataset_writer->WriteRecordBatch(next_batch, destination).Then([next_batch] { + return next_batch; }); - } - return task_group->Finish(); + }; + + return VisitAsyncGenerator( + std::move(partitioned_batch_gen), + [](const std::shared_ptr<RecordBatch>&) -> Status { return Status::OK(); }); } } // namespace Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_options, std::shared_ptr<Scanner> scanner) { - RETURN_NOT_OK(ValidateBasenameTemplate(write_options.basename_template)); - - // Things we'll un-lazy for the sake of simplicity, with the tradeoff they represent: - // - // - Fragment iteration. 
Keeping this lazy would allow us to start partitioning/writing - // any fragments we have before waiting for discovery to complete. This isn't - // currently implemented for FileSystemDataset anyway: ARROW-8613 - // - // - ScanTask iteration. Keeping this lazy would save some unnecessary blocking when - // writing Fragments which produce scan tasks slowly. No Fragments do this. - // - // NB: neither of these will have any impact whatsoever on the common case of writing - // an in-memory table to disk. - - ARROW_SUPPRESS_DEPRECATION_WARNING - - // TODO(ARROW-11782/ARROW-12288) Remove calls to Scan() - ARROW_ASSIGN_OR_RAISE(auto scan_task_it, scanner->Scan()); - ARROW_ASSIGN_OR_RAISE(ScanTaskVector scan_tasks, scan_task_it.ToVector()); - - ARROW_UNSUPPRESS_DEPRECATION_WARNING - - WriteState state(write_options); - RETURN_NOT_OK(WriteInternal(*scanner->options(), &state, std::move(scan_tasks))); - - auto task_group = scanner->options()->TaskGroup(); - for (const auto& part_queue : state.queues) { - task_group->Append([&] { - RETURN_NOT_OK(write_options.writer_pre_finish(part_queue.second->writer().get())); - RETURN_NOT_OK(part_queue.second->writer()->Finish()); - return write_options.writer_post_finish(part_queue.second->writer().get()); - }); - } - return task_group->Finish(); + ARROW_ASSIGN_OR_RAISE(auto batch_gen, scanner->ScanBatchesAsync()); + ARROW_ASSIGN_OR_RAISE(auto dataset_writer, DatasetWriter::Make(write_options)); + + AsyncGenerator<std::shared_ptr<int>> queued_batch_gen = + [batch_gen, &dataset_writer, &write_options]() -> Future<std::shared_ptr<int>> { + Future<TaggedRecordBatch> next_batch_fut = batch_gen(); + return next_batch_fut.Then( + [&dataset_writer, &write_options](const TaggedRecordBatch& batch) { + if (IsIterationEnd(batch)) { + return AsyncGeneratorEnd<std::shared_ptr<int>>(); + } + return WriteNextBatch(dataset_writer.get(), batch, write_options).Then([] { + return std::make_shared<int>(0); Review comment: Just to make sure, the shared_ptr<int> is being used purely for the end-of-stream properties here? ########## File path: cpp/src/arrow/dataset/dataset_writer.cc ########## @@ -0,0 +1,524 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/dataset/dataset_writer.h" + +#include <list> +#include <mutex> +#include <unordered_map> + +#include "arrow/filesystem/path_util.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/util/future.h" +#include "arrow/util/logging.h" +#include "arrow/util/map.h" +#include "arrow/util/string.h" + +namespace arrow { +namespace dataset { + +namespace { + +constexpr util::string_view kIntegerToken = "{i}"; + +class Throttle { + public: + explicit Throttle(uint64_t max_value) : max_value_(max_value) {} + + bool Unthrottled() const { return max_value_ <= 0; } + + Future<> Acquire(uint64_t values) { + if (Unthrottled()) { + return Future<>::MakeFinished(); + } + std::lock_guard<std::mutex> lg(mutex_); + if (values + current_value_ > max_value_) { + in_waiting_ = values; + backpressure_ = Future<>::Make(); + } else { + current_value_ += values; + } + return backpressure_; + } + + void Release(uint64_t values) { + if (Unthrottled()) { + return; + } + Future<> to_complete; + { + std::lock_guard<std::mutex> lg(mutex_); + current_value_ -= values; + if (in_waiting_ > 0 && in_waiting_ + current_value_ <= max_value_) { + in_waiting_ = 0; + to_complete = backpressure_; + } + } + if (to_complete.is_valid()) { + to_complete.MarkFinished(); + } + } + + private: + Future<> backpressure_ = Future<>::MakeFinished(); + uint64_t max_value_; + uint64_t in_waiting_ = 0; + uint64_t current_value_ = 0; + std::mutex mutex_; +}; + +class DatasetWriterFileQueue : public util::AsyncDestroyable { + public: + explicit DatasetWriterFileQueue(const Future<std::shared_ptr<FileWriter>>& writer_fut, + const FileSystemDatasetWriteOptions& options, + std::mutex* visitors_mutex) + : options_(options), visitors_mutex_(visitors_mutex) { + running_task_ = Future<>::Make(); + writer_fut.AddCallback( + [this](const Result<std::shared_ptr<FileWriter>>& maybe_writer) { + if (maybe_writer.ok()) { + writer_ = *maybe_writer; + Flush(); + } else { + Abort(maybe_writer.status()); + } + }); + } + + Future<uint64_t> Push(std::shared_ptr<RecordBatch> batch) { + std::unique_lock<std::mutex> lk(mutex); + write_queue_.push_back(std::move(batch)); + Future<uint64_t> write_future = Future<uint64_t>::Make(); + write_futures_.push_back(write_future); + if (!running_task_.is_valid()) { + running_task_ = Future<>::Make(); + FlushUnlocked(std::move(lk)); + } + return write_future; + } + + Future<> DoDestroy() override { + std::lock_guard<std::mutex> lg(mutex); + if (!running_task_.is_valid()) { + RETURN_NOT_OK(DoFinish()); + return Future<>::MakeFinished(); + } + return running_task_.Then([this] { return DoFinish(); }); + } + + private: + Future<uint64_t> WriteNext() { + // May want to prototype / measure someday pushing the async write down further + return DeferNotOk( + io::default_io_context().executor()->Submit([this]() -> Result<uint64_t> { + DCHECK(running_task_.is_valid()); + std::unique_lock<std::mutex> lk(mutex); + const std::shared_ptr<RecordBatch>& to_write = write_queue_.front(); + Future<uint64_t> on_complete = write_futures_.front(); + uint64_t rows_to_write = to_write->num_rows(); + lk.unlock(); + Status status = writer_->Write(to_write); + lk.lock(); + write_queue_.pop_front(); + write_futures_.pop_front(); + lk.unlock(); + if (!status.ok()) { + on_complete.MarkFinished(status); + } else { + on_complete.MarkFinished(rows_to_write); + } + return rows_to_write; + })); + } + + Status DoFinish() { + { + std::lock_guard<std::mutex> lg(*visitors_mutex_); + 
RETURN_NOT_OK(options_.writer_pre_finish(writer_.get())); + } + RETURN_NOT_OK(writer_->Finish()); + { + std::lock_guard<std::mutex> lg(*visitors_mutex_); + return options_.writer_post_finish(writer_.get()); + } + } + + void Abort(Status err) { + std::vector<Future<uint64_t>> futures_to_abort; + Future<> old_running_task = running_task_; + { + std::lock_guard<std::mutex> lg(mutex); + write_queue_.clear(); + futures_to_abort = + std::vector<Future<uint64_t>>(write_futures_.begin(), write_futures_.end()); + write_futures_.clear(); + running_task_ = Future<>(); + } + for (auto& fut : futures_to_abort) { + fut.MarkFinished(err); + } + old_running_task.MarkFinished(std::move(err)); + } + + void Flush() { + std::unique_lock<std::mutex> lk(mutex); + FlushUnlocked(std::move(lk)); + } + + void FlushUnlocked(std::unique_lock<std::mutex> lk) { + if (write_queue_.empty()) { + Future<> old_running_task = running_task_; + running_task_ = Future<>(); + lk.unlock(); + old_running_task.MarkFinished(); + return; + } + WriteNext().AddCallback([this](const Result<uint64_t>& res) { + if (res.ok()) { + Flush(); + } else { + Abort(res.status()); + } + }); + } + + const FileSystemDatasetWriteOptions& options_; + std::mutex* visitors_mutex_; + std::shared_ptr<FileWriter> writer_; + std::mutex mutex; + std::list<std::shared_ptr<RecordBatch>> write_queue_; + std::list<Future<uint64_t>> write_futures_; + Future<> running_task_; +}; + +struct WriteTask { + std::string filename; + uint64_t num_rows; +}; + +class DatasetWriterDirectoryQueue : public util::AsyncDestroyable { + public: + DatasetWriterDirectoryQueue(std::string directory, std::shared_ptr<Schema> schema, + const FileSystemDatasetWriteOptions& write_options, + Throttle* open_files_throttle, std::mutex* visitors_mutex) + : directory_(std::move(directory)), + schema_(std::move(schema)), + write_options_(write_options), + open_files_throttle_(open_files_throttle), + visitors_mutex_(visitors_mutex) {} + + Result<std::shared_ptr<RecordBatch>> NextWritableChunk( + std::shared_ptr<RecordBatch> batch, std::shared_ptr<RecordBatch>* remainder, + bool* will_open_file) const { + DCHECK_GT(batch->num_rows(), 0); + uint64_t rows_available = std::numeric_limits<uint64_t>::max(); + *will_open_file = rows_written_ == 0; + if (write_options_.max_rows_per_file > 0) { + rows_available = write_options_.max_rows_per_file - rows_written_; + } + + std::shared_ptr<RecordBatch> to_queue; + if (rows_available < static_cast<uint64_t>(batch->num_rows())) { + to_queue = batch->Slice(0, static_cast<int64_t>(rows_available)); + *remainder = batch->Slice(static_cast<int64_t>(rows_available)); + } else { + to_queue = std::move(batch); + } + return to_queue; + } + + Future<WriteTask> StartWrite(const std::shared_ptr<RecordBatch>& batch) { + rows_written_ += batch->num_rows(); + WriteTask task{current_filename_, static_cast<uint64_t>(batch->num_rows())}; + if (!latest_open_file_) { + ARROW_ASSIGN_OR_RAISE(latest_open_file_, OpenFileQueue(current_filename_)); + } + return latest_open_file_->Push(batch).Then([task] { return task; }); + } + + Result<std::string> GetNextFilename() { + auto basename = ::arrow::internal::Replace( + write_options_.basename_template, kIntegerToken, std::to_string(file_counter_++)); + if (!basename) { + return Status::Invalid("string interpolation of basename template failed"); + } + + return fs::internal::ConcatAbstractPath(directory_, *basename); + } + + Status FinishCurrentFile() { + if (latest_open_file_) { + latest_open_file_ = nullptr; + } + rows_written_ = 0; + 
return GetNextFilename().Value(¤t_filename_); + } + + Result<std::shared_ptr<FileWriter>> OpenWriter(const std::string& filename) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<io::OutputStream> out_stream, + write_options_.filesystem->OpenOutputStream(filename)); + return write_options_.format()->MakeWriter(std::move(out_stream), schema_, + write_options_.file_write_options, + {write_options_.filesystem, filename}); + } + + Result<std::shared_ptr<DatasetWriterFileQueue>> OpenFileQueue( + const std::string& filename) { + Future<std::shared_ptr<FileWriter>> file_writer_fut = + init_future_.Then([this, filename] { + ::arrow::internal::Executor* io_executor = + write_options_.filesystem->io_context().executor(); + return DeferNotOk( + io_executor->Submit([this, filename]() { return OpenWriter(filename); })); + }); + auto file_queue = util::MakeSharedAsync<DatasetWriterFileQueue>( + file_writer_fut, write_options_, visitors_mutex_); + RETURN_NOT_OK(task_group_.AddTask( + file_queue->on_closed().Then([this] { open_files_throttle_->Release(1); }))); + return file_queue; + } + + uint64_t rows_written() const { return rows_written_; } + + void PrepareDirectory() { + init_future_ = + DeferNotOk(write_options_.filesystem->io_context().executor()->Submit([this] { + RETURN_NOT_OK(write_options_.filesystem->CreateDir(directory_)); + if (write_options_.existing_data_behavior == kDeleteMatchingPartitions) { + fs::FileSelector selector; + selector.base_dir = directory_; + selector.recursive = true; + return write_options_.filesystem->DeleteFiles(selector); + } + return Status::OK(); + })); + } + + static Result<std::unique_ptr<DatasetWriterDirectoryQueue, + util::DestroyingDeleter<DatasetWriterDirectoryQueue>>> + Make(util::AsyncTaskGroup* task_group, + const FileSystemDatasetWriteOptions& write_options, Throttle* open_files_throttle, + std::shared_ptr<Schema> schema, std::string dir, std::mutex* visitors_mutex) { + auto dir_queue = util::MakeUniqueAsync<DatasetWriterDirectoryQueue>( + std::move(dir), std::move(schema), write_options, open_files_throttle, + visitors_mutex); + RETURN_NOT_OK(task_group->AddTask(dir_queue->on_closed())); + dir_queue->PrepareDirectory(); + ARROW_ASSIGN_OR_RAISE(dir_queue->current_filename_, dir_queue->GetNextFilename()); + // std::move required to make RTools 3.5 mingw compiler happy + return std::move(dir_queue); + } + + Future<> DoDestroy() override { + latest_open_file_.reset(); + return task_group_.WaitForTasksToFinish(); + } + + private: + util::AsyncTaskGroup task_group_; + std::string directory_; + std::shared_ptr<Schema> schema_; + const FileSystemDatasetWriteOptions& write_options_; + Throttle* open_files_throttle_; + std::mutex* visitors_mutex_; + Future<> init_future_; + std::string current_filename_; + std::shared_ptr<DatasetWriterFileQueue> latest_open_file_; + uint64_t rows_written_ = 0; + uint32_t file_counter_ = 0; +}; + +Status ValidateBasenameTemplate(util::string_view basename_template) { + if (basename_template.find(fs::internal::kSep) != util::string_view::npos) { + return Status::Invalid("basename_template contained '/'"); + } + size_t token_start = basename_template.find(kIntegerToken); + if (token_start == util::string_view::npos) { + return Status::Invalid("basename_template did not contain '", kIntegerToken, "'"); + } Review comment: Should we validate that the token doesn't occur multiple times? 
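For what it's worth, a minimal sketch of that extra check, assuming the same includes, namespace, and kIntegerToken constant as the hunk above (illustrative only, not code from the PR): it simply looks for a second occurrence of the token after the first match.

Status ValidateBasenameTemplate(util::string_view basename_template) {
  if (basename_template.find(fs::internal::kSep) != util::string_view::npos) {
    return Status::Invalid("basename_template contained '/'");
  }
  size_t token_start = basename_template.find(kIntegerToken);
  if (token_start == util::string_view::npos) {
    return Status::Invalid("basename_template did not contain '", kIntegerToken, "'");
  }
  // Hypothetical extra validation: reject a template in which "{i}" appears
  // more than once, so the file counter substitution is unambiguous.
  if (basename_template.find(kIntegerToken, token_start + kIntegerToken.size()) !=
      util::string_view::npos) {
    return Status::Invalid("basename_template contained '", kIntegerToken,
                           "' more than once");
  }
  return Status::OK();
}

With a template such as "part-{i}-{i}.parquet" this would fail fast instead of producing filenames that still contain a literal "{i}" (assuming Replace() only substitutes the first occurrence).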
########## File path: cpp/src/arrow/dataset/dataset_writer.cc ########## @@ -0,0 +1,524 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/dataset/dataset_writer.h" + +#include <list> +#include <mutex> +#include <unordered_map> + +#include "arrow/filesystem/path_util.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/util/future.h" +#include "arrow/util/logging.h" +#include "arrow/util/map.h" +#include "arrow/util/string.h" + +namespace arrow { +namespace dataset { + +namespace { + +constexpr util::string_view kIntegerToken = "{i}"; + +class Throttle { + public: + explicit Throttle(uint64_t max_value) : max_value_(max_value) {} + + bool Unthrottled() const { return max_value_ <= 0; } + + Future<> Acquire(uint64_t values) { + if (Unthrottled()) { + return Future<>::MakeFinished(); + } + std::lock_guard<std::mutex> lg(mutex_); + if (values + current_value_ > max_value_) { + in_waiting_ = values; + backpressure_ = Future<>::Make(); + } else { + current_value_ += values; + } + return backpressure_; + } + + void Release(uint64_t values) { + if (Unthrottled()) { + return; + } + Future<> to_complete; + { + std::lock_guard<std::mutex> lg(mutex_); + current_value_ -= values; + if (in_waiting_ > 0 && in_waiting_ + current_value_ <= max_value_) { + in_waiting_ = 0; + to_complete = backpressure_; + } + } + if (to_complete.is_valid()) { + to_complete.MarkFinished(); + } + } + + private: + Future<> backpressure_ = Future<>::MakeFinished(); + uint64_t max_value_; + uint64_t in_waiting_ = 0; + uint64_t current_value_ = 0; + std::mutex mutex_; +}; + +class DatasetWriterFileQueue : public util::AsyncDestroyable { + public: + explicit DatasetWriterFileQueue(const Future<std::shared_ptr<FileWriter>>& writer_fut, + const FileSystemDatasetWriteOptions& options, + std::mutex* visitors_mutex) + : options_(options), visitors_mutex_(visitors_mutex) { + running_task_ = Future<>::Make(); + writer_fut.AddCallback( + [this](const Result<std::shared_ptr<FileWriter>>& maybe_writer) { + if (maybe_writer.ok()) { + writer_ = *maybe_writer; + Flush(); + } else { + Abort(maybe_writer.status()); + } + }); + } + + Future<uint64_t> Push(std::shared_ptr<RecordBatch> batch) { + std::unique_lock<std::mutex> lk(mutex); + write_queue_.push_back(std::move(batch)); + Future<uint64_t> write_future = Future<uint64_t>::Make(); + write_futures_.push_back(write_future); + if (!running_task_.is_valid()) { + running_task_ = Future<>::Make(); + FlushUnlocked(std::move(lk)); + } + return write_future; + } + + Future<> DoDestroy() override { + std::lock_guard<std::mutex> lg(mutex); + if (!running_task_.is_valid()) { + RETURN_NOT_OK(DoFinish()); + return Future<>::MakeFinished(); + } + return running_task_.Then([this] { return DoFinish(); }); + } + + 
private: + Future<uint64_t> WriteNext() { + // May want to prototype / measure someday pushing the async write down further + return DeferNotOk( + io::default_io_context().executor()->Submit([this]() -> Result<uint64_t> { + DCHECK(running_task_.is_valid()); + std::unique_lock<std::mutex> lk(mutex); + const std::shared_ptr<RecordBatch>& to_write = write_queue_.front(); + Future<uint64_t> on_complete = write_futures_.front(); + uint64_t rows_to_write = to_write->num_rows(); + lk.unlock(); + Status status = writer_->Write(to_write); + lk.lock(); + write_queue_.pop_front(); + write_futures_.pop_front(); + lk.unlock(); + if (!status.ok()) { + on_complete.MarkFinished(status); + } else { + on_complete.MarkFinished(rows_to_write); + } + return rows_to_write; + })); + } + + Status DoFinish() { + { + std::lock_guard<std::mutex> lg(*visitors_mutex_); + RETURN_NOT_OK(options_.writer_pre_finish(writer_.get())); + } + RETURN_NOT_OK(writer_->Finish()); + { + std::lock_guard<std::mutex> lg(*visitors_mutex_); + return options_.writer_post_finish(writer_.get()); + } + } + + void Abort(Status err) { + std::vector<Future<uint64_t>> futures_to_abort; + Future<> old_running_task = running_task_; + { + std::lock_guard<std::mutex> lg(mutex); + write_queue_.clear(); + futures_to_abort = + std::vector<Future<uint64_t>>(write_futures_.begin(), write_futures_.end()); + write_futures_.clear(); + running_task_ = Future<>(); + } + for (auto& fut : futures_to_abort) { + fut.MarkFinished(err); + } + old_running_task.MarkFinished(std::move(err)); + } + + void Flush() { + std::unique_lock<std::mutex> lk(mutex); + FlushUnlocked(std::move(lk)); + } + + void FlushUnlocked(std::unique_lock<std::mutex> lk) { + if (write_queue_.empty()) { + Future<> old_running_task = running_task_; + running_task_ = Future<>(); + lk.unlock(); + old_running_task.MarkFinished(); + return; + } + WriteNext().AddCallback([this](const Result<uint64_t>& res) { + if (res.ok()) { + Flush(); + } else { + Abort(res.status()); + } + }); + } + + const FileSystemDatasetWriteOptions& options_; + std::mutex* visitors_mutex_; + std::shared_ptr<FileWriter> writer_; + std::mutex mutex; + std::list<std::shared_ptr<RecordBatch>> write_queue_; + std::list<Future<uint64_t>> write_futures_; + Future<> running_task_; +}; + +struct WriteTask { + std::string filename; + uint64_t num_rows; +}; + +class DatasetWriterDirectoryQueue : public util::AsyncDestroyable { + public: + DatasetWriterDirectoryQueue(std::string directory, std::shared_ptr<Schema> schema, + const FileSystemDatasetWriteOptions& write_options, + Throttle* open_files_throttle, std::mutex* visitors_mutex) + : directory_(std::move(directory)), + schema_(std::move(schema)), + write_options_(write_options), + open_files_throttle_(open_files_throttle), + visitors_mutex_(visitors_mutex) {} + + Result<std::shared_ptr<RecordBatch>> NextWritableChunk( + std::shared_ptr<RecordBatch> batch, std::shared_ptr<RecordBatch>* remainder, + bool* will_open_file) const { + DCHECK_GT(batch->num_rows(), 0); + uint64_t rows_available = std::numeric_limits<uint64_t>::max(); + *will_open_file = rows_written_ == 0; + if (write_options_.max_rows_per_file > 0) { + rows_available = write_options_.max_rows_per_file - rows_written_; + } + + std::shared_ptr<RecordBatch> to_queue; + if (rows_available < static_cast<uint64_t>(batch->num_rows())) { + to_queue = batch->Slice(0, static_cast<int64_t>(rows_available)); + *remainder = batch->Slice(static_cast<int64_t>(rows_available)); + } else { + to_queue = std::move(batch); + } + return 
to_queue; + } + + Future<WriteTask> StartWrite(const std::shared_ptr<RecordBatch>& batch) { + rows_written_ += batch->num_rows(); + WriteTask task{current_filename_, static_cast<uint64_t>(batch->num_rows())}; + if (!latest_open_file_) { + ARROW_ASSIGN_OR_RAISE(latest_open_file_, OpenFileQueue(current_filename_)); + } + return latest_open_file_->Push(batch).Then([task] { return task; }); + } + + Result<std::string> GetNextFilename() { + auto basename = ::arrow::internal::Replace( + write_options_.basename_template, kIntegerToken, std::to_string(file_counter_++)); + if (!basename) { + return Status::Invalid("string interpolation of basename template failed"); + } + + return fs::internal::ConcatAbstractPath(directory_, *basename); + } + + Status FinishCurrentFile() { + if (latest_open_file_) { + latest_open_file_ = nullptr; + } + rows_written_ = 0; + return GetNextFilename().Value(¤t_filename_); + } + + Result<std::shared_ptr<FileWriter>> OpenWriter(const std::string& filename) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<io::OutputStream> out_stream, + write_options_.filesystem->OpenOutputStream(filename)); + return write_options_.format()->MakeWriter(std::move(out_stream), schema_, + write_options_.file_write_options, + {write_options_.filesystem, filename}); + } + + Result<std::shared_ptr<DatasetWriterFileQueue>> OpenFileQueue( + const std::string& filename) { + Future<std::shared_ptr<FileWriter>> file_writer_fut = + init_future_.Then([this, filename] { + ::arrow::internal::Executor* io_executor = + write_options_.filesystem->io_context().executor(); + return DeferNotOk( + io_executor->Submit([this, filename]() { return OpenWriter(filename); })); + }); + auto file_queue = util::MakeSharedAsync<DatasetWriterFileQueue>( + file_writer_fut, write_options_, visitors_mutex_); + RETURN_NOT_OK(task_group_.AddTask( + file_queue->on_closed().Then([this] { open_files_throttle_->Release(1); }))); + return file_queue; + } + + uint64_t rows_written() const { return rows_written_; } + + void PrepareDirectory() { + init_future_ = + DeferNotOk(write_options_.filesystem->io_context().executor()->Submit([this] { + RETURN_NOT_OK(write_options_.filesystem->CreateDir(directory_)); + if (write_options_.existing_data_behavior == kDeleteMatchingPartitions) { + fs::FileSelector selector; + selector.base_dir = directory_; + selector.recursive = true; + return write_options_.filesystem->DeleteFiles(selector); + } + return Status::OK(); + })); + } + + static Result<std::unique_ptr<DatasetWriterDirectoryQueue, + util::DestroyingDeleter<DatasetWriterDirectoryQueue>>> + Make(util::AsyncTaskGroup* task_group, + const FileSystemDatasetWriteOptions& write_options, Throttle* open_files_throttle, + std::shared_ptr<Schema> schema, std::string dir, std::mutex* visitors_mutex) { + auto dir_queue = util::MakeUniqueAsync<DatasetWriterDirectoryQueue>( + std::move(dir), std::move(schema), write_options, open_files_throttle, + visitors_mutex); + RETURN_NOT_OK(task_group->AddTask(dir_queue->on_closed())); + dir_queue->PrepareDirectory(); + ARROW_ASSIGN_OR_RAISE(dir_queue->current_filename_, dir_queue->GetNextFilename()); + // std::move required to make RTools 3.5 mingw compiler happy + return std::move(dir_queue); + } + + Future<> DoDestroy() override { + latest_open_file_.reset(); + return task_group_.WaitForTasksToFinish(); + } + + private: + util::AsyncTaskGroup task_group_; + std::string directory_; + std::shared_ptr<Schema> schema_; + const FileSystemDatasetWriteOptions& write_options_; + Throttle* open_files_throttle_; + 
std::mutex* visitors_mutex_; + Future<> init_future_; + std::string current_filename_; + std::shared_ptr<DatasetWriterFileQueue> latest_open_file_; + uint64_t rows_written_ = 0; + uint32_t file_counter_ = 0; +}; + +Status ValidateBasenameTemplate(util::string_view basename_template) { + if (basename_template.find(fs::internal::kSep) != util::string_view::npos) { + return Status::Invalid("basename_template contained '/'"); + } + size_t token_start = basename_template.find(kIntegerToken); + if (token_start == util::string_view::npos) { + return Status::Invalid("basename_template did not contain '", kIntegerToken, "'"); + } + return Status::OK(); +} + +Status EnsureDestinationValid(const FileSystemDatasetWriteOptions& options) { + if (options.existing_data_behavior == kError) { + fs::FileSelector selector; + selector.base_dir = options.base_dir; + selector.recursive = true; + Result<std::vector<fs::FileInfo>> maybe_files = + options.filesystem->GetFileInfo(selector); + if (!maybe_files.ok()) { + // If the path doesn't exist then continue + return Status::OK(); + } + if (maybe_files->size() > 1) { + return Status::Invalid( + "Could not write to ", options.base_dir, + " as the directory is not empty and existing_data_behavior is kError"); + } + } + return Status::OK(); +} + +} // namespace + +class DatasetWriter::DatasetWriterImpl : public util::AsyncDestroyable { + public: + DatasetWriterImpl(FileSystemDatasetWriteOptions write_options, uint64_t max_rows_queued) + : write_options_(std::move(write_options)), + rows_in_flight_throttle_(max_rows_queued), + open_files_throttle_(write_options.max_open_files) {} + + Future<> WriteRecordBatch(std::shared_ptr<RecordBatch> batch, + const std::string& directory) { + RETURN_NOT_OK(CheckError()); + if (batch->num_rows() == 0) { + return Future<>::MakeFinished(); + } + if (!directory.empty()) { + auto full_path = + fs::internal::ConcatAbstractPath(write_options_.base_dir, directory); + return DoWriteRecordBatch(std::move(batch), full_path); + } else { + return DoWriteRecordBatch(std::move(batch), write_options_.base_dir); + } + } + + protected: + Status CloseLargestFile() { + std::shared_ptr<DatasetWriterDirectoryQueue> largest = nullptr; + uint64_t largest_num_rows = 0; + for (auto& dir_queue : directory_queues_) { + if (dir_queue.second->rows_written() > largest_num_rows) { + largest_num_rows = dir_queue.second->rows_written(); + largest = dir_queue.second; + } + } + DCHECK_NE(largest, nullptr); + return largest->FinishCurrentFile(); + } + + Future<> DoWriteRecordBatch(std::shared_ptr<RecordBatch> batch, + const std::string& directory) { + ARROW_ASSIGN_OR_RAISE( + auto dir_queue_itr, + ::arrow::internal::GetOrInsertGenerated( + &directory_queues_, directory, [this, &batch](const std::string& dir) { + return DatasetWriterDirectoryQueue::Make( + &task_group_, write_options_, &open_files_throttle_, batch->schema(), + dir, &visitors_mutex_); + })); + std::shared_ptr<DatasetWriterDirectoryQueue> dir_queue = dir_queue_itr->second; + std::vector<Future<WriteTask>> scheduled_writes; + Future<> backpressure; + while (batch) { + // Keep opening new files until batch is done. 
+ std::shared_ptr<RecordBatch> remainder; + bool will_open_file = false; + ARROW_ASSIGN_OR_RAISE(auto next_chunk, dir_queue->NextWritableChunk( + batch, &remainder, &will_open_file)); + + backpressure = rows_in_flight_throttle_.Acquire(next_chunk->num_rows()); + if (!backpressure.is_finished()) { + break; + } + if (will_open_file) { + backpressure = open_files_throttle_.Acquire(1); + if (!backpressure.is_finished()) { + RETURN_NOT_OK(CloseLargestFile()); + break; + } + } + scheduled_writes.push_back(dir_queue->StartWrite(next_chunk)); + batch = std::move(remainder); + if (batch) { + RETURN_NOT_OK(dir_queue->FinishCurrentFile()); + } + } + + for (auto& scheduled_write : scheduled_writes) { + // One of the below callbacks could run immediately and set err_ so we check + // it each time through the loop + RETURN_NOT_OK(CheckError()); Review comment: Don't we want to check this after the callback may have run? ########## File path: cpp/src/arrow/dataset/dataset_writer.cc ########## @@ -0,0 +1,524 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/dataset/dataset_writer.h" + +#include <list> +#include <mutex> +#include <unordered_map> + +#include "arrow/filesystem/path_util.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/util/future.h" +#include "arrow/util/logging.h" +#include "arrow/util/map.h" +#include "arrow/util/string.h" + +namespace arrow { +namespace dataset { + +namespace { + +constexpr util::string_view kIntegerToken = "{i}"; + +class Throttle { + public: + explicit Throttle(uint64_t max_value) : max_value_(max_value) {} + + bool Unthrottled() const { return max_value_ <= 0; } + + Future<> Acquire(uint64_t values) { + if (Unthrottled()) { + return Future<>::MakeFinished(); + } + std::lock_guard<std::mutex> lg(mutex_); + if (values + current_value_ > max_value_) { + in_waiting_ = values; + backpressure_ = Future<>::Make(); + } else { + current_value_ += values; + } + return backpressure_; + } + + void Release(uint64_t values) { + if (Unthrottled()) { + return; + } + Future<> to_complete; + { + std::lock_guard<std::mutex> lg(mutex_); + current_value_ -= values; + if (in_waiting_ > 0 && in_waiting_ + current_value_ <= max_value_) { + in_waiting_ = 0; + to_complete = backpressure_; + } + } + if (to_complete.is_valid()) { + to_complete.MarkFinished(); + } + } + + private: + Future<> backpressure_ = Future<>::MakeFinished(); + uint64_t max_value_; + uint64_t in_waiting_ = 0; + uint64_t current_value_ = 0; + std::mutex mutex_; +}; + +class DatasetWriterFileQueue : public util::AsyncDestroyable { + public: + explicit DatasetWriterFileQueue(const Future<std::shared_ptr<FileWriter>>& writer_fut, + const FileSystemDatasetWriteOptions& options, + std::mutex* visitors_mutex) + : options_(options), visitors_mutex_(visitors_mutex) { + running_task_ = Future<>::Make(); + writer_fut.AddCallback( + [this](const Result<std::shared_ptr<FileWriter>>& maybe_writer) { + if (maybe_writer.ok()) { + writer_ = *maybe_writer; + Flush(); + } else { + Abort(maybe_writer.status()); + } + }); + } + + Future<uint64_t> Push(std::shared_ptr<RecordBatch> batch) { + std::unique_lock<std::mutex> lk(mutex); + write_queue_.push_back(std::move(batch)); + Future<uint64_t> write_future = Future<uint64_t>::Make(); + write_futures_.push_back(write_future); + if (!running_task_.is_valid()) { + running_task_ = Future<>::Make(); + FlushUnlocked(std::move(lk)); + } + return write_future; + } + + Future<> DoDestroy() override { + std::lock_guard<std::mutex> lg(mutex); + if (!running_task_.is_valid()) { + RETURN_NOT_OK(DoFinish()); + return Future<>::MakeFinished(); + } + return running_task_.Then([this] { return DoFinish(); }); + } + + private: + Future<uint64_t> WriteNext() { + // May want to prototype / measure someday pushing the async write down further + return DeferNotOk( + io::default_io_context().executor()->Submit([this]() -> Result<uint64_t> { + DCHECK(running_task_.is_valid()); + std::unique_lock<std::mutex> lk(mutex); + const std::shared_ptr<RecordBatch>& to_write = write_queue_.front(); + Future<uint64_t> on_complete = write_futures_.front(); + uint64_t rows_to_write = to_write->num_rows(); + lk.unlock(); + Status status = writer_->Write(to_write); + lk.lock(); + write_queue_.pop_front(); + write_futures_.pop_front(); + lk.unlock(); + if (!status.ok()) { + on_complete.MarkFinished(status); + } else { + on_complete.MarkFinished(rows_to_write); + } + return rows_to_write; + })); + } + + Status DoFinish() { + { + std::lock_guard<std::mutex> lg(*visitors_mutex_); + 
RETURN_NOT_OK(options_.writer_pre_finish(writer_.get())); + } + RETURN_NOT_OK(writer_->Finish()); + { + std::lock_guard<std::mutex> lg(*visitors_mutex_); + return options_.writer_post_finish(writer_.get()); + } + } + + void Abort(Status err) { + std::vector<Future<uint64_t>> futures_to_abort; + Future<> old_running_task = running_task_; + { + std::lock_guard<std::mutex> lg(mutex); + write_queue_.clear(); + futures_to_abort = + std::vector<Future<uint64_t>>(write_futures_.begin(), write_futures_.end()); + write_futures_.clear(); + running_task_ = Future<>(); + } + for (auto& fut : futures_to_abort) { + fut.MarkFinished(err); + } + old_running_task.MarkFinished(std::move(err)); + } + + void Flush() { + std::unique_lock<std::mutex> lk(mutex); + FlushUnlocked(std::move(lk)); + } + + void FlushUnlocked(std::unique_lock<std::mutex> lk) { + if (write_queue_.empty()) { + Future<> old_running_task = running_task_; + running_task_ = Future<>(); + lk.unlock(); + old_running_task.MarkFinished(); + return; + } + WriteNext().AddCallback([this](const Result<uint64_t>& res) { + if (res.ok()) { + Flush(); + } else { + Abort(res.status()); + } + }); + } + + const FileSystemDatasetWriteOptions& options_; + std::mutex* visitors_mutex_; + std::shared_ptr<FileWriter> writer_; + std::mutex mutex; + std::list<std::shared_ptr<RecordBatch>> write_queue_; + std::list<Future<uint64_t>> write_futures_; + Future<> running_task_; +}; + +struct WriteTask { + std::string filename; + uint64_t num_rows; +}; + +class DatasetWriterDirectoryQueue : public util::AsyncDestroyable { + public: + DatasetWriterDirectoryQueue(std::string directory, std::shared_ptr<Schema> schema, + const FileSystemDatasetWriteOptions& write_options, + Throttle* open_files_throttle, std::mutex* visitors_mutex) + : directory_(std::move(directory)), + schema_(std::move(schema)), + write_options_(write_options), + open_files_throttle_(open_files_throttle), + visitors_mutex_(visitors_mutex) {} + + Result<std::shared_ptr<RecordBatch>> NextWritableChunk( + std::shared_ptr<RecordBatch> batch, std::shared_ptr<RecordBatch>* remainder, + bool* will_open_file) const { + DCHECK_GT(batch->num_rows(), 0); + uint64_t rows_available = std::numeric_limits<uint64_t>::max(); + *will_open_file = rows_written_ == 0; + if (write_options_.max_rows_per_file > 0) { + rows_available = write_options_.max_rows_per_file - rows_written_; + } + + std::shared_ptr<RecordBatch> to_queue; + if (rows_available < static_cast<uint64_t>(batch->num_rows())) { + to_queue = batch->Slice(0, static_cast<int64_t>(rows_available)); + *remainder = batch->Slice(static_cast<int64_t>(rows_available)); + } else { + to_queue = std::move(batch); + } + return to_queue; + } + + Future<WriteTask> StartWrite(const std::shared_ptr<RecordBatch>& batch) { + rows_written_ += batch->num_rows(); + WriteTask task{current_filename_, static_cast<uint64_t>(batch->num_rows())}; + if (!latest_open_file_) { + ARROW_ASSIGN_OR_RAISE(latest_open_file_, OpenFileQueue(current_filename_)); + } + return latest_open_file_->Push(batch).Then([task] { return task; }); + } + + Result<std::string> GetNextFilename() { + auto basename = ::arrow::internal::Replace( + write_options_.basename_template, kIntegerToken, std::to_string(file_counter_++)); + if (!basename) { + return Status::Invalid("string interpolation of basename template failed"); + } + + return fs::internal::ConcatAbstractPath(directory_, *basename); + } + + Status FinishCurrentFile() { + if (latest_open_file_) { + latest_open_file_ = nullptr; + } + rows_written_ = 0; + 
return GetNextFilename().Value(¤t_filename_); + } + + Result<std::shared_ptr<FileWriter>> OpenWriter(const std::string& filename) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr<io::OutputStream> out_stream, + write_options_.filesystem->OpenOutputStream(filename)); + return write_options_.format()->MakeWriter(std::move(out_stream), schema_, + write_options_.file_write_options, + {write_options_.filesystem, filename}); + } + + Result<std::shared_ptr<DatasetWriterFileQueue>> OpenFileQueue( + const std::string& filename) { + Future<std::shared_ptr<FileWriter>> file_writer_fut = + init_future_.Then([this, filename] { + ::arrow::internal::Executor* io_executor = + write_options_.filesystem->io_context().executor(); + return DeferNotOk( + io_executor->Submit([this, filename]() { return OpenWriter(filename); })); + }); + auto file_queue = util::MakeSharedAsync<DatasetWriterFileQueue>( + file_writer_fut, write_options_, visitors_mutex_); + RETURN_NOT_OK(task_group_.AddTask( + file_queue->on_closed().Then([this] { open_files_throttle_->Release(1); }))); + return file_queue; + } + + uint64_t rows_written() const { return rows_written_; } + + void PrepareDirectory() { + init_future_ = + DeferNotOk(write_options_.filesystem->io_context().executor()->Submit([this] { + RETURN_NOT_OK(write_options_.filesystem->CreateDir(directory_)); + if (write_options_.existing_data_behavior == kDeleteMatchingPartitions) { + fs::FileSelector selector; + selector.base_dir = directory_; + selector.recursive = true; + return write_options_.filesystem->DeleteFiles(selector); + } + return Status::OK(); + })); + } + + static Result<std::unique_ptr<DatasetWriterDirectoryQueue, + util::DestroyingDeleter<DatasetWriterDirectoryQueue>>> + Make(util::AsyncTaskGroup* task_group, + const FileSystemDatasetWriteOptions& write_options, Throttle* open_files_throttle, + std::shared_ptr<Schema> schema, std::string dir, std::mutex* visitors_mutex) { + auto dir_queue = util::MakeUniqueAsync<DatasetWriterDirectoryQueue>( + std::move(dir), std::move(schema), write_options, open_files_throttle, + visitors_mutex); + RETURN_NOT_OK(task_group->AddTask(dir_queue->on_closed())); + dir_queue->PrepareDirectory(); + ARROW_ASSIGN_OR_RAISE(dir_queue->current_filename_, dir_queue->GetNextFilename()); + // std::move required to make RTools 3.5 mingw compiler happy + return std::move(dir_queue); + } + + Future<> DoDestroy() override { + latest_open_file_.reset(); + return task_group_.WaitForTasksToFinish(); + } + + private: + util::AsyncTaskGroup task_group_; + std::string directory_; + std::shared_ptr<Schema> schema_; + const FileSystemDatasetWriteOptions& write_options_; + Throttle* open_files_throttle_; + std::mutex* visitors_mutex_; + Future<> init_future_; + std::string current_filename_; + std::shared_ptr<DatasetWriterFileQueue> latest_open_file_; + uint64_t rows_written_ = 0; + uint32_t file_counter_ = 0; +}; + +Status ValidateBasenameTemplate(util::string_view basename_template) { + if (basename_template.find(fs::internal::kSep) != util::string_view::npos) { + return Status::Invalid("basename_template contained '/'"); + } + size_t token_start = basename_template.find(kIntegerToken); + if (token_start == util::string_view::npos) { + return Status::Invalid("basename_template did not contain '", kIntegerToken, "'"); + } Review comment: (Though I suppose it's not a big deal.) -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
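As a side note on the earlier question about DoWriteRecordBatch ("Don't we want to check this after the callback may have run?"), a minimal sketch of the reordering being suggested, assuming scheduled_writes, CheckError(), and err_ keep the meaning they have in the hunk above; the callback body is elided because it is not quoted in this thread.

for (auto& scheduled_write : scheduled_writes) {
  // Attach the completion callback first.  If the write future is already
  // finished, the callback runs synchronously right here.
  scheduled_write.AddCallback([this](const Result<WriteTask>& res) {
    // ... hypothetical handling: release throttles and record any failure so
    // that CheckError() can observe it ...
  });
  // Checking *after* AddCallback means an error set by a synchronously-run
  // callback is caught on this iteration rather than on the next one.
  RETURN_NOT_OK(CheckError());
}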