This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 99b57e8427 ARROW-17412: [C++] AsofJoin multiple keys and types (#13880)
99b57e8427 is described below
commit 99b57e84277f24e8ec1ddadbb11ef8b4f43c8c89
Author: rtpsw <[email protected]>
AuthorDate: Thu Sep 8 23:05:40 2022 +0300
ARROW-17412: [C++] AsofJoin multiple keys and types (#13880)
See https://issues.apache.org/jira/browse/ARROW-17412
Lead-authored-by: Yaron Gvili <[email protected]>
Co-authored-by: rtpsw <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
---
cpp/src/arrow/array/data.h | 15 +-
cpp/src/arrow/compute/exec/asof_join_benchmark.cc | 2 +-
cpp/src/arrow/compute/exec/asof_join_node.cc | 675 ++++++++++---
cpp/src/arrow/compute/exec/asof_join_node_test.cc | 1079 +++++++++++++++++----
cpp/src/arrow/compute/exec/hash_join.cc | 1 -
cpp/src/arrow/compute/exec/options.h | 20 +-
cpp/src/arrow/compute/light_array.cc | 6 +
cpp/src/arrow/compute/light_array.h | 19 +-
cpp/src/arrow/type_traits.h | 7 +
9 files changed, 1482 insertions(+), 342 deletions(-)
diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h
index dde66ac79c..e024483f66 100644
--- a/cpp/src/arrow/array/data.h
+++ b/cpp/src/arrow/array/data.h
@@ -167,6 +167,11 @@ struct ARROW_EXPORT ArrayData {
std::shared_ptr<ArrayData> Copy() const { return
std::make_shared<ArrayData>(*this); }
+ bool IsNull(int64_t i) const {  // true if the element at index i (relative to `offset`) is null
+ return ((buffers[0] != NULLPTR)
+ ? !bit_util::GetBit(buffers[0]->data(), i + offset)  // validity bit clear => null
+ : null_count.load() == length);  // no validity bitmap: null only when every element is null
+ }
+
// Access a buffer's data as a typed C pointer
template <typename T>
inline const T* GetValues(int i, int64_t absolute_offset) const {
@@ -324,18 +329,14 @@ struct ARROW_EXPORT ArraySpan {
return GetValues<T>(i, this->offset);
}
- bool IsNull(int64_t i) const {
- return ((this->buffers[0].data != NULLPTR)
- ? !bit_util::GetBit(this->buffers[0].data, i + this->offset)
- : this->null_count == this->length);
- }
-
- bool IsValid(int64_t i) const {
+ inline bool IsValid(int64_t i) const {
return ((this->buffers[0].data != NULLPTR)
? bit_util::GetBit(this->buffers[0].data, i + this->offset)
: this->null_count != this->length);
}
+ inline bool IsNull(int64_t i) const { return !IsValid(i); }  // null exactly when IsValid(i) is false
+
std::shared_ptr<ArrayData> ToArrayData() const;
std::shared_ptr<Array> ToArray() const;
diff --git a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc
b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc
index 543a4ece57..7d8abc0ba4 100644
--- a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc
+++ b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc
@@ -109,7 +109,7 @@ static void TableJoinOverhead(benchmark::State& state,
static void AsOfJoinOverhead(benchmark::State& state) {
int64_t tolerance = 0;
- AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, kKeyCol,
tolerance);
+ AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, {kKeyCol},
tolerance);
TableJoinOverhead(
state,
TableGenerationProperties{int(state.range(0)), int(state.range(1)),
diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc
b/cpp/src/arrow/compute/exec/asof_join_node.cc
index 3da612aa03..869456a577 100644
--- a/cpp/src/arrow/compute/exec/asof_join_node.cc
+++ b/cpp/src/arrow/compute/exec/asof_join_node.cc
@@ -17,34 +17,63 @@
#include <condition_variable>
#include <mutex>
-#include <set>
#include <thread>
#include <unordered_map>
+#include <unordered_set>
+#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/key_hash.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/schema_util.h"
#include "arrow/compute/exec/util.h"
+#include "arrow/compute/light_array.h"
#include "arrow/record_batch.h"
#include "arrow/result.h"
#include "arrow/status.h"
+#include "arrow/type_traits.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/future.h"
#include "arrow/util/make_unique.h"
#include "arrow/util/optional.h"
+#include "arrow/util/string_view.h"
namespace arrow {
namespace compute {
-// Remove this when multiple keys and/or types is supported
-typedef int32_t KeyType;
+template <typename T, typename V = typename T::value_type>  // T: container type; V defaults to its value_type
+inline typename T::const_iterator std_find(const T& container, const V& val) {  // whole-container wrapper over std::find
+ return std::find(container.begin(), container.end(), val);  // linear search; yields container.end() when val is absent
+}
+
+template <typename T, typename V = typename T::value_type>  // same template shape as std_find
+inline bool std_has(const T& container, const V& val) {  // membership test built on std_find
+ return container.end() != std_find(container, val);  // present iff the search did not run to end()
+}
+
+typedef uint64_t ByType;  // by-key value used to index MemoStore: a key_value() result or a KeyHasher hash
+typedef uint64_t OnType;  // on-key (time) value after time_value() normalization
+typedef uint64_t HashType;  // element type of the hash vector returned by KeyHasher::HashesFor
// Maximum number of tables that can be joined
#define MAX_JOIN_TABLES 64
typedef uint64_t row_index_t;
typedef int col_index_t;
+// normalize the value to 64-bits while preserving ordering of values
+template <typename T, enable_if_t<std::is_integral<T>::value, bool> = true>  // enabled for integral T only
+static inline uint64_t time_value(T t) {
+ uint64_t bias = std::is_signed<T>::value ? (uint64_t)1 << (8 * sizeof(T) -  // signed T: 2^(bits-1); unsigned T: 0
1) : 0;
+ return t < 0 ? static_cast<uint64_t>(t + bias) : static_cast<uint64_t>(t);  // NOTE(review): only negative inputs get the bias, so -1 yields bias-1 while 0 yields 0; ordering looks preserved only within each sign -- confirm whether negative on-key values are expected
+}
+
+// indicates normalization of a key value
+template <typename T, enable_if_t<std::is_integral<T>::value, bool> = true>  // enabled for integral T only
+static inline uint64_t key_value(T t) {
+ return static_cast<uint64_t>(t);  // plain conversion to 64 bits; no bias is applied, unlike time_value
+}
+
/**
* Simple implementation for an unbound concurrent queue
*/
@@ -65,6 +94,11 @@ class ConcurrentQueue {
cond_.notify_one();
}
+ void Clear() {  // discards every element currently queued
+ std::unique_lock<std::mutex> lock(mutex_);  // lock is held until the function returns
+ queue_ = std::queue<T>();  // replace the contents with an empty queue; cond_ is not notified here
+ }
+
util::optional<T> TryPop() {
// Try to pop the oldest value from the queue (or return nullopt if none)
std::unique_lock<std::mutex> lock(mutex_);
@@ -99,7 +133,7 @@ struct MemoStore {
struct Entry {
// Timestamp associated with the entry
- int64_t time;
+ OnType time;
// Batch associated with the entry (perf is probably OK for this; batches
change
// rarely)
@@ -109,10 +143,10 @@ struct MemoStore {
row_index_t row;
};
- std::unordered_map<KeyType, Entry> entries_;
+ std::unordered_map<ByType, Entry> entries_;
- void Store(const std::shared_ptr<RecordBatch>& batch, row_index_t row,
int64_t time,
- KeyType key) {
+ void Store(const std::shared_ptr<RecordBatch>& batch, row_index_t row,
OnType time,
+ ByType key) {
auto& e = entries_[key];
// that we can do this assignment optionally, is why we
// can get array with using shared_ptr above (the batch
@@ -122,13 +156,13 @@ struct MemoStore {
e.time = time;
}
- util::optional<const Entry*> GetEntryForKey(KeyType key) const {
+ util::optional<const Entry*> GetEntryForKey(ByType key) const {
auto e = entries_.find(key);
if (entries_.end() == e) return util::nullopt;
return util::optional<const Entry*>(&e->second);
}
- void RemoveEntriesWithLesserTime(int64_t ts) {
+ void RemoveEntriesWithLesserTime(OnType ts) {
for (auto e = entries_.begin(); e != entries_.end();)
if (e->second.time < ts)
e = entries_.erase(e);
@@ -137,18 +171,89 @@ struct MemoStore {
}
};
+// a specialized higher-performance variation of Hashing64 logic from hash_join_node
+// the code here avoids recreating objects that are independent of each batch processed
+class KeyHasher {  // computes one 64-bit hash per row over selected columns of a batch and caches the result for the last batch seen
+ static constexpr int kMiniBatchLength = util::MiniBatch::kMiniBatchLength;  // rows hashed per inner pass
+
+ public:
+ explicit KeyHasher(const std::vector<col_index_t>& indices)  // indices: columns of the batch that make up the key
+ : indices_(indices),
+ metadata_(indices.size()),
+ batch_(NULLPTR),
+ hashes_(),
+ ctx_(),
+ column_arrays_(),
+ stack_() {
+ ctx_.stack = &stack_;  // ctx_ points at this object's own stack_; NOTE(review): a copied/moved KeyHasher would keep the old address -- confirm instances are never copied
+ column_arrays_.resize(indices.size());  // one KeyColumnArray slot per key column, reused for every mini-batch
+ }
+
+ Status Init(ExecContext* exec_context, const std::shared_ptr<arrow::Schema>&  // must succeed before HashesFor: fills metadata_ and sets up stack_
schema) {
+ ctx_.hardware_flags = exec_context->cpu_info()->hardware_flags();
+ const auto& fields = schema->fields();
+ for (size_t k = 0; k < metadata_.size(); k++) {  // metadata for key column k comes from schema field indices_[k]
+ ARROW_ASSIGN_OR_RAISE(metadata_[k],
+
ColumnMetadataFromDataType(fields[indices_[k]]->type()));
+ }
+ return stack_.Init(exec_context->memory_pool(),
+ 4 * kMiniBatchLength * sizeof(uint32_t));
+ }
+
+ const std::vector<HashType>& HashesFor(const RecordBatch* batch) {  // returns hashes for all rows of batch; the reference stays valid only until the next call with another batch
+ if (batch_ == batch) {  // NOTE(review): cache hit is decided by pointer identity alone -- confirm a new batch can never reuse the address of the cached one
+ return hashes_;
+ }
+ batch_ = NULLPTR; // invalidate cached hashes for batch
+ size_t batch_length = batch->num_rows();
+ hashes_.resize(batch_length);
+ for (int64_t i = 0; i < static_cast<int64_t>(batch_length); i +=  // walk the batch in slices of at most kMiniBatchLength rows
kMiniBatchLength) {
+ int64_t length = std::min(static_cast<int64_t>(batch_length - i),  // last slice may be shorter
+ static_cast<int64_t>(kMiniBatchLength));
+ for (size_t k = 0; k < indices_.size(); k++) {  // build a view of rows [i, i + length) for each key column
+ auto array_data = batch->column_data(indices_[k]);
+ column_arrays_[k] =
+ ColumnArrayFromArrayDataAndMetadata(array_data, metadata_[k], i,
length);
+ }
+ Hashing64::HashMultiColumn(column_arrays_, &ctx_, hashes_.data() + i);  // writes this slice's hashes starting at offset i
+ }
+ batch_ = batch;  // remember which batch hashes_ now describes
+ return hashes_;
+ }
+
+ private:
+ std::vector<col_index_t> indices_;  // key column indices into the batches passed to HashesFor
+ std::vector<KeyColumnMetadata> metadata_;  // per-key-column metadata filled by Init
+ const RecordBatch* batch_;  // batch whose hashes are cached in hashes_, or NULLPTR when none
+ std::vector<HashType> hashes_;  // one hash per row of batch_
+ LightContext ctx_;  // hashing context: hardware flags plus pointer to stack_
+ std::vector<KeyColumnArray> column_arrays_;  // scratch column views rebuilt for every mini-batch
+ util::TempVectorStack stack_;  // temporary storage referenced through ctx_.stack
+};
+
class InputState {
// InputState corresponds to an input
// Input record batches are queued up in InputState until processed and
// turned into output record batches.
public:
- InputState(const std::shared_ptr<arrow::Schema>& schema,
- const std::string& time_col_name, const std::string& key_col_name)
+ InputState(bool must_hash, bool may_rehash, KeyHasher* key_hasher,
+ const std::shared_ptr<arrow::Schema>& schema,
+ const col_index_t time_col_index,
+ const std::vector<col_index_t>& key_col_index)
: queue_(),
schema_(schema),
- time_col_index_(schema->GetFieldIndex(time_col_name)),
- key_col_index_(schema->GetFieldIndex(key_col_name)) {}
+ time_col_index_(time_col_index),
+ key_col_index_(key_col_index),
+ time_type_id_(schema_->fields()[time_col_index_]->type()->id()),
+ key_type_id_(key_col_index.size()),
+ key_hasher_(key_hasher),
+ must_hash_(must_hash),
+ may_rehash_(may_rehash) {
+ for (size_t k = 0; k < key_col_index_.size(); k++) {
+ key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id();
+ }
+ }
col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool
skip_time_and_key_fields) {
src_to_dst_.resize(schema_->num_fields());
@@ -164,7 +269,7 @@ class InputState {
bool IsTimeOrKeyColumn(col_index_t i) const {
DCHECK_LT(i, schema_->num_fields());
- return (i == time_col_index_) || (i == key_col_index_);
+ return (i == time_col_index_) || std_has(key_col_index_, i);
}
// Gets the latest row index, assuming the queue isn't empty
@@ -184,27 +289,87 @@ class InputState {
return queue_.UnsyncFront();
}
- KeyType GetLatestKey() const {
- return queue_.UnsyncFront()
- ->column_data(key_col_index_)
- ->GetValues<KeyType>(1)[latest_ref_row_];
+#define LATEST_VAL_CASE(id, val) /* one switch case: read element `row` of `data` as the C type of Type::id and return val(element) */ \
+ case Type::id: { \
+ using T = typename TypeIdTraits<Type::id>::Type; \
+ using CType = typename TypeTraits<T>::CType; \
+ return val(data->GetValues<CType>(1)[row]); \
+ }
+
+ inline ByType GetLatestKey() const {  // by-key of row latest_ref_row_ in the batch at the front of the queue
+ return GetLatestKey(queue_.UnsyncFront().get(), latest_ref_row_);  // front of queue is read via UnsyncFront()
}
- int64_t GetLatestTime() const {
- return queue_.UnsyncFront()
- ->column_data(time_col_index_)
- ->GetValues<int64_t>(1)[latest_ref_row_];
+ inline ByType GetLatestKey(const RecordBatch* batch, row_index_t row) const {  // by-key for an arbitrary row of an arbitrary batch
+ if (must_hash_) {  // hashing mode: the key is the row's hash from the KeyHasher
+ return key_hasher_->HashesFor(batch)[row];
+ }
+ if (key_col_index_.size() == 0) {  // no by-key columns: every row gets the same key
+ return 0;
+ }
+ auto data = batch->column_data(key_col_index_[0]);  // non-hashing mode consults only the first key column
+ switch (key_type_id_[0]) {  // dispatch on that column's type id; value is converted with key_value
+ LATEST_VAL_CASE(INT8, key_value)
+ LATEST_VAL_CASE(INT16, key_value)
+ LATEST_VAL_CASE(INT32, key_value)
+ LATEST_VAL_CASE(INT64, key_value)
+ LATEST_VAL_CASE(UINT8, key_value)
+ LATEST_VAL_CASE(UINT16, key_value)
+ LATEST_VAL_CASE(UINT32, key_value)
+ LATEST_VAL_CASE(UINT64, key_value)
+ LATEST_VAL_CASE(DATE32, key_value)
+ LATEST_VAL_CASE(DATE64, key_value)
+ LATEST_VAL_CASE(TIME32, key_value)
+ LATEST_VAL_CASE(TIME64, key_value)
+ LATEST_VAL_CASE(TIMESTAMP, key_value)
+ default:
+ DCHECK(false);  // any other type id is treated as a programming error
+ return 0; // cannot happen
+ }
}
+ inline OnType GetLatestTime() const {  // on-key (time) of row latest_ref_row_ in the batch at the front of the queue
+ return GetLatestTime(queue_.UnsyncFront().get(), latest_ref_row_);  // front of queue is read via UnsyncFront()
+ }
+
+ // Returns the on-key (time) value at `row` of `batch`, normalized through time_value().
+ // Declared as OnType (it was ByType; both alias uint64_t) to agree with GetLatestTime().
+ inline OnType GetLatestTime(const RecordBatch* batch, row_index_t row) const {
+ auto data = batch->column_data(time_col_index_);
+ switch (time_type_id_) {  // dispatch on the time column's type id
+ LATEST_VAL_CASE(INT8, time_value)
+ LATEST_VAL_CASE(INT16, time_value)
+ LATEST_VAL_CASE(INT32, time_value)
+ LATEST_VAL_CASE(INT64, time_value)
+ LATEST_VAL_CASE(UINT8, time_value)
+ LATEST_VAL_CASE(UINT16, time_value)
+ LATEST_VAL_CASE(UINT32, time_value)
+ LATEST_VAL_CASE(UINT64, time_value)
+ LATEST_VAL_CASE(DATE32, time_value)
+ LATEST_VAL_CASE(DATE64, time_value)
+ LATEST_VAL_CASE(TIME32, time_value)
+ LATEST_VAL_CASE(TIME64, time_value)
+ LATEST_VAL_CASE(TIMESTAMP, time_value)
+ default:
+ DCHECK(false);  // any other type id is treated as a programming error
+ return 0; // cannot happen
+ }
+ }
+
+#undef LATEST_VAL_CASE
+
bool Finished() const { return batches_processed_ == total_batches_; }
- bool Advance() {
+ Result<bool> Advance() {
// Try advancing to the next row and update latest_ref_row_
// Returns true if able to advance, false if not.
bool have_active_batch =
(latest_ref_row_ > 0 /*short circuit the lock on the queue*/) ||
!queue_.Empty();
if (have_active_batch) {
+ OnType next_time = GetLatestTime();
+ if (latest_time_ > next_time) {
+ return Status::Invalid("AsofJoin does not allow out-of-order on-key
values");
+ }
+ latest_time_ = next_time;
// If we have an active batch
if (++latest_ref_row_ >= (row_index_t)queue_.UnsyncFront()->num_rows()) {
// hit the end of the batch, need to get the next batch if possible.
@@ -222,46 +387,60 @@ class InputState {
// latest_time and latest_ref_row to the value that immediately pass the
// specified timestamp.
// Returns true if updates were made, false if not.
- bool AdvanceAndMemoize(int64_t ts) {
+ Result<bool> AdvanceAndMemoize(OnType ts) {
// Advance the right side row index until we reach the latest right row
(for each key)
// for the given left timestamp.
// Check if already updated for TS (or if there is no latest)
if (Empty()) return false; // can't advance if empty
- auto latest_time = GetLatestTime();
- if (latest_time > ts) return false; // already advanced
// Not updated. Try to update and possibly advance.
- bool updated = false;
+ bool advanced, updated = false;
do {
- latest_time = GetLatestTime();
+ auto latest_time = GetLatestTime();
// if Advance() returns true, then the latest_ts must also be valid
// Keep advancing right table until we hit the latest row that has
// timestamp <= ts. This is because we only need the latest row for the
// match given a left ts.
- if (latest_time <= ts) {
- memo_.Store(GetLatestBatch(), latest_ref_row_, latest_time,
GetLatestKey());
- } else {
+ if (latest_time > ts) {
break; // hit a future timestamp -- done updating for now
}
+ auto rb = GetLatestBatch();
+ if (may_rehash_ && rb->column_data(key_col_index_[0])->GetNullCount() >
0) {
+ must_hash_ = true;
+ may_rehash_ = false;
+ Rehash();
+ }
+ memo_.Store(rb, latest_ref_row_, latest_time, GetLatestKey());
updated = true;
- } while (Advance());
+ ARROW_ASSIGN_OR_RAISE(advanced, Advance());
+ } while (advanced);
return updated;
}
- void Push(const std::shared_ptr<arrow::RecordBatch>& rb) {
+ void Rehash() {  // rebuilds memo_ with keys recomputed by GetLatestKey under the current must_hash_ setting
+ MemoStore new_memo;
+ for (const auto& entry : memo_.entries_) {  // old keys are ignored; only the stored batch/row/time are reused
+ const auto& e = entry.second;
+ new_memo.Store(e.batch, e.row, e.time, GetLatestKey(e.batch.get(),  // key is re-derived from the entry's own batch and row
e.row));
+ }
+ memo_ = new_memo;  // copy-assigns the rebuilt store over the old one
+ }
+
+ Status Push(const std::shared_ptr<arrow::RecordBatch>& rb) {
if (rb->num_rows() > 0) {
queue_.Push(rb);
} else {
++batches_processed_; // don't enqueue empty batches, just record as
processed
}
+ return Status::OK();
}
- util::optional<const MemoStore::Entry*> GetMemoEntryForKey(KeyType key) {
+ util::optional<const MemoStore::Entry*> GetMemoEntryForKey(ByType key) {
return memo_.GetEntryForKey(key);
}
- util::optional<int64_t> GetMemoTimeForKey(KeyType key) {
+ util::optional<OnType> GetMemoTimeForKey(ByType key) {
auto r = GetMemoEntryForKey(key);
if (r.has_value()) {
return (*r)->time;
@@ -270,7 +449,7 @@ class InputState {
}
}
- void RemoveMemoEntriesWithLesserTime(int64_t ts) {
+ void RemoveMemoEntriesWithLesserTime(OnType ts) {
memo_.RemoveEntriesWithLesserTime(ts);
}
@@ -294,10 +473,22 @@ class InputState {
// Index of the time col
col_index_t time_col_index_;
// Index of the key col
- col_index_t key_col_index_;
+ std::vector<col_index_t> key_col_index_;
+ // Type id of the time column
+ Type::type time_type_id_;
+ // Type id of the key column
+ std::vector<Type::type> key_type_id_;
+ // Hasher for key elements
+ mutable KeyHasher* key_hasher_;
+ // True if hashing is mandatory
+ bool must_hash_;
+ // True if by-key values may be rehashed
+ bool may_rehash_;
// Index of the latest row reference within; if >0 then queue_ cannot be
empty
// Must be < queue_.front()->num_rows() if queue_ is non-empty
row_index_t latest_ref_row_ = 0;
+ // Time of latest row
+ OnType latest_time_ = std::numeric_limits<OnType>::lowest();
// Stores latest known values for the various keys
MemoStore memo_;
// Mapping of source columns to destination columns
@@ -336,18 +527,18 @@ class CompositeReferenceTable {
// Adds the latest row from the input state as a new composite reference row
// - LHS must have a valid key,timestep,and latest rows
// - RHS must have valid data memo'ed for the key
- void Emplace(std::vector<std::unique_ptr<InputState>>& in, int64_t
tolerance) {
+ void Emplace(std::vector<std::unique_ptr<InputState>>& in, OnType tolerance)
{
DCHECK_EQ(in.size(), n_tables_);
// Get the LHS key
- KeyType key = in[0]->GetLatestKey();
+ ByType key = in[0]->GetLatestKey();
// Add row and setup LHS
// (the LHS state comes just from the latest row of the LHS table)
DCHECK(!in[0]->Empty());
const std::shared_ptr<arrow::RecordBatch>& lhs_latest_batch =
in[0]->GetLatestBatch();
row_index_t lhs_latest_row = in[0]->GetLatestRow();
- int64_t lhs_latest_time = in[0]->GetLatestTime();
+ OnType lhs_latest_time = in[0]->GetLatestTime();
if (0 == lhs_latest_row) {
// On the first row of the batch, we resize the destination.
// The destination size is dictated by the size of the LHS batch.
@@ -407,29 +598,42 @@ class CompositeReferenceTable {
DCHECK_EQ(src_field->name(), dst_field->name());
const auto& field_type = src_field->type();
- if (field_type->Equals(arrow::int32())) {
- ARROW_ASSIGN_OR_RAISE(
- arrays.at(i_dst_col),
- (MaterializePrimitiveColumn<arrow::Int32Builder, int32_t>(
- memory_pool, i_table, i_src_col)));
- } else if (field_type->Equals(arrow::int64())) {
- ARROW_ASSIGN_OR_RAISE(
- arrays.at(i_dst_col),
- (MaterializePrimitiveColumn<arrow::Int64Builder, int64_t>(
- memory_pool, i_table, i_src_col)));
- } else if (field_type->Equals(arrow::float32())) {
- ARROW_ASSIGN_OR_RAISE(arrays.at(i_dst_col),
-
(MaterializePrimitiveColumn<arrow::FloatBuilder, float>(
- memory_pool, i_table, i_src_col)));
- } else if (field_type->Equals(arrow::float64())) {
- ARROW_ASSIGN_OR_RAISE(
- arrays.at(i_dst_col),
- (MaterializePrimitiveColumn<arrow::DoubleBuilder, double>(
- memory_pool, i_table, i_src_col)));
- } else {
- ARROW_RETURN_NOT_OK(
- Status::Invalid("Unsupported data type: ", src_field->name()));
+#define ASOFJOIN_MATERIALIZE_CASE(id) \
+ case Type::id: { \
+ using T = typename TypeIdTraits<Type::id>::Type; \
+ ARROW_ASSIGN_OR_RAISE( \
+ arrays.at(i_dst_col), \
+ MaterializeColumn<T>(memory_pool, field_type, i_table, i_src_col)); \
+ break; \
+ }
+
+ switch (field_type->id()) {
+ ASOFJOIN_MATERIALIZE_CASE(INT8)
+ ASOFJOIN_MATERIALIZE_CASE(INT16)
+ ASOFJOIN_MATERIALIZE_CASE(INT32)
+ ASOFJOIN_MATERIALIZE_CASE(INT64)
+ ASOFJOIN_MATERIALIZE_CASE(UINT8)
+ ASOFJOIN_MATERIALIZE_CASE(UINT16)
+ ASOFJOIN_MATERIALIZE_CASE(UINT32)
+ ASOFJOIN_MATERIALIZE_CASE(UINT64)
+ ASOFJOIN_MATERIALIZE_CASE(FLOAT)
+ ASOFJOIN_MATERIALIZE_CASE(DOUBLE)
+ ASOFJOIN_MATERIALIZE_CASE(DATE32)
+ ASOFJOIN_MATERIALIZE_CASE(DATE64)
+ ASOFJOIN_MATERIALIZE_CASE(TIME32)
+ ASOFJOIN_MATERIALIZE_CASE(TIME64)
+ ASOFJOIN_MATERIALIZE_CASE(TIMESTAMP)
+ ASOFJOIN_MATERIALIZE_CASE(STRING)
+ ASOFJOIN_MATERIALIZE_CASE(LARGE_STRING)
+ ASOFJOIN_MATERIALIZE_CASE(BINARY)
+ ASOFJOIN_MATERIALIZE_CASE(LARGE_BINARY)
+ default:
+ return Status::Invalid("Unsupported data type ",
+ src_field->type()->ToString(), " for
field ",
+ src_field->name());
}
+
+#undef ASOFJOIN_MATERIALIZE_CASE
}
}
}
@@ -459,17 +663,45 @@ class CompositeReferenceTable {
if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()]
= ref;
}
- template <class Builder, class PrimitiveType>
- Result<std::shared_ptr<Array>> MaterializePrimitiveColumn(MemoryPool*
memory_pool,
- size_t i_table,
- col_index_t i_col)
{
- Builder builder(memory_pool);
+ template <class Type, class Builder = typename TypeTraits<Type>::BuilderType>  // Builder defaults to the builder for Type
+ enable_if_fixed_width_type<Type, Status> static BuilderAppend(  // overload selected for fixed-width types
+ Builder& builder, const std::shared_ptr<ArrayData>& source, row_index_t
row) {
+ if (source->IsNull(row)) {  // propagate a null source element as a null output element
+ builder.UnsafeAppendNull();  // Unsafe* variant: presumably relies on the Reserve() done in MaterializeColumn -- confirm
+ return Status::OK();
+ }
+ using CType = typename TypeTraits<Type>::CType;
+ builder.UnsafeAppend(source->template GetValues<CType>(1)[row]);  // copy the value from buffer 1 at `row`
+ return Status::OK();  // this overload always reports success
+ }
+
+ template <class Type, class Builder = typename TypeTraits<Type>::BuilderType>  // Builder defaults to the builder for Type
+ enable_if_base_binary<Type, Status> static BuilderAppend(  // overload selected for binary/string types
+ Builder& builder, const std::shared_ptr<ArrayData>& source, row_index_t
row) {
+ if (source->IsNull(row)) {  // propagate a null source element as a null output element
+ return builder.AppendNull();
+ }
+ using offset_type = typename Type::offset_type;
+ const uint8_t* data = source->buffers[2]->data();  // value bytes live in buffer 2
+ const offset_type* offsets = source->GetValues<offset_type>(1);  // offsets live in buffer 1
+ const offset_type offset0 = offsets[row];  // start of this element's bytes
+ const offset_type offset1 = offsets[row + 1];  // one past the end of this element's bytes
+ return builder.Append(data + offset0, offset1 - offset0);  // append the byte range; Append's status is returned as-is
+ }
+
+ template <class Type, class Builder = typename TypeTraits<Type>::BuilderType>
+ Result<std::shared_ptr<Array>> MaterializeColumn(MemoryPool* memory_pool,
+ const
std::shared_ptr<DataType>& type,
+ size_t i_table, col_index_t
i_col) {
+ ARROW_ASSIGN_OR_RAISE(auto a_builder, MakeBuilder(type, memory_pool));
+ Builder& builder = *checked_cast<Builder*>(a_builder.get());
ARROW_RETURN_NOT_OK(builder.Reserve(rows_.size()));
for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) {
const auto& ref = rows_[i_row].refs[i_table];
if (ref.batch) {
- builder.UnsafeAppend(
- ref.batch->column_data(i_col)->template
GetValues<PrimitiveType>(1)[ref.row]);
+ Status st =
+ BuilderAppend<Type, Builder>(builder,
ref.batch->column_data(i_col), ref.row);
+ ARROW_RETURN_NOT_OK(st);
} else {
builder.UnsafeAppendNull();
}
@@ -480,14 +712,21 @@ class CompositeReferenceTable {
}
};
+// TODO: Currently, AsofJoinNode uses 64-bit hashing which leads to a non-negligible
+// probability of collision, which can cause incorrect results when many different by-key
+// values are processed. Thus, AsofJoinNode is currently limited to about 100k by-keys for
+// guaranteeing this probability is below 1 in a billion. The fix is 128-bit hashing.
+// See ARROW-17653
class AsofJoinNode : public ExecNode {
// Advances the RHS as far as possible to be up to date for the current LHS
timestamp
- bool UpdateRhs() {
+ Result<bool> UpdateRhs() {
auto& lhs = *state_.at(0);
auto lhs_latest_time = lhs.GetLatestTime();
bool any_updated = false;
- for (size_t i = 1; i < state_.size(); ++i)
- any_updated |= state_[i]->AdvanceAndMemoize(lhs_latest_time);
+ for (size_t i = 1; i < state_.size(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(bool advanced,
state_[i]->AdvanceAndMemoize(lhs_latest_time));
+ any_updated |= advanced;
+ }
return any_updated;
}
@@ -495,7 +734,7 @@ class AsofJoinNode : public ExecNode {
bool IsUpToDateWithLhsRow() const {
auto& lhs = *state_[0];
if (lhs.Empty()) return false; // can't proceed if nothing on the LHS
- int64_t lhs_ts = lhs.GetLatestTime();
+ OnType lhs_ts = lhs.GetLatestTime();
for (size_t i = 1; i < state_.size(); ++i) {
auto& rhs = *state_[i];
if (!rhs.Finished()) {
@@ -523,7 +762,7 @@ class AsofJoinNode : public ExecNode {
if (lhs.Finished() || lhs.Empty()) break;
// Advance each of the RHS as far as possible to be up to date for the
LHS timestamp
- bool any_rhs_advanced = UpdateRhs();
+ ARROW_ASSIGN_OR_RAISE(bool any_rhs_advanced, UpdateRhs());
// If we have received enough inputs to produce the next output batch
// (decided by IsUpToDateWithLhsRow), we will perform the join and
@@ -531,8 +770,9 @@ class AsofJoinNode : public ExecNode {
// the LHS and adding joined row to rows_ (done by Emplace). Finally,
// input batches that are no longer needed are removed to free up memory.
if (IsUpToDateWithLhsRow()) {
- dst.Emplace(state_, options_.tolerance);
- if (!lhs.Advance()) break; // if we can't advance LHS, we're done for
this batch
+ dst.Emplace(state_, tolerance_);
+ ARROW_ASSIGN_OR_RAISE(bool advanced, lhs.Advance());
+ if (!advanced) break; // if we can't advance LHS, we're done for this
batch
} else {
if (!any_rhs_advanced) break; // need to wait for new data
}
@@ -541,8 +781,7 @@ class AsofJoinNode : public ExecNode {
// Prune memo entries that have expired (to bound memory consumption)
if (!lhs.Empty()) {
for (size_t i = 1; i < state_.size(); ++i) {
- state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() -
- options_.tolerance);
+ state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() -
tolerance_);
}
}
@@ -572,7 +811,6 @@ class AsofJoinNode : public ExecNode {
ExecBatch out_b(*out_rb);
outputs_[0]->InputReceived(this, std::move(out_b));
} else {
- StopProducing();
ErrorIfNotOk(result.status());
return;
}
@@ -584,8 +822,8 @@ class AsofJoinNode : public ExecNode {
// It may happen here in cases where InputFinished was called before we
were finished
// producing results (so we didn't know the output size at that time)
if (state_.at(0)->Finished()) {
- StopProducing();
outputs_[0]->InputFinished(this, batches_produced_);
+ finished_.MarkFinished();
}
}
@@ -602,54 +840,172 @@ class AsofJoinNode : public ExecNode {
public:
AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector<std::string>
input_labels,
- const AsofJoinNodeOptions& join_options,
- std::shared_ptr<Schema> output_schema);
+ const std::vector<col_index_t>& indices_of_on_key,
+ const std::vector<std::vector<col_index_t>>& indices_of_by_key,
+ OnType tolerance, std::shared_ptr<Schema> output_schema,
+ std::vector<std::unique_ptr<KeyHasher>> key_hashers, bool
must_hash,
+ bool may_rehash);
+
+ // Initializes one KeyHasher and one InputState per input, then assigns each input's
+ // columns their positions in the output.
+ // Returns the first failure reported by KeyHasher::Init, otherwise OK.
+ Status Init() override {
+ auto inputs = this->inputs();
+ for (size_t i = 0; i < inputs.size(); i++) {
+ // The hasher reads columns of input i's batches, so its column metadata must come
+ // from input i's schema rather than from this node's output schema.
+ const auto& input_schema = inputs[i]->output_schema();
+ RETURN_NOT_OK(key_hashers_[i]->Init(plan()->exec_context(), input_schema));
+ state_.push_back(::arrow::internal::make_unique<InputState>(
+ must_hash_, may_rehash_, key_hashers_[i].get(), input_schema,
+ indices_of_on_key_[i], indices_of_by_key_[i]));
+ }
+
+ // dst_offset is 0 only for the first input, so only that input keeps its time/key fields.
+ col_index_t dst_offset = 0;
+ for (auto& state : state_)
+ dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset);
+
+ return Status::OK();
+ }
virtual ~AsofJoinNode() {
process_.Push(false); // poison pill
process_thread_.join();
}
+ const std::vector<col_index_t>& indices_of_on_key() { return
indices_of_on_key_; }  // on-key column index, one entry per input
+ const std::vector<std::vector<col_index_t>>& indices_of_by_key() {  // by-key column indices, one vector per input
+ return indices_of_by_key_;
+ }
+
+ static Status is_valid_on_field(const std::shared_ptr<Field>& field) {  // OK if the field's type may serve as the on-key, else Invalid
+ switch (field->type()->id()) {  // accepted: integer, date, time and timestamp type ids
+ case Type::INT8:
+ case Type::INT16:
+ case Type::INT32:
+ case Type::INT64:
+ case Type::UINT8:
+ case Type::UINT16:
+ case Type::UINT32:
+ case Type::UINT64:
+ case Type::DATE32:
+ case Type::DATE64:
+ case Type::TIME32:
+ case Type::TIME64:
+ case Type::TIMESTAMP:
+ return Status::OK();
+ default:  // everything else is rejected, naming the field and its type
+ return Status::Invalid("Unsupported type for on-key ", field->name(),
" : ",
+ field->type()->ToString());
+ }
+ }
+
+ static Status is_valid_by_field(const std::shared_ptr<Field>& field) {  // OK if the field's type may serve as a by-key, else Invalid
+ switch (field->type()->id()) {  // accepted: the on-key type ids plus string/binary and their large variants
+ case Type::INT8:
+ case Type::INT16:
+ case Type::INT32:
+ case Type::INT64:
+ case Type::UINT8:
+ case Type::UINT16:
+ case Type::UINT32:
+ case Type::UINT64:
+ case Type::DATE32:
+ case Type::DATE64:
+ case Type::TIME32:
+ case Type::TIME64:
+ case Type::TIMESTAMP:
+ case Type::STRING:
+ case Type::LARGE_STRING:
+ case Type::BINARY:
+ case Type::LARGE_BINARY:
+ return Status::OK();
+ default:  // everything else is rejected, naming the field and its type
+ return Status::Invalid("Unsupported type for by-key ", field->name(),
" : ",
+ field->type()->ToString());
+ }
+ }
+
+ static Status is_valid_data_field(const std::shared_ptr<Field>& field) {  // OK if the field's type may appear as a non-key data column, else Invalid
+ switch (field->type()->id()) {  // accepted: the by-key type ids plus FLOAT and DOUBLE
+ case Type::INT8:
+ case Type::INT16:
+ case Type::INT32:
+ case Type::INT64:
+ case Type::UINT8:
+ case Type::UINT16:
+ case Type::UINT32:
+ case Type::UINT64:
+ case Type::FLOAT:
+ case Type::DOUBLE:
+ case Type::DATE32:
+ case Type::DATE64:
+ case Type::TIME32:
+ case Type::TIME64:
+ case Type::TIMESTAMP:
+ case Type::STRING:
+ case Type::LARGE_STRING:
+ case Type::BINARY:
+ case Type::LARGE_BINARY:
+ return Status::OK();
+ default:  // everything else is rejected, naming the field and its type
+ return Status::Invalid("Unsupported type for data field ",
field->name(), " : ",
+ field->type()->ToString());
+ }
+ }
+
static arrow::Result<std::shared_ptr<Schema>> MakeOutputSchema(
- const std::vector<ExecNode*>& inputs, const AsofJoinNodeOptions&
options) {
+ const std::vector<ExecNode*>& inputs,
+ const std::vector<col_index_t>& indices_of_on_key,
+ const std::vector<std::vector<col_index_t>>& indices_of_by_key) {
std::vector<std::shared_ptr<arrow::Field>> fields;
- const auto& on_field_name = *options.on_key.name();
- const auto& by_field_name = *options.by_key.name();
-
+ size_t n_by = indices_of_by_key[0].size();
+ const DataType* on_key_type = NULLPTR;
+ std::vector<const DataType*> by_key_type(n_by, NULLPTR);
// Take all non-key, non-time RHS fields
for (size_t j = 0; j < inputs.size(); ++j) {
const auto& input_schema = inputs[j]->output_schema();
- const auto& on_field_ix = input_schema->GetFieldIndex(on_field_name);
- const auto& by_field_ix = input_schema->GetFieldIndex(by_field_name);
+ const auto& on_field_ix = indices_of_on_key[j];
+ const auto& by_field_ix = indices_of_by_key[j];
- if ((on_field_ix == -1) | (by_field_ix == -1)) {
+ if ((on_field_ix == -1) || std_has(by_field_ix, -1)) {
return Status::Invalid("Missing join key on table ", j);
}
+ const auto& on_field = input_schema->fields()[on_field_ix];
+ std::vector<const Field*> by_field(n_by);
+ for (size_t k = 0; k < n_by; k++) {
+ by_field[k] = input_schema->fields()[by_field_ix[k]].get();
+ }
+
+ if (on_key_type == NULLPTR) {
+ on_key_type = on_field->type().get();
+ } else if (*on_key_type != *on_field->type()) {
+ return Status::Invalid("Expected on-key type ", *on_key_type, " but
got ",
+ *on_field->type(), " for field ",
on_field->name(),
+ " in input ", j);
+ }
+ // Every input must use the same type for its k-th by-key as the first input seen.
+ for (size_t k = 0; k < n_by; k++) {
+ if (by_key_type[k] == NULLPTR) {
+ by_key_type[k] = by_field[k]->type().get();  // first input fixes the expected type
+ } else if (*by_key_type[k] != *by_field[k]->type()) {
+ // This is a by-key mismatch, so the message says "by-key" (it used to say "on-key").
+ return Status::Invalid("Expected by-key type ", *by_key_type[k], " but got ",
+ *by_field[k]->type(), " for field ", by_field[k]->name(),
+ " in input ", j);
+ }
+ }
+
for (int i = 0; i < input_schema->num_fields(); ++i) {
const auto field = input_schema->field(i);
- if (field->name() == on_field_name) {
- if (kSupportedOnTypes_.find(field->type()) ==
kSupportedOnTypes_.end()) {
- return Status::Invalid("Unsupported type for on key: ",
field->name());
- }
+ if (i == on_field_ix) {
+ ARROW_RETURN_NOT_OK(is_valid_on_field(field));
// Only add on field from the left table
if (j == 0) {
fields.push_back(field);
}
- } else if (field->name() == by_field_name) {
- if (kSupportedByTypes_.find(field->type()) ==
kSupportedByTypes_.end()) {
- return Status::Invalid("Unsupported type for by key: ",
field->name());
- }
+ } else if (std_has(by_field_ix, i)) {
+ ARROW_RETURN_NOT_OK(is_valid_by_field(field));
// Only add by field from the left table
if (j == 0) {
fields.push_back(field);
}
} else {
- if (kSupportedDataTypes_.find(field->type()) ==
kSupportedDataTypes_.end()) {
- return Status::Invalid("Unsupported data type: ", field->name());
- }
-
+ ARROW_RETURN_NOT_OK(is_valid_data_field(field));
fields.push_back(field);
}
}
@@ -657,45 +1013,91 @@ class AsofJoinNode : public ExecNode {
return std::make_shared<arrow::Schema>(fields);
}
+ static inline Result<col_index_t> FindColIndex(const Schema& schema,
+ const FieldRef& field_ref,
+ util::string_view key_kind) {
+ auto match_res = field_ref.FindOne(schema);
+ if (!match_res.ok()) {
+ return Status::Invalid("Bad join key on table : ",
match_res.status().message());
+ }
+ ARROW_ASSIGN_OR_RAISE(auto match, match_res);
+ if (match.indices().size() != 1) {
+ return Status::Invalid("AsOfJoinNode does not support a nested ",
+ to_string(key_kind), "-key ",
field_ref.ToString());
+ }
+ return match.indices()[0];
+ }
+
static arrow::Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*>
inputs,
const ExecNodeOptions& options) {
DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs";
const auto& join_options = checked_cast<const
AsofJoinNodeOptions&>(options);
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Schema> output_schema,
- MakeOutputSchema(inputs, join_options));
+ if (join_options.tolerance < 0) {
+ return Status::Invalid("AsOfJoin tolerance must be non-negative but is ",
+ join_options.tolerance);
+ }
- std::vector<std::string> input_labels(inputs.size());
- input_labels[0] = "left";
- for (size_t i = 1; i < inputs.size(); ++i) {
- input_labels[i] = "right_" + std::to_string(i);
+ size_t n_input = inputs.size(), n_by = join_options.by_key.size();
+ std::vector<std::string> input_labels(n_input);
+ std::vector<col_index_t> indices_of_on_key(n_input);
+ std::vector<std::vector<col_index_t>> indices_of_by_key(
+ n_input, std::vector<col_index_t>(n_by));
+ for (size_t i = 0; i < n_input; ++i) {
+ input_labels[i] = i == 0 ? "left" : "right_" + std::to_string(i);
+ const Schema& input_schema = *inputs[i]->output_schema();
+ ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i],
+ FindColIndex(input_schema, join_options.on_key,
"on"));
+ for (size_t k = 0; k < n_by; k++) {
+ ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k],
+ FindColIndex(input_schema,
join_options.by_key[k], "by"));
+ }
}
- return plan->EmplaceNode<AsofJoinNode>(plan, inputs,
std::move(input_labels),
- join_options,
std::move(output_schema));
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Schema> output_schema,
+ MakeOutputSchema(inputs, indices_of_on_key,
indices_of_by_key));
+
+ std::vector<std::unique_ptr<KeyHasher>> key_hashers;
+ for (size_t i = 0; i < n_input; i++) {
+ key_hashers.push_back(
+ ::arrow::internal::make_unique<KeyHasher>(indices_of_by_key[i]));
+ }
+ bool must_hash =
+ n_by > 1 ||
+ (n_by == 1 &&
+ !is_primitive(
+ inputs[0]->output_schema()->field(indices_of_by_key[0][0])->type()->id()));
+ bool may_rehash = n_by == 1 && !must_hash;
+ return plan->EmplaceNode<AsofJoinNode>(
+ plan, inputs, std::move(input_labels), std::move(indices_of_on_key),
+ std::move(indices_of_by_key), time_value(join_options.tolerance),
+ std::move(output_schema), std::move(key_hashers), must_hash,
may_rehash);
}
const char* kind_name() const override { return "AsofJoinNode"; }
void InputReceived(ExecNode* input, ExecBatch batch) override {
// Get the input
- ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) !=
inputs_.end());
- size_t k = std::find(inputs_.begin(), inputs_.end(), input) -
inputs_.begin();
+ ARROW_DCHECK(std_has(inputs_, input));
+ size_t k = std_find(inputs_, input) - inputs_.begin();
// Put into the queue
auto rb = *batch.ToRecordBatch(input->output_schema());
- state_.at(k)->Push(rb);
+ Status st = state_.at(k)->Push(rb);
+ if (!st.ok()) {
+ ErrorReceived(input, st);
+ return;
+ }
process_.Push(true);
}
void ErrorReceived(ExecNode* input, Status error) override {
outputs_[0]->ErrorReceived(this, std::move(error));
- StopProducing();
}
void InputFinished(ExecNode* input, int total_batches) override {
{
std::lock_guard<std::mutex> guard(gate_);
- ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) !=
inputs_.end());
- size_t k = std::find(inputs_.begin(), inputs_.end(), input) -
inputs_.begin();
+ ARROW_DCHECK(std_has(inputs_, input));
+ size_t k = std_find(inputs_, input) - inputs_.begin();
state_.at(k)->set_total_batches(total_batches);
}
// Trigger a process call
@@ -714,20 +1116,24 @@ class AsofJoinNode : public ExecNode {
DCHECK_EQ(output, outputs_[0]);
StopProducing();
}
- void StopProducing() override { finished_.MarkFinished(); }
+ void StopProducing() override {
+ process_.Clear();
+ process_.Push(false);
+ }
arrow::Future<> finished() override { return finished_; }
private:
- static const std::set<std::shared_ptr<DataType>> kSupportedOnTypes_;
- static const std::set<std::shared_ptr<DataType>> kSupportedByTypes_;
- static const std::set<std::shared_ptr<DataType>> kSupportedDataTypes_;
-
arrow::Future<> finished_;
+ std::vector<col_index_t> indices_of_on_key_;
+ std::vector<std::vector<col_index_t>> indices_of_by_key_;
+ std::vector<std::unique_ptr<KeyHasher>> key_hashers_;
+ bool must_hash_;
+ bool may_rehash_;
// InputStates
// Each input state corresponds to an input table
std::vector<std::unique_ptr<InputState>> state_;
std::mutex gate_;
- AsofJoinNodeOptions options_;
+ OnType tolerance_;
// Queue for triggering processing of a given input
// (a false value is a poison pill)
@@ -741,30 +1147,25 @@ class AsofJoinNode : public ExecNode {
AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs,
std::vector<std::string> input_labels,
- const AsofJoinNodeOptions& join_options,
- std::shared_ptr<Schema> output_schema)
+ const std::vector<col_index_t>& indices_of_on_key,
+ const std::vector<std::vector<col_index_t>>&
indices_of_by_key,
+ OnType tolerance, std::shared_ptr<Schema>
output_schema,
+ std::vector<std::unique_ptr<KeyHasher>> key_hashers,
+ bool must_hash, bool may_rehash)
: ExecNode(plan, inputs, input_labels,
/*output_schema=*/std::move(output_schema),
/*num_outputs=*/1),
- options_(join_options),
+ indices_of_on_key_(std::move(indices_of_on_key)),
+ indices_of_by_key_(std::move(indices_of_by_key)),
+ key_hashers_(std::move(key_hashers)),
+ must_hash_(must_hash),
+ may_rehash_(may_rehash),
+ tolerance_(tolerance),
process_(),
process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) {
- for (size_t i = 0; i < inputs.size(); ++i)
- state_.push_back(::arrow::internal::make_unique<InputState>(
- inputs[i]->output_schema(), *options_.on_key.name(),
*options_.by_key.name()));
- col_index_t dst_offset = 0;
- for (auto& state : state_)
- dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset);
-
finished_ = arrow::Future<>::MakeFinished();
}
-// Currently supported types
-const std::set<std::shared_ptr<DataType>> AsofJoinNode::kSupportedOnTypes_ =
{int64()};
-const std::set<std::shared_ptr<DataType>> AsofJoinNode::kSupportedByTypes_ =
{int32()};
-const std::set<std::shared_ptr<DataType>> AsofJoinNode::kSupportedDataTypes_ =
{
- int32(), int64(), float32(), float64()};
-
namespace internal {
void RegisterAsofJoinNode(ExecFactoryRegistry* registry) {
DCHECK_OK(registry->AddFactory("asofjoin", AsofJoinNode::Make));
diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc
index 8b993764ab..48d1ae6410 100644
--- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc
+++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc
@@ -17,11 +17,13 @@
#include <gmock/gmock-matchers.h>
+#include <chrono>
#include <numeric>
#include <random>
#include <unordered_set>
#include "arrow/api.h"
+#include "arrow/compute/api_scalar.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/test_util.h"
#include "arrow/compute/exec/util.h"
@@ -32,23 +34,185 @@
#include "arrow/testing/random.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/make_unique.h"
+#include "arrow/util/string_view.h"
#include "arrow/util/thread_pool.h"
+#define TRACED_TEST(t_class, t_name, t_body) \
+ TEST(t_class, t_name) { \
+ ARROW_SCOPED_TRACE(#t_class "_" #t_name); \
+ t_body; \
+ }
+
+#define TRACED_TEST_P(t_class, t_name, t_body) \
+ TEST_P(t_class, t_name) { \
+ ARROW_SCOPED_TRACE(#t_class "_" #t_name "_" + std::get<1>(GetParam())); \
+ t_body; \
+ }
+
using testing::UnorderedElementsAreArray;
namespace arrow {
namespace compute {
+bool is_temporal_primitive(Type::type type_id) {
+ switch (type_id) {
+ case Type::TIME32:
+ case Type::TIME64:
+ case Type::DATE32:
+ case Type::DATE64:
+ case Type::TIMESTAMP:
+ return true;
+ default:
+ return false;
+ }
+}
+
+Result<BatchesWithSchema> MakeBatchesFromNumString(
+ const std::shared_ptr<Schema>& schema,
+ const std::vector<util::string_view>& json_strings, int multiplicity = 1) {
+ FieldVector num_fields;
+ for (auto field : schema->fields()) {
+ num_fields.push_back(
+ is_base_binary_like(field->type()->id()) ? field->WithType(int64()) :
field);
+ }
+ auto num_schema =
+ std::make_shared<Schema>(num_fields, schema->endianness(),
schema->metadata());
+ BatchesWithSchema num_batches =
+ MakeBatchesFromString(num_schema, json_strings, multiplicity);
+ BatchesWithSchema batches;
+ batches.schema = schema;
+ int n_fields = schema->num_fields();
+ for (auto num_batch : num_batches.batches) {
+ std::vector<Datum> values;
+ for (int i = 0; i < n_fields; i++) {
+ auto type = schema->field(i)->type();
+ if (is_base_binary_like(type->id())) {
+ // casting to string first enables casting to binary
+ ARROW_ASSIGN_OR_RAISE(Datum as_string, Cast(num_batch.values[i],
utf8()));
+ ARROW_ASSIGN_OR_RAISE(Datum as_type, Cast(as_string, type));
+ values.push_back(as_type);
+ } else {
+ values.push_back(num_batch.values[i]);
+ }
+ }
+ ExecBatch batch(values, num_batch.length);
+ batches.batches.push_back(batch);
+ }
+ return batches;
+}
+
+void BuildNullArray(std::shared_ptr<Array>& empty, const
std::shared_ptr<DataType>& type,
+ int64_t length) {
+ ASSERT_OK_AND_ASSIGN(auto builder, MakeBuilder(type, default_memory_pool()));
+ ASSERT_OK(builder->Reserve(length));
+ ASSERT_OK(builder->AppendNulls(length));
+ ASSERT_OK(builder->Finish(&empty));
+}
+
+void BuildZeroPrimitiveArray(std::shared_ptr<Array>& empty,
+ const std::shared_ptr<DataType>& type, int64_t
length) {
+ ASSERT_OK_AND_ASSIGN(auto builder, MakeBuilder(type, default_memory_pool()));
+ ASSERT_OK(builder->Reserve(length));
+ ASSERT_OK_AND_ASSIGN(auto scalar, MakeScalar(type, 0));
+ ASSERT_OK(builder->AppendScalar(*scalar, length));
+ ASSERT_OK(builder->Finish(&empty));
+}
+
+template <typename Builder>
+void BuildZeroBaseBinaryArray(std::shared_ptr<Array>& empty, int64_t length) {
+ Builder builder(default_memory_pool());
+ ASSERT_OK(builder.Reserve(length));
+ for (int64_t i = 0; i < length; i++) {
+ ASSERT_OK(builder.Append("0", /*length=*/1));
+ }
+ ASSERT_OK(builder.Finish(&empty));
+}
+
+// mutates by copying from_key into to_key and changing from_key to zero
+Result<BatchesWithSchema> MutateByKey(BatchesWithSchema& batches, std::string
from_key,
+ std::string to_key, bool replace_key =
false,
+ bool null_key = false, bool remove_key =
false) {
+ int from_index = batches.schema->GetFieldIndex(from_key);
+ int n_fields = batches.schema->num_fields();
+ auto fields = batches.schema->fields();
+ BatchesWithSchema new_batches;
+ if (remove_key) {
+ ARROW_ASSIGN_OR_RAISE(new_batches.schema,
batches.schema->RemoveField(from_index));
+ } else {
+ auto new_field = batches.schema->field(from_index)->WithName(to_key);
+ ARROW_ASSIGN_OR_RAISE(new_batches.schema,
+ replace_key ? batches.schema->SetField(from_index,
new_field)
+ : batches.schema->AddField(from_index,
new_field));
+ }
+ for (const ExecBatch& batch : batches.batches) {
+ std::vector<Datum> new_values;
+ for (int i = 0; i < n_fields; i++) {
+ const Datum& value = batch.values[i];
+ if (i == from_index) {
+ if (remove_key) {
+ continue;
+ }
+ auto type = fields[i]->type();
+ if (null_key) {
+ std::shared_ptr<Array> empty;
+ BuildNullArray(empty, type, batch.length);
+ new_values.push_back(empty);
+ } else if (is_primitive(type->id())) {
+ std::shared_ptr<Array> empty;
+ BuildZeroPrimitiveArray(empty, type, batch.length);
+ new_values.push_back(empty);
+ } else if (is_base_binary_like(type->id())) {
+ std::shared_ptr<Array> empty;
+ switch (type->id()) {
+ case Type::STRING:
+ BuildZeroBaseBinaryArray<StringBuilder>(empty, batch.length);
+ break;
+ case Type::LARGE_STRING:
+ BuildZeroBaseBinaryArray<LargeStringBuilder>(empty,
batch.length);
+ break;
+ case Type::BINARY:
+ BuildZeroBaseBinaryArray<BinaryBuilder>(empty, batch.length);
+ break;
+ case Type::LARGE_BINARY:
+ BuildZeroBaseBinaryArray<LargeBinaryBuilder>(empty,
batch.length);
+ break;
+ default:
+ DCHECK(false);
+ break;
+ }
+ new_values.push_back(empty);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto sub, Subtract(value, value));
+ new_values.push_back(sub);
+ }
+ if (replace_key) {
+ continue;
+ }
+ }
+ new_values.push_back(value);
+ }
+ new_batches.batches.emplace_back(new_values, batch.length);
+ }
+ return new_batches;
+}
+
+// code generation for the by_key types supported by AsofJoinNodeOptions constructors
+// which cannot be directly done using templates because of failure to deduce the template
+// argument for an invocation with a string- or initializer_list-typed keys-argument
+#define EXPAND_BY_KEY_TYPE(macro) \
+ macro(const FieldRef); \
+ macro(std::vector<FieldRef>); \
+ macro(std::initializer_list<FieldRef>);
+
void CheckRunOutput(const BatchesWithSchema& l_batches,
const BatchesWithSchema& r0_batches,
const BatchesWithSchema& r1_batches,
- const BatchesWithSchema& exp_batches, const FieldRef time,
- const FieldRef keys, const int64_t tolerance) {
+ const BatchesWithSchema& exp_batches,
+ const AsofJoinNodeOptions join_options) {
auto exec_ctx =
arrow::internal::make_unique<ExecContext>(default_memory_pool(),
nullptr);
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
- AsofJoinNodeOptions join_options(time, keys, tolerance);
Declaration join{"asofjoin", join_options};
join.inputs.emplace_back(Declaration{
@@ -64,6 +228,9 @@ void CheckRunOutput(const BatchesWithSchema& l_batches,
.AddToPlan(plan.get()));
ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(),
sink_gen));
+ for (auto batch : res) {
+ ASSERT_EQ(exp_batches.schema->num_fields(), batch.values.size());
+ }
ASSERT_OK_AND_ASSIGN(auto exp_table,
TableFromExecBatches(exp_batches.schema,
exp_batches.batches));
@@ -74,237 +241,783 @@ void CheckRunOutput(const BatchesWithSchema& l_batches,
/*same_chunk_layout=*/true, /*flatten=*/true);
}
-void DoRunBasicTest(const std::vector<util::string_view>& l_data,
- const std::vector<util::string_view>& r0_data,
- const std::vector<util::string_view>& r1_data,
- const std::vector<util::string_view>& exp_data, int64_t
tolerance) {
- auto l_schema =
- schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())});
- auto r0_schema =
- schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())});
- auto r1_schema =
- schema({field("time", int64()), field("key", int32()), field("r1_v0",
float32())});
-
- auto exp_schema = schema({
- field("time", int64()),
- field("key", int32()),
- field("l_v0", float64()),
- field("r0_v0", float64()),
- field("r1_v0", float32()),
- });
-
- // Test three table join
- BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches;
- l_batches = MakeBatchesFromString(l_schema, l_data);
- r0_batches = MakeBatchesFromString(r0_schema, r0_data);
- r1_batches = MakeBatchesFromString(r1_schema, r1_data);
- exp_batches = MakeBatchesFromString(exp_schema, exp_data);
- CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key",
- tolerance);
-}
-
-void DoRunInvalidTypeTest(const std::shared_ptr<Schema>& l_schema,
- const std::shared_ptr<Schema>& r_schema) {
- BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"});
- BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"});
-
+#define CHECK_RUN_OUTPUT(by_key_type) \
+ void CheckRunOutput( \
+ const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, \
+ const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, \
+ const FieldRef time, by_key_type key, const int64_t tolerance) { \
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, \
+ AsofJoinNodeOptions(time, {key}, tolerance)); \
+ }
+
+EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT)
+
+void DoInvalidPlanTest(const BatchesWithSchema& l_batches,
+ const BatchesWithSchema& r_batches,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str,
+ bool fail_on_plan_creation = false) {
ExecContext exec_ctx;
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx));
- AsofJoinNodeOptions join_options("time", "key", 0);
Declaration join{"asofjoin", join_options};
join.inputs.emplace_back(Declaration{
"source", SourceNodeOptions{l_batches.schema, l_batches.gen(false,
false)}});
join.inputs.emplace_back(Declaration{
"source", SourceNodeOptions{r_batches.schema, r_batches.gen(false,
false)}});
- ASSERT_RAISES(Invalid, join.AddToPlan(plan.get()));
+ if (fail_on_plan_creation) {
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+ ASSERT_OK(Declaration::Sequence({join, {"sink",
SinkNodeOptions{&sink_gen}}})
+ .AddToPlan(plan.get()));
+ EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr(expected_error_str),
+ StartAndCollect(plan.get(),
sink_gen));
+ } else {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr(expected_error_str),
+ join.AddToPlan(plan.get()));
+ }
+}
+
+void DoRunInvalidPlanTest(const BatchesWithSchema& l_batches,
+ const BatchesWithSchema& r_batches,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str) {
+ DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str);
+}
+
+void DoRunInvalidPlanTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str) {
+ ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema,
{R"([])"}));
+ ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema,
{R"([])"}));
+
+ return DoRunInvalidPlanTest(l_batches, r_batches, join_options,
expected_error_str);
+}
+
+void DoRunInvalidPlanTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema, int64_t
tolerance,
+ const std::string& expected_error_str) {
+ DoRunInvalidPlanTest(l_schema, r_schema,
+ AsofJoinNodeOptions("time", {"key"}, tolerance),
+ expected_error_str);
+}
+
+void DoRunInvalidTypeTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Unsupported type for ");
+}
+
+void DoRunInvalidToleranceTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, -1,
+ "AsOfJoin tolerance must be non-negative but is ");
+}
+
+void DoRunMissingKeysTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : No match");
+}
+
+void DoRunMissingOnKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema,
+ AsofJoinNodeOptions("invalid_time", {"key"}, 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunMissingByKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema,
+ AsofJoinNodeOptions("time", {"invalid_key"}, 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunNestedOnKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"},
{"key"}, 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunNestedByKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema,
+ AsofJoinNodeOptions("time", {FieldRef{0, 1}}, 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunAmbiguousOnKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches");
+}
+
+void DoRunAmbiguousByKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches");
+}
+
+// Gets a batch for testing as a JSON string
+// The batch will have n_rows rows and n_cols columns, the first column being the on-field
+// If unordered is true then the first column will be out-of-order
+std::string GetTestBatchAsJsonString(int n_rows, int n_cols, bool unordered =
false) {
+ int order_mask = unordered ? 1 : 0;
+ std::stringstream s;
+ s << '[';
+ for (int i = 0; i < n_rows; i++) {
+ if (i > 0) {
+ s << ", ";
+ }
+ s << '[';
+ for (int j = 0; j < n_cols; j++) {
+ if (j > 0) {
+ s << ", " << j;
+ } else if (j < 2) {
+ s << (i ^ order_mask);
+ } else {
+ s << i;
+ }
+ }
+ s << ']';
+ }
+ s << ']';
+ return s.str();
+}
+
+void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered,
+ const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str) {
+ ASSERT_TRUE(l_unordered || r_unordered);
+ int n_rows = 5;
+ auto l_str = GetTestBatchAsJsonString(n_rows, l_schema->num_fields(),
l_unordered);
+ auto r_str = GetTestBatchAsJsonString(n_rows, r_schema->num_fields(),
r_unordered);
+ ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema,
{l_str}));
+ ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema,
{r_str}));
+
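+ // NOTE: despite the parameter's name, passing true makes DoInvalidPlanTest add the
+ // plan successfully and expect the error only when the plan is run
+ return DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str,
+ /*fail_on_plan_creation=*/true);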
}
+void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered,
+ const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunUnorderedPlanTest(l_unordered, r_unordered, l_schema, r_schema,
+ AsofJoinNodeOptions("time", {"key"}, 1000),
+ "out-of-order on-key values");
+}
+
+struct BasicTestTypes {
+ std::shared_ptr<DataType> time, key, l_val, r0_val, r1_val;
+};
+
+struct BasicTest {
+ BasicTest(const std::vector<util::string_view>& l_data,
+ const std::vector<util::string_view>& r0_data,
+ const std::vector<util::string_view>& r1_data,
+ const std::vector<util::string_view>& exp_nokey_data,
+ const std::vector<util::string_view>& exp_emptykey_data,
+ const std::vector<util::string_view>& exp_data, int64_t tolerance)
+ : l_data(std::move(l_data)),
+ r0_data(std::move(r0_data)),
+ r1_data(std::move(r1_data)),
+ exp_nokey_data(std::move(exp_nokey_data)),
+ exp_emptykey_data(std::move(exp_emptykey_data)),
+ exp_data(std::move(exp_data)),
+ tolerance(tolerance) {}
+
+ static inline void check_init(const std::vector<std::shared_ptr<DataType>>&
types) {
+ ASSERT_NE(0, types.size());
+ }
+
+ template <typename TypeCond>
+ static inline std::vector<std::shared_ptr<DataType>> init_types(
+ const std::vector<std::shared_ptr<DataType>>& all_types, TypeCond
type_cond) {
+ std::vector<std::shared_ptr<DataType>> types;
+ for (auto type : all_types) {
+ if (type_cond(type)) {
+ types.push_back(type);
+ }
+ }
+ check_init(types);
+ return types;
+ }
+
+ void RunSingleByKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time",
"key",
+ tolerance);
+ });
+ }
+ static void DoSingleByKey(BasicTest& basic_tests) {
basic_tests.RunSingleByKey(); }
+ void RunDoubleByKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time",
+ {"key", "key"}, tolerance);
+ });
+ }
+ static void DoDoubleByKey(BasicTest& basic_tests) {
basic_tests.RunDoubleByKey(); }
+ void RunMutateByKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2"));
+ ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2"));
+ ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2"));
+ ASSERT_OK_AND_ASSIGN(exp_batches, MutateByKey(exp_batches, "key",
"key2"));
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time",
+ {"key", "key2"}, tolerance);
+ });
+ }
+ static void DoMutateByKey(BasicTest& basic_tests) {
basic_tests.RunMutateByKey(); }
+ void RunMutateNoKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2",
true));
+ ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2",
true));
+ ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2",
true));
+ ASSERT_OK_AND_ASSIGN(exp_nokey_batches,
+ MutateByKey(exp_nokey_batches, "key", "key2",
true));
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches,
"time", "key2",
+ tolerance);
+ });
+ }
+ static void DoMutateNoKey(BasicTest& basic_tests) {
basic_tests.RunMutateNoKey(); }
+ void RunMutateNullKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2",
true, true));
+ ASSERT_OK_AND_ASSIGN(r0_batches,
+ MutateByKey(r0_batches, "key", "key2", true, true));
+ ASSERT_OK_AND_ASSIGN(r1_batches,
+ MutateByKey(r1_batches, "key", "key2", true, true));
+ ASSERT_OK_AND_ASSIGN(exp_nokey_batches,
+ MutateByKey(exp_nokey_batches, "key", "key2", true,
true));
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches,
+ AsofJoinNodeOptions("time", {"key2"}, tolerance));
+ });
+ }
+ static void DoMutateNullKey(BasicTest& basic_tests) {
basic_tests.RunMutateNullKey(); }
+ void RunMutateEmptyKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ ASSERT_OK_AND_ASSIGN(r0_batches,
+ MutateByKey(r0_batches, "key", "key", false, false,
true));
+ ASSERT_OK_AND_ASSIGN(r1_batches,
+ MutateByKey(r1_batches, "key", "key", false, false,
true));
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_emptykey_batches,
+ AsofJoinNodeOptions("time", {}, tolerance));
+ });
+ }
+ static void DoMutateEmptyKey(BasicTest& basic_tests) {
+ basic_tests.RunMutateEmptyKey();
+ }
+ template <typename BatchesRunner>
+ void RunBatches(BatchesRunner batches_runner) {
+ std::vector<std::shared_ptr<DataType>> all_types = {
+ utf8(),
+ large_utf8(),
+ binary(),
+ large_binary(),
+ int8(),
+ int16(),
+ int32(),
+ int64(),
+ uint8(),
+ uint16(),
+ uint32(),
+ uint64(),
+ date32(),
+ date64(),
+ time32(TimeUnit::MILLI),
+ time32(TimeUnit::SECOND),
+ time64(TimeUnit::NANO),
+ time64(TimeUnit::MICRO),
+ timestamp(TimeUnit::NANO, "UTC"),
+ timestamp(TimeUnit::MICRO, "UTC"),
+ timestamp(TimeUnit::MILLI, "UTC"),
+ timestamp(TimeUnit::SECOND, "UTC"),
+ float32(),
+ float64()};
+ using T = const std::shared_ptr<DataType>;
+ // byte_width > 1 below allows fitting the tested data
+ auto time_types = init_types(
+ all_types, [](T& t) { return t->byte_width() > 1 &&
!is_floating(t->id()); });
+ auto key_types = init_types(all_types, [](T& t) { return
!is_floating(t->id()); });
+ auto l_types = init_types(all_types, [](T& t) { return true; });
+ auto r0_types = init_types(all_types, [](T& t) { return t->byte_width() >
1; });
+ auto r1_types = init_types(all_types, [](T& t) { return t->byte_width() >
1; });
+
+ // sample a limited number of type-combinations to keep the running time reasonable
+ // the scoped-traces below help reproduce a test failure, should it happen
+ auto start_time = std::chrono::system_clock::now();
+ auto seed = start_time.time_since_epoch().count();
+ ARROW_SCOPED_TRACE("Types seed: ", seed);
+ std::default_random_engine engine(static_cast<unsigned int>(seed));
+ std::uniform_int_distribution<size_t> time_distribution(0,
time_types.size() - 1);
+ std::uniform_int_distribution<size_t> key_distribution(0, key_types.size()
- 1);
+ std::uniform_int_distribution<size_t> l_distribution(0, l_types.size() -
1);
+ std::uniform_int_distribution<size_t> r0_distribution(0, r0_types.size() -
1);
+ std::uniform_int_distribution<size_t> r1_distribution(0, r1_types.size() -
1);
+
+ for (int i = 0; i < 1000; i++) {
+ auto time_type = time_types[time_distribution(engine)];
+ ARROW_SCOPED_TRACE("Time type: ", *time_type);
+ auto key_type = key_types[key_distribution(engine)];
+ ARROW_SCOPED_TRACE("Key type: ", *key_type);
+ auto l_type = l_types[l_distribution(engine)];
+ ARROW_SCOPED_TRACE("Left type: ", *l_type);
+ auto r0_type = r0_types[r0_distribution(engine)];
+ ARROW_SCOPED_TRACE("Right-0 type: ", *r0_type);
+ auto r1_type = r1_types[r1_distribution(engine)];
+ ARROW_SCOPED_TRACE("Right-1 type: ", *r1_type);
+
+ RunTypes({time_type, key_type, l_type, r0_type, r1_type},
batches_runner);
+
+ auto end_time = std::chrono::system_clock::now();
+ std::chrono::duration<double> diff = end_time - start_time;
+ if (diff.count() > 2) {
+ // this normally happens on slow CI systems, but is fine
+ break;
+ }
+ }
+ }
+ template <typename BatchesRunner>
+ void RunTypes(BasicTestTypes basic_test_types, BatchesRunner batches_runner)
{
+ const BasicTestTypes& b = basic_test_types;
+ auto l_schema =
+ schema({field("time", b.time), field("key", b.key), field("l_v0",
b.l_val)});
+ auto r0_schema =
+ schema({field("time", b.time), field("key", b.key), field("r0_v0",
b.r0_val)});
+ auto r1_schema =
+ schema({field("time", b.time), field("key", b.key), field("r1_v0",
b.r1_val)});
+
+ auto exp_schema = schema({
+ field("time", b.time),
+ field("key", b.key),
+ field("l_v0", b.l_val),
+ field("r0_v0", b.r0_val),
+ field("r1_v0", b.r1_val),
+ });
+
+ // Test three table join
+ ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema,
l_data));
+ ASSERT_OK_AND_ASSIGN(auto r0_batches, MakeBatchesFromNumString(r0_schema,
r0_data));
+ ASSERT_OK_AND_ASSIGN(auto r1_batches, MakeBatchesFromNumString(r1_schema,
r1_data));
+ ASSERT_OK_AND_ASSIGN(auto exp_nokey_batches,
+ MakeBatchesFromNumString(exp_schema, exp_nokey_data));
+ ASSERT_OK_AND_ASSIGN(auto exp_emptykey_batches,
+ MakeBatchesFromNumString(exp_schema,
exp_emptykey_data));
+ ASSERT_OK_AND_ASSIGN(auto exp_batches,
+ MakeBatchesFromNumString(exp_schema, exp_data));
+ batches_runner(l_batches, r0_batches, r1_batches, exp_nokey_batches,
+ exp_emptykey_batches, exp_batches);
+ }
+
+ std::vector<util::string_view> l_data;
+ std::vector<util::string_view> r0_data;
+ std::vector<util::string_view> r1_data;
+ std::vector<util::string_view> exp_nokey_data;
+ std::vector<util::string_view> exp_emptykey_data;
+ std::vector<util::string_view> exp_data;
+ int64_t tolerance;
+};
+
+using AsofJoinBasicParams = std::tuple<std::function<void(BasicTest&)>,
std::string>;
+
+struct AsofJoinBasicTest : public testing::TestWithParam<AsofJoinBasicParams>
{};
+
class AsofJoinTest : public testing::Test {};
-TEST(AsofJoinTest, TestBasic1) {
+BasicTest GetBasicTest1() {
// Single key, single batch
- DoRunBasicTest(
- /*l*/ {R"([[0, 1, 1.0], [1000, 1, 2.0]])"},
- /*r0*/ {R"([[0, 1, 11.0]])"},
- /*r1*/ {R"([[1000, 1, 101.0]])"},
- /*exp*/ {R"([[0, 1, 1.0, 11.0, null], [1000, 1, 2.0, 11.0, 101.0]])"},
1000);
+ return BasicTest(
+ /*l*/ {R"([[0, 1, 1], [1000, 1, 2]])"},
+ /*r0*/ {R"([[0, 1, 11]])"},
+ /*r1*/ {R"([[1000, 1, 101]])"},
+ /*exp_nokey*/ {R"([[0, 0, 1, 11, null], [1000, 0, 2, 11, 101]])"},
+ /*exp_emptykey*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"},
+ /*exp*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, 1000);
}
-TEST(AsofJoinTest, TestBasic2) {
+TRACED_TEST_P(AsofJoinBasicTest, TestBasic1, {
+ BasicTest basic_test = GetBasicTest1();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
+BasicTest GetBasicTest2() {
// Single key, multiple batches
- DoRunBasicTest(
- /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"},
- /*r0*/ {R"([[0, 1, 11.0]])", R"([[1000, 1, 12.0]])"},
- /*r1*/ {R"([[0, 1, 101.0]])", R"([[1000, 1, 102.0]])"},
- /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"},
1000);
+ return BasicTest(
+ /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"},
+ /*r0*/ {R"([[0, 1, 11]])", R"([[1000, 1, 12]])"},
+ /*r1*/ {R"([[0, 1, 101]])", R"([[1000, 1, 102]])"},
+ /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"},
+ /*exp_emptykey*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"},
+ /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000);
}
-TEST(AsofJoinTest, TestBasic3) {
+TRACED_TEST_P(AsofJoinBasicTest, TestBasic2, {
+ BasicTest basic_test = GetBasicTest2();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
+BasicTest GetBasicTest3() {
// Single key, multiple left batches, single right batches
- DoRunBasicTest(
- /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"},
- /*r0*/ {R"([[0, 1, 11.0], [1000, 1, 12.0]])"},
- /*r1*/ {R"([[0, 1, 101.0], [1000, 1, 102.0]])"},
- /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"},
1000);
+ return BasicTest(
+ /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"},
+ /*r0*/ {R"([[0, 1, 11], [1000, 1, 12]])"},
+ /*r1*/ {R"([[0, 1, 101], [1000, 1, 102]])"},
+ /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"},
+ /*exp_emptykey*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"},
+ /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000);
}
-TEST(AsofJoinTest, TestBasic4) {
+TRACED_TEST_P(AsofJoinBasicTest, TestBasic3, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic3_" +
std::get<1>(GetParam()));
+ BasicTest basic_test = GetBasicTest3();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
+BasicTest GetBasicTest4() {
// Multi key, multiple batches, misaligned batches
- DoRunBasicTest(
+ return BasicTest(
/*l*/
- {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500,
1, 3.0], [1500, 2, 23.0]])",
- R"([[2000, 1, 4.0], [2000, 2, 24.0]])"},
+ {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3],
[1500, 2, 23]])",
+ R"([[2000, 1, 4], [2000, 2, 24]])"},
/*r0*/
- {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])",
- R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"},
+ {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])",
+ R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"},
/*r1*/
- {R"([[0, 2, 1001.0], [500, 1, 101.0]])",
- R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"},
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101],
[1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])",
+ R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"},
+ /*exp_emptykey*/
+ {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101],
[1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"},
/*exp*/
- {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0,
11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 12.0, 102.0],
[1500, 2, 23.0, 32.0, 1002.0]])",
- R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"},
+ {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101],
[1000, 2, 22, 31, 1001], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"},
1000);
}
-TEST(AsofJoinTest, TestBasic5) {
+TRACED_TEST_P(AsofJoinBasicTest, TestBasic4, {
+ BasicTest basic_test = GetBasicTest4();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
+BasicTest GetBasicTest5() {
// Multi key, multiple batches, misaligned batches, smaller tolerance
- DoRunBasicTest(/*l*/
- {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2,
22.0], [1500, 1, 3.0], [1500, 2, 23.0]])",
- R"([[2000, 1, 4.0], [2000, 2, 24.0]])"},
- /*r0*/
- {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])",
- R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"},
- /*r1*/
- {R"([[0, 2, 1001.0], [500, 1, 101.0]])",
- R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1,
103.0]])"},
- /*exp*/
- {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0],
[500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, null], [1500, 1, 3.0, 12.0,
102.0], [1500, 2, 23.0, 32.0, 1002.0]])",
- R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0,
1002.0]])"},
- 500);
-}
-
-TEST(AsofJoinTest, TestBasic6) {
+ return BasicTest(/*l*/
+ {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22],
[1500, 1, 3], [1500, 2, 23]])",
+ R"([[2000, 1, 4], [2000, 2, 24]])"},
+ /*r0*/
+ {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])",
+ R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2,
31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32,
1002]])",
+ R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"},
+ /*exp_emptykey*/
+ {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2,
31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32,
1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"},
+ /*exp*/
+ {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1,
2, 11, 101], [1000, 2, 22, 31, null], [1500, 1, 3, 12, 102], [1500, 2, 23, 32,
1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"},
+ 500);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestBasic5, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic5_" +
std::get<1>(GetParam()));
+ BasicTest basic_test = GetBasicTest5();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
+BasicTest GetBasicTest6() {
// Multi key, multiple batches, misaligned batches, zero tolerance
- DoRunBasicTest(/*l*/
- {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2,
22.0], [1500, 1, 3.0], [1500, 2, 23.0]])",
- R"([[2000, 1, 4.0], [2000, 2, 24.0]])"},
- /*r0*/
- {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])",
- R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"},
- /*r1*/
- {R"([[0, 2, 1001.0], [500, 1, 101.0]])",
- R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1,
103.0]])"},
- /*exp*/
- {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0],
[500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, null], [1500, 1, 3.0, null,
null], [1500, 2, 23.0, 32.0, 1002.0]])",
- R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, null,
null]])"},
- 0);
-}
-
-TEST(AsofJoinTest, TestEmpty1) {
+ return BasicTest(/*l*/
+ {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22],
[1500, 1, 3], [1500, 2, 23]])",
+ R"([[2000, 1, 4], [2000, 2, 24]])"},
+ /*r0*/
+ {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])",
+ R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2,
31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32,
1002]])",
+ R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"},
+ /*exp_emptykey*/
+ {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2,
31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32,
1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"},
+ /*exp*/
+ {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1,
2, null, 101], [1000, 2, 22, null, null], [1500, 1, 3, null, null], [1500, 2,
23, 32, 1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, null, null]])"},
+ 0);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestBasic6, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic6_" +
std::get<1>(GetParam()));
+ BasicTest basic_test = GetBasicTest6();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
+BasicTest GetEmptyTest1() {
// Empty left batch
- DoRunBasicTest(/*l*/
- {R"([])", R"([[2000, 1, 4.0], [2000, 2, 24.0]])"},
- /*r0*/
- {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])",
- R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"},
- /*r1*/
- {R"([[0, 2, 1001.0], [500, 1, 101.0]])",
- R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1,
103.0]])"},
- /*exp*/
- {R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0,
1002.0]])"},
- 1000);
-}
-
-TEST(AsofJoinTest, TestEmpty2) {
+ return BasicTest(/*l*/
+ {R"([])", R"([[2000, 1, 4], [2000, 2, 24]])"},
+ /*r0*/
+ {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])",
+ R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"},
+ /*exp_emptykey*/
+ {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"},
+ /*exp*/
+ {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"},
1000);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestEmpty1, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty1_" +
std::get<1>(GetParam()));
+ BasicTest basic_test = GetEmptyTest1();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
+BasicTest GetEmptyTest2() {
// Empty left input
- DoRunBasicTest(/*l*/
- {R"([])"},
- /*r0*/
- {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])",
- R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"},
- /*r1*/
- {R"([[0, 2, 1001.0], [500, 1, 101.0]])",
- R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1,
103.0]])"},
- /*exp*/
- {R"([])"}, 1000);
-}
-
-TEST(AsofJoinTest, TestEmpty3) {
+ return BasicTest(/*l*/
+ {R"([])"},
+ /*r0*/
+ {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])",
+ R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([])"},
+ /*exp_emptykey*/
+ {R"([])"},
+ /*exp*/
+ {R"([])"}, 1000);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestEmpty2, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty2_" +
std::get<1>(GetParam()));
+ BasicTest basic_test = GetEmptyTest2();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
+BasicTest GetEmptyTest3() {
// Empty right batch
- DoRunBasicTest(/*l*/
- {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2,
22.0], [1500, 1, 3.0], [1500, 2, 23.0]])",
- R"([[2000, 1, 4.0], [2000, 2, 24.0]])"},
- /*r0*/
- {R"([])", R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2,
33.0]])"},
- /*r1*/
- {R"([[0, 2, 1001.0], [500, 1, 101.0]])",
- R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1,
103.0]])"},
- /*exp*/
- {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0],
[500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null,
102.0], [1500, 2, 23.0, 32.0, 1002.0]])",
- R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0,
1002.0]])"},
- 1000);
-}
-
-TEST(AsofJoinTest, TestEmpty4) {
- // Empty right input
- DoRunBasicTest(/*l*/
- {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2,
22.0], [1500, 1, 3.0], [1500, 2, 23.0]])",
- R"([[2000, 1, 4.0], [2000, 2, 24.0]])"},
- /*r0*/
- {R"([])"},
- /*r1*/
- {R"([[0, 2, 1001.0], [500, 1, 101.0]])",
- R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1,
103.0]])"},
- /*exp*/
- {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0],
[500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null,
102.0], [1500, 2, 23.0, null, 1002.0]])",
- R"([[2000, 1, 4.0, null, 103.0], [2000, 2, 24.0, null,
1002.0]])"},
- 1000);
-}
-
-TEST(AsofJoinTest, TestEmpty5) {
- // All empty
- DoRunBasicTest(/*l*/
- {R"([])"},
- /*r0*/
- {R"([])"},
- /*r1*/
- {R"([])"},
- /*exp*/
- {R"([])"}, 1000);
+ return BasicTest(/*l*/
+ {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22],
[1500, 1, 3], [1500, 2, 23]])",
+ R"([[2000, 1, 4], [2000, 2, 24]])"},
+ /*r0*/
+ {R"([])", R"([[1500, 2, 32], [2000, 1, 13], [2500, 2,
33]])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500,
0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, 32, 1002], [1500, 0,
23, 32, 1002]])",
+ R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"},
+ /*exp_emptykey*/
+ {R"([[0, 1, 1, null, 1001], [0, 2, 21, null, 1001], [500,
1, 2, null, 101], [1000, 2, 22, null, 102], [1500, 1, 3, 32, 1002], [1500, 2,
23, 32, 1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"},
+ /*exp*/
+ {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500,
1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2,
23, 32, 1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"},
+ 1000);
}
-TEST(AsofJoinTest, TestUnsupportedOntype) {
- DoRunInvalidTypeTest(
- schema({field("time", utf8()), field("key", int32()), field("l_v0",
float64())}),
- schema({field("time", utf8()), field("key", int32()), field("r0_v0",
float32())}));
+TRACED_TEST_P(AsofJoinBasicTest, TestEmpty3, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty3_" +
std::get<1>(GetParam()));
+ BasicTest basic_test = GetEmptyTest3();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
+BasicTest GetEmptyTest4() {
+ // Empty right input
+ return BasicTest(/*l*/
+ {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22],
[1500, 1, 3], [1500, 2, 23]])",
+ R"([[2000, 1, 4], [2000, 2, 24]])"},
+ /*r0*/
+ {R"([])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500,
0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, null, 1002], [1500, 0,
23, null, 1002]])",
+ R"([[2000, 0, 4, null, 103], [2000, 0, 24, null, 103]])"},
+ /*exp_emptykey*/
+ {R"([[0, 1, 1, null, 1001], [0, 2, 21, null, 1001], [500,
1, 2, null, 101], [1000, 2, 22, null, 102], [1500, 1, 3, null, 1002], [1500, 2,
23, null, 1002]])",
+ R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 103]])"},
+ /*exp*/
+ {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500,
1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2,
23, null, 1002]])",
+ R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 1002]])"},
+ 1000);
}
-TEST(AsofJoinTest, TestUnsupportedBytype) {
- DoRunInvalidTypeTest(
- schema({field("time", int64()), field("key", utf8()), field("l_v0",
float64())}),
- schema({field("time", int64()), field("key", utf8()), field("r0_v0",
float32())}));
+TRACED_TEST_P(AsofJoinBasicTest, TestEmpty4, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty4_" +
std::get<1>(GetParam()));
+ BasicTest basic_test = GetEmptyTest4();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
+BasicTest GetEmptyTest5() {
+ // All empty
+ return BasicTest(/*l*/
+ {R"([])"},
+ /*r0*/
+ {R"([])"},
+ /*r1*/
+ {R"([])"},
+ /*exp_nokey*/
+ {R"([])"},
+ /*exp_emptykey*/
+ {R"([])"},
+ /*exp*/
+ {R"([])"}, 1000);
}
-TEST(AsofJoinTest, TestUnsupportedDatatype) {
- // Utf8 is unsupported
+TRACED_TEST_P(AsofJoinBasicTest, TestEmpty5, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty5_" +
std::get<1>(GetParam()));
+ BasicTest basic_test = GetEmptyTest5();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
+INSTANTIATE_TEST_SUITE_P(  // one AsofJoinBasicTest instantiation per (runner, label) pair below
+ AsofJoinNodeTest, AsofJoinBasicTest,
+ testing::Values(AsofJoinBasicParams(BasicTest::DoSingleByKey,
"SingleByKey"),
+ AsofJoinBasicParams(BasicTest::DoDoubleByKey,
"DoubleByKey"),
+ AsofJoinBasicParams(BasicTest::DoMutateByKey,
"MutateByKey"),
+ AsofJoinBasicParams(BasicTest::DoMutateNoKey,
"MutateNoKey"),
+ AsofJoinBasicParams(BasicTest::DoMutateNullKey,
"MutateNullKey"),
+ AsofJoinBasicParams(BasicTest::DoMutateEmptyKey,
"MutateEmptyKey")));
+
+TRACED_TEST(AsofJoinTest, TestUnsupportedOntype, {  // "time" is list(int32()) on both sides; presumably rejected as the "on" key type - confirm in DoRunInvalidTypeTest
+ DoRunInvalidTypeTest(schema({field("time", list(int32())), field("key",
int32()),
+ field("l_v0", float64())}),
+ schema({field("time", list(int32())), field("key",
int32()),
+ field("r0_v0", float32())}));
+})
+
+TRACED_TEST(AsofJoinTest, TestUnsupportedBytype, {  // "key" is list(int32()) on both sides; presumably rejected as the "by" key type - confirm in DoRunInvalidTypeTest
+ DoRunInvalidTypeTest(schema({field("time", int64()), field("key",
list(int32())),
+ field("l_v0", float64())}),
+ schema({field("time", int64()), field("key",
list(int32())),
+ field("r0_v0", float32())}));
+})
+
+TRACED_TEST(AsofJoinTest, TestUnsupportedDatatype, {
+ // List is unsupported
DoRunInvalidTypeTest(
schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())}),
- schema({field("time", int64()), field("key", int32()), field("r0_v0",
utf8())}));
-}
+ schema({field("time", int64()), field("key", int32()),
+ field("r0_v0", list(int32()))}));
+})
-TEST(AsofJoinTest, TestMissingKeys) {
- DoRunInvalidTypeTest(
+TRACED_TEST(AsofJoinTest, TestMissingKeys, {
+ DoRunMissingKeysTest(
schema({field("time1", int64()), field("key", int32()), field("l_v0",
float64())}),
schema(
{field("time1", int64()), field("key", int32()), field("r0_v0",
float64())}));
- DoRunInvalidTypeTest(
+ DoRunMissingKeysTest(
schema({field("time", int64()), field("key1", int32()), field("l_v0",
float64())}),
schema(
{field("time", int64()), field("key1", int32()), field("r0_v0",
float64())}));
-}
+})
+
+TRACED_TEST(AsofJoinTest, TestUnsupportedTolerance, {
+ // NOTE(review): both schemas are valid (no Utf8 column); presumably DoRunInvalidToleranceTest supplies the bad tolerance - confirm
+ DoRunInvalidToleranceTest(
+ schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())}),
+ schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())}));
+})
+
+TRACED_TEST(AsofJoinTest, TestMissingOnKey, {
+ DoRunMissingOnKeyTest(
+ schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())}),
+ schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())}));
+})
+
+TRACED_TEST(AsofJoinTest, TestMissingByKey, {
+ DoRunMissingByKeyTest(
+ schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())}),
+ schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())}));
+})
+
+TRACED_TEST(AsofJoinTest, TestNestedOnKey, {
+ DoRunNestedOnKeyTest(
+ schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())}),
+ schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())}));
+})
+
+TRACED_TEST(AsofJoinTest, TestNestedByKey, {
+ DoRunNestedByKeyTest(
+ schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())}),
+ schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())}));
+})
+
+TRACED_TEST(AsofJoinTest, TestAmbiguousOnKey, {  // the left schema declares the "time" field twice
+ DoRunAmbiguousOnKeyTest(
+ schema({field("time", int64()), field("time", int64()), field("key",
int32()),
+ field("l_v0", float64())}),
+ schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())}));
+})
+
+TRACED_TEST(AsofJoinTest, TestAmbiguousByKey, {  // the left schema declares the "key" field twice (int64 and int32)
+ DoRunAmbiguousByKeyTest(
+ schema({field("time", int64()), field("key", int64()), field("key",
int32()),
+ field("l_v0", float64())}),
+ schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())}));
+})
+
+TRACED_TEST(AsofJoinTest, TestLeftUnorderedOnKey, {
+ DoRunUnorderedPlanTest(
+ /*l_unordered=*/true, /*r_unordered=*/false,
+ schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())}),
+ schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())}));
+})
+
+TRACED_TEST(AsofJoinTest, TestRightUnorderedOnKey, {
+ DoRunUnorderedPlanTest(
+ /*l_unordered=*/false, /*r_unordered=*/true,
+ schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())}),
+ schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())}));
+})
+
+TRACED_TEST(AsofJoinTest, TestUnorderedOnKey, {
+ DoRunUnorderedPlanTest(
+ /*l_unordered=*/true, /*r_unordered=*/true,
+ schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())}),
+ schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())}));
+})
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/hash_join.cc
b/cpp/src/arrow/compute/exec/hash_join.cc
index 5cf66b3d09..da1710fe08 100644
--- a/cpp/src/arrow/compute/exec/hash_join.cc
+++ b/cpp/src/arrow/compute/exec/hash_join.cc
@@ -26,7 +26,6 @@
#include <vector>
#include "arrow/compute/exec/hash_join_dict.h"
-#include "arrow/compute/exec/key_hash.h"
#include "arrow/compute/exec/task_util.h"
#include "arrow/compute/kernels/row_encoder.h"
#include "arrow/compute/row/encode_internal.h"
diff --git a/cpp/src/arrow/compute/exec/options.h
b/cpp/src/arrow/compute/exec/options.h
index a8e8c1ee23..e0172bff7f 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -397,23 +397,25 @@ class ARROW_EXPORT HashJoinNodeOptions : public
ExecNodeOptions {
/// This node will output one row for each row in the left table.
class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions {
public:
- AsofJoinNodeOptions(FieldRef on_key, FieldRef by_key, int64_t tolerance)
- : on_key(std::move(on_key)), by_key(std::move(by_key)),
tolerance(tolerance) {}
+ AsofJoinNodeOptions(FieldRef on_key, std::vector<FieldRef> by_key, int64_t
tolerance)
+ : on_key(std::move(on_key)), by_key(std::move(by_key)), tolerance(tolerance) {}  // both by-value sink parameters are moved into the members, not copied
- /// \brief "on" key for the join. Each
+ /// \brief "on" key for the join.
///
- /// All inputs tables must be sorted by the "on" key. Inexact
- /// match is used on the "on" key. i.e., a row is considiered match iff
+ /// All inputs tables must be sorted by the "on" key. Must be a single field
of a common
+ /// type. Inexact match is used on the "on" key. i.e., a row is considered
match iff
/// left_on - tolerance <= right_on <= left_on.
- /// Currently, "on" key must be an int64 field
+ /// Currently, the "on" key must be of an integer, date, or timestamp type.
FieldRef on_key;
/// \brief "by" key for the join.
///
/// All input tables must have the "by" key. Exact equality
/// is used for the "by" key.
- /// Currently, the "by" key must be an int32 field
- FieldRef by_key;
- /// Tolerance for inexact "on" key matching
+ /// Currently, the "by" key must be of an integer, date, timestamp, or
base-binary type
+ std::vector<FieldRef> by_key;
+ /// \brief Tolerance for inexact "on" key matching. Must be non-negative.
+ ///
+ /// The tolerance is interpreted in the same units as the "on" key.
int64_t tolerance;
};
diff --git a/cpp/src/arrow/compute/light_array.cc
b/cpp/src/arrow/compute/light_array.cc
index a337d4f999..caa392319b 100644
--- a/cpp/src/arrow/compute/light_array.cc
+++ b/cpp/src/arrow/compute/light_array.cc
@@ -147,6 +147,12 @@ Result<KeyColumnArray> ColumnArrayFromArrayData(
const std::shared_ptr<ArrayData>& array_data, int64_t start_row, int64_t
num_rows) {
ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata metadata,
ColumnMetadataFromDataType(array_data->type));
+ return ColumnArrayFromArrayDataAndMetadata(array_data, metadata, start_row,
num_rows);
+}
+
+KeyColumnArray ColumnArrayFromArrayDataAndMetadata(
+ const std::shared_ptr<ArrayData>& array_data, const KeyColumnMetadata&
metadata,
+ int64_t start_row, int64_t num_rows) {
KeyColumnArray column_array = KeyColumnArray(
metadata, array_data->offset + start_row + num_rows,
array_data->buffers[0] != NULLPTR ? array_data->buffers[0]->data() :
nullptr,
diff --git a/cpp/src/arrow/compute/light_array.h
b/cpp/src/arrow/compute/light_array.h
index 0620f6d3eb..389b63cca4 100644
--- a/cpp/src/arrow/compute/light_array.h
+++ b/cpp/src/arrow/compute/light_array.h
@@ -135,7 +135,7 @@ class ARROW_EXPORT KeyColumnArray {
/// Only valid if this is a view into a varbinary type
uint32_t* mutable_offsets() {
DCHECK(!metadata_.is_fixed_length);
- DCHECK(metadata_.fixed_length == sizeof(uint32_t));
+ DCHECK_EQ(metadata_.fixed_length, sizeof(uint32_t));
return reinterpret_cast<uint32_t*>(mutable_data(kFixedLengthBuffer));
}
/// \brief Return a read-only version of the offsets buffer
@@ -143,7 +143,7 @@ class ARROW_EXPORT KeyColumnArray {
/// Only valid if this is a view into a varbinary type
const uint32_t* offsets() const {
DCHECK(!metadata_.is_fixed_length);
- DCHECK(metadata_.fixed_length == sizeof(uint32_t));
+ DCHECK_EQ(metadata_.fixed_length, sizeof(uint32_t));
return reinterpret_cast<const uint32_t*>(data(kFixedLengthBuffer));
}
/// \brief Return a mutable version of the large-offsets buffer
@@ -151,7 +151,7 @@ class ARROW_EXPORT KeyColumnArray {
/// Only valid if this is a view into a large varbinary type
uint64_t* mutable_large_offsets() {
DCHECK(!metadata_.is_fixed_length);
- DCHECK(metadata_.fixed_length == sizeof(uint64_t));
+ DCHECK_EQ(metadata_.fixed_length, sizeof(uint64_t));
return reinterpret_cast<uint64_t*>(mutable_data(kFixedLengthBuffer));
}
/// \brief Return a read-only version of the large-offsets buffer
@@ -159,7 +159,7 @@ class ARROW_EXPORT KeyColumnArray {
/// Only valid if this is a view into a large varbinary type
const uint64_t* large_offsets() const {
DCHECK(!metadata_.is_fixed_length);
- DCHECK(metadata_.fixed_length == sizeof(uint64_t));
+ DCHECK_EQ(metadata_.fixed_length, sizeof(uint64_t));
return reinterpret_cast<const uint64_t*>(data(kFixedLengthBuffer));
}
/// \brief Return the type metadata
@@ -205,6 +205,17 @@ ARROW_EXPORT Result<KeyColumnMetadata>
ColumnMetadataFromDataType(
ARROW_EXPORT Result<KeyColumnArray> ColumnArrayFromArrayData(
const std::shared_ptr<ArrayData>& array_data, int64_t start_row, int64_t
num_rows);
+/// \brief Create KeyColumnArray from ArrayData and KeyColumnMetadata
+///
+/// If `type` is a dictionary type then this will return the KeyColumnArray for
+/// the indices array
+///
+/// The caller should ensure this is only called on "key" columns.
+/// \see ColumnMetadataFromDataType for details
+ARROW_EXPORT KeyColumnArray ColumnArrayFromArrayDataAndMetadata(
+ const std::shared_ptr<ArrayData>& array_data, const KeyColumnMetadata&
metadata,
+ int64_t start_row, int64_t num_rows);
+
/// \brief Create KeyColumnMetadata instances from an ExecBatch
///
/// column_metadatas will be resized to fit
diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h
index 66da3cadcb..e2b74e865f 100644
--- a/cpp/src/arrow/type_traits.h
+++ b/cpp/src/arrow/type_traits.h
@@ -622,6 +622,13 @@ using is_fixed_size_binary_type =
std::is_base_of<FixedSizeBinaryType, T>;
template <typename T, typename R = void>
using enable_if_fixed_size_binary =
enable_if_t<is_fixed_size_binary_type<T>::value, R>;
+// This includes primitive, dictionary, and fixed-size-binary types
+template <typename T>
+using is_fixed_width_type = std::is_base_of<FixedWidthType, T>;
+
+template <typename T, typename R = void>
+using enable_if_fixed_width_type = enable_if_t<is_fixed_width_type<T>::value,
R>;
+
template <typename T>
using is_binary_like_type =
std::integral_constant<bool, (is_base_binary_type<T>::value &&