You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@quickstep.apache.org by zu...@apache.org on 2018/05/04 19:59:28 UTC
incubator-quickstep git commit: Refactored ScalarCaseExpression.
Repository: incubator-quickstep
Updated Branches:
refs/heads/master 77287a788 -> 666102fff
Refactored ScalarCaseExpression.
Project: http://git-wip-us.apache.org/repos/asf/incubator-quickstep/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-quickstep/commit/666102ff
Tree: http://git-wip-us.apache.org/repos/asf/incubator-quickstep/tree/666102ff
Diff: http://git-wip-us.apache.org/repos/asf/incubator-quickstep/diff/666102ff
Branch: refs/heads/master
Commit: 666102fff32d258a5a2c85e33bf8d8ebb5d3a9cf
Parents: 77287a7
Author: Zuyu Zhang <zu...@cs.wisc.edu>
Authored: Wed May 2 16:06:59 2018 -0500
Committer: Zuyu Zhang <zu...@cs.wisc.edu>
Committed: Fri May 4 14:40:28 2018 -0500
----------------------------------------------------------------------
expressions/scalar/ScalarCaseExpression.cpp | 265 ++++++++--------
expressions/scalar/ScalarCaseExpression.hpp | 26 --
.../tests/ScalarCaseExpression_unittest.cpp | 307 +++++++++++++++++++
3 files changed, 450 insertions(+), 148 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/666102ff/expressions/scalar/ScalarCaseExpression.cpp
----------------------------------------------------------------------
diff --git a/expressions/scalar/ScalarCaseExpression.cpp b/expressions/scalar/ScalarCaseExpression.cpp
index 6847425..c2af83b 100644
--- a/expressions/scalar/ScalarCaseExpression.cpp
+++ b/expressions/scalar/ScalarCaseExpression.cpp
@@ -41,6 +41,102 @@
namespace quickstep {
+namespace {
+
+// Merge the values in the NativeColumnVector 'case_result' into '*output' at
+// the positions specified by 'case_matches'. If 'source_sequence' is
+// non-NULL, it indicates which positions actually have tuples in the input;
+// otherwise it is assumed that there are no holes in the input.
+void MultiplexNativeColumnVector(
+ const TupleIdSequence *source_sequence,
+ const TupleIdSequence &case_matches,
+ const NativeColumnVector &case_result,
+ NativeColumnVector *output) {
+ if (source_sequence == nullptr) {
+ if (case_result.typeIsNullable()) {
+ TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
+ for (std::size_t input_pos = 0;
+ input_pos < case_result.size();
+ ++input_pos, ++output_pos_it) {
+ const void *value = case_result.getUntypedValue<true>(input_pos);
+ if (value) {
+ output->positionalWriteUntypedValue(*output_pos_it, value);
+ } else {
+ output->positionalWriteNullValue(*output_pos_it);
+ }
+ }
+ } else {
+ TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
+ for (std::size_t input_pos = 0;
+ input_pos < case_result.size();
+ ++input_pos, ++output_pos_it) {
+ output->positionalWriteUntypedValue(*output_pos_it,
+ case_result.getUntypedValue<false>(input_pos));
+ }
+ }
+ } else {
+ if (case_result.typeIsNullable()) {
+ std::size_t input_pos = 0;
+ TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
+ for (std::size_t output_pos = 0;
+ output_pos < output->size();
+ ++output_pos, ++source_sequence_it) {
+ if (case_matches.get(*source_sequence_it)) {
+ const void *value = case_result.getUntypedValue<true>(input_pos++);
+ if (value) {
+ output->positionalWriteUntypedValue(output_pos, value);
+ } else {
+ output->positionalWriteNullValue(output_pos);
+ }
+ }
+ }
+ } else {
+ std::size_t input_pos = 0;
+ TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
+ for (std::size_t output_pos = 0;
+ output_pos < output->size();
+ ++output_pos, ++source_sequence_it) {
+ if (case_matches.get(*source_sequence_it)) {
+ output->positionalWriteUntypedValue(output_pos,
+ case_result.getUntypedValue<false>(input_pos++));
+ }
+ }
+ }
+ }
+}
+
+// Same as MultiplexNativeColumnVector(), but works on IndirectColumnVectors
+// instead of NativeColumnVectors.
+void MultiplexIndirectColumnVector(
+ const TupleIdSequence *source_sequence,
+ const TupleIdSequence &case_matches,
+ const IndirectColumnVector &case_result,
+ IndirectColumnVector *output) {
+ if (source_sequence == nullptr) {
+ TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
+ for (std::size_t input_pos = 0;
+ input_pos < case_result.size();
+ ++input_pos, ++output_pos_it) {
+ output->positionalWriteTypedValue(*output_pos_it,
+ case_result.getTypedValue(input_pos));
+ }
+ } else {
+ std::size_t input_pos = 0;
+ TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
+ for (std::size_t output_pos = 0;
+ output_pos < output->size();
+ ++output_pos, ++source_sequence_it) {
+ if (case_matches.get(*source_sequence_it)) {
+ output->positionalWriteTypedValue(output_pos,
+ case_result.getTypedValue(input_pos++));
+ }
+ }
+ }
+}
+
+} // namespace
+
+
ScalarCaseExpression::ScalarCaseExpression(
const Type &result_type,
std::vector<std::unique_ptr<Predicate>> &&when_predicates,
@@ -96,17 +192,17 @@ serialization::Scalar ScalarCaseExpression::getProto() const {
serialization::Scalar proto;
proto.set_data_source(serialization::Scalar::CASE_EXPRESSION);
proto.MutableExtension(serialization::ScalarCaseExpression::result_type)
- ->CopyFrom(type_.getProto());
+ ->MergeFrom(type_.getProto());
for (const std::unique_ptr<Predicate> &when_pred : when_predicates_) {
proto.AddExtension(serialization::ScalarCaseExpression::when_predicate)
- ->CopyFrom(when_pred->getProto());
+ ->MergeFrom(when_pred->getProto());
}
for (const std::unique_ptr<Scalar> &result_expr : result_expressions_) {
proto.AddExtension(serialization::ScalarCaseExpression::result_expression)
- ->CopyFrom(result_expr->getProto());
+ ->MergeFrom(result_expr->getProto());
}
proto.MutableExtension(serialization::ScalarCaseExpression::else_result_expression)
- ->CopyFrom(else_result_expression_->getProto());
+ ->MergeFrom(else_result_expression_->getProto());
return proto;
}
@@ -137,16 +233,16 @@ TypedValue ScalarCaseExpression::getValueForSingleTuple(
return static_value_.makeReferenceToThis();
} else if (fixed_result_expression_ != nullptr) {
return fixed_result_expression_->getValueForSingleTuple(accessor, tuple);
- } else {
- for (std::vector<std::unique_ptr<Predicate>>::size_type case_idx = 0;
- case_idx < when_predicates_.size();
- ++case_idx) {
- if (when_predicates_[case_idx]->matchesForSingleTuple(accessor, tuple)) {
- return result_expressions_[case_idx]->getValueForSingleTuple(accessor, tuple);
- }
+ }
+
+ for (std::vector<std::unique_ptr<Predicate>>::size_type case_idx = 0;
+ case_idx < when_predicates_.size();
+ ++case_idx) {
+ if (when_predicates_[case_idx]->matchesForSingleTuple(accessor, tuple)) {
+ return result_expressions_[case_idx]->getValueForSingleTuple(accessor, tuple);
}
- return else_result_expression_->getValueForSingleTuple(accessor, tuple);
}
+ return else_result_expression_->getValueForSingleTuple(accessor, tuple);
}
TypedValue ScalarCaseExpression::getValueForJoinedTuples(
@@ -165,33 +261,33 @@ TypedValue ScalarCaseExpression::getValueForJoinedTuples(
right_accessor,
right_relation_id,
right_tuple_id);
- } else {
- for (std::vector<std::unique_ptr<Predicate>>::size_type case_idx = 0;
- case_idx < when_predicates_.size();
- ++case_idx) {
- if (when_predicates_[case_idx]->matchesForJoinedTuples(left_accessor,
- left_relation_id,
- left_tuple_id,
- right_accessor,
- right_relation_id,
- right_tuple_id)) {
- return result_expressions_[case_idx]->getValueForJoinedTuples(
- left_accessor,
- left_relation_id,
- left_tuple_id,
- right_accessor,
- right_relation_id,
- right_tuple_id);
- }
+ }
+
+ for (std::vector<std::unique_ptr<Predicate>>::size_type case_idx = 0;
+ case_idx < when_predicates_.size();
+ ++case_idx) {
+ if (when_predicates_[case_idx]->matchesForJoinedTuples(left_accessor,
+ left_relation_id,
+ left_tuple_id,
+ right_accessor,
+ right_relation_id,
+ right_tuple_id)) {
+ return result_expressions_[case_idx]->getValueForJoinedTuples(
+ left_accessor,
+ left_relation_id,
+ left_tuple_id,
+ right_accessor,
+ right_relation_id,
+ right_tuple_id);
}
- return else_result_expression_->getValueForJoinedTuples(
- left_accessor,
- left_relation_id,
- left_tuple_id,
- right_accessor,
- right_relation_id,
- right_tuple_id);
}
+ return else_result_expression_->getValueForJoinedTuples(
+ left_accessor,
+ left_relation_id,
+ left_tuple_id,
+ right_accessor,
+ right_relation_id,
+ right_tuple_id);
}
ColumnVectorPtr ScalarCaseExpression::getAllValues(
@@ -280,6 +376,16 @@ ColumnVectorPtr ScalarCaseExpression::getAllValuesForJoin(
ValueAccessor *right_accessor,
const std::vector<std::pair<tuple_id, tuple_id>> &joined_tuple_ids,
ColumnVectorCache *cv_cache) const {
+ if (has_static_value_) {
+ return ColumnVectorPtr(
+ ColumnVector::MakeVectorOfValue(type_, static_value_, joined_tuple_ids.size()));
+ } else if (fixed_result_expression_) {
+ return fixed_result_expression_->getAllValuesForJoin(
+ left_relation_id, left_accessor,
+ right_relation_id, right_accessor,
+ joined_tuple_ids, cv_cache);
+ }
+
// Slice 'joined_tuple_ids' apart by case.
//
// NOTE(chasseur): We use TupleIdSequence to keep track of the positions in
@@ -368,91 +474,6 @@ ColumnVectorPtr ScalarCaseExpression::getAllValuesForJoin(
else_results);
}
-void ScalarCaseExpression::MultiplexNativeColumnVector(
- const TupleIdSequence *source_sequence,
- const TupleIdSequence &case_matches,
- const NativeColumnVector &case_result,
- NativeColumnVector *output) {
- if (source_sequence == nullptr) {
- if (case_result.typeIsNullable()) {
- TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
- for (std::size_t input_pos = 0;
- input_pos < case_result.size();
- ++input_pos, ++output_pos_it) {
- const void *value = case_result.getUntypedValue<true>(input_pos);
- if (value == nullptr) {
- output->positionalWriteNullValue(*output_pos_it);
- } else {
- output->positionalWriteUntypedValue(*output_pos_it, value);
- }
- }
- } else {
- TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
- for (std::size_t input_pos = 0;
- input_pos < case_result.size();
- ++input_pos, ++output_pos_it) {
- output->positionalWriteUntypedValue(*output_pos_it,
- case_result.getUntypedValue<false>(input_pos));
- }
- }
- } else {
- if (case_result.typeIsNullable()) {
- std::size_t input_pos = 0;
- TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
- for (std::size_t output_pos = 0;
- output_pos < output->size();
- ++output_pos, ++source_sequence_it) {
- if (case_matches.get(*source_sequence_it)) {
- const void *value = case_result.getUntypedValue<true>(input_pos++);
- if (value == nullptr) {
- output->positionalWriteNullValue(output_pos);
- } else {
- output->positionalWriteUntypedValue(output_pos, value);
- }
- }
- }
- } else {
- std::size_t input_pos = 0;
- TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
- for (std::size_t output_pos = 0;
- output_pos < output->size();
- ++output_pos, ++source_sequence_it) {
- if (case_matches.get(*source_sequence_it)) {
- output->positionalWriteUntypedValue(output_pos,
- case_result.getUntypedValue<false>(input_pos++));
- }
- }
- }
- }
-}
-
-void ScalarCaseExpression::MultiplexIndirectColumnVector(
- const TupleIdSequence *source_sequence,
- const TupleIdSequence &case_matches,
- const IndirectColumnVector &case_result,
- IndirectColumnVector *output) {
- if (source_sequence == nullptr) {
- TupleIdSequence::const_iterator output_pos_it = case_matches.begin();
- for (std::size_t input_pos = 0;
- input_pos < case_result.size();
- ++input_pos, ++output_pos_it) {
- output->positionalWriteTypedValue(*output_pos_it,
- case_result.getTypedValue(input_pos));
- }
- } else {
- std::size_t input_pos = 0;
- TupleIdSequence::const_iterator source_sequence_it = source_sequence->begin();
- for (std::size_t output_pos = 0;
- output_pos < output->size();
- ++output_pos, ++source_sequence_it) {
- if (case_matches.get(*source_sequence_it)) {
- output->positionalWriteTypedValue(output_pos,
- case_result.getTypedValue(input_pos++));
- }
- }
- }
-}
-
ColumnVectorPtr ScalarCaseExpression::multiplexColumnVectors(
const std::size_t output_size,
const TupleIdSequence *source_sequence,
http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/666102ff/expressions/scalar/ScalarCaseExpression.hpp
----------------------------------------------------------------------
diff --git a/expressions/scalar/ScalarCaseExpression.hpp b/expressions/scalar/ScalarCaseExpression.hpp
index 3d0ed71..22acfa8 100644
--- a/expressions/scalar/ScalarCaseExpression.hpp
+++ b/expressions/scalar/ScalarCaseExpression.hpp
@@ -124,14 +124,6 @@ class ScalarCaseExpression : public Scalar {
}
}
- relation_id getRelationIdForValueAccessor() const override {
- if (fixed_result_expression_ != nullptr) {
- return fixed_result_expression_->getRelationIdForValueAccessor();
- } else {
- return -1;
- }
- }
-
ColumnVectorPtr getAllValues(ValueAccessor *accessor,
const SubBlocksReference *sub_blocks_ref,
ColumnVectorCache *cv_cache) const override;
@@ -154,24 +146,6 @@ class ScalarCaseExpression : public Scalar {
std::vector<std::vector<const Expression*>> *container_child_fields) const override;
private:
- // Merge the values in the NativeColumnVector 'case_result' into '*output' at
- // the positions specified by 'case_matches'. If '*source_sequence' is
- // non-NULL, it indicates which positions actually have tuples in the input,
- // otherwise it is assumed that there are no holes in the input.
- static void MultiplexNativeColumnVector(
- const TupleIdSequence *source_sequence,
- const TupleIdSequence &case_matches,
- const NativeColumnVector &case_result,
- NativeColumnVector *output);
-
- // Same as MultiplexNativeColumnVector(), but works on IndirectColumnVectors
- // instead of NativeColumnVectors.
- static void MultiplexIndirectColumnVector(
- const TupleIdSequence *source_sequence,
- const TupleIdSequence &case_matches,
- const IndirectColumnVector &case_result,
- IndirectColumnVector *output);
-
// Create and return a new ColumnVector by multiplexing the ColumnVectors
// containing results for individual CASE branches at the appropriate
// positions. 'output_size' is the total number of values in the output.
http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/666102ff/expressions/scalar/tests/ScalarCaseExpression_unittest.cpp
----------------------------------------------------------------------
diff --git a/expressions/scalar/tests/ScalarCaseExpression_unittest.cpp b/expressions/scalar/tests/ScalarCaseExpression_unittest.cpp
index 7182642..f385b74 100644
--- a/expressions/scalar/tests/ScalarCaseExpression_unittest.cpp
+++ b/expressions/scalar/tests/ScalarCaseExpression_unittest.cpp
@@ -875,6 +875,313 @@ TEST_F(ScalarCaseExpressionTest,
}
}
+// Test CASE evaluation over joins that always takes the same branch, whose
+// result is a constant.
+TEST_F(ScalarCaseExpressionTest, JoinStaticBranchConstantTest) {
+ // Simulate a join with another relation.
+ CatalogRelation other_relation(nullptr, "other", 1);
+ other_relation.addAttribute(new CatalogAttribute(&other_relation,
+ "other_double",
+ TypeFactory::GetType(kDouble, false)));
+ other_relation.addAttribute(new CatalogAttribute(&other_relation,
+ "other_int",
+ TypeFactory::GetType(kInt, false)));
+
+ static const double kOtherDoubleValues[] = {-250.0, -750.0};
+ std::unique_ptr<NativeColumnVector> other_double_column(
+ new NativeColumnVector(TypeFactory::GetType(kDouble, false), 2));
+ other_double_column->appendUntypedValue(kOtherDoubleValues);
+ other_double_column->appendUntypedValue(kOtherDoubleValues + 1);
+
+ static const int kOtherIntValues[] = {10, -1};
+ std::unique_ptr<NativeColumnVector> other_int_column(
+ new NativeColumnVector(TypeFactory::GetType(kInt, false), 2));
+ other_int_column->appendUntypedValue(kOtherIntValues);
+ other_int_column->appendUntypedValue(kOtherIntValues + 1);
+
+ ColumnVectorsValueAccessor other_accessor;
+ other_accessor.addColumn(other_double_column.release());
+ other_accessor.addColumn(other_int_column.release());
+
+ const Type &int_type = TypeFactory::GetType(kInt);
+
+ // Setup expression.
+ std::vector<std::unique_ptr<Predicate>> when_predicates;
+ std::vector<std::unique_ptr<Scalar>> result_expressions;
+
+ // WHEN 1 > 2 THEN int_attr + other_int
+ when_predicates.emplace_back(new ComparisonPredicate(
+ ComparisonFactory::GetComparison(ComparisonID::kGreater),
+ new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+ new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+ result_expressions.emplace_back(new ScalarBinaryExpression(
+ BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd),
+ new ScalarAttribute(*sample_relation_->getAttributeById(0)),
+ new ScalarAttribute(*other_relation.getAttributeById(1))));
+
+ const int kConstant = 72;
+ // WHEN 1 < 2 THEN kConstant
+ when_predicates.emplace_back(new ComparisonPredicate(
+ ComparisonFactory::GetComparison(ComparisonID::kLess),
+ new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+ new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+ result_expressions.emplace_back(
+ new ScalarLiteral(TypedValue(kConstant), int_type));
+
+ // WHEN double_attr = other_double THEN 0
+ when_predicates.emplace_back(new ComparisonPredicate(
+ ComparisonFactory::GetComparison(ComparisonID::kEqual),
+ new ScalarAttribute(*sample_relation_->getAttributeById(1)),
+ new ScalarAttribute(*other_relation.getAttributeById(0))));
+ result_expressions.emplace_back(new ScalarLiteral(TypedValue(0), TypeFactory::GetType(kInt)));
+
+ const Type &int_nullable_type = TypeFactory::GetType(kInt, true);
+
+ // ELSE NULL
+ ScalarCaseExpression case_expr(
+ int_nullable_type,
+ std::move(when_predicates),
+ std::move(result_expressions),
+ new ScalarLiteral(TypedValue(kInt), int_nullable_type));
+
+ // Create a list of joined tuple-id pairs (just the cross-product of tuples).
+ std::vector<std::pair<tuple_id, tuple_id>> joined_tuple_ids;
+ for (std::size_t tuple_num = 0; tuple_num < kNumSampleTuples; ++tuple_num) {
+ joined_tuple_ids.emplace_back(tuple_num, 0);
+ joined_tuple_ids.emplace_back(tuple_num, 1);
+ }
+
+ ColumnVectorPtr result_cv(case_expr.getAllValuesForJoin(
+ 0,
+ &sample_data_value_accessor_,
+ 1,
+ &other_accessor,
+ joined_tuple_ids,
+ nullptr /* cv_cache */));
+ ASSERT_TRUE(result_cv->isNative());
+ const NativeColumnVector &native_result_cv
+ = static_cast<const NativeColumnVector&>(*result_cv);
+ EXPECT_EQ(kNumSampleTuples * 2, native_result_cv.size());
+
+ for (std::size_t result_num = 0;
+ result_num < native_result_cv.size();
+ ++result_num) {
+ EXPECT_EQ(kConstant,
+ *static_cast<const int*>(native_result_cv.getUntypedValue(result_num)));
+ }
+}
+
+// Test CASE evaluation over joins that always takes the same branch, whose
+// result is a ScalarAttribute.
+TEST_F(ScalarCaseExpressionTest, JoinStaticBranchOnScalarAttributeTest) {
+ // Simulate a join with another relation.
+ CatalogRelation other_relation(nullptr, "other", 1);
+ other_relation.addAttribute(new CatalogAttribute(&other_relation,
+ "other_double",
+ TypeFactory::GetType(kDouble, false)));
+ other_relation.addAttribute(new CatalogAttribute(&other_relation,
+ "other_int",
+ TypeFactory::GetType(kInt, false)));
+
+ static const double kOtherDoubleValues[] = {-250.0, -750.0};
+ std::unique_ptr<NativeColumnVector> other_double_column(
+ new NativeColumnVector(TypeFactory::GetType(kDouble, false), 2));
+ other_double_column->appendUntypedValue(kOtherDoubleValues);
+ other_double_column->appendUntypedValue(kOtherDoubleValues + 1);
+
+ static const int kOtherIntValues[] = {10, -1};
+ std::unique_ptr<NativeColumnVector> other_int_column(
+ new NativeColumnVector(TypeFactory::GetType(kInt, false), 2));
+ other_int_column->appendUntypedValue(kOtherIntValues);
+ other_int_column->appendUntypedValue(kOtherIntValues + 1);
+
+ ColumnVectorsValueAccessor other_accessor;
+ other_accessor.addColumn(other_double_column.release());
+ other_accessor.addColumn(other_int_column.release());
+
+ const Type &int_type = TypeFactory::GetType(kInt);
+
+ // Setup expression.
+ std::vector<std::unique_ptr<Predicate>> when_predicates;
+ std::vector<std::unique_ptr<Scalar>> result_expressions;
+
+ // WHEN 1 > 2 THEN int_attr + other_int
+ when_predicates.emplace_back(new ComparisonPredicate(
+ ComparisonFactory::GetComparison(ComparisonID::kGreater),
+ new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+ new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+ result_expressions.emplace_back(new ScalarBinaryExpression(
+ BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd),
+ new ScalarAttribute(*sample_relation_->getAttributeById(0)),
+ new ScalarAttribute(*other_relation.getAttributeById(1))));
+
+ // WHEN 1 < 2 THEN int_attr
+ when_predicates.emplace_back(new ComparisonPredicate(
+ ComparisonFactory::GetComparison(ComparisonID::kLess),
+ new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+ new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+ result_expressions.emplace_back(
+ new ScalarAttribute(*sample_relation_->getAttributeById(0)));
+
+ // WHEN double_attr = other_double THEN 0
+ when_predicates.emplace_back(new ComparisonPredicate(
+ ComparisonFactory::GetComparison(ComparisonID::kEqual),
+ new ScalarAttribute(*sample_relation_->getAttributeById(1)),
+ new ScalarAttribute(*other_relation.getAttributeById(0))));
+ result_expressions.emplace_back(new ScalarLiteral(TypedValue(0), TypeFactory::GetType(kInt)));
+
+ const Type &int_nullable_type = TypeFactory::GetType(kInt, true);
+
+ // ELSE NULL
+ ScalarCaseExpression case_expr(
+ int_nullable_type,
+ std::move(when_predicates),
+ std::move(result_expressions),
+ new ScalarLiteral(TypedValue(kInt), int_nullable_type));
+
+ // Create a list of joined tuple-id pairs (just the cross-product of tuples).
+ std::vector<std::pair<tuple_id, tuple_id>> joined_tuple_ids;
+ for (std::size_t tuple_num = 0; tuple_num < kNumSampleTuples; ++tuple_num) {
+ joined_tuple_ids.emplace_back(tuple_num, 0);
+ joined_tuple_ids.emplace_back(tuple_num, 1);
+ }
+
+ ColumnVectorPtr result_cv(case_expr.getAllValuesForJoin(
+ 0,
+ &sample_data_value_accessor_,
+ 1,
+ &other_accessor,
+ joined_tuple_ids,
+ nullptr /* cv_cache */));
+ ASSERT_TRUE(result_cv->isNative());
+ const NativeColumnVector &native_result_cv
+ = static_cast<const NativeColumnVector&>(*result_cv);
+ EXPECT_EQ(kNumSampleTuples * 2, native_result_cv.size());
+
+ for (std::size_t result_num = 0;
+ result_num < native_result_cv.size();
+ ++result_num) {
+ // For convenience, calculate expected tuple values here.
+ const bool sample_int_null = ((result_num >> 1) % 10 == 0);
+ const int sample_int = result_num >> 1;
+
+ if (sample_int_null) {
+ EXPECT_EQ(nullptr, native_result_cv.getUntypedValue(result_num));
+ } else {
+ ASSERT_NE(nullptr, native_result_cv.getUntypedValue(result_num));
+ EXPECT_EQ(sample_int,
+ *static_cast<const int*>(native_result_cv.getUntypedValue(result_num)));
+ }
+ }
+}
+
+// Test CASE evaluation over joins that always takes the same branch, whose
+// result is a ScalarBinaryExpression.
+TEST_F(ScalarCaseExpressionTest, JoinStaticBranchTest) {
+ // Simulate a join with another relation.
+ CatalogRelation other_relation(nullptr, "other", 1);
+ other_relation.addAttribute(new CatalogAttribute(&other_relation,
+ "other_double",
+ TypeFactory::GetType(kDouble, false)));
+ other_relation.addAttribute(new CatalogAttribute(&other_relation,
+ "other_int",
+ TypeFactory::GetType(kInt, false)));
+
+ static const double kOtherDoubleValues[] = {-250.0, -750.0};
+ std::unique_ptr<NativeColumnVector> other_double_column(
+ new NativeColumnVector(TypeFactory::GetType(kDouble, false), 2));
+ other_double_column->appendUntypedValue(kOtherDoubleValues);
+ other_double_column->appendUntypedValue(kOtherDoubleValues + 1);
+
+ static const int kOtherIntValues[] = {10, -1};
+ std::unique_ptr<NativeColumnVector> other_int_column(
+ new NativeColumnVector(TypeFactory::GetType(kInt, false), 2));
+ other_int_column->appendUntypedValue(kOtherIntValues);
+ other_int_column->appendUntypedValue(kOtherIntValues + 1);
+
+ ColumnVectorsValueAccessor other_accessor;
+ other_accessor.addColumn(other_double_column.release());
+ other_accessor.addColumn(other_int_column.release());
+
+ const Type &int_type = TypeFactory::GetType(kInt);
+
+ // Setup expression.
+ std::vector<std::unique_ptr<Predicate>> when_predicates;
+ std::vector<std::unique_ptr<Scalar>> result_expressions;
+
+ // WHEN 1 > 2 THEN int_attr + other_int
+ when_predicates.emplace_back(new ComparisonPredicate(
+ ComparisonFactory::GetComparison(ComparisonID::kGreater),
+ new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+ new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+ result_expressions.emplace_back(new ScalarBinaryExpression(
+ BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd),
+ new ScalarAttribute(*sample_relation_->getAttributeById(0)),
+ new ScalarAttribute(*other_relation.getAttributeById(1))));
+
+ // WHEN 1 < 2 THEN int_attr * other_int
+ when_predicates.emplace_back(new ComparisonPredicate(
+ ComparisonFactory::GetComparison(ComparisonID::kLess),
+ new ScalarLiteral(TypedValue(static_cast<int>(1)), int_type),
+ new ScalarLiteral(TypedValue(static_cast<int>(2)), int_type)));
+ result_expressions.emplace_back(new ScalarBinaryExpression(
+ BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kMultiply),
+ new ScalarAttribute(*sample_relation_->getAttributeById(0)),
+ new ScalarAttribute(*other_relation.getAttributeById(1))));
+
+ // WHEN double_attr = other_double THEN 0
+ when_predicates.emplace_back(new ComparisonPredicate(
+ ComparisonFactory::GetComparison(ComparisonID::kEqual),
+ new ScalarAttribute(*sample_relation_->getAttributeById(1)),
+ new ScalarAttribute(*other_relation.getAttributeById(0))));
+ result_expressions.emplace_back(new ScalarLiteral(TypedValue(0), TypeFactory::GetType(kInt)));
+
+ const Type &int_nullable_type = TypeFactory::GetType(kInt, true);
+
+ // ELSE NULL
+ ScalarCaseExpression case_expr(
+ int_nullable_type,
+ std::move(when_predicates),
+ std::move(result_expressions),
+ new ScalarLiteral(TypedValue(kInt), int_nullable_type));
+
+ // Create a list of joined tuple-id pairs (just the cross-product of tuples).
+ std::vector<std::pair<tuple_id, tuple_id>> joined_tuple_ids;
+ for (std::size_t tuple_num = 0; tuple_num < kNumSampleTuples; ++tuple_num) {
+ joined_tuple_ids.emplace_back(tuple_num, 0);
+ joined_tuple_ids.emplace_back(tuple_num, 1);
+ }
+
+ ColumnVectorPtr result_cv(case_expr.getAllValuesForJoin(
+ 0,
+ &sample_data_value_accessor_,
+ 1,
+ &other_accessor,
+ joined_tuple_ids,
+ nullptr /* cv_cache */));
+ ASSERT_TRUE(result_cv->isNative());
+ const NativeColumnVector &native_result_cv
+ = static_cast<const NativeColumnVector&>(*result_cv);
+ EXPECT_EQ(kNumSampleTuples * 2, native_result_cv.size());
+
+ for (std::size_t result_num = 0;
+ result_num < native_result_cv.size();
+ ++result_num) {
+ // For convenience, calculate expected tuple values here.
+ const bool sample_int_null = ((result_num >> 1) % 10 == 0);
+ const int sample_int = result_num >> 1;
+ const int other_int = kOtherIntValues[result_num & 0x1];
+
+ if (sample_int_null) {
+ EXPECT_EQ(nullptr, native_result_cv.getUntypedValue(result_num));
+ } else {
+ ASSERT_NE(nullptr, native_result_cv.getUntypedValue(result_num));
+ EXPECT_EQ(sample_int * other_int,
+ *static_cast<const int*>(native_result_cv.getUntypedValue(result_num)));
+ }
+ }
+}
+
// Test CASE evaluation over joins, with both WHEN predicates and THEN
// expressions referencing attributes in both relations.
TEST_F(ScalarCaseExpressionTest, JoinTest) {