This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 7a5f57f212 GH-36367: [C++] Add a zipped range utility (#36393)
7a5f57f212 is described below
commit 7a5f57f21221b187a7a306a5431f2aeb9a1b1f6f
Author: Benjamin Kietzman <[email protected]>
AuthorDate: Thu Jul 6 07:08:17 2023 -0700
GH-36367: [C++] Add a zipped range utility (#36393)
### Rationale for this change
We write a lot of loops over parallel ranges. Zipping parallel ranges into a single
loop is a proven pattern for improving code clarity in other languages
### What changes are included in this PR?
A zip utility is added which simplifies writing loops over parallel ranges.
```diff
@@ -1118,9 +1118,8 @@ Status GetFieldsFromArray(const RjArray& json_fields,
FieldPosition parent_pos,
DictionaryMemo* dictionary_memo,
std::vector<std::shared_ptr<Field>>* fields) {
fields->resize(json_fields.Size());
- for (rj::SizeType i = 0; i < json_fields.Size(); ++i) {
- RETURN_NOT_OK(GetField(json_fields[i],
parent_pos.child(static_cast<int>(i)),
- dictionary_memo, &(*fields)[i]));
+ for (auto [json_field, field, i] : Zip(json_fields, *fields,
Enumerate<int>)) {
+ RETURN_NOT_OK(GetField(json_field, parent_pos.child(i),
dictionary_memo, &field));
}
return Status::OK();
}
```
Some (most) of the loops in json_internal.cc are rewritten to showcase usage.
### Are these changes tested?
Yes, in `range_test.cc`.
### Are there any user-facing changes?
No, this is an internal utility.
* Closes: #36367
Authored-by: Benjamin Kietzman <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/testing/json_internal.cc | 224 ++++++++++++++++-----------------
cpp/src/arrow/util/range.h | 111 +++++++++++++++-
cpp/src/arrow/util/range_test.cc | 130 ++++++++++++++++++-
3 files changed, 342 insertions(+), 123 deletions(-)
diff --git a/cpp/src/arrow/testing/json_internal.cc
b/cpp/src/arrow/testing/json_internal.cc
index 796142db54..81c4befbf2 100644
--- a/cpp/src/arrow/testing/json_internal.cc
+++ b/cpp/src/arrow/testing/json_internal.cc
@@ -19,6 +19,7 @@
#include <cstdint>
#include <cstdlib>
+#include <iomanip>
#include <iostream>
#include <memory>
#include <string>
@@ -46,6 +47,7 @@
#include "arrow/util/formatting.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
+#include "arrow/util/range.h"
#include "arrow/util/string.h"
#include "arrow/util/value_parsing.h"
#include "arrow/visit_array_inline.h"
@@ -54,7 +56,9 @@
namespace arrow {
using internal::checked_cast;
+using internal::Enumerate;
using internal::ParseValue;
+using internal::Zip;
using ipc::DictionaryFieldMapper;
using ipc::DictionaryMemo;
@@ -118,10 +122,8 @@ class SchemaWriter {
writer_->StartArray();
FieldPosition field_pos;
- int i = 0;
- for (const std::shared_ptr<Field>& field : schema_.fields()) {
+ for (auto [field, i] : Zip(schema_.fields(), Enumerate<int>)) {
RETURN_NOT_OK(VisitField(field, field_pos.child(i)));
- ++i;
}
writer_->EndArray();
WriteKeyValueMetadata(schema_.metadata());
@@ -139,12 +141,12 @@ class SchemaWriter {
writer_->StartArray();
if (metadata != nullptr) {
- for (int64_t i = 0; i < metadata->size(); ++i) {
- WriteKeyValue(metadata->key(i), metadata->value(i));
+ for (auto [key, value] : Zip(metadata->keys(), metadata->values())) {
+ WriteKeyValue(key, value);
}
}
- for (const auto& kv : additional_metadata) {
- WriteKeyValue(kv.first, kv.second);
+ for (const auto& [key, value] : additional_metadata) {
+ WriteKeyValue(key, value);
}
writer_->EndArray();
}
@@ -334,8 +336,8 @@ class SchemaWriter {
// Write type ids
writer_->Key("typeIds");
writer_->StartArray();
- for (size_t i = 0; i < type.type_codes().size(); ++i) {
- writer_->Int(type.type_codes()[i]);
+ for (int8_t i : type.type_codes()) {
+ writer_->Int(i);
}
writer_->EndArray();
}
@@ -365,10 +367,8 @@ class SchemaWriter {
FieldPosition field_pos) {
writer_->Key("children");
writer_->StartArray();
- int i = 0;
- for (const std::shared_ptr<Field>& field : children) {
+ for (auto [i, field] : Zip(Enumerate<int>, children)) {
RETURN_NOT_OK(VisitField(field, field_pos.child(i)));
- ++i;
}
writer_->EndArray();
return Status::OK();
@@ -669,14 +669,14 @@ class ArrayWriter {
Status WriteChildren(const std::vector<std::shared_ptr<Field>>& fields,
const std::vector<std::shared_ptr<Array>>& arrays) {
// NOTE: the Java parser fails on an empty "children" member (ARROW-11483).
- if (fields.size() > 0) {
- writer_->Key("children");
- writer_->StartArray();
- for (size_t i = 0; i < fields.size(); ++i) {
- RETURN_NOT_OK(VisitArray(fields[i]->name(), *arrays[i]));
- }
- writer_->EndArray();
+ if (fields.size() == 0) return Status::OK();
+
+ writer_->Key("children");
+ writer_->StartArray();
+ for (auto [field, array] : Zip(fields, arrays)) {
+ RETURN_NOT_OK(VisitArray(field->name(), *array));
}
+ writer_->EndArray();
return Status::OK();
}
@@ -1118,9 +1118,8 @@ Status GetFieldsFromArray(const RjArray& json_fields,
FieldPosition parent_pos,
DictionaryMemo* dictionary_memo,
std::vector<std::shared_ptr<Field>>* fields) {
fields->resize(json_fields.Size());
- for (rj::SizeType i = 0; i < json_fields.Size(); ++i) {
- RETURN_NOT_OK(GetField(json_fields[i],
parent_pos.child(static_cast<int>(i)),
- dictionary_memo, &(*fields)[i]));
+ for (auto [json_field, field, i] : Zip(json_fields, *fields,
Enumerate<int>)) {
+ RETURN_NOT_OK(GetField(json_field, parent_pos.child(i), dictionary_memo,
&field));
}
return Status::OK();
}
@@ -1295,14 +1294,8 @@ class ArrayReader {
typename TypeTraits<T>::BuilderType builder(type_, pool_);
ARROW_ASSIGN_OR_RAISE(const auto json_data_arr, GetDataArray(obj_));
-
- for (int i = 0; i < length_; ++i) {
- if (!is_valid_[i]) {
- RETURN_NOT_OK(builder.AppendNull());
- continue;
- }
- const rj::Value& val = json_data_arr[i];
- RETURN_NOT_OK(builder.Append(UnboxValue<T>(val)));
+ for (auto [is_valid, val] : Zip(is_valid_, json_data_arr)) {
+ RETURN_NOT_OK(is_valid ? builder.Append(UnboxValue<T>(val)) :
builder.AppendNull());
}
return FinishBuilder(&builder);
}
@@ -1329,41 +1322,48 @@ class ArrayReader {
"JSON OFFSET array size differs from advertised array length + 1");
}
- for (int i = 0; i < length_; ++i) {
- if (!is_valid_[i]) {
+ for (auto [i, is_valid, json_val] :
+ Zip(Enumerate<rj::SizeType>, is_valid_, json_data_arr)) {
+ if (!is_valid) {
RETURN_NOT_OK(builder.AppendNull());
continue;
}
- const rj::Value& val = json_data_arr[i];
- DCHECK(val.IsString());
+
+ DCHECK(json_val.IsString());
+ std::string_view val{
+ json_val.GetString()}; // XXX can we use json_val.GetStringLength()?
int64_t offset_start = ParseOffset(json_offsets[i]);
int64_t offset_end = ParseOffset(json_offsets[i + 1]);
- DCHECK(offset_end >= offset_start);
+ DCHECK_GE(offset_end, offset_start);
+ auto val_len = static_cast<size_t>(offset_end - offset_start);
- if (T::is_utf8) {
- auto str = val.GetString();
- DCHECK(std::string(str).size() == static_cast<size_t>(offset_end -
offset_start));
- RETURN_NOT_OK(builder.Append(str));
+ if constexpr (T::is_utf8) {
+ if (val.size() != val_len) {
+ return Status::Invalid("Value ", std::quoted(val),
+ " differs from advertised length ", val_len);
+ }
+ RETURN_NOT_OK(builder.Append(json_val.GetString()));
} else {
- std::string hex_string = val.GetString();
-
- if (hex_string.size() % 2 != 0) {
+ if (val.size() % 2 != 0) {
return Status::Invalid("Expected base16 hex string");
}
- const auto value_len = static_cast<int64_t>(hex_string.size()) / 2;
+ if (val.size() / 2 != val_len) {
+ return Status::Invalid("Value 0x", val, " differs from advertised
byte length ",
+ val_len);
+ }
- ARROW_ASSIGN_OR_RAISE(auto byte_buffer, AllocateBuffer(value_len,
pool_));
+ ARROW_ASSIGN_OR_RAISE(auto byte_buffer, AllocateBuffer(val_len,
pool_));
- const char* hex_data = hex_string.c_str();
uint8_t* byte_buffer_data = byte_buffer->mutable_data();
- for (int64_t j = 0; j < value_len; ++j) {
- RETURN_NOT_OK(ParseHexValue(hex_data + j * 2, &byte_buffer_data[j]));
+ for (size_t j = 0; j < val_len; ++j) {
+ RETURN_NOT_OK(ParseHexValue(&val[j * 2], &byte_buffer_data[j]));
}
RETURN_NOT_OK(
- builder.Append(byte_buffer_data,
static_cast<offset_type>(value_len)));
+ builder.Append(byte_buffer_data,
static_cast<offset_type>(val_len)));
}
}
+
return FinishBuilder(&builder);
}
@@ -1372,15 +1372,13 @@ class ArrayReader {
ARROW_ASSIGN_OR_RAISE(const auto json_data_arr, GetDataArray(obj_));
- for (int i = 0; i < length_; ++i) {
- if (!is_valid_[i]) {
+ for (auto [is_valid, val] : Zip(is_valid_, json_data_arr)) {
+ if (!is_valid) {
RETURN_NOT_OK(builder.AppendNull());
continue;
}
-
- const rj::Value& val = json_data_arr[i];
DCHECK(val.IsObject());
- DayTimeIntervalType::DayMilliseconds dm = {0, 0};
+ DayTimeIntervalType::DayMilliseconds dm;
dm.days = val[kDays].GetInt();
dm.milliseconds = val[kMilliseconds].GetInt();
RETURN_NOT_OK(builder.Append(dm));
@@ -1393,19 +1391,17 @@ class ArrayReader {
ARROW_ASSIGN_OR_RAISE(const auto json_data_arr, GetDataArray(obj_));
- for (int i = 0; i < length_; ++i) {
- if (!is_valid_[i]) {
+ for (auto [is_valid, val] : Zip(is_valid_, json_data_arr)) {
+ if (!is_valid) {
RETURN_NOT_OK(builder.AppendNull());
continue;
}
-
- const rj::Value& val = json_data_arr[i];
DCHECK(val.IsObject());
- MonthDayNanoIntervalType::MonthDayNanos dm = {0, 0, 0};
- dm.months = val[kMonths].GetInt();
- dm.days = val[kDays].GetInt();
- dm.nanoseconds = val[kNanoseconds].GetInt64();
- RETURN_NOT_OK(builder.Append(dm));
+ MonthDayNanoIntervalType::MonthDayNanos mdn;
+ mdn.months = val[kMonths].GetInt();
+ mdn.days = val[kDays].GetInt();
+ mdn.nanoseconds = val[kNanoseconds].GetInt64();
+ RETURN_NOT_OK(builder.Append(mdn));
}
return FinishBuilder(&builder);
}
@@ -1423,25 +1419,24 @@ class ArrayReader {
ARROW_ASSIGN_OR_RAISE(auto byte_buffer, AllocateBuffer(byte_width, pool_));
uint8_t* byte_buffer_data = byte_buffer->mutable_data();
- for (int i = 0; i < length_; ++i) {
- if (!is_valid_[i]) {
+ for (auto [is_valid, json_val] : Zip(is_valid_, json_data_arr)) {
+ if (!is_valid) {
RETURN_NOT_OK(builder.AppendNull());
- } else {
- const rj::Value& val = json_data_arr[i];
- DCHECK(val.IsString())
- << "Found non-string JSON value when parsing FixedSizeBinary
value";
- std::string hex_string = val.GetString();
- if (static_cast<int32_t>(hex_string.size()) != byte_width * 2) {
- DCHECK(false) << "Expected size: " << byte_width * 2
- << " got: " << hex_string.size();
- }
- const char* hex_data = hex_string.c_str();
+ continue;
+ }
- for (int32_t j = 0; j < byte_width; ++j) {
- RETURN_NOT_OK(ParseHexValue(hex_data + j * 2, &byte_buffer_data[j]));
- }
- RETURN_NOT_OK(builder.Append(byte_buffer_data));
+ DCHECK(json_val.IsString())
+ << "Found non-string JSON value when parsing FixedSizeBinary value";
+
+ std::string_view val = json_val.GetString();
+ if (static_cast<int32_t>(val.size()) != byte_width * 2) {
+ DCHECK(false) << "Expected size: " << byte_width * 2 << " got: " <<
val.size();
+ }
+
+ for (int32_t j = 0; j < byte_width; ++j) {
+ RETURN_NOT_OK(ParseHexValue(&val[j * 2], &byte_buffer_data[j]));
}
+ RETURN_NOT_OK(builder.Append(byte_buffer_data));
}
return FinishBuilder(&builder);
}
@@ -1451,22 +1446,25 @@ class ArrayReader {
typename TypeTraits<T>::BuilderType builder(type_, pool_);
ARROW_ASSIGN_OR_RAISE(const auto json_data_arr, GetDataArray(obj_));
+ if (static_cast<rj::SizeType>(length_) != json_data_arr.Size()) {
+ return Status::Invalid("Integer array had unexpected length ",
json_data_arr.Size(),
+ " (expected ", length_, ")");
+ }
- for (int i = 0; i < length_; ++i) {
- if (!is_valid_[i]) {
+ for (auto [is_valid, val] : Zip(is_valid_, json_data_arr)) {
+ if (!is_valid) {
RETURN_NOT_OK(builder.AppendNull());
- } else {
- const rj::Value& val = json_data_arr[i];
- DCHECK(val.IsString())
- << "Found non-string JSON value when parsing Decimal128 value";
- DCHECK_GT(val.GetStringLength(), 0)
- << "Empty string found when parsing Decimal128 value";
-
- using Value = typename TypeTraits<T>::ScalarType::ValueType;
- Value value;
- ARROW_ASSIGN_OR_RAISE(value, Value::FromString(val.GetString()));
- RETURN_NOT_OK(builder.Append(value));
+ continue;
}
+
+ DCHECK(val.IsString())
+ << "Found non-string JSON value when parsing Decimal128 value";
+ DCHECK_GT(val.GetStringLength(), 0)
+ << "Empty string found when parsing Decimal128 value";
+
+ using Value = typename TypeTraits<T>::ScalarType::ValueType;
+ ARROW_ASSIGN_OR_RAISE(Value decimal_val,
Value::FromString(val.GetString()));
+ RETURN_NOT_OK(builder.Append(decimal_val));
}
return FinishBuilder(&builder);
@@ -1475,26 +1473,29 @@ class ArrayReader {
template <typename T>
Status GetIntArray(const RjArray& json_array, const int32_t length,
std::shared_ptr<Buffer>* out) {
- using ArrowType = typename CTypeTraits<T>::ArrowType;
+ if (static_cast<rj::SizeType>(length) != json_array.Size()) {
+ return Status::Invalid("Integer array had unexpected length ",
json_array.Size(),
+ " (expected ", length, ")");
+ }
+
ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(length * sizeof(T),
pool_));
T* values = reinterpret_cast<T*>(buffer->mutable_data());
- if (sizeof(T) < sizeof(int64_t)) {
- for (int i = 0; i < length; ++i) {
- const rj::Value& val = json_array[i];
+
+ for (auto [i, val] : Zip(Enumerate<rj::SizeType>, json_array)) {
+ if constexpr (sizeof(T) < sizeof(int64_t)) {
DCHECK(val.IsInt() || val.IsInt64());
if (val.IsInt()) {
values[i] = static_cast<T>(val.GetInt());
} else {
values[i] = static_cast<T>(val.GetInt64());
}
- }
- } else {
- // Read 64-bit integers as strings, as JSON numbers cannot represent
- // them exactly.
- for (int i = 0; i < length; ++i) {
- const rj::Value& val = json_array[i];
+ } else {
+ // Read 64-bit integers as strings, as JSON numbers cannot represent
+ // them exactly.
DCHECK(val.IsString());
+
+ using ArrowType = typename CTypeTraits<T>::ArrowType;
if (!ParseValue<ArrowType>(val.GetString(), val.GetStringLength(),
&values[i])) {
return Status::Invalid("Failed to parse integer: '",
std::string(val.GetString(),
val.GetStringLength()),
@@ -1614,7 +1615,7 @@ class ArrayReader {
}
Status GetNullBitmap() {
- const int64_t length = static_cast<int64_t>(is_valid_.size());
+ const auto length = static_cast<int64_t>(is_valid_.size());
ARROW_ASSIGN_OR_RAISE(data_->buffers[0], AllocateEmptyBitmap(length,
pool_));
uint8_t* bitmap = data_->buffers[0]->mutable_data();
@@ -1643,19 +1644,17 @@ class ArrayReader {
}
data_->child_data.resize(type.num_fields());
- for (int i = 0; i < type.num_fields(); ++i) {
- const rj::Value& json_child = json_children[i];
+ for (auto [json_child, child_field, child_data] :
+ Zip(json_children, type.fields(), data_->child_data)) {
DCHECK(json_child.IsObject());
const auto& child_obj = json_child.GetObject();
- std::shared_ptr<Field> child_field = type.field(i);
-
auto it = json_child.FindMember("name");
RETURN_NOT_STRING("name", it, json_child);
-
DCHECK_EQ(it->value.GetString(), child_field->name());
+
ArrayReader child_reader(child_obj, pool_, child_field);
- ARROW_ASSIGN_OR_RAISE(data_->child_data[i], child_reader.Parse());
+ ARROW_ASSIGN_OR_RAISE(child_data, child_reader.Parse());
}
return Status::OK();
@@ -1795,9 +1794,8 @@ Status ReadRecordBatch(const rj::Value& json_obj, const
std::shared_ptr<Schema>&
ARROW_ASSIGN_OR_RAISE(const auto json_columns, GetMemberArray(batch_obj,
"columns"));
ArrayDataVector columns(json_columns.Size());
- for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
- ARROW_ASSIGN_OR_RAISE(columns[i],
- ReadArrayData(pool, json_columns[i],
schema->field(i)));
+ for (auto [column, json_column, field] : Zip(columns, json_columns,
schema->fields())) {
+ ARROW_ASSIGN_OR_RAISE(column, ReadArrayData(pool, json_column, field));
}
RETURN_NOT_OK(ResolveDictionaries(columns, *dictionary_memo, pool));
@@ -1835,9 +1833,7 @@ Status WriteRecordBatch(const RecordBatch& batch,
RjWriter* writer) {
writer->Key("columns");
writer->StartArray();
- for (int i = 0; i < batch.num_columns(); ++i) {
- const std::shared_ptr<Array>& column = batch.column(i);
-
+ for (auto [column, i] : Zip(batch.columns(), Enumerate<int>)) {
DCHECK_EQ(batch.num_rows(), column->length())
<< "Array length did not match record batch length: " <<
batch.num_rows()
<< " != " << column->length() << " " << batch.column_name(i);
diff --git a/cpp/src/arrow/util/range.h b/cpp/src/arrow/util/range.h
index ea0fb0eeaa..2055328798 100644
--- a/cpp/src/arrow/util/range.h
+++ b/cpp/src/arrow/util/range.h
@@ -21,11 +21,11 @@
#include <cstdint>
#include <iterator>
#include <numeric>
+#include <tuple>
#include <utility>
#include <vector>
-namespace arrow {
-namespace internal {
+namespace arrow::internal {
/// Create a vector containing the values from start up to stop
template <typename T>
@@ -151,5 +151,108 @@ LazyRange<Generator> MakeLazyRange(Generator&& gen,
int64_t length) {
return LazyRange<Generator>(std::forward<Generator>(gen), length);
}
-} // namespace internal
-} // namespace arrow
+/// \brief A helper for iterating multiple ranges simultaneously, similar to
C++23's
+/// zip() view adapter modelled after python's built-in zip() function.
+///
+/// \code {.cpp}
+/// const std::vector<SomeTable>& tables = ...
+/// std::function<std::vector<std::string>()> GetNames = ...
+/// for (auto [table, name] : Zip(tables, GetNames())) {
+/// static_assert(std::is_same_v<decltype(table), const SomeTable&>);
+/// static_assert(std::is_same_v<decltype(name), std::string&>);
+/// // temporaries (like this vector of strings) are kept alive for the
+/// // duration of a loop and are safely movable).
+/// RegisterTableWithName(std::move(name), &table);
+/// }
+/// \endcode
+///
+/// The zipped sequence ends as soon as any of its member ranges ends.
+///
+/// Always use `auto` for the loop's declaration; it will always be a tuple
+/// of references so for example using `const auto&` will compile but will
+/// *look* like forcing const-ness even though the members of the tuple are
+/// still mutable references.
+///
+/// NOTE: we *could* make Zip a more full fledged range and enable things like
+/// - gtest recognizing it as a container; it currently doesn't since Zip is
+/// always mutable so this breaks:
+/// EXPECT_THAT(Zip(std::vector{0}, std::vector{1}),
+/// ElementsAre(std::tuple{0, 1}));
+/// - letting it be random access when possible so we can do things like *sort*
+/// parallel ranges
+/// - ...
+///
+/// However doing this will increase the compile time overhead of using Zip as
+/// long as we're still using headers. Therefore until we can use c++20
modules:
+/// *don't* extend Zip.
+template <typename Ranges, typename Indices>
+struct Zip;
+
+template <typename... Ranges>
+Zip(Ranges&&...) -> Zip<std::tuple<Ranges...>,
std::index_sequence_for<Ranges...>>;
+
+template <typename... Ranges, size_t... I>
+struct Zip<std::tuple<Ranges...>, std::index_sequence<I...>> {
+ explicit Zip(Ranges... ranges) : ranges_(std::forward<Ranges>(ranges)...) {}
+
+ std::tuple<Ranges...> ranges_;
+
+ using sentinel = std::tuple<decltype(std::end(std::get<I>(ranges_)))...>;
+ constexpr sentinel end() { return {std::end(std::get<I>(ranges_))...}; }
+
+ struct iterator : std::tuple<decltype(std::begin(std::get<I>(ranges_)))...> {
+ using std::tuple<decltype(std::begin(std::get<I>(ranges_)))...>::tuple;
+
+ constexpr auto operator*() {
+ return
std::tuple<decltype(*std::get<I>(*this))...>{*std::get<I>(*this)...};
+ }
+
+ constexpr iterator& operator++() {
+ (++std::get<I>(*this), ...);
+ return *this;
+ }
+
+ constexpr bool operator!=(const sentinel& s) const {
+ bool all_iterators_valid = (... && (std::get<I>(*this) !=
std::get<I>(s)));
+ return all_iterators_valid;
+ }
+ };
+ constexpr iterator begin() { return {std::begin(std::get<I>(ranges_))...}; }
+};
+
+/// \brief A lazy sequence of integers which starts from 0 and never stops.
+///
+/// This can be used in conjunction with Zip() to emulate python's built-in
+/// enumerate() function:
+///
+/// \code {.cpp}
+/// const std::vector<SomeTable>& tables = ...
+/// for (auto [i, table] : Zip(Enumerate<>, tables)) {
+/// std::cout << "#" << i << ": " << table.name() << std::endl;
+/// }
+/// \endcode
+template <typename I = size_t>
+constexpr auto Enumerate = [] {
+ struct {
+ struct sentinel {};
+ constexpr sentinel end() const { return {}; }
+
+ struct iterator {
+ I value{0};
+
+ constexpr I operator*() { return value; }
+
+ constexpr iterator& operator++() {
+ ++value;
+ return *this;
+ }
+
+ constexpr std::true_type operator!=(sentinel) const { return {}; }
+ };
+ constexpr iterator begin() const { return {}; }
+ } out;
+
+ return out;
+}();
+
+} // namespace arrow::internal
diff --git a/cpp/src/arrow/util/range_test.cc b/cpp/src/arrow/util/range_test.cc
index 7fedcde998..282c559438 100644
--- a/cpp/src/arrow/util/range_test.cc
+++ b/cpp/src/arrow/util/range_test.cc
@@ -20,18 +20,21 @@
#include <cstdint>
#include <vector>
+#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include "arrow/testing/random.h"
#include "arrow/testing/util.h"
#include "arrow/util/range.h"
-namespace arrow {
+using testing::ElementsAre;
+
+namespace arrow::internal {
class TestLazyIter : public ::testing::Test {
public:
int64_t kSize = 1000;
- void SetUp() {
+ void SetUp() override {
randint(kSize, 0, 1000000, &source_);
target_.resize(kSize);
}
@@ -43,7 +46,7 @@ class TestLazyIter : public ::testing::Test {
TEST_F(TestLazyIter, TestIncrementCopy) {
auto add_one = [this](int64_t index) { return source_[index] + 1; };
- auto lazy_range = internal::MakeLazyRange(add_one, kSize);
+ auto lazy_range = MakeLazyRange(add_one, kSize);
std::copy(lazy_range.begin(), lazy_range.end(), target_.begin());
for (int64_t index = 0; index < kSize; ++index) {
@@ -53,7 +56,7 @@ TEST_F(TestLazyIter, TestIncrementCopy) {
TEST_F(TestLazyIter, TestPostIncrementCopy) {
auto add_one = [this](int64_t index) { return source_[index] + 1; };
- auto lazy_range = internal::MakeLazyRange(add_one, kSize);
+ auto lazy_range = MakeLazyRange(add_one, kSize);
auto iter = lazy_range.begin();
auto end = lazy_range.end();
auto target_iter = target_.begin();
@@ -66,4 +69,121 @@ TEST_F(TestLazyIter, TestPostIncrementCopy) {
ASSERT_EQ(source_[index] + 1, target_[index]);
}
}
-} // namespace arrow
+
+TEST(Zip, TupleTypes) {
+ char arr[3];
+ const std::string const_arr[3];
+ for (auto tuple :
+ Zip(arr, // 1. mutable lvalue range
+ const_arr, // 2. const lvalue range
+ std::vector<float>{}, // 3. rvalue range
+ std::vector<bool>{}, // 4. rvalue range dereferencing to non ref
+ Enumerate<int>)) { // 6. Enumerate
+ // (const lvalue range dereferencing to
non ref)
+ static_assert(
+ std::is_same_v<decltype(tuple),
+ std::tuple<char&, // 1. mutable lvalue
ref binding
+ const std::string&, // 2. const lvalue ref
binding
+ float&, // 3. mutable lvalue
ref binding
+ std::vector<bool>::reference, // 4.
by-value non ref
+ // binding
(thanks STL)
+ int // 5. by-value non ref binding
+ // (that's fine they're just ints)
+ >>);
+ }
+
+ static size_t max_count;
+ static size_t count = 0;
+
+ struct Counted {
+ static void increment_count() {
+ ++count;
+ EXPECT_LE(count, max_count);
+ }
+
+ Counted() { increment_count(); }
+ Counted(Counted&&) { increment_count(); }
+
+ ~Counted() { --count; }
+
+ Counted(const Counted&) = delete;
+ Counted& operator=(const Counted&) = delete;
+ Counted& operator=(Counted&&) = delete;
+ };
+
+ {
+ max_count = 3;
+ const Counted const_arr[3];
+ EXPECT_EQ(count, 3);
+
+ for (auto [e] : Zip(const_arr)) {
+ // Putting a const reference to range into Zip results in no copies and
the
+ // corresponding tuple element will also be a const reference
+ EXPECT_EQ(count, 3);
+ static_assert(std::is_same_v<decltype(e), const Counted&>);
+ }
+ EXPECT_EQ(count, 3);
+ }
+
+ {
+ max_count = 3;
+ Counted arr[3];
+ EXPECT_EQ(count, 3);
+
+ for (auto [e] : Zip(arr)) {
+ // Putting a mutable reference to range into Zip results in no copies
and the
+ // corresponding tuple element will also be a mutable reference
+ EXPECT_EQ(count, 3);
+ static_assert(std::is_same_v<decltype(e), Counted&>);
+ }
+ EXPECT_EQ(count, 3);
+ }
+
+ {
+ max_count = 3;
+ EXPECT_EQ(count, 0);
+ for (auto [e] : Zip(std::vector<Counted>(3))) {
+ // Putting a prvalue vector into Zip results in no copies and keeps the
temporary
+ // alive as a mutable vector so that we can move out of it if we might
reuse the
+ // elements:
+ EXPECT_EQ(count, 3);
+ static_assert(std::is_same_v<decltype(e), Counted&>);
+ }
+ EXPECT_EQ(count, 0);
+ }
+
+ {
+ std::vector<bool> v{false, false, false, false};
+ for (auto [i, e] : Zip(Enumerate<int>, v)) {
+ // Testing with a range whose references aren't actually references
+ static_assert(std::is_same_v<decltype(e), decltype(v)::reference>);
+ static_assert(std::is_same_v<decltype(e), decltype(v[0])>);
+ static_assert(!std::is_reference_v<decltype(e)>);
+ e = (i % 2 == 0);
+ }
+
+ EXPECT_THAT(v, ElementsAre(true, false, true, false));
+ }
+}
+
+TEST(Zip, EndAfterShortestEnds) {
+ std::vector<int> shorter{0, 0, 0}, longer{9, 9, 9, 9, 9, 9};
+
+ for (auto [s, l] : Zip(shorter, longer)) {
+ std::swap(s, l);
+ }
+
+ EXPECT_THAT(longer, ElementsAre(0, 0, 0, 9, 9, 9));
+}
+
+TEST(Zip, Enumerate) {
+ std::vector<std::string> vec(3);
+
+ for (auto [i, s] : Zip(Enumerate<>, vec)) {
+ static_assert(std::is_same_v<decltype(s), std::string&>);
+ s = std::to_string(i + 7);
+ }
+
+ EXPECT_THAT(vec, ElementsAre("7", "8", "9"));
+}
+} // namespace arrow::internal