This is an automated email from the ASF dual-hosted git repository.
bkietz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new e94ad8c ARROW-10183: [C++] Apply composable futures to CSV
e94ad8c is described below
commit e94ad8c6432b56dca551ec19fc6dfe60738432f1
Author: Weston Pace <[email protected]>
AuthorDate: Tue Feb 16 09:45:34 2021 -0500
ARROW-10183: [C++] Apply composable futures to CSV
Closes #9095 from westonpace/feature/arrow-10183-2
Lead-authored-by: Weston Pace <[email protected]>
Co-authored-by: Benjamin Kietzman <[email protected]>
Signed-off-by: Benjamin Kietzman <[email protected]>
---
c_glib/arrow-glib/reader.cpp | 1 +
cpp/examples/minimal_build/example.cc | 1 +
cpp/src/arrow/CMakeLists.txt | 1 -
cpp/src/arrow/csv/CMakeLists.txt | 3 +-
cpp/src/arrow/csv/column_decoder.cc | 2 +-
cpp/src/arrow/csv/reader.cc | 309 +++++++++++++--------
cpp/src/arrow/csv/reader.h | 5 +
cpp/src/arrow/csv/reader_test.cc | 156 +++++++++++
cpp/src/arrow/csv/test_common.cc | 54 ++++
cpp/src/arrow/csv/test_common.h | 3 +
cpp/src/arrow/json/reader.cc | 1 +
cpp/src/arrow/result.h | 4 +-
cpp/src/arrow/testing/gtest_util.h | 46 +++-
cpp/src/arrow/util/async_generator.h | 388 +++++++++++++++++++++++++++
cpp/src/arrow/util/future.cc | 14 +
cpp/src/arrow/util/future.h | 171 ++++++++++--
cpp/src/arrow/util/future_test.cc | 488 +++++++++++++++++++++++++---------
cpp/src/arrow/util/iterator.cc | 175 ------------
cpp/src/arrow/util/iterator.h | 235 ++++++++--------
cpp/src/arrow/util/iterator_test.cc | 452 +++++++++++++++++++++++++++++--
cpp/src/arrow/util/task_group.cc | 30 +++
cpp/src/arrow/util/task_group.h | 14 +
cpp/src/arrow/util/task_group_test.cc | 91 +++++++
cpp/src/arrow/util/thread_pool.h | 22 ++
docs/source/cpp/csv.rst | 2 +
python/pyarrow/_csv.pyx | 3 +-
python/pyarrow/includes/libarrow.pxd | 5 +-
r/src/csv.cpp | 5 +-
28 files changed, 2109 insertions(+), 572 deletions(-)
diff --git a/c_glib/arrow-glib/reader.cpp b/c_glib/arrow-glib/reader.cpp
index c308227..17100e7 100644
--- a/c_glib/arrow-glib/reader.cpp
+++ b/c_glib/arrow-glib/reader.cpp
@@ -1592,6 +1592,7 @@ garrow_csv_reader_new(GArrowInputStream *input,
auto arrow_reader =
arrow::csv::TableReader::Make(arrow::default_memory_pool(),
+ arrow::io::AsyncContext(),
arrow_input,
read_options,
parse_options,
diff --git a/cpp/examples/minimal_build/example.cc
b/cpp/examples/minimal_build/example.cc
index 4b6acd2..8f58de5 100644
--- a/cpp/examples/minimal_build/example.cc
+++ b/cpp/examples/minimal_build/example.cc
@@ -39,6 +39,7 @@ Status RunMain(int argc, char** argv) {
ARROW_ASSIGN_OR_RAISE(
auto csv_reader,
arrow::csv::TableReader::Make(arrow::default_memory_pool(),
+ arrow::io::AsyncContext(),
input_file,
arrow::csv::ReadOptions::Defaults(),
arrow::csv::ParseOptions::Defaults(),
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 1e93cf9..4403def 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -189,7 +189,6 @@ set(ARROW_SRCS
util/future.cc
util/int_util.cc
util/io_util.cc
- util/iterator.cc
util/logging.cc
util/key_value_metadata.cc
util/memory.cc
diff --git a/cpp/src/arrow/csv/CMakeLists.txt b/cpp/src/arrow/csv/CMakeLists.txt
index 84b1a10..2766cfd 100644
--- a/cpp/src/arrow/csv/CMakeLists.txt
+++ b/cpp/src/arrow/csv/CMakeLists.txt
@@ -21,7 +21,8 @@ add_arrow_test(csv-test
column_builder_test.cc
column_decoder_test.cc
converter_test.cc
- parser_test.cc)
+ parser_test.cc
+ reader_test.cc)
add_arrow_benchmark(converter_benchmark PREFIX "arrow-csv")
add_arrow_benchmark(parser_benchmark PREFIX "arrow-csv")
diff --git a/cpp/src/arrow/csv/column_decoder.cc
b/cpp/src/arrow/csv/column_decoder.cc
index c57477e..1dd13bc 100644
--- a/cpp/src/arrow/csv/column_decoder.cc
+++ b/cpp/src/arrow/csv/column_decoder.cc
@@ -84,7 +84,7 @@ class ConcreteColumnDecoder : public ColumnDecoder {
auto chunk_index = next_chunk_++;
WaitForChunkUnlocked(chunk_index);
// Move Future to avoid keeping chunk alive
- return std::move(chunks_[chunk_index]).result();
+ return chunks_[chunk_index].MoveResult();
}
protected:
diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc
index cf5047a..f0fa1f2 100644
--- a/cpp/src/arrow/csv/reader.cc
+++ b/cpp/src/arrow/csv/reader.cc
@@ -40,6 +40,8 @@
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/type.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/future.h"
#include "arrow/util/iterator.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
@@ -60,8 +62,7 @@ class InputStream;
namespace csv {
-using internal::GetCpuThreadPool;
-using internal::ThreadPool;
+using internal::Executor;
struct ConversionSchema {
struct Column {
@@ -94,20 +95,24 @@ struct ConversionSchema {
// An iterator of Buffers that makes sure there is no straddling CRLF sequence.
class CSVBufferIterator {
public:
- explicit CSVBufferIterator(Iterator<std::shared_ptr<Buffer>> buffer_iterator)
- : buffer_iterator_(std::move(buffer_iterator)) {}
-
static Iterator<std::shared_ptr<Buffer>> Make(
Iterator<std::shared_ptr<Buffer>> buffer_iterator) {
- CSVBufferIterator it(std::move(buffer_iterator));
- return Iterator<std::shared_ptr<Buffer>>(std::move(it));
+ Transformer<std::shared_ptr<Buffer>, std::shared_ptr<Buffer>> fn =
+ CSVBufferIterator();
+ return MakeTransformedIterator(std::move(buffer_iterator), fn);
+ }
+
+ static AsyncGenerator<std::shared_ptr<Buffer>> MakeAsync(
+ AsyncGenerator<std::shared_ptr<Buffer>> buffer_iterator) {
+ Transformer<std::shared_ptr<Buffer>, std::shared_ptr<Buffer>> fn =
+ CSVBufferIterator();
+ return MakeAsyncGenerator(std::move(buffer_iterator), fn);
}
- Result<std::shared_ptr<Buffer>> Next() {
- ARROW_ASSIGN_OR_RAISE(auto buf, buffer_iterator_.Next());
+ Result<TransformFlow<std::shared_ptr<Buffer>>>
operator()(std::shared_ptr<Buffer> buf) {
if (buf == nullptr) {
// EOF
- return nullptr;
+ return TransformFinish();
}
int64_t offset = 0;
@@ -127,14 +132,13 @@ class CSVBufferIterator {
buf = SliceBuffer(buf, offset);
if (buf->size() == 0) {
// EOF
- return nullptr;
+ return TransformFinish();
} else {
- return buf;
+ return TransformYield(buf);
}
}
protected:
- Iterator<std::shared_ptr<Buffer>> buffer_iterator_;
bool first_buffer_ = true;
// Whether there was a trailing CR at the end of last received buffer
bool trailing_cr_ = false;
@@ -150,20 +154,36 @@ struct CSVBlock {
std::function<Status(int64_t)> consume_bytes;
};
+} // namespace csv
+
+template <>
+struct IterationTraits<csv::CSVBlock> {
+ static csv::CSVBlock End() { return csv::CSVBlock{{}, {}, {}, -1, true, {}};
}
+};
+
+namespace csv {
+
+// The == operator must be defined to be used as T in Iterator<T>
+bool operator==(const CSVBlock& left, const CSVBlock& right) {
+ return left.block_index == right.block_index;
+}
+bool operator!=(const CSVBlock& left, const CSVBlock& right) {
+ return left.block_index != right.block_index;
+}
+
+// This is a callable that can be used to transform an iterator. The source
iterator
+// will contain buffers of data and the output iterator will contain delimited
CSV
+// blocks. util::optional is used so that there is an end token (required by
the
+// iterator APIs (e.g. Visit)) even though an empty optional is never used in
this code.
class BlockReader {
public:
- BlockReader(std::unique_ptr<Chunker> chunker,
- Iterator<std::shared_ptr<Buffer>> buffer_iterator,
- std::shared_ptr<Buffer> first_buffer)
+ BlockReader(std::unique_ptr<Chunker> chunker, std::shared_ptr<Buffer>
first_buffer)
: chunker_(std::move(chunker)),
- buffer_iterator_(std::move(buffer_iterator)),
partial_(std::make_shared<Buffer>("")),
buffer_(std::move(first_buffer)) {}
protected:
std::unique_ptr<Chunker> chunker_;
- Iterator<std::shared_ptr<Buffer>> buffer_iterator_;
-
std::shared_ptr<Buffer> partial_, buffer_;
int64_t block_index_ = 0;
// Whether there was a trailing CR at the end of last received buffer
@@ -177,14 +197,25 @@ class SerialBlockReader : public BlockReader {
public:
using BlockReader::BlockReader;
- Result<arrow::util::optional<CSVBlock>> Next() {
+ static Iterator<CSVBlock> MakeIterator(
+ Iterator<std::shared_ptr<Buffer>> buffer_iterator,
std::unique_ptr<Chunker> chunker,
+ std::shared_ptr<Buffer> first_buffer) {
+ auto block_reader =
+ std::make_shared<SerialBlockReader>(std::move(chunker), first_buffer);
+ // Wrap shared pointer in callable
+ Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn =
+ [block_reader](std::shared_ptr<Buffer> buf) {
+ return (*block_reader)(std::move(buf));
+ };
+ return MakeTransformedIterator(std::move(buffer_iterator),
block_reader_fn);
+ }
+
+ Result<TransformFlow<CSVBlock>> operator()(std::shared_ptr<Buffer>
next_buffer) {
if (buffer_ == nullptr) {
- // EOF
- return util::optional<CSVBlock>();
+ return TransformFinish();
}
- std::shared_ptr<Buffer> next_buffer, completion;
- ARROW_ASSIGN_OR_RAISE(next_buffer, buffer_iterator_.Next());
+ std::shared_ptr<Buffer> completion;
bool is_final = (next_buffer == nullptr);
if (is_final) {
@@ -210,8 +241,9 @@ class SerialBlockReader : public BlockReader {
return Status::OK();
};
- return CSVBlock{partial_, completion, buffer_,
- block_index_++, is_final, std::move(consume_bytes)};
+ return TransformYield<CSVBlock>(CSVBlock{partial_, completion, buffer_,
+ block_index_++, is_final,
+ std::move(consume_bytes)});
}
};
@@ -220,14 +252,35 @@ class ThreadedBlockReader : public BlockReader {
public:
using BlockReader::BlockReader;
- Result<arrow::util::optional<CSVBlock>> Next() {
+ static Iterator<CSVBlock> MakeIterator(
+ Iterator<std::shared_ptr<Buffer>> buffer_iterator,
std::unique_ptr<Chunker> chunker,
+ std::shared_ptr<Buffer> first_buffer) {
+ auto block_reader =
+ std::make_shared<ThreadedBlockReader>(std::move(chunker),
first_buffer);
+ // Wrap shared pointer in callable
+ Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn =
+ [block_reader](std::shared_ptr<Buffer> next) { return
(*block_reader)(next); };
+ return MakeTransformedIterator(std::move(buffer_iterator),
block_reader_fn);
+ }
+
+ static AsyncGenerator<CSVBlock> MakeAsyncIterator(
+ AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator,
+ std::unique_ptr<Chunker> chunker, std::shared_ptr<Buffer> first_buffer) {
+ auto block_reader =
+ std::make_shared<ThreadedBlockReader>(std::move(chunker),
first_buffer);
+ // Wrap shared pointer in callable
+ Transformer<std::shared_ptr<Buffer>, CSVBlock> block_reader_fn =
+ [block_reader](std::shared_ptr<Buffer> next) { return
(*block_reader)(next); };
+ return MakeAsyncGenerator(std::move(buffer_generator), block_reader_fn);
+ }
+
+ Result<TransformFlow<CSVBlock>> operator()(std::shared_ptr<Buffer>
next_buffer) {
if (buffer_ == nullptr) {
// EOF
- return util::optional<CSVBlock>();
+ return TransformFinish();
}
- std::shared_ptr<Buffer> next_buffer, whole, completion, next_partial;
- ARROW_ASSIGN_OR_RAISE(next_buffer, buffer_iterator_.Next());
+ std::shared_ptr<Buffer> whole, completion, next_partial;
bool is_final = (next_buffer == nullptr);
auto current_partial = std::move(partial_);
@@ -252,7 +305,8 @@ class ThreadedBlockReader : public BlockReader {
partial_ = std::move(next_partial);
buffer_ = std::move(next_buffer);
- return CSVBlock{current_partial, completion, whole, block_index_++,
is_final, {}};
+ return TransformYield<CSVBlock>(
+ CSVBlock{current_partial, completion, whole, block_index_++, is_final,
{}});
}
};
@@ -449,7 +503,6 @@ class ReaderMixin {
ConversionSchema conversion_schema_;
std::shared_ptr<io::InputStream> input_;
- Iterator<std::shared_ptr<Buffer>> buffer_iterator_;
std::shared_ptr<internal::TaskGroup> task_group_;
};
@@ -462,6 +515,10 @@ class BaseTableReader : public ReaderMixin, public
csv::TableReader {
virtual Status Init() = 0;
+ Future<std::shared_ptr<Table>> ReadAsync() override {
+ return Future<std::shared_ptr<Table>>::MakeFinished(Read());
+ }
+
protected:
// Make column builders from conversion schema
Status MakeColumnBuilders() {
@@ -624,6 +681,7 @@ class BaseStreamingReader : public ReaderMixin, public
csv::StreamingReader {
std::vector<std::shared_ptr<ColumnDecoder>> column_decoders_;
std::shared_ptr<Schema> schema_;
std::shared_ptr<RecordBatch> pending_batch_;
+ Iterator<std::shared_ptr<Buffer>> buffer_iterator_;
bool eof_ = false;
};
@@ -656,7 +714,7 @@ class SerialStreamingReader : public BaseStreamingReader {
if (eof_) {
return nullptr;
}
- if (block_reader_ == nullptr) {
+ if (!block_iterator_) {
Status st = SetupReader();
if (!st.ok()) {
// Can't setup reader => bail out
@@ -670,18 +728,18 @@ class SerialStreamingReader : public BaseStreamingReader {
}
if (!source_eof_) {
- ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_reader_->Next());
- if (maybe_block.has_value()) {
- last_block_index_ = maybe_block->block_index;
- auto maybe_parsed = ParseAndInsert(maybe_block->partial,
maybe_block->completion,
- maybe_block->buffer,
maybe_block->block_index,
- maybe_block->is_final);
+ ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_iterator_.Next());
+ if (maybe_block != IterationTraits<CSVBlock>::End()) {
+ last_block_index_ = maybe_block.block_index;
+ auto maybe_parsed = ParseAndInsert(maybe_block.partial,
maybe_block.completion,
+ maybe_block.buffer,
maybe_block.block_index,
+ maybe_block.is_final);
if (!maybe_parsed.ok()) {
// Parse error => bail out
eof_ = true;
return maybe_parsed.status();
}
- RETURN_NOT_OK(maybe_block->consume_bytes(*maybe_parsed));
+ RETURN_NOT_OK(maybe_block.consume_bytes(*maybe_parsed));
} else {
source_eof_ = true;
for (auto& decoder : column_decoders_) {
@@ -705,15 +763,15 @@ class SerialStreamingReader : public BaseStreamingReader {
RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer));
RETURN_NOT_OK(MakeColumnDecoders());
- block_reader_ =
std::make_shared<SerialBlockReader>(MakeChunker(parse_options_),
-
std::move(buffer_iterator_),
-
std::move(first_buffer));
+ block_iterator_ =
SerialBlockReader::MakeIterator(std::move(buffer_iterator_),
+
MakeChunker(parse_options_),
+ std::move(first_buffer));
return Status::OK();
}
bool source_eof_ = false;
int64_t last_block_index_ = 0;
- std::shared_ptr<SerialBlockReader> block_reader_;
+ Iterator<CSVBlock> block_iterator_;
};
/////////////////////////////////////////////////////////////////////////
@@ -746,41 +804,46 @@ class SerialTableReader : public BaseTableReader {
RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer));
RETURN_NOT_OK(MakeColumnBuilders());
- SerialBlockReader block_reader(MakeChunker(parse_options_),
- std::move(buffer_iterator_),
std::move(first_buffer));
-
+ auto block_iterator =
SerialBlockReader::MakeIterator(std::move(buffer_iterator_),
+
MakeChunker(parse_options_),
+
std::move(first_buffer));
while (true) {
- ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_reader.Next());
- if (!maybe_block.has_value()) {
+ ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_iterator.Next());
+ if (maybe_block == IterationTraits<CSVBlock>::End()) {
// EOF
break;
}
- ARROW_ASSIGN_OR_RAISE(int64_t parsed_bytes,
- ParseAndInsert(maybe_block->partial,
maybe_block->completion,
- maybe_block->buffer,
maybe_block->block_index,
- maybe_block->is_final));
- RETURN_NOT_OK(maybe_block->consume_bytes(parsed_bytes));
+ ARROW_ASSIGN_OR_RAISE(
+ int64_t parsed_bytes,
+ ParseAndInsert(maybe_block.partial, maybe_block.completion,
maybe_block.buffer,
+ maybe_block.block_index, maybe_block.is_final));
+ RETURN_NOT_OK(maybe_block.consume_bytes(parsed_bytes));
}
// Finish conversion, create schema and table
RETURN_NOT_OK(task_group_->Finish());
return MakeTable();
}
-};
-/////////////////////////////////////////////////////////////////////////
-// Parallel TableReader implementation
+ protected:
+ Iterator<std::shared_ptr<Buffer>> buffer_iterator_;
+};
-class ThreadedTableReader : public BaseTableReader {
+class AsyncThreadedTableReader
+ : public BaseTableReader,
+ public std::enable_shared_from_this<AsyncThreadedTableReader> {
public:
using BaseTableReader::BaseTableReader;
- ThreadedTableReader(MemoryPool* pool, std::shared_ptr<io::InputStream> input,
- const ReadOptions& read_options, const ParseOptions&
parse_options,
- const ConvertOptions& convert_options, ThreadPool*
thread_pool)
+ AsyncThreadedTableReader(MemoryPool* pool, std::shared_ptr<io::InputStream>
input,
+ const ReadOptions& read_options,
+ const ParseOptions& parse_options,
+ const ConvertOptions& convert_options, Executor*
cpu_executor,
+ Executor* io_executor)
: BaseTableReader(pool, input, read_options, parse_options,
convert_options),
- thread_pool_(thread_pool) {}
+ cpu_executor_(cpu_executor),
+ io_executor_(io_executor) {}
- ~ThreadedTableReader() override {
+ ~AsyncThreadedTableReader() override {
if (task_group_) {
// In case of error, make sure all pending tasks are finished before
// we start destroying BaseTableReader members
@@ -792,65 +855,98 @@ class ThreadedTableReader : public BaseTableReader {
ARROW_ASSIGN_OR_RAISE(auto istream_it,
io::MakeInputStreamIterator(input_,
read_options_.block_size));
- int32_t block_queue_size = thread_pool_->GetCapacity();
- ARROW_ASSIGN_OR_RAISE(auto rh_it,
- MakeReadaheadIterator(std::move(istream_it),
block_queue_size));
- buffer_iterator_ = CSVBufferIterator::Make(std::move(rh_it));
+ // TODO: use io_executor_ here, see ARROW-11590
+ ARROW_ASSIGN_OR_RAISE(auto background_executor,
internal::ThreadPool::Make(1));
+ ARROW_ASSIGN_OR_RAISE(auto bg_it,
MakeBackgroundGenerator(std::move(istream_it),
+
background_executor.get()));
+ AsyncGenerator<std::shared_ptr<Buffer>> wrapped_bg_it =
+ [bg_it, background_executor]() { return bg_it(); };
+
+ auto transferred_it =
+ MakeTransferredGenerator(std::move(wrapped_bg_it), cpu_executor_);
+
+ int32_t block_queue_size = cpu_executor_->GetCapacity();
+ auto rh_it = MakeReadaheadGenerator(std::move(transferred_it),
block_queue_size);
+ buffer_generator_ = CSVBufferIterator::MakeAsync(std::move(rh_it));
return Status::OK();
}
- Result<std::shared_ptr<Table>> Read() override {
- task_group_ = internal::TaskGroup::MakeThreaded(thread_pool_);
+ Result<std::shared_ptr<Table>> Read() override { return
ReadAsync().result(); }
+
+ Future<std::shared_ptr<Table>> ReadAsync() override {
+ task_group_ = internal::TaskGroup::MakeThreaded(cpu_executor_);
+
+ auto self = shared_from_this();
+ return ProcessFirstBuffer().Then([self](std::shared_ptr<Buffer>
first_buffer) {
+ auto block_generator = ThreadedBlockReader::MakeAsyncIterator(
+ self->buffer_generator_, MakeChunker(self->parse_options_),
+ std::move(first_buffer));
+
+ std::function<Status(CSVBlock)> block_visitor =
+ [self](CSVBlock maybe_block) -> Status {
+ // The logic in VisitAsyncGenerator ensures that we will never be
+ // passed an empty block (visit does not call with the end token) so
+ // we can be assured maybe_block has a value.
+ DCHECK_GE(maybe_block.block_index, 0);
+ DCHECK(!maybe_block.consume_bytes);
+
+ // Launch parse task
+ self->task_group_->Append([self, maybe_block] {
+ return self
+ ->ParseAndInsert(maybe_block.partial, maybe_block.completion,
+ maybe_block.buffer, maybe_block.block_index,
+ maybe_block.is_final)
+ .status();
+ });
+ return Status::OK();
+ };
+
+ return VisitAsyncGenerator(std::move(block_generator), block_visitor)
+ .Then([self](...) -> Future<> {
+ // By this point we've added all top level tasks so it is safe to
call
+ // FinishAsync
+ return self->task_group_->FinishAsync();
+ })
+ .Then([self](...) -> Result<std::shared_ptr<Table>> {
+ // Finish conversion, create schema and table
+ return self->MakeTable();
+ });
+ });
+ }
+ protected:
+ Future<std::shared_ptr<Buffer>> ProcessFirstBuffer() {
// First block
- ARROW_ASSIGN_OR_RAISE(auto first_buffer, buffer_iterator_.Next());
- if (first_buffer == nullptr) {
- return Status::Invalid("Empty CSV file");
- }
- RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer));
- RETURN_NOT_OK(MakeColumnBuilders());
-
- ThreadedBlockReader block_reader(MakeChunker(parse_options_),
- std::move(buffer_iterator_),
- std::move(first_buffer));
-
- while (true) {
- ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_reader.Next());
- if (!maybe_block.has_value()) {
- // EOF
- break;
+ auto first_buffer_future = buffer_generator_();
+ return first_buffer_future.Then([this](const std::shared_ptr<Buffer>&
first_buffer)
+ -> Result<std::shared_ptr<Buffer>> {
+ if (first_buffer == nullptr) {
+ return Status::Invalid("Empty CSV file");
}
- DCHECK(!maybe_block->consume_bytes);
-
- // Launch parse task
- task_group_->Append([this, maybe_block] {
- return ParseAndInsert(maybe_block->partial, maybe_block->completion,
- maybe_block->buffer, maybe_block->block_index,
- maybe_block->is_final)
- .status();
- });
- }
-
- // Finish conversion, create schema and table
- RETURN_NOT_OK(task_group_->Finish());
- return MakeTable();
+ std::shared_ptr<Buffer> first_buffer_processed;
+ RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer_processed));
+ RETURN_NOT_OK(MakeColumnBuilders());
+ return first_buffer_processed;
+ });
}
- protected:
- ThreadPool* thread_pool_;
+ Executor* cpu_executor_;
+ Executor* io_executor_;
+ AsyncGenerator<std::shared_ptr<Buffer>> buffer_generator_;
};
/////////////////////////////////////////////////////////////////////////
// Factory functions
Result<std::shared_ptr<TableReader>> TableReader::Make(
- MemoryPool* pool, std::shared_ptr<io::InputStream> input,
- const ReadOptions& read_options, const ParseOptions& parse_options,
- const ConvertOptions& convert_options) {
+ MemoryPool* pool, io::AsyncContext async_context,
+ std::shared_ptr<io::InputStream> input, const ReadOptions& read_options,
+ const ParseOptions& parse_options, const ConvertOptions& convert_options) {
std::shared_ptr<BaseTableReader> reader;
if (read_options.use_threads) {
- reader = std::make_shared<ThreadedTableReader>(
- pool, input, read_options, parse_options, convert_options,
GetCpuThreadPool());
+ reader = std::make_shared<AsyncThreadedTableReader>(
+ pool, input, read_options, parse_options, convert_options,
async_context.executor,
+ internal::GetCpuThreadPool());
} else {
reader = std::make_shared<SerialTableReader>(pool, input, read_options,
parse_options,
convert_options);
@@ -871,4 +967,5 @@ Result<std::shared_ptr<StreamingReader>>
StreamingReader::Make(
}
} // namespace csv
+
} // namespace arrow
diff --git a/cpp/src/arrow/csv/reader.h b/cpp/src/arrow/csv/reader.h
index 652cedc..c361fbd 100644
--- a/cpp/src/arrow/csv/reader.h
+++ b/cpp/src/arrow/csv/reader.h
@@ -20,10 +20,12 @@
#include <memory>
#include "arrow/csv/options.h" // IWYU pragma: keep
+#include "arrow/io/interfaces.h"
#include "arrow/record_batch.h"
#include "arrow/result.h"
#include "arrow/type.h"
#include "arrow/type_fwd.h"
+#include "arrow/util/future.h"
#include "arrow/util/visibility.h"
namespace arrow {
@@ -40,9 +42,12 @@ class ARROW_EXPORT TableReader {
/// Read the entire CSV file and convert it to a Arrow Table
virtual Result<std::shared_ptr<Table>> Read() = 0;
+ /// Read the entire CSV file and convert it to a Arrow Table
+ virtual Future<std::shared_ptr<Table>> ReadAsync() = 0;
/// Create a TableReader instance
static Result<std::shared_ptr<TableReader>> Make(MemoryPool* pool,
+ io::AsyncContext
async_context,
std::shared_ptr<io::InputStream> input,
const ReadOptions&,
const ParseOptions&,
diff --git a/cpp/src/arrow/csv/reader_test.cc b/cpp/src/arrow/csv/reader_test.cc
new file mode 100644
index 0000000..64010ae
--- /dev/null
+++ b/cpp/src/arrow/csv/reader_test.cc
@@ -0,0 +1,156 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <string>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "arrow/csv/options.h"
+#include "arrow/csv/reader.h"
+#include "arrow/csv/test_common.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/io/memory.h"
+#include "arrow/status.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/future.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+namespace csv {
+
+using TableReaderFactory =
+
std::function<Result<std::shared_ptr<TableReader>>(std::shared_ptr<io::InputStream>)>;
+
+void StressTableReader(TableReaderFactory reader_factory) {
+ const int NTASKS = 100;
+ const int NROWS = 1000;
+ ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(NROWS));
+
+ std::vector<Future<std::shared_ptr<Table>>> task_futures(NTASKS);
+ for (int i = 0; i < NTASKS; i++) {
+ auto input = std::make_shared<io::BufferReader>(table_buffer);
+ ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input));
+ task_futures[i] = reader->ReadAsync();
+ }
+ auto combined_future = All(task_futures);
+ combined_future.Wait();
+
+ ASSERT_OK_AND_ASSIGN(std::vector<Result<std::shared_ptr<Table>>> results,
+ combined_future.result());
+ for (auto&& result : results) {
+ ASSERT_OK_AND_ASSIGN(auto table, result);
+ ASSERT_EQ(NROWS, table->num_rows());
+ }
+}
+
+void StressInvalidTableReader(TableReaderFactory reader_factory) {
+ const int NTASKS = 100;
+ const int NROWS = 1000;
+ ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(NROWS, false));
+
+ std::vector<Future<std::shared_ptr<Table>>> task_futures(NTASKS);
+ for (int i = 0; i < NTASKS; i++) {
+ auto input = std::make_shared<io::BufferReader>(table_buffer);
+ ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input));
+ task_futures[i] = reader->ReadAsync();
+ }
+ auto combined_future = All(task_futures);
+ combined_future.Wait();
+
+ ASSERT_OK_AND_ASSIGN(std::vector<Result<std::shared_ptr<Table>>> results,
+ combined_future.result());
+ for (auto&& result : results) {
+ ASSERT_RAISES(Invalid, result);
+ }
+}
+
+void TestNestedParallelism(std::shared_ptr<internal::ThreadPool> thread_pool,
+ TableReaderFactory reader_factory) {
+ const int NROWS = 1000;
+ ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(NROWS));
+ auto input = std::make_shared<io::BufferReader>(table_buffer);
+ ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input));
+
+ Future<std::shared_ptr<Table>> table_future;
+
+ auto read_task = [&reader, &table_future]() mutable {
+ table_future = reader->ReadAsync();
+ return Status::OK();
+ };
+ ASSERT_OK_AND_ASSIGN(auto future, thread_pool->Submit(read_task));
+
+ ASSERT_FINISHES_OK(future);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto table, table_future);
+ ASSERT_EQ(table->num_rows(), NROWS);
+} // namespace csv
+
+TableReaderFactory MakeSerialFactory() {
+ return [](std::shared_ptr<io::InputStream> input_stream) {
+ auto read_options = ReadOptions::Defaults();
+ read_options.block_size = 1 << 10;
+ read_options.use_threads = false;
+ return TableReader::Make(default_memory_pool(), io::AsyncContext(),
input_stream,
+ read_options, ParseOptions::Defaults(),
+ ConvertOptions::Defaults());
+ };
+}
+
+TEST(SerialReaderTests, Stress) { StressTableReader(MakeSerialFactory()); }
+TEST(SerialReaderTests, StressInvalid) {
StressInvalidTableReader(MakeSerialFactory()); }
+TEST(SerialReaderTests, NestedParallelism) {
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));
+ TestNestedParallelism(thread_pool, MakeSerialFactory());
+}
+
+Result<TableReaderFactory> MakeAsyncFactory(
+ std::shared_ptr<internal::ThreadPool> thread_pool = nullptr) {
+ if (!thread_pool) {
+ ARROW_ASSIGN_OR_RAISE(thread_pool, internal::ThreadPool::Make(1));
+ }
+ return [thread_pool](std::shared_ptr<io::InputStream> input_stream)
+ -> Result<std::shared_ptr<TableReader>> {
+ ReadOptions read_options = ReadOptions::Defaults();
+ read_options.use_threads = true;
+ read_options.block_size = 1 << 10;
+ auto table_reader = TableReader::Make(
+ default_memory_pool(), io::AsyncContext(thread_pool.get()),
input_stream,
+ read_options, ParseOptions::Defaults(), ConvertOptions::Defaults());
+ return table_reader;
+ };
+}
+
+TEST(AsyncReaderTests, Stress) {
+ ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory());
+ StressTableReader(table_factory);
+}
+TEST(AsyncReaderTests, StressInvalid) {
+ ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory());
+ StressInvalidTableReader(table_factory);
+}
+TEST(AsyncReaderTests, NestedParallelism) {
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));
+ ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory(thread_pool));
+ TestNestedParallelism(thread_pool, table_factory);
+}
+
+} // namespace csv
+} // namespace arrow
diff --git a/cpp/src/arrow/csv/test_common.cc b/cpp/src/arrow/csv/test_common.cc
index 08981a7..c3d0241 100644
--- a/cpp/src/arrow/csv/test_common.cc
+++ b/cpp/src/arrow/csv/test_common.cc
@@ -61,5 +61,59 @@ void MakeColumnParser(std::vector<std::string> items,
std::shared_ptr<BlockParse
ASSERT_EQ((*out)->num_rows(), items.size());
}
+namespace {
+
+const std::vector<std::string> int64_rows = {"123", "4", "-317005557", "",
"N/A", "0"};
+const std::vector<std::string> float_rows = {"0", "123.456", "-3170.55766",
"", "N/A"};
+const std::vector<std::string> decimal128_rows = {"0", "123.456",
"-3170.55766",
+ "", "N/A",
"1233456789.123456789"};
+const std::vector<std::string> iso8601_rows = {"1917-10-17", "2018-09-13",
+ "1941-06-22 04:00", "1945-05-09
09:45:38"};
+const std::vector<std::string> strptime_rows = {"10/17/1917", "9/13/2018",
"9/5/1945"};
+
+static void WriteHeader(std::ostream& writer) {
+ writer << "Int64,Float,Decimal128,ISO8601,Strptime" << std::endl;
+}
+
+static std::string GetCell(const std::vector<std::string>& base_rows, size_t
row_index) {
+ return base_rows[row_index % base_rows.size()];
+}
+
+static void WriteRow(std::ostream& writer, size_t row_index) {
+ writer << GetCell(int64_rows, row_index);
+ writer << ',';
+ writer << GetCell(float_rows, row_index);
+ writer << ',';
+ writer << GetCell(decimal128_rows, row_index);
+ writer << ',';
+ writer << GetCell(iso8601_rows, row_index);
+ writer << ',';
+ writer << GetCell(strptime_rows, row_index);
+ writer << std::endl;
+}
+
+static void WriteInvalidRow(std::ostream& writer, size_t row_index) {
+ writer << "\"" << std::endl << "\"";
+ writer << std::endl;
+}
+} // namespace
+
+Result<std::shared_ptr<Buffer>> MakeSampleCsvBuffer(size_t num_rows, bool
valid) {
+ std::stringstream writer;
+
+ WriteHeader(writer);
+ for (size_t i = 0; i < num_rows; ++i) {
+ if (i == num_rows / 2 && !valid) {
+ WriteInvalidRow(writer, i);
+ } else {
+ WriteRow(writer, i);
+ }
+ }
+
+ auto table_str = writer.str();
+ auto table_buffer = std::make_shared<Buffer>(table_str);
+ return MemoryManager::CopyBuffer(table_buffer, default_cpu_memory_manager());
+}
+
} // namespace csv
} // namespace arrow
diff --git a/cpp/src/arrow/csv/test_common.h b/cpp/src/arrow/csv/test_common.h
index 119da03..823cf64 100644
--- a/cpp/src/arrow/csv/test_common.h
+++ b/cpp/src/arrow/csv/test_common.h
@@ -46,5 +46,8 @@ void MakeCSVParser(std::vector<std::string> lines,
std::shared_ptr<BlockParser>*
ARROW_TESTING_EXPORT
void MakeColumnParser(std::vector<std::string> items,
std::shared_ptr<BlockParser>* out);
+ARROW_TESTING_EXPORT
+Result<std::shared_ptr<Buffer>> MakeSampleCsvBuffer(size_t num_rows, bool
valid = true);
+
} // namespace csv
} // namespace arrow
diff --git a/cpp/src/arrow/json/reader.cc b/cpp/src/arrow/json/reader.cc
index dc0d6e0..44aa260 100644
--- a/cpp/src/arrow/json/reader.cc
+++ b/cpp/src/arrow/json/reader.cc
@@ -29,6 +29,7 @@
#include "arrow/json/parser.h"
#include "arrow/record_batch.h"
#include "arrow/table.h"
+#include "arrow/util/async_generator.h"
#include "arrow/util/iterator.h"
#include "arrow/util/logging.h"
#include "arrow/util/string_view.h"
diff --git a/cpp/src/arrow/result.h b/cpp/src/arrow/result.h
index 6504d95..0172a85 100644
--- a/cpp/src/arrow/result.h
+++ b/cpp/src/arrow/result.h
@@ -317,7 +317,7 @@ class ARROW_MUST_USE_TYPE Result : public
util::EqualityComparable<Result<T>> {
return ValueUnsafe();
}
const T& operator*() const& { return ValueOrDie(); }
- const T* operator->() const& { return &ValueOrDie(); }
+ const T* operator->() const { return &ValueOrDie(); }
/// Gets a mutable reference to the stored `T` value.
///
@@ -332,7 +332,7 @@ class ARROW_MUST_USE_TYPE Result : public
util::EqualityComparable<Result<T>> {
return ValueUnsafe();
}
T& operator*() & { return ValueOrDie(); }
- T* operator->() & { return &ValueOrDie(); }
+ T* operator->() { return &ValueOrDie(); }
/// Moves and returns the internally-stored `T` value.
///
diff --git a/cpp/src/arrow/testing/gtest_util.h
b/cpp/src/arrow/testing/gtest_util.h
index cdb23a9..fafccc2 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -135,15 +135,55 @@
ASSERT_EQ(expected, _actual); \
} while (0)
+// This macro should be called by futures that are expected to
+// complete pretty quickly. 2 seconds is the default max wait
+// here. Anything longer than that and it's a questionable
+// unit test anyways.
+#define ASSERT_FINISHES_IMPL(fut) \
+ do { \
+ ASSERT_TRUE(fut.Wait(10)); \
+ if (!fut.is_finished()) { \
+ FAIL() << "Future did not finish in a timely fashion"; \
+ } \
+ } while (false)
+
+#define ASSERT_FINISHES_OK(expr) \
+ do { \
+ auto&& _fut = (expr); \
+ ASSERT_TRUE(_fut.Wait(10)); \
+ if (!_fut.is_finished()) { \
+ FAIL() << "Future did not finish in a timely fashion"; \
+ } \
+ auto _st = _fut.status(); \
+ if (!_st.ok()) { \
+ FAIL() << "'" ARROW_STRINGIFY(expr) "' failed with " << _st.ToString(); \
+ } \
+ } while (false)
+
+#define ASSERT_FINISHES_ERR(ENUM, expr) \
+ do { \
+ auto&& fut = (expr); \
+ ASSERT_FINISHES_IMPL(fut); \
+ ASSERT_RAISES(ENUM, fut.status()); \
+ } while (false)
+
+#define ASSERT_FINISHES_OK_AND_ASSIGN_IMPL(lhs, rexpr, future_name) \
+ auto future_name = (rexpr); \
+ ASSERT_FINISHES_IMPL(future_name); \
+ ASSERT_OK_AND_ASSIGN(lhs, future_name.result());
+
+#define ASSERT_FINISHES_OK_AND_ASSIGN(lhs, rexpr) \
+ ASSERT_FINISHES_OK_AND_ASSIGN_IMPL(lhs, rexpr, \
+ ARROW_ASSIGN_OR_RAISE_NAME(_fut,
__COUNTER__))
+
namespace arrow {
+// ----------------------------------------------------------------------
+// Useful testing::Types declarations
inline void PrintTo(StatusCode code, std::ostream* os) {
*os << Status::CodeAsString(code);
}
-// ----------------------------------------------------------------------
-// Useful testing::Types declarations
-
using NumericArrowTypes =
::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type,
Int16Type,
Int32Type, Int64Type, FloatType, DoubleType>;
diff --git a/cpp/src/arrow/util/async_generator.h
b/cpp/src/arrow/util/async_generator.h
new file mode 100644
index 0000000..8e88813
--- /dev/null
+++ b/cpp/src/arrow/util/async_generator.h
@@ -0,0 +1,388 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+#include <queue>
+
+#include "arrow/util/functional.h"
+#include "arrow/util/future.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/optional.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+template <typename T>
+using AsyncGenerator = std::function<Future<T>()>;
+
+/// Iterates through a generator of futures, visiting the result of each one
and
+/// returning a future that completes when all have been visited
+template <typename T>
+Future<> VisitAsyncGenerator(AsyncGenerator<T> generator,
+ std::function<Status(T)> visitor) {
+ struct LoopBody {
+ struct Callback {
+ Result<ControlFlow<detail::Empty>> operator()(const T& result) {
+ if (result == IterationTraits<T>::End()) {
+ return Break(detail::Empty());
+ } else {
+ auto visited = visitor(result);
+ if (visited.ok()) {
+ return Continue();
+ } else {
+ return visited;
+ }
+ }
+ }
+
+ std::function<Status(T)> visitor;
+ };
+
+ Future<ControlFlow<detail::Empty>> operator()() {
+ Callback callback{visitor};
+ auto next = generator();
+ return next.Then(std::move(callback));
+ }
+
+ AsyncGenerator<T> generator;
+ std::function<Status(T)> visitor;
+ };
+
+ return Loop(LoopBody{std::move(generator), std::move(visitor)});
+}
+
+template <typename T>
+Future<std::vector<T>> CollectAsyncGenerator(AsyncGenerator<T> generator) {
+ auto vec = std::make_shared<std::vector<T>>();
+ struct LoopBody {
+ Future<ControlFlow<std::vector<T>>> operator()() {
+ auto next = generator();
+ auto vec = vec_;
+ return next.Then([vec](const T& result) ->
Result<ControlFlow<std::vector<T>>> {
+ if (result == IterationTraits<T>::End()) {
+ return Break(*vec);
+ } else {
+ vec->push_back(result);
+ return Continue();
+ }
+ });
+ }
+ AsyncGenerator<T> generator;
+ std::shared_ptr<std::vector<T>> vec_;
+ };
+ return Loop(LoopBody{std::move(generator), std::move(vec)});
+}
+
+template <typename T, typename V>
+class TransformingGenerator {
+ // The transforming generator state will be referenced as an async generator
but will
+ // also be referenced via callback to various futures. If the async
generator owner
+ // moves it around we need the state to be consistent for future callbacks.
+ struct TransformingGeneratorState
+ : std::enable_shared_from_this<TransformingGeneratorState> {
+ TransformingGeneratorState(AsyncGenerator<T> generator, Transformer<T, V>
transformer)
+ : generator_(std::move(generator)),
+ transformer_(std::move(transformer)),
+ last_value_(),
+ finished_() {}
+
+ Future<V> operator()() {
+ while (true) {
+ auto maybe_next_result = Pump();
+ if (!maybe_next_result.ok()) {
+ return Future<V>::MakeFinished(maybe_next_result.status());
+ }
+ auto maybe_next = std::move(maybe_next_result).ValueUnsafe();
+ if (maybe_next.has_value()) {
+ return Future<V>::MakeFinished(*std::move(maybe_next));
+ }
+
+ auto next_fut = generator_();
+ // If finished already, process results immediately inside the loop to
avoid stack
+ // overflow
+ if (next_fut.is_finished()) {
+ auto next_result = next_fut.result();
+ if (next_result.ok()) {
+ last_value_ = *next_result;
+ } else {
+ return Future<V>::MakeFinished(next_result.status());
+ }
+ // Otherwise, if not finished immediately, add callback to process
results
+ } else {
+ auto self = this->shared_from_this();
+ return next_fut.Then([self](const Result<T>& next_result) {
+ if (next_result.ok()) {
+ self->last_value_ = *next_result;
+ return (*self)();
+ } else {
+ return Future<V>::MakeFinished(next_result.status());
+ }
+ });
+ }
+ }
+ }
+
+ // See comment on TransformingIterator::Pump
+ Result<util::optional<V>> Pump() {
+ if (!finished_ && last_value_.has_value()) {
+ ARROW_ASSIGN_OR_RAISE(TransformFlow<V> next,
transformer_(*last_value_));
+ if (next.ReadyForNext()) {
+ if (*last_value_ == IterationTraits<T>::End()) {
+ finished_ = true;
+ }
+ last_value_.reset();
+ }
+ if (next.Finished()) {
+ finished_ = true;
+ }
+ if (next.HasValue()) {
+ return next.Value();
+ }
+ }
+ if (finished_) {
+ return IterationTraits<V>::End();
+ }
+ return util::nullopt;
+ }
+
+ AsyncGenerator<T> generator_;
+ Transformer<T, V> transformer_;
+ util::optional<T> last_value_;
+ bool finished_;
+ };
+
+ public:
+ explicit TransformingGenerator(AsyncGenerator<T> generator,
+ Transformer<T, V> transformer)
+ :
state_(std::make_shared<TransformingGeneratorState>(std::move(generator),
+
std::move(transformer))) {}
+
+ Future<V> operator()() { return (*state_)(); }
+
+ protected:
+ std::shared_ptr<TransformingGeneratorState> state_;
+};
+
+template <typename T>
+class ReadaheadGenerator {
+ public:
+ ReadaheadGenerator(AsyncGenerator<T> source_generator, int max_readahead)
+ : source_generator_(std::move(source_generator)),
max_readahead_(max_readahead) {
+ auto finished = std::make_shared<std::atomic<bool>>();
+ mark_finished_if_done_ = [finished](const Result<T>& next_result) {
+ if (!next_result.ok()) {
+ finished->store(true);
+ } else {
+ const auto& next = *next_result;
+ if (next == IterationTraits<T>::End()) {
+ *finished = true;
+ }
+ }
+ };
+ finished_ = std::move(finished);
+ }
+
+ Future<T> operator()() {
+ if (readahead_queue_.empty()) {
+ // This is the first request, let's pump the underlying queue
+ for (int i = 0; i < max_readahead_; i++) {
+ auto next = source_generator_();
+ next.AddCallback(mark_finished_if_done_);
+ readahead_queue_.push(std::move(next));
+ }
+ }
+ // Pop one and add one
+ auto result = readahead_queue_.front();
+ readahead_queue_.pop();
+ if (finished_->load()) {
+
readahead_queue_.push(Future<T>::MakeFinished(IterationTraits<T>::End()));
+ } else {
+ auto back_of_queue = source_generator_();
+ back_of_queue.AddCallback(mark_finished_if_done_);
+ readahead_queue_.push(std::move(back_of_queue));
+ }
+ return result;
+ }
+
+ private:
+ AsyncGenerator<T> source_generator_;
+ int max_readahead_;
+ std::function<void(const Result<T>&)> mark_finished_if_done_;
+ // Can't use a bool here because finished may be referenced by callbacks that
+ // outlive this class
+ std::shared_ptr<std::atomic<bool>> finished_;
+ std::queue<Future<T>> readahead_queue_;
+};
+
+/// \brief Creates a generator that pulls reentrantly from a source
+/// This generator will pull reentrantly from a source, ensuring that
max_readahead
+/// requests are active at any given time.
+///
+/// The source generator must be async-reentrant
+///
+/// This generator itself is async-reentrant.
+template <typename T>
+AsyncGenerator<T> MakeReadaheadGenerator(AsyncGenerator<T> source_generator,
+ int max_readahead) {
+ return ReadaheadGenerator<T>(std::move(source_generator), max_readahead);
+}
+
+/// \brief Transforms an async generator using a transformer function
returning a new
+/// AsyncGenerator
+///
+/// The transform function here behaves exactly the same as the transform
function in
+/// MakeTransformedIterator and you can safely use the same transform function
to
+/// transform both synchronous and asynchronous streams.
+///
+/// This generator is not async-reentrant
+template <typename T, typename V>
+AsyncGenerator<V> MakeAsyncGenerator(AsyncGenerator<T> generator,
+ Transformer<T, V> transformer) {
+ return TransformingGenerator<T, V>(generator, transformer);
+}
+
+/// \brief Transfers execution of the generator onto the given executor
+///
+/// This generator is async-reentrant if the source generator is
async-reentrant
+template <typename T>
+class TransferringGenerator {
+ public:
+ explicit TransferringGenerator(AsyncGenerator<T> source, internal::Executor*
executor)
+ : source_(std::move(source)), executor_(executor) {}
+
+ Future<T> operator()() { return executor_->Transfer(source_()); }
+
+ private:
+ AsyncGenerator<T> source_;
+ internal::Executor* executor_;
+};
+
+/// \brief Transfers a future to an underlying executor.
+///
+/// Continuations run on the returned future will be run on the given executor
+/// if they cannot be run synchronously.
+///
+/// This is often needed to move computation off I/O threads or other external
+/// completion sources and back on to the CPU executor so the I/O thread can
+/// stay busy and focused on I/O
+///
+/// Keep in mind that continuations called on an already completed future will
+/// always be run synchronously and so no transfer will happen in that case.
+template <typename T>
+AsyncGenerator<T> MakeTransferredGenerator(AsyncGenerator<T> source,
+ internal::Executor* executor) {
+ return TransferringGenerator<T>(std::move(source), executor);
+}
+
+/// \brief Async generator that iterates on an underlying iterator in a
+/// separate executor.
+///
+/// This generator is async-reentrant
+template <typename T>
+class BackgroundGenerator {
+ public:
+ explicit BackgroundGenerator(Iterator<T> it, internal::Executor* io_executor)
+ : io_executor_(io_executor) {
+ task_ = Task{std::make_shared<Iterator<T>>(std::move(it)),
+ std::make_shared<std::atomic<bool>>(false)};
+ }
+
+ ~BackgroundGenerator() {
+ // The thread pool will be disposed of automatically. By default it will
not wait
+ // so the background thread may outlive this object. That should be ok.
Any task
+ // objects in the thread pool are copies of task_ and have their own
shared_ptr to
+ // the iterator.
+ }
+
+ ARROW_DEFAULT_MOVE_AND_ASSIGN(BackgroundGenerator);
+ ARROW_DISALLOW_COPY_AND_ASSIGN(BackgroundGenerator);
+
+ Future<T> operator()() {
+ auto submitted_future = io_executor_->Submit(task_);
+ if (!submitted_future.ok()) {
+ return Future<T>::MakeFinished(submitted_future.status());
+ }
+ return std::move(*submitted_future);
+ }
+
+ protected:
+ struct Task {
+ Result<T> operator()() {
+ if (*done_) {
+ return IterationTraits<T>::End();
+ }
+ auto next = it_->Next();
+ if (!next.ok() || *next == IterationTraits<T>::End()) {
+ *done_ = true;
+ }
+ return next;
+ }
+ // This task is going to be copied so we need to convert the iterator ptr
to
+ // a shared ptr. This should be safe however because the background
executor only
+ // has a single thread so it can't access it_ across multiple threads.
+ std::shared_ptr<Iterator<T>> it_;
+ std::shared_ptr<std::atomic<bool>> done_;
+ };
+
+ Task task_;
+ internal::Executor* io_executor_;
+};
+
+/// \brief Creates an AsyncGenerator<T> by iterating over an Iterator<T> on a
background
+/// thread
+template <typename T>
+static Result<AsyncGenerator<T>> MakeBackgroundGenerator(
+ Iterator<T> iterator, internal::Executor* io_executor) {
+ auto background_iterator = std::make_shared<BackgroundGenerator<T>>(
+ std::move(iterator), std::move(io_executor));
+ return [background_iterator]() { return (*background_iterator)(); };
+}
+
+/// \brief Converts an AsyncGenerator<T> to an Iterator<T> by blocking until
each future
+/// is finished
+template <typename T>
+class GeneratorIterator {
+ public:
+ explicit GeneratorIterator(AsyncGenerator<T> source) :
source_(std::move(source)) {}
+
+ Result<T> Next() { return source_().result(); }
+
+ private:
+ AsyncGenerator<T> source_;
+};
+
+template <typename T>
+Result<Iterator<T>> MakeGeneratorIterator(AsyncGenerator<T> source) {
+ return Iterator<T>(GeneratorIterator<T>(std::move(source)));
+}
+
+template <typename T>
+Result<Iterator<T>> MakeReadaheadIterator(Iterator<T> it, int
readahead_queue_size) {
+ ARROW_ASSIGN_OR_RAISE(auto io_executor, internal::ThreadPool::Make(1));
+ ARROW_ASSIGN_OR_RAISE(auto background_generator,
+ MakeBackgroundGenerator(std::move(it),
io_executor.get()));
+ // Capture io_executor to keep it alive as long as owned_bg_generator is
still
+ // referenced
+ AsyncGenerator<T> owned_bg_generator = [io_executor, background_generator]()
{
+ return background_generator();
+ };
+ auto readahead_generator =
+ MakeReadaheadGenerator(std::move(owned_bg_generator),
readahead_queue_size);
+ return MakeGeneratorIterator(std::move(readahead_generator));
+}
+
+} // namespace arrow
diff --git a/cpp/src/arrow/util/future.cc b/cpp/src/arrow/util/future.cc
index f8d12ad..3a77f34 100644
--- a/cpp/src/arrow/util/future.cc
+++ b/cpp/src/arrow/util/future.cc
@@ -239,6 +239,16 @@ class ConcreteFutureImpl : public FutureImpl {
}
}
+ bool TryAddCallback(const std::function<Callback()>& callback_factory) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (IsFutureFinished(state_)) {
+ return false;
+ } else {
+ callbacks_.push_back(callback_factory());
+ return true;
+ }
+ }
+
void DoMarkFinishedOrFailed(FutureState state) {
{
// Lock the hypothetical waiter first, and the future after.
@@ -326,4 +336,8 @@ void FutureImpl::AddCallback(Callback callback) {
GetConcreteFuture(this)->AddCallback(std::move(callback));
}
+bool FutureImpl::TryAddCallback(const std::function<Callback()>&
callback_factory) {
+ return GetConcreteFuture(this)->TryAddCallback(callback_factory);
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h
index 2fc040c..ee053cf 100644
--- a/cpp/src/arrow/util/future.h
+++ b/cpp/src/arrow/util/future.h
@@ -29,6 +29,7 @@
#include "arrow/status.h"
#include "arrow/util/functional.h"
#include "arrow/util/macros.h"
+#include "arrow/util/optional.h"
#include "arrow/util/type_fwd.h"
#include "arrow/util/visibility.h"
@@ -152,6 +153,7 @@ class ARROW_EXPORT FutureImpl {
using Callback = internal::FnOnce<void()>;
void AddCallback(Callback callback);
+ bool TryAddCallback(const std::function<Callback()>& callback_factory);
// Waiter API
inline FutureState SetWaiter(FutureWaiter* w, int future_num);
@@ -273,7 +275,14 @@ class ARROW_MUST_USE_TYPE Future {
Wait();
return *GetResult();
}
- Result<ValueType>&& result() && {
+
+ /// \brief Returns an rvalue to the result. This method is potentially
unsafe
+ ///
+ /// The future is not the unique owner of the result, copies of a future will
+ /// also point to the same result. You must make sure that no other copies
+ /// of the future exist. Attempts to add callbacks after you move the result
+ /// will result in undefined behavior.
+ Result<ValueType>&& MoveResult() {
Wait();
return std::move(*GetResult());
}
@@ -326,7 +335,10 @@ class ARROW_MUST_USE_TYPE Future {
/// \brief Producer API: instantiate a valid Future
///
- /// The Future's state is initialized with PENDING.
+ /// The Future's state is initialized with PENDING. If you are creating a
future with
+ /// this method you must ensure that future is eventually completed (with
success or
+ /// failure). Creating a future, returning it, and never completing the
future can lead
+ /// to memory leaks (for example, see Loop).
static Future Make() {
Future fut;
fut.impl_ = FutureImpl::Make();
@@ -375,22 +387,33 @@ class ARROW_MUST_USE_TYPE Future {
/// In this example `fut` falls out of scope but is not destroyed because it
holds a
/// cyclic reference to itself through the callback.
template <typename OnComplete>
- void AddCallback(OnComplete&& on_complete) const {
- struct Callback {
- void operator()() && {
- auto self = weak_self.get();
- std::move(on_complete)(*self.GetResult());
- }
-
- WeakFuture<T> weak_self;
- OnComplete on_complete;
- };
-
+ void AddCallback(OnComplete on_complete) const {
// We know impl_ will not be dangling when invoking callbacks because at
least one
// thread will be waiting for MarkFinished to return. Thus it's safe to
keep a
// weak reference to impl_ here
impl_->AddCallback(
- Callback{WeakFuture<T>(*this), std::forward<OnComplete>(on_complete)});
+ Callback<OnComplete>{WeakFuture<T>(*this), std::move(on_complete)});
+ }
+
+ /// \brief Overload of AddCallback that will return false instead of running
+ /// synchronously
+ ///
+ /// This overload will guarantee the callback is never run synchronously.
If the future
+ /// is already finished then it will simply return false. This can be
useful to avoid
+ /// stack overflow in a situation where you have recursive Futures. For an
example
+ /// see the Loop function
+ ///
+ /// Takes in a callback factory function to allow moving callbacks (the
factory function
+ /// will only be called if the callback can successfully be added)
+ ///
+ /// Returns true if a callback was actually added and false if the callback
failed
+ /// to add because the future was marked complete.
+ template <typename CallbackFactory>
+ bool TryAddCallback(const CallbackFactory& callback_factory) const {
+ return impl_->TryAddCallback([this, &callback_factory]() {
+ return
Callback<detail::result_of_t<CallbackFactory()>>{WeakFuture<T>(*this),
+
callback_factory()};
+ });
}
/// \brief Consumer API: Register a continuation to run when this future
completes
@@ -428,7 +451,7 @@ class ARROW_MUST_USE_TYPE Future {
template <typename OnSuccess, typename OnFailure,
typename ContinuedFuture =
detail::ContinueFuture::ForSignature<OnSuccess && (const T&)>>
- ContinuedFuture Then(OnSuccess&& on_success, OnFailure&& on_failure) const {
+ ContinuedFuture Then(OnSuccess on_success, OnFailure on_failure) const {
static_assert(
std::is_same<detail::ContinueFuture::ForSignature<OnFailure && (const
Status&)>,
ContinuedFuture>::value,
@@ -471,6 +494,17 @@ class ARROW_MUST_USE_TYPE Future {
}
protected:
+ template <typename OnComplete>
+ struct Callback {
+ void operator()() && {
+ auto self = weak_self.get();
+ std::move(on_complete)(*self.GetResult());
+ }
+
+ WeakFuture<T> weak_self;
+ OnComplete on_complete;
+ };
+
Result<ValueType>* GetResult() const {
return static_cast<Result<ValueType>*>(impl_->result_.get());
}
@@ -557,6 +591,38 @@ inline bool WaitForAll(const std::vector<Future<T>*>&
futures,
return waiter->Wait(seconds);
}
+/// \brief Create a Future which completes when all of `futures` complete.
+///
+/// The future's result is a vector of the results of `futures`.
+/// Note that this future will never be marked "failed"; failed results
+/// will be stored in the result vector alongside successful results.
+template <typename T>
+Future<std::vector<Result<T>>> All(std::vector<Future<T>> futures) {
+ struct State {
+ explicit State(std::vector<Future<T>> f)
+ : futures(std::move(f)), n_remaining(futures.size()) {}
+
+ std::vector<Future<T>> futures;
+ std::atomic<size_t> n_remaining;
+ };
+
+ auto state = std::make_shared<State>(std::move(futures));
+
+ auto out = Future<std::vector<Result<T>>>::Make();
+ for (const Future<T>& future : state->futures) {
+ future.AddCallback([state, out](const Result<T>&) mutable {
+ if (state->n_remaining.fetch_sub(1) != 1) return;
+
+ std::vector<Result<T>> results(state->futures.size());
+ for (size_t i = 0; i < results.size(); ++i) {
+ results[i] = state->futures[i].result();
+ }
+ out.MarkFinished(std::move(results));
+ });
+ }
+ return out;
+}
+
/// \brief Wait for one of the futures to end, or for the given timeout to
expire.
///
/// The indices of all completed futures are returned. Note that some futures
@@ -581,4 +647,79 @@ inline std::vector<int> WaitForAny(const
std::vector<Future<T>*>& futures,
return waiter->MoveFinishedFutures();
}
+struct Continue {
+ template <typename T>
+ operator util::optional<T>() && { // NOLINT explicit
+ return {};
+ }
+};
+
+template <typename T = detail::Empty>
+util::optional<T> Break(T break_value = {}) {
+ return util::optional<T>{std::move(break_value)};
+}
+
+template <typename T = detail::Empty>
+using ControlFlow = util::optional<T>;
+
+/// \brief Loop through an asynchronous sequence
+///
+/// \param[in] iterate A generator of Future<ControlFlow<BreakValue>>. On
completion of
+/// each yielded future the resulting ControlFlow will be examined. A Break
will terminate
+/// the loop, while a Continue will re-invoke `iterate`. \return A future
which will
+/// complete when a Future returned by iterate completes with a Break
+template <typename Iterate,
+ typename Control = typename
detail::result_of_t<Iterate()>::ValueType,
+ typename BreakValueType = typename Control::value_type>
+Future<BreakValueType> Loop(Iterate iterate) {
+ auto break_fut = Future<BreakValueType>::Make();
+
+ struct Callback {
+ bool CheckForTermination(const Result<Control>& control_res) {
+ if (!control_res.ok()) {
+ break_fut.MarkFinished(control_res.status());
+ return true;
+ }
+ if (control_res->has_value()) {
+ break_fut.MarkFinished(*std::move(*control_res));
+ return true;
+ }
+ return false;
+ }
+
+ void operator()(const Result<Control>& maybe_control) && {
+ if (CheckForTermination(maybe_control)) return;
+
+ auto control_fut = iterate();
+ while (true) {
+ if (control_fut.TryAddCallback([this]() { return *this; })) {
+ // Adding a callback succeeded; control_fut was not finished
+ // and we must wait to CheckForTermination.
+ return;
+ }
+ // Adding a callback failed; control_fut was finished and we
+ // can CheckForTermination immediately. This also avoids recursion and
potential
+ // stack overflow.
+ if (CheckForTermination(control_fut.result())) return;
+
+ control_fut = iterate();
+ }
+ }
+
+ Iterate iterate;
+
+ // If the future returned by control_fut is never completed then we will
be hanging on
+ // to break_fut forever even if the listener has given up listening on it.
Instead we
+ // rely on the fact that a producer (the caller of Future<>::Make) is
always
+ // responsible for completing the futures they create.
+ // TODO: Could avoid this kind of situation with "future abandonment"
similar to mesos
+ Future<BreakValueType> break_fut;
+ };
+
+ auto control_fut = iterate();
+ control_fut.AddCallback(Callback{std::move(iterate), break_fut});
+
+ return break_fut;
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/util/future_test.cc
b/cpp/src/arrow/util/future_test.cc
index 203f05b..2a4fc6b 100644
--- a/cpp/src/arrow/util/future_test.cc
+++ b/cpp/src/arrow/util/future_test.cc
@@ -20,7 +20,9 @@
#include <algorithm>
#include <chrono>
+#include <condition_variable>
#include <memory>
+#include <mutex>
#include <ostream>
#include <random>
#include <string>
@@ -287,6 +289,109 @@ TEST(FutureSyncTest, Int) {
}
}
+TEST(FutureSyncTest, Foo) {
+ {
+ auto fut = Future<Foo>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Foo(42));
+ AssertSuccessful(fut);
+ auto res = fut.result();
+ ASSERT_OK(res);
+ Foo value = *res;
+ ASSERT_EQ(value, 42);
+ ASSERT_OK(fut.status());
+ res = std::move(fut).result();
+ ASSERT_OK(res);
+ value = *res;
+ ASSERT_EQ(value, 42);
+ }
+ {
+ // MarkFinished(Result<Foo>)
+ auto fut = Future<Foo>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Result<Foo>(Foo(42)));
+ AssertSuccessful(fut);
+ ASSERT_OK_AND_ASSIGN(Foo value, fut.result());
+ ASSERT_EQ(value, 42);
+ }
+ {
+ // MarkFinished(failed Result<Foo>)
+ auto fut = Future<Foo>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Result<Foo>(Status::IOError("xxx")));
+ AssertFailed(fut);
+ ASSERT_RAISES(IOError, fut.result());
+ }
+}
+
+TEST(FutureSyncTest, Empty) {
+ {
+ // MarkFinished()
+ auto fut = Future<>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished();
+ AssertSuccessful(fut);
+ }
+ {
+ // MakeFinished()
+ auto fut = Future<>::MakeFinished();
+ AssertSuccessful(fut);
+ auto res = fut.result();
+ ASSERT_OK(res);
+ res = std::move(fut.result());
+ ASSERT_OK(res);
+ }
+ {
+ // MarkFinished(Status)
+ auto fut = Future<>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Status::OK());
+ AssertSuccessful(fut);
+ }
+ {
+ // MakeFinished(Status)
+ auto fut = Future<>::MakeFinished(Status::OK());
+ AssertSuccessful(fut);
+ fut = Future<>::MakeFinished(Status::IOError("xxx"));
+ AssertFailed(fut);
+ }
+ {
+ // MarkFinished(Status)
+ auto fut = Future<>::Make();
+ AssertNotFinished(fut);
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut);
+ ASSERT_RAISES(IOError, fut.status());
+ }
+}
+
+TEST(FutureSyncTest, GetStatusFuture) {
+ {
+ auto fut = Future<MoveOnlyDataType>::Make();
+ Future<> status_future(fut);
+
+ AssertNotFinished(fut);
+ AssertNotFinished(status_future);
+
+ fut.MarkFinished(MoveOnlyDataType(42));
+ AssertSuccessful(fut);
+ AssertSuccessful(status_future);
+ ASSERT_EQ(&fut.status(), &status_future.status());
+ }
+ {
+ auto fut = Future<MoveOnlyDataType>::Make();
+ Future<> status_future(fut);
+
+ AssertNotFinished(fut);
+ AssertNotFinished(status_future);
+
+ fut.MarkFinished(Status::IOError("xxx"));
+ AssertFailed(fut);
+ AssertFailed(status_future);
+ ASSERT_EQ(&fut.status(), &status_future.status());
+ }
+}
+
TEST(FutureRefTest, ChainRemoved) {
// Creating a future chain should not prevent the futures from being deleted
if the
// entire chain is deleted
@@ -359,7 +464,7 @@ TEST(FutureRefTest, HeadRemoved) {
ASSERT_TRUE(ref.expired());
}
-TEST(FutureTest, StressCallback) {
+TEST(FutureStessTest, Callback) {
for (unsigned int n = 0; n < 1000; n++) {
auto fut = Future<>::Make();
std::atomic<unsigned int> count_finished_immediately(0);
@@ -404,6 +509,56 @@ TEST(FutureTest, StressCallback) {
}
}
+TEST(FutureStessTest, TryAddCallback) {
+ for (unsigned int n = 0; n < 1; n++) {
+ auto fut = Future<>::Make();
+ std::atomic<unsigned int> callbacks_added(0);
+ std::atomic<bool> finished;
+ std::mutex mutex;
+ std::condition_variable cv;
+ std::thread::id callback_adder_thread_id;
+
+ std::thread callback_adder([&] {
+ callback_adder_thread_id = std::this_thread::get_id();
+ std::function<void(const Result<detail::Empty>&)> callback =
+ [&callback_adder_thread_id](const Result<detail::Empty>&) {
+ if (std::this_thread::get_id() == callback_adder_thread_id) {
+ FAIL() << "TryAddCallback allowed a callback to be run
synchronously";
+ }
+ };
+ std::function<std::function<void(const Result<detail::Empty>&)>()>
+ callback_factory = [&callback]() { return callback; };
+ while (true) {
+ auto callback_added = fut.TryAddCallback(callback_factory);
+ if (callback_added) {
+ callbacks_added++;
+ } else {
+ break;
+ }
+ }
+ {
+ std::lock_guard<std::mutex> lg(mutex);
+ finished.store(true);
+ }
+ cv.notify_one();
+ });
+
+ while (callbacks_added.load() == 0) {
+ // Spin until the callback_adder has started running
+ }
+
+ fut.MarkFinished();
+
+ std::unique_lock<std::mutex> lk(mutex);
+ cv.wait_for(lk, std::chrono::duration<double>(0.5),
+ [&finished] { return finished.load(); });
+ lk.unlock();
+
+ ASSERT_TRUE(finished);
+ callback_adder.join();
+ }
+}
+
TEST(FutureCompletionTest, Void) {
{
// Simple callback
@@ -832,142 +987,213 @@ TEST(FutureCompletionTest, FutureVoid) {
}
}
-TEST(FutureSyncTest, Foo) {
- {
- // MarkFinished(Foo)
- auto fut = Future<Foo>::Make();
- AssertNotFinished(fut);
- fut.MarkFinished(Foo(42));
- AssertSuccessful(fut);
- auto res = fut.result();
- ASSERT_OK(res);
- Foo value = *res;
- ASSERT_EQ(value, 42);
- ASSERT_OK(fut.status());
- res = std::move(fut).result();
- ASSERT_OK(res);
- value = *res;
- ASSERT_EQ(value, 42);
- }
- {
- // MarkFinished(Result<Foo>)
- auto fut = Future<Foo>::Make();
- AssertNotFinished(fut);
- fut.MarkFinished(Result<Foo>(Foo(42)));
- AssertSuccessful(fut);
- ASSERT_OK_AND_ASSIGN(Foo value, fut.result());
- ASSERT_EQ(value, 42);
- }
- {
- // MarkFinished(failed Result<Foo>)
- auto fut = Future<Foo>::Make();
- AssertNotFinished(fut);
- fut.MarkFinished(Result<Foo>(Status::IOError("xxx")));
- AssertFailed(fut);
- ASSERT_RAISES(IOError, fut.result());
- }
+TEST(FutureAllTest, Simple) {
+ auto f1 = Future<int>::Make();
+ auto f2 = Future<int>::Make();
+ std::vector<Future<int>> futures = {f1, f2};
+ auto combined = arrow::All(futures);
+
+ auto after_assert = combined.Then([](std::vector<Result<int>> results) {
+ ASSERT_EQ(2, results.size());
+ ASSERT_EQ(1, *results[0]);
+ ASSERT_EQ(2, *results[1]);
+ });
+
+ // Finish in reverse order, results should still be delivered in proper order
+ AssertNotFinished(after_assert);
+ f2.MarkFinished(2);
+ AssertNotFinished(after_assert);
+ f1.MarkFinished(1);
+ AssertSuccessful(after_assert);
}
-TEST(FutureSyncTest, MoveOnlyDataType) {
- {
- // MarkFinished(MoveOnlyDataType)
- auto fut = Future<MoveOnlyDataType>::Make();
- AssertNotFinished(fut);
- fut.MarkFinished(MoveOnlyDataType(42));
- AssertSuccessful(fut);
- const auto& res = fut.result();
- ASSERT_TRUE(res.ok());
- ASSERT_EQ(*res, 42);
- ASSERT_OK_AND_ASSIGN(MoveOnlyDataType value, std::move(fut).result());
- ASSERT_EQ(value, 42);
- }
+TEST(FutureAllTest, Failure) {
+ auto f1 = Future<int>::Make();
+ auto f2 = Future<int>::Make();
+ auto f3 = Future<int>::Make();
+ std::vector<Future<int>> futures = {f1, f2, f3};
+ auto combined = arrow::All(futures);
+
+ auto after_assert = combined.Then([](std::vector<Result<int>> results) {
+ ASSERT_EQ(3, results.size());
+ ASSERT_EQ(1, *results[0]);
+ ASSERT_EQ(Status::IOError("XYZ"), results[1].status());
+ ASSERT_EQ(3, *results[2]);
+ });
+
+ f1.MarkFinished(1);
+ f2.MarkFinished(Status::IOError("XYZ"));
+ f3.MarkFinished(3);
+
+ AssertFinished(after_assert);
+}
+
+TEST(FutureLoopTest, Sync) {
+ struct {
+ int i = 0;
+ Future<int> Get() { return Future<int>::MakeFinished(i++); }
+ } IntSource;
+
+ bool do_fail = false;
+ std::vector<int> ints;
+ auto loop_body = [&] {
+ return IntSource.Get().Then([&](int i) -> Result<ControlFlow<int>> {
+ if (do_fail && i == 3) {
+ return Status::IOError("xxx");
+ }
+
+ if (i == 5) {
+ int sum = 0;
+ for (int i : ints) sum += i;
+ return Break(sum);
+ }
+
+ ints.push_back(i);
+ return Continue();
+ });
+ };
+
{
- // MarkFinished(Result<MoveOnlyDataType>)
- auto fut = Future<MoveOnlyDataType>::Make();
- AssertNotFinished(fut);
- fut.MarkFinished(Result<MoveOnlyDataType>(MoveOnlyDataType(43)));
- AssertSuccessful(fut);
- ASSERT_OK_AND_ASSIGN(MoveOnlyDataType value, std::move(fut).result());
- ASSERT_EQ(value, 43);
+ auto sum_fut = Loop(loop_body);
+ AssertSuccessful(sum_fut);
+
+ ASSERT_OK_AND_ASSIGN(auto sum, sum_fut.result());
+ ASSERT_EQ(sum, 0 + 1 + 2 + 3 + 4);
}
+
{
- // MarkFinished(failed Result<MoveOnlyDataType>)
- auto fut = Future<MoveOnlyDataType>::Make();
- AssertNotFinished(fut);
- fut.MarkFinished(Result<MoveOnlyDataType>(Status::IOError("xxx")));
- AssertFailed(fut);
- ASSERT_RAISES(IOError, fut.status());
- const auto& res = fut.result();
- ASSERT_TRUE(res.status().IsIOError());
- ASSERT_RAISES(IOError, std::move(fut).result());
+ do_fail = true;
+ IntSource.i = 0;
+ auto sum_fut = Loop(loop_body);
+ AssertFailed(sum_fut);
+ ASSERT_RAISES(IOError, sum_fut.result());
}
}
-TEST(FutureSyncTest, Empty) {
- {
- // MarkFinished()
- auto fut = Future<>::Make();
- AssertNotFinished(fut);
- fut.MarkFinished();
- AssertSuccessful(fut);
+TEST(FutureLoopTest, EmptyBreakValue) {
+ Future<> none_fut =
+ Loop([&] { return Future<>::MakeFinished().Then([&](...) { return
Break(); }); });
+ AssertSuccessful(none_fut);
+}
+
+TEST(FutureLoopTest, EmptyLoop) {
+ auto loop_body = []() -> Future<ControlFlow<int>> {
+ return Future<ControlFlow<int>>::MakeFinished(Break(0));
+ };
+ auto loop_fut = Loop(loop_body);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto loop_res, loop_fut);
+ ASSERT_EQ(loop_res, 0);
+}
+
+// TODO - Test provided by Ben but I don't understand how it can pass
legitimately.
+// Any future result will be passed by reference to the callbacks (as there
can be
+// multiple callbacks). In the Loop construct it takes the break and forwards
it
+// on to the outer future. Since there is no way to move a reference this can
only
+// be done by copying.
+//
+// In theory it should be safe since Loop is guaranteed to be the last
callback added
+// to the control future and so the value can be safely moved at that point.
However,
+// I'm unable to reproduce whatever trick you had in ControlFlow to make this
work.
+// If we want to formalize this "last callback can steal" concept then we
could add
+// a "last callback" to Future which gets called with an rvalue instead of an
lvalue
+// reference but that seems overly complicated.
+//
+// Ben, can you recreate whatever trick you had in place before that allowed
this to
+// pass? Perhaps some kind of cast. Worst case, I can move back to using
+// ControlFlow instead of std::optional
+//
+// TEST(FutureLoopTest, MoveOnlyBreakValue) {
+// Future<MoveOnlyDataType> one_fut = Loop([&] {
+// return Future<int>::MakeFinished(1).Then(
+// [&](int i) { return Break(MoveOnlyDataType(i)); });
+// });
+// AssertSuccessful(one_fut);
+// ASSERT_OK_AND_ASSIGN(auto one, std::move(one_fut).result());
+// ASSERT_EQ(one, 1);
+// }
+
+TEST(FutureLoopTest, StackOverflow) {
+ // Looping over futures is normally a rather recursive task. If the futures
complete
+ // synchronously (because they are already finished) it could lead to a
stack overflow
+ // if care is not taken.
+ int counter = 0;
+ auto loop_body = [&counter]() -> Future<ControlFlow<int>> {
+ while (counter < 1000000) {
+ counter++;
+ return Future<ControlFlow<int>>::MakeFinished(Continue());
+ }
+ return Future<ControlFlow<int>>::MakeFinished(Break(-1));
+ };
+ auto loop_fut = Loop(loop_body);
+ ASSERT_TRUE(loop_fut.Wait(0.1));
+}
+
+TEST(FutureLoopTest, AllowsBreakFutToBeDiscarded) {
+ int counter = 0;
+ auto loop_body = [&counter]() -> Future<ControlFlow<int>> {
+ while (counter < 10) {
+ counter++;
+ return Future<ControlFlow<int>>::MakeFinished(Continue());
+ }
+ return Future<ControlFlow<int>>::MakeFinished(Break(-1));
+ };
+ auto loop_fut = Loop(loop_body).Then([](...) { return Status::OK(); });
+ ASSERT_TRUE(loop_fut.Wait(0.1));
+}
+
+class MoveTrackingCallable {
+ public:
+ MoveTrackingCallable() {
+ // std::cout << "CONSTRUCT" << std::endl;
}
- {
- // MakeFinished()
- auto fut = Future<>::MakeFinished();
- AssertSuccessful(fut);
- auto res = fut.result();
- ASSERT_OK(res);
- res = std::move(fut.result());
- ASSERT_OK(res);
+ ~MoveTrackingCallable() {
+ valid_ = false;
+ // std::cout << "DESTRUCT" << std::endl;
}
- {
- // MarkFinished(Status)
- auto fut = Future<>::Make();
- AssertNotFinished(fut);
- fut.MarkFinished(Status::OK());
- AssertSuccessful(fut);
+ MoveTrackingCallable(const MoveTrackingCallable& other) {
+ // std::cout << "COPY CONSTRUCT" << std::endl;
}
- {
- // MakeFinished(Status)
- auto fut = Future<>::MakeFinished(Status::OK());
- AssertSuccessful(fut);
- fut = Future<>::MakeFinished(Status::IOError("xxx"));
- AssertFailed(fut);
+ MoveTrackingCallable(MoveTrackingCallable&& other) {
+ other.valid_ = false;
+ // std::cout << "MOVE CONSTRUCT" << std::endl;
}
- {
- // MarkFinished(Status)
- auto fut = Future<>::Make();
- AssertNotFinished(fut);
- fut.MarkFinished(Status::IOError("xxx"));
- AssertFailed(fut);
- ASSERT_RAISES(IOError, fut.status());
+ MoveTrackingCallable& operator=(const MoveTrackingCallable& other) {
+ // std::cout << "COPY ASSIGN" << std::endl;
+ return *this;
+ }
+ MoveTrackingCallable& operator=(MoveTrackingCallable&& other) {
+ other.valid_ = false;
+ // std::cout << "MOVE ASSIGN" << std::endl;
+ return *this;
}
-}
-TEST(FutureSyncTest, GetStatusFuture) {
- {
- auto fut = Future<MoveOnlyDataType>::Make();
- Future<> status_future(fut);
+ Status operator()(...) {
+ // std::cout << "TRIGGER" << std::endl;
+ if (valid_) {
+ return Status::OK();
+ } else {
+ return Status::Invalid("Invalid callback triggered");
+ }
+ }
- AssertNotFinished(fut);
- AssertNotFinished(status_future);
+ private:
+ bool valid_ = true;
+};
- fut.MarkFinished(MoveOnlyDataType(42));
- AssertSuccessful(fut);
- AssertSuccessful(status_future);
- ASSERT_EQ(&fut.status(), &status_future.status());
- }
+TEST(FutureCompletionTest, ReuseCallback) {
+ auto fut = Future<>::Make();
+
+ Future<> continuation;
{
- auto fut = Future<MoveOnlyDataType>::Make();
- Future<> status_future(fut);
+ MoveTrackingCallable callback;
+ continuation = fut.Then(callback);
+ }
- AssertNotFinished(fut);
- AssertNotFinished(status_future);
+ fut.MarkFinished(Status::OK());
- fut.MarkFinished(Status::IOError("xxx"));
- AssertFailed(fut);
- AssertFailed(status_future);
- ASSERT_EQ(&fut.status(), &status_future.status());
+ ASSERT_TRUE(continuation.is_finished());
+ if (continuation.is_finished()) {
+ ASSERT_OK(continuation.status());
}
}
@@ -1287,34 +1513,34 @@ class FutureTestBase : public ::testing::Test {
};
template <typename T>
-class FutureTest : public FutureTestBase<T> {};
+class FutureWaitTest : public FutureTestBase<T> {};
-using FutureTestTypes = ::testing::Types<int, Foo, MoveOnlyDataType>;
+using FutureWaitTestTypes = ::testing::Types<int, Foo, MoveOnlyDataType>;
-TYPED_TEST_SUITE(FutureTest, FutureTestTypes);
+TYPED_TEST_SUITE(FutureWaitTest, FutureWaitTestTypes);
-TYPED_TEST(FutureTest, BasicWait) { this->TestBasicWait(); }
+TYPED_TEST(FutureWaitTest, BasicWait) { this->TestBasicWait(); }
-TYPED_TEST(FutureTest, TimedWait) { this->TestTimedWait(); }
+TYPED_TEST(FutureWaitTest, TimedWait) { this->TestTimedWait(); }
-TYPED_TEST(FutureTest, StressWait) { this->TestStressWait(); }
+TYPED_TEST(FutureWaitTest, StressWait) { this->TestStressWait(); }
-TYPED_TEST(FutureTest, BasicWaitForAny) { this->TestBasicWaitForAny(); }
+TYPED_TEST(FutureWaitTest, BasicWaitForAny) { this->TestBasicWaitForAny(); }
-TYPED_TEST(FutureTest, TimedWaitForAny) { this->TestTimedWaitForAny(); }
+TYPED_TEST(FutureWaitTest, TimedWaitForAny) { this->TestTimedWaitForAny(); }
-TYPED_TEST(FutureTest, StressWaitForAny) { this->TestStressWaitForAny(); }
+TYPED_TEST(FutureWaitTest, StressWaitForAny) { this->TestStressWaitForAny(); }
-TYPED_TEST(FutureTest, BasicWaitForAll) { this->TestBasicWaitForAll(); }
+TYPED_TEST(FutureWaitTest, BasicWaitForAll) { this->TestBasicWaitForAll(); }
-TYPED_TEST(FutureTest, TimedWaitForAll) { this->TestTimedWaitForAll(); }
+TYPED_TEST(FutureWaitTest, TimedWaitForAll) { this->TestTimedWaitForAll(); }
-TYPED_TEST(FutureTest, StressWaitForAll) { this->TestStressWaitForAll(); }
+TYPED_TEST(FutureWaitTest, StressWaitForAll) { this->TestStressWaitForAll(); }
template <typename T>
class FutureIteratorTest : public FutureTestBase<T> {};
-using FutureIteratorTestTypes = ::testing::Types<Foo, MoveOnlyDataType>;
+using FutureIteratorTestTypes = ::testing::Types<Foo>;
TYPED_TEST_SUITE(FutureIteratorTest, FutureIteratorTestTypes);
diff --git a/cpp/src/arrow/util/iterator.cc b/cpp/src/arrow/util/iterator.cc
deleted file mode 100644
index 0c71bba..0000000
--- a/cpp/src/arrow/util/iterator.cc
+++ /dev/null
@@ -1,175 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include "arrow/util/iterator.h"
-
-#include <condition_variable>
-#include <cstdint>
-#include <deque>
-#include <mutex>
-#include <thread>
-
-#include "arrow/util/logging.h"
-
-namespace arrow {
-namespace detail {
-
-ReadaheadPromise::~ReadaheadPromise() {}
-
-class ReadaheadQueue::Impl : public
std::enable_shared_from_this<ReadaheadQueue::Impl> {
- public:
- explicit Impl(int64_t readahead_queue_size) :
max_readahead_(readahead_queue_size) {}
-
- ~Impl() { EnsureShutdownOrDie(false); }
-
- void Start() {
- // Cannot do this in constructor as shared_from_this() would throw
- DCHECK(!thread_.joinable());
- auto self = shared_from_this();
- thread_ = std::thread([self]() { self->DoWork(); });
- DCHECK(thread_.joinable());
- }
-
- void EnsureShutdownOrDie(bool wait = true) {
- std::unique_lock<std::mutex> lock(mutex_);
- if (!please_shutdown_) {
- ARROW_CHECK_OK(ShutdownUnlocked(std::move(lock), wait));
- }
- DCHECK(!thread_.joinable());
- }
-
- Status Append(std::unique_ptr<ReadaheadPromise> promise) {
- std::unique_lock<std::mutex> lock(mutex_);
- if (please_shutdown_) {
- return Status::Invalid("Shutdown requested");
- }
- todo_.push_back(std::move(promise));
- if (static_cast<int64_t>(todo_.size()) == 1) {
- // Signal there's more work to do
- lock.unlock();
- worker_wakeup_.notify_one();
- }
- return Status::OK();
- }
-
- Status PopDone(std::unique_ptr<ReadaheadPromise>* out) {
- std::unique_lock<std::mutex> lock(mutex_);
- if (please_shutdown_) {
- return Status::Invalid("Shutdown requested");
- }
- work_done_.wait(lock, [this]() { return done_.size() > 0; });
- *out = std::move(done_.front());
- done_.pop_front();
- if (static_cast<int64_t>(done_.size()) < max_readahead_) {
- // Signal there's more work to do
- lock.unlock();
- worker_wakeup_.notify_one();
- }
- return Status::OK();
- }
-
- Status Pump(std::function<std::unique_ptr<ReadaheadPromise>()> factory) {
- std::unique_lock<std::mutex> lock(mutex_);
- if (please_shutdown_) {
- return Status::Invalid("Shutdown requested");
- }
- while (static_cast<int64_t>(done_.size() + todo_.size()) < max_readahead_)
{
- todo_.push_back(factory());
- }
- // Signal there's more work to do
- lock.unlock();
- worker_wakeup_.notify_one();
- return Status::OK();
- }
-
- Status Shutdown(bool wait = true) {
- return ShutdownUnlocked(std::unique_lock<std::mutex>(mutex_), wait);
- }
-
- Status ShutdownUnlocked(std::unique_lock<std::mutex> lock, bool wait = true)
{
- if (please_shutdown_) {
- return Status::Invalid("Shutdown already requested");
- }
- DCHECK(thread_.joinable());
- please_shutdown_ = true;
- lock.unlock();
- worker_wakeup_.notify_one();
- if (wait) {
- thread_.join();
- } else {
- thread_.detach();
- }
- return Status::OK();
- }
-
- void DoWork() {
- std::unique_lock<std::mutex> lock(mutex_);
- while (!please_shutdown_) {
- while (static_cast<int64_t>(done_.size()) < max_readahead_ &&
todo_.size() > 0) {
- auto promise = std::move(todo_.front());
- todo_.pop_front();
- lock.unlock();
- promise->Call();
- lock.lock();
- done_.push_back(std::move(promise));
- work_done_.notify_one();
- // Exit eagerly
- if (please_shutdown_) {
- return;
- }
- }
- // Wait for more work to do
- worker_wakeup_.wait(lock);
- }
- }
-
- std::deque<std::unique_ptr<ReadaheadPromise>> todo_;
- std::deque<std::unique_ptr<ReadaheadPromise>> done_;
- int64_t max_readahead_;
- bool please_shutdown_ = false;
-
- std::thread thread_;
- std::mutex mutex_;
- std::condition_variable worker_wakeup_;
- std::condition_variable work_done_;
-};
-
-ReadaheadQueue::ReadaheadQueue(int readahead_queue_size)
- : impl_(new Impl(readahead_queue_size)) {
- impl_->Start();
-}
-
-ReadaheadQueue::~ReadaheadQueue() {}
-
-Status ReadaheadQueue::Append(std::unique_ptr<ReadaheadPromise> promise) {
- return impl_->Append(std::move(promise));
-}
-
-Status ReadaheadQueue::PopDone(std::unique_ptr<ReadaheadPromise>* out) {
- return impl_->PopDone(out);
-}
-
-Status ReadaheadQueue::Pump(std::function<std::unique_ptr<ReadaheadPromise>()>
factory) {
- return impl_->Pump(std::move(factory));
-}
-
-Status ReadaheadQueue::Shutdown() { return impl_->Shutdown(); }
-
-void ReadaheadQueue::EnsureShutdownOrDie() { return
impl_->EnsureShutdownOrDie(); }
-
-} // namespace detail
-} // namespace arrow
diff --git a/cpp/src/arrow/util/iterator.h b/cpp/src/arrow/util/iterator.h
index 58dda5d..75ccf28 100644
--- a/cpp/src/arrow/util/iterator.h
+++ b/cpp/src/arrow/util/iterator.h
@@ -20,6 +20,7 @@
#include <cassert>
#include <functional>
#include <memory>
+#include <queue>
#include <tuple>
#include <type_traits>
#include <utility>
@@ -187,6 +188,127 @@ class Iterator : public
util::EqualityComparable<Iterator<T>> {
};
template <typename T>
+struct TransformFlow {
+ using YieldValueType = T;
+
+ TransformFlow(YieldValueType value, bool ready_for_next)
+ : finished_(false),
+ ready_for_next_(ready_for_next),
+ yield_value_(std::move(value)) {}
+ TransformFlow(bool finished, bool ready_for_next)
+ : finished_(finished), ready_for_next_(ready_for_next), yield_value_() {}
+
+ bool HasValue() const { return yield_value_.has_value(); }
+ bool Finished() const { return finished_; }
+ bool ReadyForNext() const { return ready_for_next_; }
+ T Value() const { return *yield_value_; }
+
+ bool finished_ = false;
+ bool ready_for_next_ = false;
+ util::optional<YieldValueType> yield_value_;
+};
+
+struct TransformFinish {
+ template <typename T>
+ operator TransformFlow<T>() && { // NOLINT explicit
+ return TransformFlow<T>(true, true);
+ }
+};
+
+struct TransformSkip {
+ template <typename T>
+ operator TransformFlow<T>() && { // NOLINT explicit
+ return TransformFlow<T>(false, true);
+ }
+};
+
+template <typename T>
+TransformFlow<T> TransformYield(T value = {}, bool ready_for_next = true) {
+ return TransformFlow<T>(std::move(value), ready_for_next);
+}
+
+template <typename T, typename V>
+using Transformer = std::function<Result<TransformFlow<V>>(T)>;
+
+template <typename T, typename V>
+class TransformIterator {
+ public:
+ explicit TransformIterator(Iterator<T> it, Transformer<T, V> transformer)
+ : it_(std::move(it)),
+ transformer_(std::move(transformer)),
+ last_value_(),
+ finished_() {}
+
+ Result<V> Next() {
+ while (!finished_) {
+ ARROW_ASSIGN_OR_RAISE(util::optional<V> next, Pump());
+ if (next.has_value()) {
+ return std::move(*next);
+ }
+ ARROW_ASSIGN_OR_RAISE(last_value_, it_.Next());
+ }
+ return IterationTraits<V>::End();
+ }
+
+ private:
+ // Calls the transform function on the current value. Can return in several
ways
+ // * If the next value is requested (e.g. skip) it will return an empty
optional
+ // * If an invalid status is encountered that will be returned
+ // * If finished it will return IterationTraits<V>::End()
+ // * If a value is returned by the transformer that will be returned
+ Result<util::optional<V>> Pump() {
+ if (!finished_ && last_value_.has_value()) {
+ auto next_res = transformer_(*last_value_);
+ if (!next_res.ok()) {
+ finished_ = true;
+ return next_res.status();
+ }
+ auto next = *next_res;
+ if (next.ReadyForNext()) {
+ if (*last_value_ == IterationTraits<T>::End()) {
+ finished_ = true;
+ }
+ last_value_.reset();
+ }
+ if (next.Finished()) {
+ finished_ = true;
+ }
+ if (next.HasValue()) {
+ return next.Value();
+ }
+ }
+ if (finished_) {
+ return IterationTraits<V>::End();
+ }
+ return util::nullopt;
+ }
+
+ Iterator<T> it_;
+ Transformer<T, V> transformer_;
+ util::optional<T> last_value_;
+ bool finished_ = false;
+};
+
+/// \brief Transforms an iterator according to a transformer, returning a new
Iterator.
+///
+/// The transformer will be called on each element of the source iterator and
for each
+/// call it can yield a value, skip, or finish the iteration. When yielding a
value the
+/// transformer can choose to consume the source item (the default,
ready_for_next = true)
+/// or to keep it and it will be called again on the same value.
+///
+/// This is essentially a more generic form of the map operation that can
return 0, 1, or
+/// many values for each of the source items.
+///
+/// The transformer will be exposed to the end of the source sequence
+/// (IterationTraits::End) in case it needs to return some penultimate item(s).
+///
+/// Any invalid status returned by the transformer will be returned
immediately.
+template <typename T, typename V>
+Iterator<V> MakeTransformedIterator(Iterator<T> it, Transformer<T, V> op) {
+ return Iterator<V>(TransformIterator<T, V>(std::move(it), std::move(op)));
+}
+
+template <typename T>
struct IterationTraits<Iterator<T>> {
// The end condition for an Iterator of Iterators is a default constructed
(null)
// Iterator.
@@ -414,117 +536,4 @@ Iterator<T> MakeFlattenIterator(Iterator<Iterator<T>> it)
{
return Iterator<T>(FlattenIterator<T>(std::move(it)));
}
-namespace detail {
-
-// A type-erased promise object for ReadaheadQueue.
-struct ARROW_EXPORT ReadaheadPromise {
- virtual ~ReadaheadPromise();
- virtual void Call() = 0;
-};
-
-template <typename T>
-struct ReadaheadIteratorPromise : ReadaheadPromise {
- ~ReadaheadIteratorPromise() override {}
-
- explicit ReadaheadIteratorPromise(Iterator<T>* it) : it_(it) {}
-
- void Call() override {
- assert(!called_);
- out_ = it_->Next();
- called_ = true;
- }
-
- Iterator<T>* it_;
- Result<T> out_ = IterationTraits<T>::End();
- bool called_ = false;
-};
-
-class ARROW_EXPORT ReadaheadQueue {
- public:
- explicit ReadaheadQueue(int readahead_queue_size);
- ~ReadaheadQueue();
-
- Status Append(std::unique_ptr<ReadaheadPromise>);
- Status PopDone(std::unique_ptr<ReadaheadPromise>*);
- Status Pump(std::function<std::unique_ptr<ReadaheadPromise>()> factory);
- Status Shutdown();
- void EnsureShutdownOrDie();
-
- protected:
- class Impl;
- std::shared_ptr<Impl> impl_;
-};
-
-} // namespace detail
-
-/// \brief Readahead iterator that iterates on the underlying iterator in a
-/// separate thread, getting up to N values in advance.
-template <typename T>
-class ReadaheadIterator {
- using PromiseType = typename detail::ReadaheadIteratorPromise<T>;
-
- public:
- // Public default constructor creates an empty iterator
- ReadaheadIterator() : done_(true) {}
-
- ~ReadaheadIterator() {
- if (queue_) {
- // Make sure the queue doesn't call any promises after this object
- // is destroyed.
- queue_->EnsureShutdownOrDie();
- }
- }
-
- ARROW_DEFAULT_MOVE_AND_ASSIGN(ReadaheadIterator);
- ARROW_DISALLOW_COPY_AND_ASSIGN(ReadaheadIterator);
-
- Result<T> Next() {
- if (done_) {
- return IterationTraits<T>::End();
- }
-
- std::unique_ptr<detail::ReadaheadPromise> promise;
- ARROW_RETURN_NOT_OK(queue_->PopDone(&promise));
- auto it_promise = static_cast<PromiseType*>(promise.get());
-
- ARROW_RETURN_NOT_OK(queue_->Append(MakePromise()));
-
- ARROW_ASSIGN_OR_RAISE(auto out, it_promise->out_);
- if (out == IterationTraits<T>::End()) {
- done_ = true;
- }
- return out;
- }
-
- static Result<Iterator<T>> Make(Iterator<T> it, int readahead_queue_size) {
- ReadaheadIterator rh(std::move(it), readahead_queue_size);
- ARROW_RETURN_NOT_OK(rh.Pump());
- return Iterator<T>(std::move(rh));
- }
-
- private:
- explicit ReadaheadIterator(Iterator<T> it, int readahead_queue_size)
- : it_(new Iterator<T>(std::move(it))),
- queue_(new detail::ReadaheadQueue(readahead_queue_size)) {}
-
- Status Pump() {
- return queue_->Pump([this]() { return MakePromise(); });
- }
-
- std::unique_ptr<detail::ReadaheadPromise> MakePromise() {
- return std::unique_ptr<detail::ReadaheadPromise>(new
PromiseType{it_.get()});
- }
-
- // The underlying iterator is referenced by pointer in ReadaheadPromise,
- // so make sure it doesn't move.
- std::unique_ptr<Iterator<T>> it_;
- std::unique_ptr<detail::ReadaheadQueue> queue_;
- bool done_ = false;
-};
-
-template <typename T>
-Result<Iterator<T>> MakeReadaheadIterator(Iterator<T> it, int
readahead_queue_size) {
- return ReadaheadIterator<T>::Make(std::move(it), readahead_queue_size);
-}
-
} // namespace arrow
diff --git a/cpp/src/arrow/util/iterator_test.cc
b/cpp/src/arrow/util/iterator_test.cc
index 7295627..322611b 100644
--- a/cpp/src/arrow/util/iterator_test.cc
+++ b/cpp/src/arrow/util/iterator_test.cc
@@ -15,8 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-#include "arrow/util/iterator.h"
-
#include <algorithm>
#include <chrono>
#include <condition_variable>
@@ -28,6 +26,8 @@
#include <vector>
#include "arrow/testing/gtest_util.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/iterator.h"
namespace arrow {
@@ -49,6 +49,32 @@ struct IterationTraits<TestInt> {
static TestInt End() { return TestInt(); }
};
+struct TestStr {
+ TestStr() : value("") {}
+ TestStr(const std::string& s) : value(s) {} // NOLINT runtime/explicit
+ TestStr(const char* s) : value(s) {} // NOLINT runtime/explicit
+ explicit TestStr(const TestInt& test_int) {
+ if (test_int == IterationTraits<TestInt>::End()) {
+ value = "";
+ } else {
+ value = std::to_string(test_int.value);
+ }
+ }
+ std::string value;
+
+ bool operator==(const TestStr& other) const { return value == other.value; }
+
+ friend std::ostream& operator<<(std::ostream& os, const TestStr& v) {
+ os << "{\"" << v.value << "\"}";
+ return os;
+ }
+};
+
+template <>
+struct IterationTraits<TestStr> {
+ static TestStr End() { return TestStr(); }
+};
+
template <typename T>
class TracingIterator {
public:
@@ -129,11 +155,45 @@ template <typename T>
inline Iterator<T> EmptyIt() {
return MakeEmptyIterator<T>();
}
-
inline Iterator<TestInt> VectorIt(std::vector<TestInt> v) {
return MakeVectorIterator<TestInt>(std::move(v));
}
+AsyncGenerator<TestInt> AsyncVectorIt(std::vector<TestInt> v) {
+ size_t index = 0;
+ return [index, v]() mutable -> Future<TestInt> {
+ if (index >= v.size()) {
+ return Future<TestInt>::MakeFinished(IterationTraits<TestInt>::End());
+ }
+ return Future<TestInt>::MakeFinished(v[index++]);
+ };
+}
+
+constexpr auto kYieldDuration = std::chrono::microseconds(50);
+
+// Yields items with a small pause between each one from a background thread
+std::function<Future<TestInt>()> BackgroundAsyncVectorIt(std::vector<TestInt>
v) {
+ auto pool = internal::GetCpuThreadPool();
+ auto iterator = VectorIt(v);
+ auto slow_iterator = MakeTransformedIterator<TestInt, TestInt>(
+ std::move(iterator), [](TestInt item) -> Result<TransformFlow<TestInt>> {
+ std::this_thread::sleep_for(kYieldDuration);
+ return TransformYield(item);
+ });
+ EXPECT_OK_AND_ASSIGN(auto background,
+
MakeBackgroundGenerator<TestInt>(std::move(slow_iterator),
+
internal::GetCpuThreadPool()));
+ return MakeTransferredGenerator(background, pool);
+}
+
+std::vector<TestInt> RangeVector(unsigned int max) {
+ std::vector<TestInt> range(max);
+ for (unsigned int i = 0; i < max; i++) {
+ range[i] = i;
+ }
+ return range;
+}
+
template <typename T>
inline Iterator<T> VectorIt(std::vector<T> v) {
return MakeVectorIterator<T>(std::move(v));
@@ -155,6 +215,13 @@ void AssertIteratorMatch(std::vector<T> expected,
Iterator<T> actual) {
}
template <typename T>
+void AssertAsyncGeneratorMatch(std::vector<T> expected, AsyncGenerator<T>
actual) {
+ auto vec_future = CollectAsyncGenerator(std::move(actual));
+ EXPECT_OK_AND_ASSIGN(auto vec, vec_future.result());
+ EXPECT_EQ(expected, vec);
+}
+
+template <typename T>
void AssertIteratorNoMatch(std::vector<T> expected, Iterator<T> actual) {
EXPECT_NE(expected, IteratorToVector(std::move(actual)));
}
@@ -170,6 +237,9 @@ void AssertIteratorExhausted(Iterator<T>& it) {
AssertIteratorNext(IterationTraits<T>::End(), it);
}
+// --------------------------------------------------------------------
+// Synchronous iterator tests
+
TEST(TestEmptyIterator, Basic) { AssertIteratorMatch({}, EmptyIt<TestInt>()); }
TEST(TestVectorIterator, Basic) {
@@ -214,6 +284,118 @@ TEST(TestVectorIterator, RangeForLoop) {
ASSERT_EQ(ints_it, ints.end());
}
+Transformer<TestInt, TestStr> MakeFirstN(int n) {
+ int remaining = n;
+ return [remaining](TestInt next) mutable -> Result<TransformFlow<TestStr>> {
+ if (remaining > 0) {
+ remaining--;
+ return TransformYield(TestStr(next));
+ }
+ return TransformFinish();
+ };
+}
+
+template <typename T>
+Transformer<T, T> MakeFirstNGeneric(int n) {
+ int remaining = n;
+ return [remaining](T next) mutable -> Result<TransformFlow<T>> {
+ if (remaining > 0) {
+ remaining--;
+ return TransformYield(next);
+ }
+ return TransformFinish();
+ };
+}
+
+TEST(TestIteratorTransform, Truncating) {
+ auto original = VectorIt({1, 2, 3});
+ auto truncated = MakeTransformedIterator(std::move(original), MakeFirstN(2));
+ AssertIteratorMatch({"1", "2"}, std::move(truncated));
+}
+
+TEST(TestIteratorTransform, TestPointer) {
+ auto original = VectorIt<std::shared_ptr<int>>(
+ {std::make_shared<int>(1), std::make_shared<int>(2),
std::make_shared<int>(3)});
+ auto truncated = MakeTransformedIterator(std::move(original),
+
MakeFirstNGeneric<std::shared_ptr<int>>(2));
+ ASSERT_OK_AND_ASSIGN(auto result, truncated.ToVector());
+ ASSERT_EQ(2, result.size());
+}
+
+TEST(TestIteratorTransform, TruncatingShort) {
+ // Tests the failsafe case where we never call Finish
+ auto original = VectorIt({1});
+ auto truncated =
+ MakeTransformedIterator<TestInt, TestStr>(std::move(original),
MakeFirstN(2));
+ AssertIteratorMatch({"1"}, std::move(truncated));
+}
+
+Transformer<TestInt, TestStr> MakeFilter(std::function<bool(TestInt&)> filter)
{
+ return [filter](TestInt next) -> Result<TransformFlow<TestStr>> {
+ if (filter(next)) {
+ return TransformYield(TestStr(next));
+ } else {
+ return TransformSkip();
+ }
+ };
+}
+
+TEST(TestIteratorTransform, SkipSome) {
+ // Exercises TransformSkip
+ auto original = VectorIt({1, 2, 3});
+ auto filter = MakeFilter([](TestInt& t) { return t.value != 2; });
+ auto filtered = MakeTransformedIterator(std::move(original), filter);
+ AssertIteratorMatch({"1", "3"}, std::move(filtered));
+}
+
+TEST(TestIteratorTransform, SkipAll) {
+ // Exercises TransformSkip
+ auto original = VectorIt({1, 2, 3});
+ auto filter = MakeFilter([](TestInt& t) { return false; });
+ auto filtered = MakeTransformedIterator(std::move(original), filter);
+ AssertIteratorMatch({}, std::move(filtered));
+}
+
+Transformer<TestInt, TestStr> MakeAbortOnSecond() {
+ int counter = 0;
+ return [counter](TestInt next) mutable -> Result<TransformFlow<TestStr>> {
+ if (counter++ == 1) {
+ return Status::Invalid("X");
+ }
+ return TransformYield(TestStr(next));
+ };
+}
+
+TEST(TestIteratorTransform, Abort) {
+ auto original = VectorIt({1, 2, 3});
+ auto transformed = MakeTransformedIterator(std::move(original),
MakeAbortOnSecond());
+ ASSERT_OK(transformed.Next());
+ ASSERT_RAISES(Invalid, transformed.Next());
+ ASSERT_OK_AND_ASSIGN(auto third, transformed.Next());
+ ASSERT_EQ(IterationTraits<TestStr>::End(), third);
+}
+
+template <typename T>
+Transformer<T, T> MakeRepeatN(int repeat_count) {
+ int current_repeat = 0;
+ return [repeat_count, current_repeat](T next) mutable ->
Result<TransformFlow<T>> {
+ current_repeat++;
+ bool ready_for_next = false;
+ if (current_repeat == repeat_count) {
+ current_repeat = 0;
+ ready_for_next = true;
+ }
+ return TransformYield(next, ready_for_next);
+ };
+}
+
+TEST(TestIteratorTransform, Repeating) {
+ auto original = VectorIt({1, 2, 3});
+ auto repeated = MakeTransformedIterator<TestInt,
TestInt>(std::move(original),
+
MakeRepeatN<TestInt>(2));
+ AssertIteratorMatch({1, 1, 2, 2, 3, 3}, std::move(repeated));
+}
+
TEST(TestFunctionIterator, RangeForLoop) {
int i = 0;
auto fails_at_3 = MakeFunctionIterator([&]() -> Result<TestInt> {
@@ -295,13 +477,6 @@ TEST(FlattenVectorIterator, Pyramid) {
AssertIteratorMatch({1, 2, 2, 3, 3, 3}, std::move(it));
}
-TEST(ReadaheadIterator, DefaultConstructor) {
- ReadaheadIterator<TestInt> it;
- TestInt v{42};
- ASSERT_OK_AND_ASSIGN(v, it.Next());
- ASSERT_EQ(v, TestInt());
-}
-
TEST(ReadaheadIterator, Empty) {
ASSERT_OK_AND_ASSIGN(auto it, MakeReadaheadIterator(VectorIt({}), 2));
AssertIteratorMatch({}, std::move(it));
@@ -329,13 +504,16 @@ TEST(ReadaheadIterator, Trace) {
ASSERT_OK_AND_ASSIGN(
auto it, MakeReadaheadIterator(Iterator<TestInt>(std::move(tracing_it)),
2));
- tracing->WaitForValues(2);
- SleepABit(); // check no further value is emitted
- tracing->AssertValuesEqual({1, 2});
+ SleepABit(); // Background iterator won't start pumping until first request
comes in
+ ASSERT_EQ(tracing->values().size(), 0);
+
+ AssertIteratorNext({1}, it); // Once we ask for one value we should get
that one value
+ // as well as 2 read ahead
- AssertIteratorNext({1}, it);
tracing->WaitForValues(3);
- SleepABit();
+ tracing->AssertValuesEqual({1, 2, 3});
+
+ SleepABit(); // No further values should be fetched
tracing->AssertValuesEqual({1, 2, 3});
AssertIteratorNext({2}, it);
@@ -383,13 +561,247 @@ TEST(ReadaheadIterator, NextError) {
ASSERT_RAISES(IOError, it.Next().status());
- AssertIteratorNext({1}, it);
- tracing->WaitForValues(3);
+ AssertIteratorExhausted(it);
SleepABit();
- tracing->AssertValuesEqual({1, 2, 3});
- AssertIteratorNext({2}, it);
- AssertIteratorNext({3}, it);
+ tracing->AssertValuesEqual({});
AssertIteratorExhausted(it);
}
+// --------------------------------------------------------------------
+// Asynchronous iterator tests
+
+TEST(TestAsyncUtil, Visit) {
+ auto generator = AsyncVectorIt({1, 2, 3});
+ unsigned int sum = 0;
+ auto sum_future = VisitAsyncGenerator<TestInt>(generator, [&sum](TestInt
item) {
+ sum += item.value;
+ return Status::OK();
+ });
+ ASSERT_TRUE(sum_future.is_finished());
+ ASSERT_EQ(6, sum);
+}
+
+TEST(TestAsyncUtil, Collect) {
+ std::vector<TestInt> expected = {1, 2, 3};
+ auto generator = AsyncVectorIt(expected);
+ auto collected = CollectAsyncGenerator(generator);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto collected_val, collected);
+ ASSERT_EQ(expected, collected_val);
+}
+
+TEST(TestAsyncUtil, SynchronousFinish) {
+ AsyncGenerator<TestInt> generator = []() {
+ return Future<TestInt>::MakeFinished(IterationTraits<TestInt>::End());
+ };
+ Transformer<TestInt, TestStr> skip_all = [](TestInt value) { return
TransformSkip(); };
+ auto transformed = MakeAsyncGenerator(generator, skip_all);
+ auto future = CollectAsyncGenerator(transformed);
+ ASSERT_TRUE(future.is_finished());
+ ASSERT_OK_AND_ASSIGN(auto actual, future.result());
+ ASSERT_EQ(std::vector<TestStr>(), actual);
+}
+
+TEST(TestAsyncUtil, GeneratorIterator) {
+ auto generator = BackgroundAsyncVectorIt({1, 2, 3});
+ ASSERT_OK_AND_ASSIGN(auto iterator,
MakeGeneratorIterator(std::move(generator)));
+ ASSERT_OK_AND_EQ(TestInt(1), iterator.Next());
+ ASSERT_OK_AND_EQ(TestInt(2), iterator.Next());
+ ASSERT_OK_AND_EQ(TestInt(3), iterator.Next());
+ ASSERT_OK_AND_EQ(IterationTraits<TestInt>::End(), iterator.Next());
+ ASSERT_OK_AND_EQ(IterationTraits<TestInt>::End(), iterator.Next());
+}
+
+TEST(TestAsyncUtil, MakeTransferredGenerator) {
+ std::mutex mutex;
+ std::condition_variable cv;
+ std::atomic<bool> finished(false);
+
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1));
+
+ // Needs to be a slow source to ensure we don't call Then on a completed
+ AsyncGenerator<TestInt> slow_generator = [&]() {
+ return thread_pool
+ ->Submit([&] {
+ std::unique_lock<std::mutex> lock(mutex);
+ cv.wait_for(lock, std::chrono::duration<double>(30),
+ [&] { return finished.load(); });
+ return IterationTraits<TestInt>::End();
+ })
+ .ValueOrDie();
+ };
+
+ auto transferred =
+ MakeTransferredGenerator<TestInt>(std::move(slow_generator),
thread_pool.get());
+
+ auto current_thread_id = std::this_thread::get_id();
+ auto fut = transferred().Then([¤t_thread_id](const Result<TestInt>&
result) {
+ ASSERT_NE(current_thread_id, std::this_thread::get_id());
+ });
+
+ {
+ std::lock_guard<std::mutex> lg(mutex);
+ finished.store(true);
+ }
+ cv.notify_one();
+ ASSERT_FINISHES_OK(fut);
+}
+
+// This test is too slow for valgrind
+#if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER))
+
+TEST(TestAsyncUtil, StackOverflow) {
+ int counter = 0;
+ AsyncGenerator<TestInt> generator = [&counter]() {
+ if (counter < 1000000) {
+ return Future<TestInt>::MakeFinished(counter++);
+ } else {
+ return Future<TestInt>::MakeFinished(IterationTraits<TestInt>::End());
+ }
+ };
+ Transformer<TestInt, TestStr> discard =
+ [](TestInt next) -> Result<TransformFlow<TestStr>> { return
TransformSkip(); };
+ auto transformed = MakeAsyncGenerator(generator, discard);
+ auto collected_future = CollectAsyncGenerator(transformed);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, collected_future);
+ ASSERT_EQ(0, collected.size());
+}
+
+#endif
+
+TEST(TestAsyncUtil, Background) {
+ std::vector<TestInt> expected = {1, 2, 3};
+ auto background = BackgroundAsyncVectorIt(expected);
+ auto future = CollectAsyncGenerator(background);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, future);
+ ASSERT_EQ(expected, collected);
+}
+
+struct SlowEmptyIterator {
+ Result<TestInt> Next() {
+ if (called_) {
+ return Status::Invalid("Should not have been called twice");
+ }
+ SleepFor(0.1);
+ return IterationTraits<TestInt>::End();
+ }
+
+ private:
+ bool called_ = false;
+};
+
+TEST(TestAsyncUtil, BackgroundRepeatEnd) {
+ // Ensure that the background generator properly fulfills the asyncgenerator
contract
+ // and can be called after it ends.
+ ASSERT_OK_AND_ASSIGN(auto io_pool, internal::ThreadPool::Make(1));
+
+ auto iterator = Iterator<TestInt>(SlowEmptyIterator());
+ ASSERT_OK_AND_ASSIGN(auto background_gen,
+ MakeBackgroundGenerator(std::move(iterator),
io_pool.get()));
+
+ background_gen =
+ MakeTransferredGenerator(std::move(background_gen),
internal::GetCpuThreadPool());
+
+ auto one = background_gen();
+ auto two = background_gen();
+
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto one_fin, one);
+ ASSERT_EQ(IterationTraits<TestInt>::End(), one_fin);
+
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto two_fin, two);
+ ASSERT_EQ(IterationTraits<TestInt>::End(), two_fin);
+}
+
+TEST(TestAsyncUtil, CompleteBackgroundStressTest) {
+ auto expected = RangeVector(20);
+ std::vector<Future<std::vector<TestInt>>> futures;
+ for (unsigned int i = 0; i < 20; i++) {
+ auto background = BackgroundAsyncVectorIt(expected);
+ futures.push_back(CollectAsyncGenerator(background));
+ }
+ auto combined = All(futures);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto completed_vectors, combined);
+ for (std::size_t i = 0; i < completed_vectors.size(); i++) {
+ ASSERT_OK_AND_ASSIGN(auto vector, completed_vectors[i]);
+ ASSERT_EQ(vector, expected);
+ }
+}
+
+TEST(TestAsyncUtil, Readahead) {
+ int num_delivered = 0;
+ auto source = [&num_delivered]() {
+ if (num_delivered < 5) {
+ return Future<TestInt>::MakeFinished(num_delivered++);
+ } else {
+ return Future<TestInt>::MakeFinished(IterationTraits<TestInt>::End());
+ }
+ };
+ auto readahead = MakeReadaheadGenerator<TestInt>(source, 10);
+ // Should not pump until first item requested
+ ASSERT_EQ(0, num_delivered);
+
+ auto first = readahead();
+ // At this point the pumping should have happened
+ ASSERT_EQ(5, num_delivered);
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto first_val, first);
+ ASSERT_EQ(TestInt(0), first_val);
+
+ // Read the rest
+ for (int i = 0; i < 4; i++) {
+ auto next = readahead();
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto next_val, next);
+ ASSERT_EQ(TestInt(i + 1), next_val);
+ }
+
+ // Next should be end
+ auto last = readahead();
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto last_val, last);
+ ASSERT_EQ(IterationTraits<TestInt>::End(), last_val);
+}
+
+TEST(TestAsyncUtil, ReadaheadFailed) {
+ ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(4));
+ std::atomic<int32_t> counter(0);
+ // All tasks are a little slow. The first task fails.
+ // The readahead will have spawned 9 more tasks and they
+ // should all pass
+ auto source = [thread_pool, &counter]() -> Future<TestInt> {
+ auto count = counter++;
+ return *thread_pool->Submit([count]() -> Result<TestInt> {
+ if (count == 0) {
+ return Status::Invalid("X");
+ }
+ return TestInt(count);
+ });
+ };
+ auto readahead = MakeReadaheadGenerator<TestInt>(source, 10);
+ ASSERT_FINISHES_ERR(Invalid, readahead());
+ SleepABit();
+
+ for (int i = 0; i < 9; i++) {
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto next_val, readahead());
+ ASSERT_EQ(TestInt(i + 1), next_val);
+ }
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto after, readahead());
+
+ // It's possible that finished was set quickly and there
+ // are only 10 elements
+ if (after == IterationTraits<TestInt>::End()) {
+ return;
+ }
+
+ // It's also possible that finished was too slow and there
+ // ended up being 11 elements
+ ASSERT_EQ(TestInt(10), after);
+ // There can't be 12 elements because SleepABit will prevent it
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto definitely_last, readahead());
+ ASSERT_EQ(IterationTraits<TestInt>::End(), definitely_last);
+}
+
+TEST(TestAsyncIteratorTransform, SkipSome) {
+ auto original = AsyncVectorIt({1, 2, 3});
+ auto filter = MakeFilter([](TestInt& t) { return t.value != 2; });
+ auto filtered = MakeAsyncGenerator(std::move(original), filter);
+ AssertAsyncGeneratorMatch({"1", "3"}, std::move(filtered));
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/util/task_group.cc b/cpp/src/arrow/util/task_group.cc
index 8765602..a7b5592 100644
--- a/cpp/src/arrow/util/task_group.cc
+++ b/cpp/src/arrow/util/task_group.cc
@@ -54,6 +54,8 @@ class SerialTaskGroup : public TaskGroup {
return status_;
}
+ Future<> FinishAsync() override { return Future<>::MakeFinished(Finish()); }
+
int parallelism() override { return 1; }
Status status_;
@@ -114,6 +116,18 @@ class ThreadedTaskGroup : public TaskGroup {
return status_;
}
+ Future<> FinishAsync() override {
+ std::lock_guard<std::mutex> lock(mutex_);
+ if (!completion_future_.has_value()) {
+ if (nremaining_.load() == 0) {
+ completion_future_ = Future<>::MakeFinished(status_);
+ } else {
+ completion_future_ = Future<>::Make();
+ }
+ }
+ return *completion_future_;
+ }
+
int parallelism() override { return executor_->GetCapacity(); }
protected:
@@ -135,6 +149,21 @@ class ThreadedTaskGroup : public TaskGroup {
// before cv.notify_one() has returned
std::unique_lock<std::mutex> lock(mutex_);
cv_.notify_one();
+ if (completion_future_.has_value()) {
+ // MarkFinished could be slow. We don't want to call it while we are
holding
+ // the lock.
+ auto& future = *completion_future_;
+ const auto finished = completion_future_->is_finished();
+ const auto& status = status_;
+ // This will be redundant if the user calls Finish and not FinishAsync
+ if (!finished && !finished_) {
+ finished_ = true;
+ lock.unlock();
+ future.MarkFinished(status);
+ } else {
+ lock.unlock();
+ }
+ }
}
}
@@ -148,6 +177,7 @@ class ThreadedTaskGroup : public TaskGroup {
std::condition_variable cv_;
Status status_;
bool finished_ = false;
+ util::optional<Future<>> completion_future_;
};
std::shared_ptr<TaskGroup> TaskGroup::MakeSerial() {
diff --git a/cpp/src/arrow/util/task_group.h b/cpp/src/arrow/util/task_group.h
index db3265d..a6df43f 100644
--- a/cpp/src/arrow/util/task_group.h
+++ b/cpp/src/arrow/util/task_group.h
@@ -63,6 +63,20 @@ class ARROW_EXPORT TaskGroup : public
std::enable_shared_from_this<TaskGroup> {
/// task (or subgroup).
virtual Status Finish() = 0;
+ /// Returns a future that will complete the first time all tasks are
finished.
+ /// This should be called only after all top level tasks
+ /// have been added to the task group.
+ ///
+ /// If you are using a TaskGroup asynchronously there are a few
considerations to keep
+ /// in mind. The tasks should not block on I/O, etc (defeats the purpose of
using
+ /// futures) and should not be doing any nested locking or you run the risk
of the tasks
+ /// getting stuck in the thread pool waiting for tasks which cannot get
scheduled.
+ ///
+ /// Primarily this call is intended to help migrate existing work written
with TaskGroup
+ /// in mind to using futures without having to do a complete conversion on
the first
+ /// pass.
+ virtual Future<> FinishAsync() = 0;
+
/// The current aggregate error Status. Non-blocking, useful for stopping
early.
virtual Status current_status() = 0;
diff --git a/cpp/src/arrow/util/task_group_test.cc
b/cpp/src/arrow/util/task_group_test.cc
index 1e47a34..38f4b21 100644
--- a/cpp/src/arrow/util/task_group_test.cc
+++ b/cpp/src/arrow/util/task_group_test.cc
@@ -17,6 +17,7 @@
#include <atomic>
#include <chrono>
+#include <condition_variable>
#include <cstdint>
#include <functional>
#include <memory>
@@ -243,6 +244,68 @@ void TestNoCopyTask(std::shared_ptr<TaskGroup> task_group)
{
ASSERT_EQ(0, *counter);
}
+void TestFinishNotSticky(std::function<std::shared_ptr<TaskGroup>()> factory) {
+ // If a task is added that runs very quickly it might decrement the task
counter back
+ // down to 0 and mark the completion future as complete before all tasks are
added.
+ // The "finished future" of the task group could get stuck to complete.
+ //
+ // Instead the task group should not allow the finished future to be marked
complete
+ // until after FinishAsync has been called.
+ const int NTASKS = 100;
+ for (int i = 0; i < NTASKS; ++i) {
+ auto task_group = factory();
+ // Add a task and let it complete
+ task_group->Append([] { return Status::OK(); });
+ // Wait a little bit, if the task group was going to lock the finish
hopefully it
+ // would do so here while we wait
+ SleepFor(1e-2);
+
+ // Add a new task that will still be running
+ std::atomic<bool> ready(false);
+ std::mutex m;
+ std::condition_variable cv;
+ task_group->Append([&m, &cv, &ready] {
+ std::unique_lock<std::mutex> lk(m);
+ cv.wait(lk, [&ready] { return ready.load(); });
+ return Status::OK();
+ });
+
+ // Ensure task group not finished already
+ auto finished = task_group->FinishAsync();
+ ASSERT_FALSE(finished.is_finished());
+
+ std::unique_lock<std::mutex> lk(m);
+ ready = true;
+ lk.unlock();
+ cv.notify_one();
+
+ ASSERT_FINISHES_OK(finished);
+ }
+}
+
+void TestFinishNeverStarted(std::shared_ptr<TaskGroup> task_group) {
+ // If we call FinishAsync we are done adding tasks so if we never added any
it should be
+ // completed
+ auto finished = task_group->FinishAsync();
+ ASSERT_TRUE(finished.Wait(1));
+}
+
+void TestFinishAlreadyCompleted(std::function<std::shared_ptr<TaskGroup>()>
factory) {
+ // If we call FinishAsync we are done adding tasks so even if no tasks are
running we
+ // should still be completed
+ const int NTASKS = 100;
+ for (int i = 0; i < NTASKS; ++i) {
+ auto task_group = factory();
+ // Add a task and let it complete
+ task_group->Append([] { return Status::OK(); });
+ // Wait a little bit, hopefully enough time for the task to finish on one
of these
+ // iterations
+ SleepFor(1e-2);
+ auto finished = task_group->FinishAsync();
+ ASSERT_FINISHES_OK(finished);
+ }
+}
+
TEST(SerialTaskGroup, Success) {
TestTaskGroupSuccess(TaskGroup::MakeSerial()); }
TEST(SerialTaskGroup, Errors) { TestTaskGroupErrors(TaskGroup::MakeSerial()); }
@@ -251,6 +314,14 @@ TEST(SerialTaskGroup, TasksSpawnTasks) {
TestTasksSpawnTasks(TaskGroup::MakeSeri
TEST(SerialTaskGroup, NoCopyTask) { TestNoCopyTask(TaskGroup::MakeSerial()); }
+TEST(SerialTaskGroup, FinishNeverStarted) {
+ TestFinishNeverStarted(TaskGroup::MakeSerial());
+}
+
+TEST(SerialTaskGroup, FinishAlreadyCompleted) {
+ TestFinishAlreadyCompleted([] { return TaskGroup::MakeSerial(); });
+}
+
TEST(ThreadedTaskGroup, Success) {
auto task_group = TaskGroup::MakeThreaded(GetCpuThreadPool());
TestTaskGroupSuccess(task_group);
@@ -291,5 +362,25 @@ TEST(ThreadedTaskGroup, StressFailingTaskGroupLifetime) {
[&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
}
+TEST(ThreadedTaskGroup, FinishNotSticky) {
+ std::shared_ptr<ThreadPool> thread_pool;
+ ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16));
+
+ TestFinishNotSticky([&] { return TaskGroup::MakeThreaded(thread_pool.get());
});
+}
+
+TEST(ThreadedTaskGroup, FinishNeverStarted) {
+ std::shared_ptr<ThreadPool> thread_pool;
+ ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4));
+ TestFinishNeverStarted(TaskGroup::MakeThreaded(thread_pool.get()));
+}
+
+TEST(ThreadedTaskGroup, FinishAlreadyCompleted) {
+ std::shared_ptr<ThreadPool> thread_pool;
+ ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16));
+
+ TestFinishAlreadyCompleted([&] { return
TaskGroup::MakeThreaded(thread_pool.get()); });
+}
+
} // namespace internal
} // namespace arrow
diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h
index 03b925d..5db3a9a 100644
--- a/cpp/src/arrow/util/thread_pool.h
+++ b/cpp/src/arrow/util/thread_pool.h
@@ -86,6 +86,28 @@ class ARROW_EXPORT Executor {
return SpawnReal(hints, std::forward<Function>(func));
}
+ // Transfers a future to this executor. Any continuations added to the
+ // returned future will run in this executor. Otherwise they would run
+ // on the same thread that called MarkFinished.
+ //
+ // This is necessary when (for example) an I/O task is completing a future.
+ // The continuations of that future should run on the CPU thread pool keeping
+ // CPU heavy work off the I/O thread pool. So the I/O task should transfer
+ // the future to the CPU executor before returning.
+ template <typename T>
+ Future<T> Transfer(Future<T> future) {
+ auto transferred = Future<T>::Make();
+ future.AddCallback([this, transferred](const Result<T>& result) mutable {
+ auto spawn_status = Spawn([transferred, result]() mutable {
+ transferred.MarkFinished(std::move(result));
+ });
+ if (!spawn_status.ok()) {
+ transferred.MarkFinished(spawn_status);
+ }
+ });
+ return transferred;
+ }
+
// Submit a callable and arguments for execution. Return a future that
// will return the callable's result value once.
// The callable's arguments are copied before execution.
diff --git a/docs/source/cpp/csv.rst b/docs/source/cpp/csv.rst
index 9f17d56..44dc149 100644
--- a/docs/source/cpp/csv.rst
+++ b/docs/source/cpp/csv.rst
@@ -42,6 +42,7 @@ A CSV file is read from a :class:`~arrow::io::InputStream`.
{
// ...
arrow::MemoryPool* pool = default_memory_pool();
+ arrow::io::AsyncContext async_context;
std::shared_ptr<arrow::io::InputStream> input = ...;
auto read_options = arrow::csv::ReadOptions::Defaults();
@@ -51,6 +52,7 @@ A CSV file is read from a :class:`~arrow::io::InputStream`.
// Instantiate TableReader from input stream and options
auto maybe_reader =
arrow::csv::TableReader::Make(pool,
+ async_context,
input,
read_options,
parse_options,
diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx
index 34c6693..4068a0b 100644
--- a/python/pyarrow/_csv.pyx
+++ b/python/pyarrow/_csv.pyx
@@ -700,6 +700,7 @@ def read_csv(input_file, read_options=None,
parse_options=None,
CCSVConvertOptions c_convert_options
shared_ptr[CCSVReader] reader
shared_ptr[CTable] table
+ CAsyncContext c_async_ctx = CAsyncContext()
_get_reader(input_file, read_options, &stream)
_get_read_options(read_options, &c_read_options)
@@ -707,7 +708,7 @@ def read_csv(input_file, read_options=None,
parse_options=None,
_get_convert_options(convert_options, &c_convert_options)
reader = GetResultValue(CCSVReader.Make(
- maybe_unbox_memory_pool(memory_pool), stream,
+ maybe_unbox_memory_pool(memory_pool), c_async_ctx, stream,
c_read_options, c_parse_options, c_convert_options))
with nogil:
diff --git a/python/pyarrow/includes/libarrow.pxd
b/python/pyarrow/includes/libarrow.pxd
index 41159bd..6c1c7f6 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1140,6 +1140,9 @@ cdef extern from "arrow/io/api.h" namespace "arrow::io"
nogil:
ObjectType_FILE" arrow::io::ObjectType::FILE"
ObjectType_DIRECTORY" arrow::io::ObjectType::DIRECTORY"
+ cdef cppclass CAsyncContext" arrow::io::AsyncContext":
+ CAsyncContext()
+
cdef cppclass FileStatistics:
int64_t size
ObjectType kind
@@ -1618,7 +1621,7 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv"
nogil:
cdef cppclass CCSVReader" arrow::csv::TableReader":
@staticmethod
CResult[shared_ptr[CCSVReader]] Make(
- CMemoryPool*, shared_ptr[CInputStream],
+ CMemoryPool*, CAsyncContext, shared_ptr[CInputStream],
CCSVReadOptions, CCSVParseOptions, CCSVConvertOptions)
CResult[shared_ptr[CTable]] Read()
diff --git a/r/src/csv.cpp b/r/src/csv.cpp
index 54d3abc..69b834a 100644
--- a/r/src/csv.cpp
+++ b/r/src/csv.cpp
@@ -141,8 +141,9 @@ std::shared_ptr<arrow::csv::TableReader>
csv___TableReader__Make(
const std::shared_ptr<arrow::csv::ReadOptions>& read_options,
const std::shared_ptr<arrow::csv::ParseOptions>& parse_options,
const std::shared_ptr<arrow::csv::ConvertOptions>& convert_options) {
- return ValueOrStop(arrow::csv::TableReader::Make(gc_memory_pool(), input,
*read_options,
- *parse_options,
*convert_options));
+ return ValueOrStop(
+ arrow::csv::TableReader::Make(gc_memory_pool(),
arrow::io::AsyncContext(), input,
+ *read_options, *parse_options,
*convert_options));
}
// [[arrow::export]]