pitrou commented on a change in pull request #11946:
URL: https://github.com/apache/arrow/pull/11946#discussion_r781218663
##########
File path: cpp/src/arrow/record_batch.h
##########
@@ -234,6 +234,67 @@ class ARROW_EXPORT RecordBatchReader {
return batch;
}
+ class RecordBatchReaderIterator {
Review comment:
Nit, but since this is a nested class, perhaps calling it `Iterator` is
sufficient?
##########
File path: cpp/src/arrow/record_batch.h
##########
@@ -234,6 +234,67 @@ class ARROW_EXPORT RecordBatchReader {
return batch;
}
+ class RecordBatchReaderIterator {
+ public:
+ using iterator_category = std::input_iterator_tag;
+ using difference_type = std::ptrdiff_t;
+ using value_type = std::shared_ptr<RecordBatch>;
+ using pointer = value_type const*;
+ using reference = value_type const&;
+
+ RecordBatchReaderIterator() : batch_(RecordBatchEnd()), reader_(NULLPTR) {}
+
+ explicit RecordBatchReaderIterator(RecordBatchReader* reader)
+ : batch_(RecordBatchEnd()), reader_(reader) {
+ Next();
+ }
+
+ bool operator==(const RecordBatchReaderIterator& other) const {
+ return batch_ == other.batch_;
+ }
+
+ bool operator!=(const RecordBatchReaderIterator& other) const {
+ return !(*this == other);
+ }
+
+ Result<std::shared_ptr<RecordBatch>> operator*() {
+ ARROW_RETURN_NOT_OK(batch_.status());
+
+ return batch_;
+ }
+
+ RecordBatchReaderIterator& operator++() {
+ Next();
+ return *this;
+ }
+
+ RecordBatchReaderIterator operator++(int) {
+ RecordBatchReaderIterator tmp(*this);
+ Next();
+ return tmp;
+ }
+
+ private:
+ std::shared_ptr<RecordBatch> RecordBatchEnd() {
+ return std::shared_ptr<RecordBatch>(NULLPTR);
+ }
+
+ void Next() {
Review comment:
I don't think anything more is required, no. The caller should notice
that the iterator is equal to `end()`, and then stop reading.
##########
File path: cpp/src/arrow/record_batch_test.cc
##########
@@ -350,4 +350,64 @@ TEST_F(TestRecordBatch, MakeEmpty) {
ASSERT_EQ(empty->num_rows(), 0);
}
+class TestRecordBatchReader : public ::testing::Test {
+ public:
+ void SetUp() override { MakeBatchesAndReader(100); }
+
+ protected:
+ void MakeBatchesAndReader(int length) {
+ auto field1 = field("f1", int32());
+ auto field2 = field("f2", uint8());
+ auto field3 = field("f3", int16());
+
+ auto schema = ::arrow::schema({field1, field2, field3});
+
+ random::RandomArrayGenerator gen(42);
+
+ auto array1_1 = gen.ArrayOf(int32(), length);
+ auto array1_2 = gen.ArrayOf(int32(), length);
+ auto array1_3 = gen.ArrayOf(int32(), length);
+
+ auto array2_1 = gen.ArrayOf(uint8(), length);
+ auto array2_2 = gen.ArrayOf(uint8(), length);
+ auto array2_3 = gen.ArrayOf(uint8(), length);
+
+ auto array3_1 = gen.ArrayOf(int16(), length);
+ auto array3_2 = gen.ArrayOf(int16(), length);
+ auto array3_3 = gen.ArrayOf(int16(), length);
+
+ auto batch1 = RecordBatch::Make(schema, length, {array1_1, array2_1, array3_1});
+ auto batch2 = RecordBatch::Make(schema, length, {array1_2, array2_2, array3_2});
+ auto batch3 = RecordBatch::Make(schema, length, {array1_3, array2_3, array3_3});
+
+ batches_ = {batch1, batch2, batch3};
+
+ ASSERT_OK_AND_ASSIGN(reader_, RecordBatchReader::Make(batches_));
+ }
+ std::vector<std::shared_ptr<RecordBatch>> batches_;
+ std::shared_ptr<RecordBatchReader> reader_;
+};
+
+TEST_F(TestRecordBatchReader, RangeForLoop) {
+ int64_t i = 0;
+ ASSERT_LT(i, static_cast<int64_t>(batches_.size()));
Review comment:
This check was meant to go inside the loop, before `i` is incremented.
##########
File path: cpp/src/arrow/record_batch.h
##########
@@ -234,6 +234,67 @@ class ARROW_EXPORT RecordBatchReader {
return batch;
}
+ class RecordBatchReaderIterator {
+ public:
+ using iterator_category = std::input_iterator_tag;
+ using difference_type = std::ptrdiff_t;
+ using value_type = std::shared_ptr<RecordBatch>;
+ using pointer = value_type const*;
+ using reference = value_type const&;
+
+ RecordBatchReaderIterator() : batch_(RecordBatchEnd()), reader_(NULLPTR) {}
+
+ explicit RecordBatchReaderIterator(RecordBatchReader* reader)
+ : batch_(RecordBatchEnd()), reader_(reader) {
+ Next();
+ }
+
+ bool operator==(const RecordBatchReaderIterator& other) const {
+ return batch_ == other.batch_;
+ }
+
+ bool operator!=(const RecordBatchReaderIterator& other) const {
+ return !(*this == other);
+ }
+
+ Result<std::shared_ptr<RecordBatch>> operator*() {
+ ARROW_RETURN_NOT_OK(batch_.status());
+
+ return batch_;
+ }
+
+ RecordBatchReaderIterator& operator++() {
+ Next();
+ return *this;
+ }
+
+ RecordBatchReaderIterator operator++(int) {
+ RecordBatchReaderIterator tmp(*this);
+ Next();
+ return tmp;
+ }
+
+ private:
+ std::shared_ptr<RecordBatch> RecordBatchEnd() {
+ return std::shared_ptr<RecordBatch>(NULLPTR);
+ }
+
+ void Next() {
+ if (reader_ == NULLPTR) {
+ batch_ = RecordBatchEnd();
+ return;
+ }
+ batch_ = reader_->Next();
+ }
+
+ Result<std::shared_ptr<RecordBatch>> batch_;
+ RecordBatchReader* reader_;
+ };
+
+ RecordBatchReaderIterator begin() { return RecordBatchReaderIterator(this); }
+
+ RecordBatchReaderIterator end() { return RecordBatchReaderIterator(); }
Review comment:
Can you add a terse docstring to these two public methods?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]