You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2023/01/08 20:04:31 UTC
[arrow] branch master updated: ARROW-18427: [C++] Support negative tolerance in `AsofJoinNode` (#14934)
This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 2acc51a7d5 ARROW-18427: [C++] Support negative tolerance in `AsofJoinNode` (#14934)
2acc51a7d5 is described below
commit 2acc51a7d5304c3fc6a432f1c09946547ca91d74
Author: rtpsw <rt...@hotmail.com>
AuthorDate: Sun Jan 8 22:04:25 2023 +0200
ARROW-18427: [C++] Support negative tolerance in `AsofJoinNode` (#14934)
See https://issues.apache.org/jira/browse/ARROW-18427
Lead-authored-by: Yaron Gvili <rt...@hotmail.com>
Co-authored-by: rtpsw <rt...@hotmail.com>
Signed-off-by: Weston Pace <we...@gmail.com>
---
cpp/src/arrow/compute/exec/asof_join_node.cc | 248 ++++++++++++++++-----
cpp/src/arrow/compute/exec/asof_join_node_test.cc | 256 ++++++++++++++++++++--
2 files changed, 437 insertions(+), 67 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc
index d071c0ce7f..a752cf800d 100644
--- a/cpp/src/arrow/compute/exec/asof_join_node.cc
+++ b/cpp/src/arrow/compute/exec/asof_join_node.cc
@@ -18,6 +18,7 @@
#include "arrow/compute/exec/asof_join_node.h"
#include <condition_variable>
+#include <limits>
#include <memory>
#include <mutex>
#include <optional>
@@ -64,6 +65,37 @@ typedef uint64_t ByType;
typedef uint64_t OnType;
typedef uint64_t HashType;
+/// A tolerance type with overflow-avoiding operations.
+///
+/// The signed tolerance is kept as a magnitude (`value`) and a sign flag (`negative`) so
+/// it can be combined with unsigned `OnType` times without wrapping around.
+/// A non-negative tolerance accepts RHS times in [left - value, left]; a negative one
+/// accepts RHS times in [left, left + value].
+struct TolType {
+ constexpr static OnType kMinValue = std::numeric_limits<OnType>::lowest();
+ constexpr static OnType kMaxValue = std::numeric_limits<OnType>::max();
+
+ // The magnitude is computed in unsigned arithmetic: `-tol` would be a signed overflow
+ // (undefined behavior) for the minimum int64_t, whereas subtracting from zero after the
+ // cast is well-defined and gives the exact magnitude of every negative input.
+ explicit TolType(int64_t tol)
+ : value(tol < 0 ? uint64_t{0} - static_cast<uint64_t>(tol) : static_cast<uint64_t>(tol)),
+ negative(tol < 0) {}
+
+ // magnitude of the tolerance
+ OnType value;
+ // true when the tolerance passed to the constructor was below zero
+ bool negative;
+
+ // an entry with a time below this threshold expires;
+ // the subtraction saturates at kMinValue rather than wrapping around
+ OnType Expiry(OnType left_value) const {
+ return negative ? left_value
+ : (left_value < kMinValue + value ? kMinValue : left_value - value);
+ }
+
+ // an entry with a time after this threshold is distant;
+ // the addition saturates at kMaxValue rather than wrapping around
+ OnType Horizon(OnType left_value) const {
+ return negative ? (left_value > kMaxValue - value ? kMaxValue : left_value + value)
+ : left_value;
+ }
+
+ // true when the tolerance accepts the RHS time given the LHS one;
+ // the operands are ordered before subtracting so the difference cannot wrap
+ bool Accepts(OnType left_value, OnType right_value) const {
+ return negative
+ ? (left_value > right_value ? false : right_value - left_value <= value)
+ : (left_value < right_value ? false : left_value - right_value <= value);
+ }
+};
+
// Maximum number of tables that can be joined
#define MAX_JOIN_TABLES 64
typedef uint64_t row_index_t;
@@ -140,6 +172,17 @@ struct MemoStore {
// Stores last known values for all the keys
struct Entry {
+ Entry() = default;
+
+ Entry(OnType time, std::shared_ptr<arrow::RecordBatch> batch, row_index_t row)
+ : time(time), batch(batch), row(row) {}
+
+ void swap(Entry& other) {
+ std::swap(time, other.time);
+ std::swap(batch, other.batch);
+ std::swap(row, other.row);
+ }
+
// Timestamp associated with the entry
OnType time;
@@ -151,31 +194,103 @@ struct MemoStore {
row_index_t row;
};
+ explicit MemoStore(bool no_future)
+ : no_future_(no_future), current_time_(std::numeric_limits<OnType>::lowest()) {}
+
+ // true when there are no future entries, which is the case for the LHS table and the
+ // case for when the tolerance is non-negative. A regular non-negative-tolerance as-of-join
+ // operation requires memorizing only the most recently observed entry per key. OTOH, a
+ // negative-tolerance (future) as-of-join operation requires memorizing per-key queues
+ // of entries up to the tolerance's horizon and in particular distinguishes between the
+ // current (front-of-queue) and latest (back-of-queue) entries per key.
+ bool no_future_;
+ // the time of the current entry, defaulting to 0.
+ // when entries with a time less than T are removed, the current time is updated to the
+ // time of the next (by-time) and now-current entry or to T if no such entry exists.
+ OnType current_time_;
+ // current entry per key
std::unordered_map<ByType, Entry> entries_;
+ // future entries per key
+ std::unordered_map<ByType, std::queue<Entry>> future_entries_;
+ // current and future (distinct) times of existing entries
+ std::deque<OnType> times_;
+
+ void swap(MemoStore& memo) {
+ std::swap(no_future_, memo.no_future_);
+ std::swap(current_time_, memo.current_time_);
+ entries_.swap(memo.entries_);
+ future_entries_.swap(memo.future_entries_);
+ times_.swap(memo.times_);
+ }
- void Store(const std::shared_ptr<RecordBatch>& batch, row_index_t row, OnType time,
- ByType key) {
- auto& e = entries_[key];
- // that we can do this assignment optionally, is why we
- // can get array with using shared_ptr above (the batch
- // shouldn't change that often)
- if (e.batch != batch) e.batch = batch;
- e.row = row;
- e.time = time;
+ // Memorizes row `row` of `batch`, whose time is `time`, under the by-key `key`.
+ // - When no_future_ is set, or when `key` has no current entry yet, the row becomes
+ // (overwrites) the current entry of `key`.
+ // - Otherwise the current entry is left alone and the row is appended to the queue of
+ // future entries of `key`.
+ // In both cases the time is then recorded in times_.
+ // NOTE(review): `for_time` is not referenced anywhere in the body.
+ void Store(OnType for_time, const std::shared_ptr<RecordBatch>& batch, row_index_t row,
+ OnType time, ByType key) {
+ if (no_future_ || entries_.count(key) == 0) {
+ auto& e = entries_[key];
+ // that we can do this assignment optionally, is why we
+ // can get away with using shared_ptr above (the batch
+ // shouldn't change that often)
+ if (e.batch != batch) e.batch = batch;
+ e.row = row;
+ e.time = time;
+ } else {
+ future_entries_[key].emplace(time, batch, row);
+ }
+ // the else-branch below is taken only when no_future_ is set and the front of times_
+ // already equals `time`, so its assignment leaves times_ unchanged
+ if (!no_future_ || times_.empty() || times_.front() != time) {
+ times_.push_back(time);
+ } else {
+ times_.front() = time;
+ }
}
std::optional<const Entry*> GetEntryForKey(ByType key) const {
auto e = entries_.find(key);
- if (entries_.end() == e) return std::nullopt;
- return std::optional<const Entry*>(&e->second);
+ return entries_.end() == e ? std::nullopt : std::optional<const Entry*>(&e->second);
}
- void RemoveEntriesWithLesserTime(OnType ts) {
- for (auto e = entries_.begin(); e != entries_.end();)
- if (e->second.time < ts)
- e = entries_.erase(e);
- else
+ // Drops every memorized entry whose time is below `ts` and brings current_time_ up to
+ // date. Returns true when a time was removed from times_ or current_time_ was changed,
+ // and false otherwise.
+ bool RemoveEntriesWithLesserTime(OnType ts) {
+ // step 1: discard future entries older than `ts` and erase queues left empty by that
+ for (auto fe = future_entries_.begin(); fe != future_entries_.end();) {
+ auto& queue = fe->second;
+ while (!queue.empty() && queue.front().time < ts) queue.pop();
+ if (queue.empty()) {
+ fe = future_entries_.erase(fe);
+ } else {
+ ++fe;
+ }
+ }
+ // step 2: replace each current entry older than `ts` with the next future entry of
+ // the same key when there is one, or erase it otherwise
+ for (auto e = entries_.begin(); e != entries_.end();) {
+ if (e->second.time < ts) {
+ auto fe = future_entries_.find(e->first);
+ if (fe != future_entries_.end() && !fe->second.empty()) {
+ auto& queue = fe->second;
+ // the swap moves the front future entry into the current slot; the pop then
+ // discards the old current entry. A queue emptied here stays in the map, which
+ // the emptiness check above tolerates.
+ e->second.swap(queue.front());
+ queue.pop();
+ ++e;
+ } else {
+ e = entries_.erase(e);
+ }
+ } else {
++e;
+ }
+ }
+ bool updated = false;
+ // step 3: pop the times below `ts` off the front of times_
+ while (!times_.empty() && times_.front() < ts) {
+ current_time_ = times_.front();
+ times_.pop_front();
+ updated = true;
+ }
+ // step 4: raise current_time_ using the remaining times, stopping after the first
+ // time that is strictly greater than `ts`
+ for (auto times_it = times_.begin(); times_it != times_.end(); times_it++) {
+ if (current_time_ < *times_it) {
+ current_time_ = *times_it;
+ updated = true;
+ }
+ if (*times_it > ts) break;
+ }
+ // step 5: current_time_ never stays below `ts`
+ if (current_time_ < ts) {
+ current_time_ = ts;
+ updated = true;
+ }
+ return updated;
}
};
@@ -245,8 +360,8 @@ class InputState {
// turned into output record batches.
public:
- InputState(bool must_hash, bool may_rehash, KeyHasher* key_hasher,
- const std::shared_ptr<arrow::Schema>& schema,
+ InputState(size_t index, TolType tolerance, bool must_hash, bool may_rehash,
+ KeyHasher* key_hasher, const std::shared_ptr<arrow::Schema>& schema,
const col_index_t time_col_index,
const std::vector<col_index_t>& key_col_index)
: queue_(),
@@ -257,7 +372,9 @@ class InputState {
key_type_id_(key_col_index.size()),
key_hasher_(key_hasher),
must_hash_(must_hash),
- may_rehash_(may_rehash) {
+ may_rehash_(may_rehash),
+ tolerance_(tolerance),
+ memo_(/*no_future=*/index == 0 || !tolerance.negative) {
for (size_t k = 0; k < key_col_index_.size(); k++) {
key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id();
}
@@ -290,6 +407,20 @@ class InputState {
return queue_.Empty();
}
+ // true when the queue is empty and, when memo may have future entries (the case of a
+ // negative tolerance), when the memo is empty.
+ // used when checking whether RHS is up to date with LHS.
+ bool CurrentEmpty() const {
+ return memo_.no_future_ ? Empty() : memo_.times_.empty() && Empty();
+ }
+
+ // in case memo may not have future entries (the case of a non-negative tolerance),
+ // returns the latest time (which is current); otherwise, returns the current time.
+ // used when checking whether RHS is up to date with LHS.
+ OnType GetCurrentTime() const {
+ return memo_.no_future_ ? GetLatestTime() : memo_.current_time_;
+ }
+
int total_batches() const { return total_batches_; }
// Gets latest batch (precondition: must not be empty)
@@ -305,10 +436,10 @@ class InputState {
}
inline ByType GetLatestKey() const {
- return GetLatestKey(queue_.UnsyncFront().get(), latest_ref_row_);
+ return GetKey(GetLatestBatch().get(), latest_ref_row_);
}
- inline ByType GetLatestKey(const RecordBatch* batch, row_index_t row) const {
+ inline ByType GetKey(const RecordBatch* batch, row_index_t row) const {
if (must_hash_) {
return key_hasher_->HashesFor(batch)[row];
}
@@ -337,10 +468,10 @@ class InputState {
}
inline OnType GetLatestTime() const {
- return GetLatestTime(queue_.UnsyncFront().get(), latest_ref_row_);
+ return GetTime(GetLatestBatch().get(), latest_ref_row_);
}
- inline ByType GetLatestTime(const RecordBatch* batch, row_index_t row) const {
+ inline ByType GetTime(const RecordBatch* batch, row_index_t row) const {
auto data = batch->column_data(time_col_index_);
switch (time_type_id_) {
LATEST_VAL_CASE(INT8, time_value)
@@ -391,16 +522,18 @@ class InputState {
return have_active_batch;
}
- // Advance the data to be immediately past the specified timestamp, update
- // latest_time and latest_ref_row to the value that immediately pass the
- // specified timestamp.
+ // Advance the data to be immediately past the tolerance's horizon for the specified
+ // timestamp, and update latest_time and latest_ref_row to the values immediately past
+ // the horizon. Update the memo-store with any entries or future entries so observed.
// Returns true if updates were made, false if not.
Result<bool> AdvanceAndMemoize(OnType ts) {
// Advance the right side row index until we reach the latest right row (for each key)
// for the given left timestamp.
// Check if already updated for TS (or if there is no latest)
- if (Empty()) return false; // can't advance if empty
+ if (Empty()) { // can't advance if empty and no future entries
+ return memo_.no_future_ ? false : memo_.RemoveEntriesWithLesserTime(ts);
+ }
// Not updated. Try to update and possibly advance.
bool advanced, updated = false;
@@ -410,8 +543,8 @@ class InputState {
// Keep advancing right table until we hit the latest row that has
// timestamp <= ts. This is because we only need the latest row for the
// match given a left ts.
- if (latest_time > ts) {
- break; // hit a future timestamp -- done updating for now
+ if (latest_time > tolerance_.Horizon(ts)) { // hit a distant timestamp
+ if (memo_.no_future_ || !memo_.times_.empty()) break; // no future entries
}
auto rb = GetLatestBatch();
if (may_rehash_ && rb->column_data(key_col_index_[0])->GetNullCount() > 0) {
@@ -419,20 +552,30 @@ class InputState {
may_rehash_ = false;
Rehash();
}
- memo_.Store(rb, latest_ref_row_, latest_time, GetLatestKey());
- updated = true;
+ memo_.Store(ts, rb, latest_ref_row_, latest_time, GetLatestKey());
+ updated = memo_.no_future_;
ARROW_ASSIGN_OR_RAISE(advanced, Advance());
} while (advanced);
+ if (!memo_.no_future_) { // "updated" was not modified in the loop; set it here
+ updated = memo_.RemoveEntriesWithLesserTime(ts);
+ }
return updated;
}
+ // Rebuilds the memo-store so that every entry is filed under the key that GetKey
+ // computes for its batch and row at the time of the call.
void Rehash() {
- MemoStore new_memo;
- for (const auto& entry : memo_.entries_) {
- const auto& e = entry.second;
- new_memo.Store(e.batch, e.row, e.time, GetLatestKey(e.batch.get(), e.row));
+ MemoStore new_memo(memo_.no_future_);
+ new_memo.current_time_ = memo_.current_time_;
+ for (auto e = memo_.entries_.begin(); e != memo_.entries_.end(); ++e) {
+ auto& entry = e->second;
+ auto new_key = GetKey(entry.batch.get(), entry.row);
+ new_memo.entries_[new_key].swap(entry);
+ // the queue of future entries, if any, follows its current entry to the new key
+ auto fe = memo_.future_entries_.find(e->first);
+ if (fe != memo_.future_entries_.end()) {
+ new_memo.future_entries_[new_key].swap(fe->second);
+ }
}
- memo_ = new_memo;
+ // hand the recorded times over to new_memo first, so that they end up back in memo_
+ // after the whole-store swap
+ memo_.times_.swap(new_memo.times_);
+ memo_.swap(new_memo);
}
Status Push(const std::shared_ptr<arrow::RecordBatch>& rb) {
@@ -492,6 +635,8 @@ class InputState {
bool must_hash_;
// True if by-key values may be rehashed
bool may_rehash_;
+ // Tolerance
+ TolType tolerance_;
// Index of the latest row reference within; if >0 then queue_ cannot be empty
// Must be < queue_.front()->num_rows() if queue_ is non-empty
row_index_t latest_ref_row_ = 0;
@@ -535,7 +680,7 @@ class CompositeReferenceTable {
// Adds the latest row from the input state as a new composite reference row
// - LHS must have a valid key,timestep,and latest rows
// - RHS must have valid data memo'ed for the key
- void Emplace(std::vector<std::unique_ptr<InputState>>& in, OnType tolerance) {
+ void Emplace(std::vector<std::unique_ptr<InputState>>& in, TolType tolerance) {
DCHECK_EQ(in.size(), n_tables_);
// Get the LHS key
@@ -566,7 +711,7 @@ class CompositeReferenceTable {
std::optional<const MemoStore::Entry*> opt_entry = in[i]->GetMemoEntryForKey(key);
if (opt_entry.has_value()) {
DCHECK(*opt_entry);
- if ((*opt_entry)->time + tolerance >= lhs_latest_time) {
+ if (tolerance.Accepts(lhs_latest_time, (*opt_entry)->time)) {
// Have a valid entry
const MemoStore::Entry* entry = *opt_entry;
row.refs[i].batch = entry->batch.get();
@@ -752,9 +897,9 @@ class AsofJoinNode : public ExecNode {
auto& rhs = *state_[i];
if (!rhs.Finished()) {
// If RHS is finished, then we know it's up to date
- if (rhs.Empty())
+ if (rhs.CurrentEmpty())
return false; // RHS isn't finished, but is empty --> not up to date
- if (lhs_ts >= rhs.GetLatestTime())
+ if (lhs_ts >= rhs.GetCurrentTime())
return false; // RHS isn't up to date (and not finished)
}
}
@@ -794,8 +939,9 @@ class AsofJoinNode : public ExecNode {
// Prune memo entries that have expired (to bound memory consumption)
if (!lhs.Empty()) {
for (size_t i = 1; i < state_.size(); ++i) {
- if (lhs.GetLatestTime() > tolerance_) {
- state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - tolerance_);
+ OnType ts = tolerance_.Expiry(lhs.GetLatestTime());
+ if (ts != TolType::kMinValue) {
+ state_[i]->RemoveMemoEntriesWithLesserTime(ts);
}
}
}
@@ -890,7 +1036,7 @@ class AsofJoinNode : public ExecNode {
AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector<std::string> input_labels,
const std::vector<col_index_t>& indices_of_on_key,
const std::vector<std::vector<col_index_t>>& indices_of_by_key,
- OnType tolerance, std::shared_ptr<Schema> output_schema,
+ TolType tolerance, std::shared_ptr<Schema> output_schema,
std::vector<std::unique_ptr<KeyHasher>> key_hashers, bool must_hash,
bool may_rehash);
@@ -900,8 +1046,8 @@ class AsofJoinNode : public ExecNode {
RETURN_NOT_OK(key_hashers_[i]->Init(plan()->query_context()->exec_context(),
output_schema()));
state_.push_back(std::make_unique<InputState>(
- must_hash_, may_rehash_, key_hashers_[i].get(), inputs[i]->output_schema(),
- indices_of_on_key_[i], indices_of_by_key_[i]));
+ i, tolerance_, must_hash_, may_rehash_, key_hashers_[i].get(),
+ inputs[i]->output_schema(), indices_of_on_key_[i], indices_of_by_key_[i]));
}
col_index_t dst_offset = 0;
@@ -1130,13 +1276,7 @@ class AsofJoinNode : public ExecNode {
static arrow::Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs";
-
const auto& join_options = checked_cast<const AsofJoinNodeOptions&>(options);
- if (join_options.tolerance < 0) {
- return Status::Invalid("AsOfJoin tolerance must be non-negative but is ",
- join_options.tolerance);
- }
-
ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(join_options.input_keys));
size_t n_input = inputs.size();
std::vector<std::string> input_labels(n_input);
@@ -1165,7 +1305,7 @@ class AsofJoinNode : public ExecNode {
bool may_rehash = n_by == 1 && !must_hash;
return plan->EmplaceNode<AsofJoinNode>(
plan, inputs, std::move(input_labels), std::move(indices_of_on_key),
- std::move(indices_of_by_key), time_value(join_options.tolerance),
+ std::move(indices_of_by_key), TolType(join_options.tolerance),
std::move(output_schema), std::move(key_hashers), must_hash, may_rehash);
}
@@ -1224,7 +1364,7 @@ class AsofJoinNode : public ExecNode {
// Each input state corresponds to an input table
std::vector<std::unique_ptr<InputState>> state_;
std::mutex gate_;
- OnType tolerance_;
+ TolType tolerance_;
// Queue for triggering processing of a given input
// (a false value is a poison pill)
@@ -1240,7 +1380,7 @@ AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs,
std::vector<std::string> input_labels,
const std::vector<col_index_t>& indices_of_on_key,
const std::vector<std::vector<col_index_t>>& indices_of_by_key,
- OnType tolerance, std::shared_ptr<Schema> output_schema,
+ TolType tolerance, std::shared_ptr<Schema> output_schema,
std::vector<std::unique_ptr<KeyHasher>> key_hashers,
bool must_hash, bool may_rehash)
: ExecNode(plan, inputs, input_labels,
diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc
index e30e842095..6968aa03c9 100644
--- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc
+++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc
@@ -318,12 +318,6 @@ void DoRunInvalidTypeTest(const std::shared_ptr<Schema>& l_schema,
DoRunInvalidPlanTest(l_schema, r_schema, 0, "Unsupported type for ");
}
-void DoRunInvalidToleranceTest(const std::shared_ptr<Schema>& l_schema,
- const std::shared_ptr<Schema>& r_schema) {
- DoRunInvalidPlanTest(l_schema, r_schema, -1,
- "AsOfJoin tolerance must be non-negative but is ");
-}
-
void DoRunMissingKeysTest(const std::shared_ptr<Schema>& l_schema,
const std::shared_ptr<Schema>& r_schema) {
DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : No match");
@@ -668,6 +662,23 @@ TRACED_TEST_P(AsofJoinBasicTest, TestBasic1, {
runner(basic_test);
})
+BasicTest GetBasicTest1Negative() {
+ // Single key, single batch
+ return BasicTest(
+ /*l*/ {R"([[0, 1, 1], [1000, 1, 2]])"},
+ /*r0*/ {R"([[1000, 1, 11]])"},
+ /*r1*/ {R"([[2000, 1, 101]])"},
+ /*exp_nokey*/ {R"([[0, 0, 1, 11, null], [1000, 0, 2, 11, 101]])"},
+ /*exp_emptykey*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"},
+ /*exp*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, -1000);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestBasic1Negative, {
+ BasicTest basic_test = GetBasicTest1Negative();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
BasicTest GetBasicTest2() {
// Single key, multiple batches
return BasicTest(
@@ -685,6 +696,23 @@ TRACED_TEST_P(AsofJoinBasicTest, TestBasic2, {
runner(basic_test);
})
+BasicTest GetBasicTest2Negative() {
+ // Single key, multiple batches
+ return BasicTest(
+ /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"},
+ /*r0*/ {R"([[500, 1, 11]])", R"([[1000, 1, 12]])"},
+ /*r1*/ {R"([[500, 1, 101]])", R"([[1000, 1, 102]])"},
+ /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"},
+ /*exp_emptykey*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"},
+ /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, -1000);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestBasic2Negative, {
+ BasicTest basic_test = GetBasicTest2Negative();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
BasicTest GetBasicTest3() {
// Single key, multiple left batches, single right batches
return BasicTest(
@@ -703,6 +731,24 @@ TRACED_TEST_P(AsofJoinBasicTest, TestBasic3, {
runner(basic_test);
})
+BasicTest GetBasicTest3Negative() {
+ // Single key, multiple left batches, single right batches
+ return BasicTest(
+ /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"},
+ /*r0*/ {R"([[500, 1, 11], [1000, 1, 12]])"},
+ /*r1*/ {R"([[500, 1, 101], [1000, 1, 102]])"},
+ /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"},
+ /*exp_emptykey*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"},
+ /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, -1000);
+}
+
+// Runs the negative-tolerance variant of basic test 3 with the parameterized runner.
+TRACED_TEST_P(AsofJoinBasicTest, TestBasic3Negative, {
+ // the trace label names this test, so a failure is not attributed to TestBasic3
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic3Negative_" + std::get<1>(GetParam()));
+ BasicTest basic_test = GetBasicTest3Negative();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
BasicTest GetBasicTest4() {
// Multi key, multiple batches, misaligned batches
return BasicTest(
@@ -733,6 +779,36 @@ TRACED_TEST_P(AsofJoinBasicTest, TestBasic4, {
runner(basic_test);
})
+BasicTest GetBasicTest4Negative() {
+ // Multi key, multiple batches, misaligned batches
+ return BasicTest(
+ /*l*/
+ {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])",
+ R"([[2000, 1, 4], [2000, 2, 24]])"},
+ /*r0*/
+ {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])",
+ R"([[1600, 2, 32], [1900, 2, 33], [2100, 1, 13]])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1100, 1, 102], [1600, 2, 1002], [2100, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])",
+ R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"},
+ /*exp_emptykey*/
+ {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"},
+ /*exp*/
+ {R"([[0, 1, 1, 11, 101], [0, 2, 21, 31, 1001], [500, 1, 2, 12, 101], [1000, 2, 22, 32, 1002], [1500, 1, 3, 13, 103], [1500, 2, 23, 32, 1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, null, null]])"},
+ -1000);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestBasic4Negative, {
+ BasicTest basic_test = GetBasicTest4Negative();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
BasicTest GetBasicTest5() {
// Multi key, multiple batches, misaligned batches, smaller tolerance
return BasicTest(/*l*/
@@ -763,6 +839,36 @@ TRACED_TEST_P(AsofJoinBasicTest, TestBasic5, {
runner(basic_test);
})
+BasicTest GetBasicTest5Negative() {
+ // Multi key, multiple batches, misaligned batches, smaller tolerance
+ return BasicTest(/*l*/
+ {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])",
+ R"([[2000, 1, 4], [2000, 2, 24]])"},
+ /*r0*/
+ {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])",
+ R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])",
+ R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"},
+ /*exp_emptykey*/
+ {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"},
+ /*exp*/
+ {R"([[0, 1, 1, 11, 101], [0, 2, 21, 31, 1001], [500, 1, 2, 12, 101], [1000, 2, 22, 32, 1002], [1500, 1, 3, 13, 103], [1500, 2, 23, 32, 1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 33, null]])"},
+ -500);
+}
+
+// Runs the negative-tolerance variant of basic test 5 with the parameterized runner.
+TRACED_TEST_P(AsofJoinBasicTest, TestBasic5Negative, {
+ // the trace label names this test, so a failure is not attributed to TestBasic5
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic5Negative_" + std::get<1>(GetParam()));
+ BasicTest basic_test = GetBasicTest5Negative();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
BasicTest GetBasicTest6() {
// Multi key, multiple batches, misaligned batches, zero tolerance
return BasicTest(/*l*/
@@ -818,6 +924,31 @@ TRACED_TEST_P(AsofJoinBasicTest, TestEmpty1, {
runner(basic_test);
})
+BasicTest GetEmptyTest1Negative() {
+ // Empty left batch
+ return BasicTest(/*l*/
+ {R"([])", R"([[2000, 1, 4], [2000, 2, 24]])"},
+ /*r0*/
+ {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])",
+ R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"},
+ /*exp_emptykey*/
+ {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"},
+ /*exp*/
+ {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 33, null]])"}, -1000);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestEmpty1Negative, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty1Negative_" + std::get<1>(GetParam()));
+ BasicTest basic_test = GetEmptyTest1Negative();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
BasicTest GetEmptyTest2() {
// Empty left input
return BasicTest(/*l*/
@@ -843,6 +974,31 @@ TRACED_TEST_P(AsofJoinBasicTest, TestEmpty2, {
runner(basic_test);
})
+BasicTest GetEmptyTest2Negative() {
+ // Empty left input
+ return BasicTest(/*l*/
+ {R"([])"},
+ /*r0*/
+ {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])",
+ R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([])"},
+ /*exp_emptykey*/
+ {R"([])"},
+ /*exp*/
+ {R"([])"}, -1000);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestEmpty2Negative, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty2Negative_" + std::get<1>(GetParam()));
+ BasicTest basic_test = GetEmptyTest2Negative();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
BasicTest GetEmptyTest3() {
// Empty right batch
return BasicTest(/*l*/
@@ -872,6 +1028,35 @@ TRACED_TEST_P(AsofJoinBasicTest, TestEmpty3, {
runner(basic_test);
})
+BasicTest GetEmptyTest3Negative() {
+ // Empty right batch
+ return BasicTest(/*l*/
+ {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])",
+ R"([[2000, 1, 4], [2000, 2, 24]])"},
+ /*r0*/
+ {R"([])", R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, 32, 101], [1000, 0, 22, 32, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])",
+ R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"},
+ /*exp_emptykey*/
+ {R"([[0, 1, 1, null, 1001], [0, 2, 21, null, 1001], [500, 1, 2, 32, 101], [1000, 2, 22, 32, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"},
+ /*exp*/
+ {R"([[0, 1, 1, null, 101], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, 32, 1002], [1500, 1, 3, 13, 103], [1500, 2, 23, 32, 1002]])",
+ R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 33, null]])"},
+ -1000);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestEmpty3Negative, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty3Negative_" + std::get<1>(GetParam()));
+ BasicTest basic_test = GetEmptyTest3Negative();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
BasicTest GetEmptyTest4() {
// Empty right input
return BasicTest(/*l*/
@@ -901,6 +1086,35 @@ TRACED_TEST_P(AsofJoinBasicTest, TestEmpty4, {
runner(basic_test);
})
+BasicTest GetEmptyTest4Negative() {
+ // Empty right input
+ return BasicTest(/*l*/
+ {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])",
+ R"([[2000, 1, 4], [2000, 2, 24]])"},
+ /*r0*/
+ {R"([])"},
+ /*r1*/
+ {R"([[0, 2, 1001], [500, 1, 101]])",
+ R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"},
+ /*exp_nokey*/
+ {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, null, 1002], [1500, 0, 23, null, 1002]])",
+ R"([[2000, 0, 4, null, 103], [2000, 0, 24, null, 103]])"},
+ /*exp_emptykey*/
+ {R"([[0, 1, 1, null, 1001], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 102], [1500, 1, 3, null, 1002], [1500, 2, 23, null, 1002]])",
+ R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 103]])"},
+ /*exp*/
+ {R"([[0, 1, 1, null, 101], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1002], [1500, 1, 3, null, 103], [1500, 2, 23, null, 1002]])",
+ R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, null]])"},
+ -1000);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestEmpty4Negative, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty4Negative_" + std::get<1>(GetParam()));
+ BasicTest basic_test = GetEmptyTest4Negative();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
BasicTest GetEmptyTest5() {
// All empty
return BasicTest(/*l*/
@@ -924,6 +1138,29 @@ TRACED_TEST_P(AsofJoinBasicTest, TestEmpty5, {
runner(basic_test);
})
+BasicTest GetEmptyTest5Negative() {
+ // All empty
+ return BasicTest(/*l*/
+ {R"([])"},
+ /*r0*/
+ {R"([])"},
+ /*r1*/
+ {R"([])"},
+ /*exp_nokey*/
+ {R"([])"},
+ /*exp_emptykey*/
+ {R"([])"},
+ /*exp*/
+ {R"([])"}, -1000);
+}
+
+TRACED_TEST_P(AsofJoinBasicTest, TestEmpty5Negative, {
+ ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty5Negative_" + std::get<1>(GetParam()));
+ BasicTest basic_test = GetEmptyTest5Negative();
+ auto runner = std::get<0>(GetParam());
+ runner(basic_test);
+})
+
INSTANTIATE_TEST_SUITE_P(
AsofJoinNodeTest, AsofJoinBasicTest,
testing::Values(AsofJoinBasicParams(BasicTest::DoSingleByKey, "SingleByKey"),
@@ -967,13 +1204,6 @@ TRACED_TEST(AsofJoinTest, TestMissingKeys, {
{field("time", int64()), field("key1", int32()), field("r0_v0", float64())}));
})
-TRACED_TEST(AsofJoinTest, TestUnsupportedTolerance, {
- // Utf8 is unsupported
- DoRunInvalidToleranceTest(
- schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}),
- schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())}));
-})
-
TRACED_TEST(AsofJoinTest, TestMissingOnKey, {
DoRunMissingOnKeyTest(
schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}),