This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 9c109533afd5d14f7ad7f79564e5d59efeaca479
Author: Pindikura Ravindra <[email protected]>
AuthorDate: Wed Sep 26 12:09:10 2018 +0530

    [Gandiva] fix bug with nested if in condition
    
    - in if-else handling, do not re-use a bitmap across conditions,
      or between a condition and the then/else branches.
    - the field ToString should include the field-name.
---
 cpp/src/gandiva/expr_decomposer.cc      | 36 ++++++++++---
 cpp/src/gandiva/expr_decomposer.h       | 17 ++++--
 cpp/src/gandiva/expr_decomposer_test.cc | 94 ++++++++++++++++++++++++++++++++-
 cpp/src/gandiva/tests/if_expr_test.cc   | 64 ++++++++++++++++++++++
 cpp/src/gandiva/tests/projector_test.cc | 25 +++++++++
 cpp/src/gandiva/tests/to_string_test.cc | 14 ++---
 6 files changed, 231 insertions(+), 19 deletions(-)

diff --git a/cpp/src/gandiva/expr_decomposer.cc b/cpp/src/gandiva/expr_decomposer.cc
index a5eede6..c0e994d 100644
--- a/cpp/src/gandiva/expr_decomposer.cc
+++ b/cpp/src/gandiva/expr_decomposer.cc
@@ -116,12 +116,14 @@ Status ExprDecomposer::Visit(const FunctionNode &in_node) {
 }
 
 // Decompose an IfNode
-Status ExprDecomposer::Visit(const IfNode& node) {
-  // Add a local bitmap to track the output validity.
+Status ExprDecomposer::Visit(const IfNode &node) {
+  PushConditionEntry(node);
   auto status = node.condition()->Accept(*this);
   GANDIVA_RETURN_NOT_OK(status);
   auto condition_vv = result();
+  PopConditionEntry(node);
 
+  // Add a local bitmap to track the output validity.
   int local_bitmap_idx = PushThenEntry(node);
   status = node.then_node()->Accept(*this);
   GANDIVA_RETURN_NOT_OK(status);
@@ -194,7 +196,8 @@ Status ExprDecomposer::Visit(const LiteralNode& node) {
 int ExprDecomposer::PushThenEntry(const IfNode& node) {
   int local_bitmap_idx;
 
-  if (!if_entries_stack_.empty() && !if_entries_stack_.top()->is_then_) {
+  if (!if_entries_stack_.empty() &&
+      if_entries_stack_.top()->entry_type_ == kStackEntryElse) {
     auto top = if_entries_stack_.top().get();
 
    // inside a nested else statement (i.e if-else-if). use the parent's bitmap.
@@ -209,7 +212,7 @@ int ExprDecomposer::PushThenEntry(const IfNode& node) {
 
   // push new entry to the stack.
   std::unique_ptr<IfStackEntry> entry(new IfStackEntry(
-      node, true /*is_then*/, false /*is_terminal_else*/, local_bitmap_idx));
+      node, kStackEntryThen, false /*is_terminal_else*/, local_bitmap_idx));
   if_entries_stack_.push(std::move(entry));
   return local_bitmap_idx;
 }
@@ -218,7 +221,8 @@ void ExprDecomposer::PopThenEntry(const IfNode& node) {
   DCHECK_EQ(if_entries_stack_.empty(), false) << "PopThenEntry: found empty 
stack";
 
   auto top = if_entries_stack_.top().get();
-  DCHECK_EQ(top->is_then_, true) << "PopThenEntry: found else, expected then";
+  DCHECK_EQ(top->entry_type_, kStackEntryThen)
+      << "PopThenEntry: found " << top->entry_type_ << " expected then";
   DCHECK_EQ(&top->if_node_, &node) << "PopThenEntry: found mismatched node";
 
   if_entries_stack_.pop();
@@ -226,7 +230,7 @@ void ExprDecomposer::PopThenEntry(const IfNode& node) {
 
 void ExprDecomposer::PushElseEntry(const IfNode& node, int local_bitmap_idx) {
   std::unique_ptr<IfStackEntry> entry(new IfStackEntry(
-      node, false /*is_then*/, true /*is_terminal_else*/, local_bitmap_idx));
+      node, kStackEntryElse, true /*is_terminal_else*/, local_bitmap_idx));
   if_entries_stack_.push(std::move(entry));
 }
 
@@ -234,12 +238,28 @@ bool ExprDecomposer::PopElseEntry(const IfNode& node) {
   DCHECK_EQ(if_entries_stack_.empty(), false) << "PopElseEntry: found empty 
stack";
 
   auto top = if_entries_stack_.top().get();
-  DCHECK_EQ(top->is_then_, false) << "PopElseEntry: found then, expected else";
-  DCHECK_EQ(&top->if_node_, &node) << "PopThenEntry: found mismatched node";
+  DCHECK_EQ(top->entry_type_, kStackEntryElse)
+      << "PopElseEntry: found " << top->entry_type_ << " expected else";
+  DCHECK_EQ(&top->if_node_, &node) << "PopElseEntry: found mismatched node";
   bool is_terminal_else = top->is_terminal_else_;
 
   if_entries_stack_.pop();
   return is_terminal_else;
 }
 
+void ExprDecomposer::PushConditionEntry(const IfNode &node) {
+  std::unique_ptr<IfStackEntry> entry(new IfStackEntry(node, 
kStackEntryCondition));
+  if_entries_stack_.push(std::move(entry));
+}
+
+void ExprDecomposer::PopConditionEntry(const IfNode &node) {
+  DCHECK_EQ(if_entries_stack_.empty(), false) << "PopConditionEntry: found 
empty stack";
+
+  auto top = if_entries_stack_.top().get();
+  DCHECK_EQ(top->entry_type_, kStackEntryCondition)
+      << "PopConditionEntry: found " << top->entry_type_ << " expected 
condition";
+  DCHECK_EQ(&top->if_node_, &node) << "PopConditionEntry: found mismatched 
node";
+  if_entries_stack_.pop();
+}
+
 }  // namespace gandiva
diff --git a/cpp/src/gandiva/expr_decomposer.h b/cpp/src/gandiva/expr_decomposer.h
index 97b242a..55e9d70 100644
--- a/cpp/src/gandiva/expr_decomposer.h
+++ b/cpp/src/gandiva/expr_decomposer.h
@@ -52,6 +52,7 @@ class ExprDecomposer : public NodeVisitor {
   FRIEND_TEST(TestExprDecomposer, TestNested);
   FRIEND_TEST(TestExprDecomposer, TestInternalIf);
   FRIEND_TEST(TestExprDecomposer, TestParallelIf);
+  FRIEND_TEST(TestExprDecomposer, TestIfInCondition);
 
   Status Visit(const FieldNode& node) override;
   Status Visit(const FunctionNode& node) override;
@@ -62,22 +63,30 @@ class ExprDecomposer : public NodeVisitor {
   // Optimize a function node, if possible.
   const FunctionNode TryOptimize(const FunctionNode &node);
 
+  enum StackEntryType { kStackEntryCondition, kStackEntryThen, kStackEntryElse 
};
+
   // stack of if nodes.
   class IfStackEntry {
    public:
-    IfStackEntry(const IfNode& if_node, bool is_then, bool is_terminal_else,
-                 int local_bitmap_idx)
+    IfStackEntry(const IfNode &if_node, StackEntryType entry_type,
+                 bool is_terminal_else = false, int local_bitmap_idx = 0)
         : if_node_(if_node),
-          is_then_(is_then),
+          entry_type_(entry_type),
           is_terminal_else_(is_terminal_else),
           local_bitmap_idx_(local_bitmap_idx) {}
 
     const IfNode& if_node_;
-    bool is_then_;
+    StackEntryType entry_type_;
     bool is_terminal_else_;
     int local_bitmap_idx_;
   };
 
+  // push 'condition entry' to stack.
+  void PushConditionEntry(const IfNode &node);
+
+  // pop 'condition entry' from stack.
+  void PopConditionEntry(const IfNode &node);
+
   // push 'then entry' to stack. returns either a new local bitmap or the parent's
   // bitmap (in case of nested if-else).
   int PushThenEntry(const IfNode& node);
diff --git a/cpp/src/gandiva/expr_decomposer_test.cc b/cpp/src/gandiva/expr_decomposer_test.cc
index 2439047..0421dfe 100644
--- a/cpp/src/gandiva/expr_decomposer_test.cc
+++ b/cpp/src/gandiva/expr_decomposer_test.cc
@@ -43,6 +43,9 @@ TEST_F(TestExprDecomposer, TestStackSimple) {
   // else _
   IfNode node_a(nullptr, nullptr, nullptr, int32());
 
+  decomposer.PushConditionEntry(node_a);
+  decomposer.PopConditionEntry(node_a);
+
   int idx_a = decomposer.PushThenEntry(node_a);
   EXPECT_EQ(idx_a, 0);
   decomposer.PopThenEntry(node_a);
@@ -64,6 +67,9 @@ TEST_F(TestExprDecomposer, TestNested) {
   IfNode node_a(nullptr, nullptr, nullptr, int32());
   IfNode node_b(nullptr, nullptr, nullptr, int32());
 
+  decomposer.PushConditionEntry(node_a);
+  decomposer.PopConditionEntry(node_a);
+
   int idx_a = decomposer.PushThenEntry(node_a);
   EXPECT_EQ(idx_a, 0);
   decomposer.PopThenEntry(node_a);
@@ -71,6 +77,9 @@ TEST_F(TestExprDecomposer, TestNested) {
   decomposer.PushElseEntry(node_a, idx_a);
 
   {  // start b
+    decomposer.PushConditionEntry(node_b);
+    decomposer.PopConditionEntry(node_b);
+
     int idx_b = decomposer.PushThenEntry(node_b);
     EXPECT_EQ(idx_b, 0);  // must reuse bitmap.
     decomposer.PopThenEntry(node_b);
@@ -97,10 +106,16 @@ TEST_F(TestExprDecomposer, TestInternalIf) {
   IfNode node_a(nullptr, nullptr, nullptr, int32());
   IfNode node_b(nullptr, nullptr, nullptr, int32());
 
+  decomposer.PushConditionEntry(node_a);
+  decomposer.PopConditionEntry(node_a);
+
   int idx_a = decomposer.PushThenEntry(node_a);
   EXPECT_EQ(idx_a, 0);
 
   {  // start b
+    decomposer.PushConditionEntry(node_b);
+    decomposer.PopConditionEntry(node_b);
+
     int idx_b = decomposer.PushThenEntry(node_b);
     EXPECT_EQ(idx_b, 1);  // must not reuse bitmap.
     decomposer.PopThenEntry(node_b);
@@ -130,6 +145,9 @@ TEST_F(TestExprDecomposer, TestParallelIf) {
   IfNode node_a(nullptr, nullptr, nullptr, int32());
   IfNode node_b(nullptr, nullptr, nullptr, int32());
 
+  decomposer.PushConditionEntry(node_a);
+  decomposer.PopConditionEntry(node_a);
+
   int idx_a = decomposer.PushThenEntry(node_a);
   EXPECT_EQ(idx_a, 0);
 
@@ -140,6 +158,9 @@ TEST_F(TestExprDecomposer, TestParallelIf) {
   EXPECT_EQ(is_terminal_a, true);  // there was no nested if.
 
   // start b
+  decomposer.PushConditionEntry(node_b);
+  decomposer.PopConditionEntry(node_b);
+
   int idx_b = decomposer.PushThenEntry(node_b);
   EXPECT_EQ(idx_b, 1);  // must not reuse bitmap.
   decomposer.PopThenEntry(node_b);
@@ -151,7 +172,78 @@ TEST_F(TestExprDecomposer, TestParallelIf) {
   EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
 }
 
-int main(int argc, char** argv) {
+TEST_F(TestExprDecomposer, TestIfInCondition) {
+  Annotator annotator;
+  ExprDecomposer decomposer(registry_, annotator);
+
+  // if (if _ else _)   : a
+  //   -
+  // else
+  //   if (if _ else _)  : b
+  //    -
+  //   else
+  //    -
+  IfNode node_a(nullptr, nullptr, nullptr, int32());
+  IfNode node_b(nullptr, nullptr, nullptr, int32());
+  IfNode cond_node_a(nullptr, nullptr, nullptr, int32());
+  IfNode cond_node_b(nullptr, nullptr, nullptr, int32());
+
+  // start a
+  decomposer.PushConditionEntry(node_a);
+  {
+    // start cond_node_a
+    decomposer.PushConditionEntry(cond_node_a);
+    decomposer.PopConditionEntry(cond_node_a);
+
+    int idx_cond_a = decomposer.PushThenEntry(cond_node_a);
+    EXPECT_EQ(idx_cond_a, 0);
+    decomposer.PopThenEntry(cond_node_a);
+
+    decomposer.PushElseEntry(cond_node_a, idx_cond_a);
+    bool is_terminal = decomposer.PopElseEntry(cond_node_a);
+    EXPECT_EQ(is_terminal, true);  // there was no nested if.
+  }
+  decomposer.PopConditionEntry(node_a);
+
+  int idx_a = decomposer.PushThenEntry(node_a);
+  EXPECT_EQ(idx_a, 1);  // no re-use
+  decomposer.PopThenEntry(node_a);
+
+  decomposer.PushElseEntry(node_a, idx_a);
+
+  {  // start b
+    decomposer.PushConditionEntry(node_b);
+    {
+      // start cond_node_b
+      decomposer.PushConditionEntry(cond_node_b);
+      decomposer.PopConditionEntry(cond_node_b);
+
+      int idx_cond_b = decomposer.PushThenEntry(cond_node_b);
+      EXPECT_EQ(idx_cond_b, 2);  // no re-use
+      decomposer.PopThenEntry(cond_node_b);
+
+      decomposer.PushElseEntry(cond_node_b, idx_cond_b);
+      bool is_terminal = decomposer.PopElseEntry(cond_node_b);
+      EXPECT_EQ(is_terminal, true);  // there was no nested if.
+    }
+    decomposer.PopConditionEntry(node_b);
+
+    int idx_b = decomposer.PushThenEntry(node_b);
+    EXPECT_EQ(idx_b, 1);  // must reuse bitmap.
+    decomposer.PopThenEntry(node_b);
+
+    decomposer.PushElseEntry(node_b, idx_b);
+    bool is_terminal = decomposer.PopElseEntry(node_b);
+    EXPECT_EQ(is_terminal, true);
+  }  // end b
+
+  bool is_terminal_a = decomposer.PopElseEntry(node_a);
+  EXPECT_EQ(is_terminal_a, false);  // there was a nested if.
+
+  EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
+}
+
+int main(int argc, char **argv) {
   ::testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();
 }
diff --git a/cpp/src/gandiva/tests/if_expr_test.cc b/cpp/src/gandiva/tests/if_expr_test.cc
index 46bcb02..811137f 100644
--- a/cpp/src/gandiva/tests/if_expr_test.cc
+++ b/cpp/src/gandiva/tests/if_expr_test.cc
@@ -256,6 +256,70 @@ TEST_F(TestIfExpr, TestNestedInIf) {
   EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
 }
 
+TEST_F(TestIfExpr, TestNestedInCondition) {
+  // schema for input fields
+  auto fielda = field("a", int32());
+  auto fieldb = field("b", int32());
+  auto schema = arrow::schema({fielda, fieldb});
+
+  // output fields
+  auto field_result = field("res", int32());
+
+  // build expression.
+  // if (if (a > b) then true else if (a < b) false else null)
+  //   1
+  // else if !(if (a > b) then true else if (a < b) false else null)
+  //   2
+  // else
+  //   3
+  auto node_a = TreeExprBuilder::MakeField(fielda);
+  auto node_b = TreeExprBuilder::MakeField(fieldb);
+  auto literal_1 = TreeExprBuilder::MakeLiteral(1);
+  auto literal_2 = TreeExprBuilder::MakeLiteral(2);
+  auto literal_3 = TreeExprBuilder::MakeLiteral(3);
+  auto literal_true = TreeExprBuilder::MakeLiteral(true);
+  auto literal_false = TreeExprBuilder::MakeLiteral(false);
+  auto literal_null = TreeExprBuilder::MakeNull(boolean());
+
+  auto a_gt_b =
+      TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, 
boolean());
+  auto a_lt_b = TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, 
boolean());
+  auto cond_else =
+      TreeExprBuilder::MakeIf(a_lt_b, literal_false, literal_null, boolean());
+  auto cond_if = TreeExprBuilder::MakeIf(a_gt_b, literal_true, cond_else, 
boolean());
+  auto not_cond_if = TreeExprBuilder::MakeFunction("not", {cond_if}, 
boolean());
+
+  auto outer_else = TreeExprBuilder::MakeIf(not_cond_if, literal_2, literal_3, 
int32());
+  auto outer_if = TreeExprBuilder::MakeIf(cond_if, literal_1, outer_else, 
int32());
+  auto expr = TreeExprBuilder::MakeExpression(outer_if, field_result);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+  Status status = Projector::Make(schema, {expr}, &projector);
+  EXPECT_TRUE(status.ok());
+
+  // Create a row-batch with some sample data
+  int num_records = 6;
+  auto array_a =
+      MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, 
true, true});
+  auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19},
+                                     {true, true, true, false, false, false});
+  // expected output
+  auto exp =
+      MakeArrowArrayInt32({1, 2, 2, 3, 3, 3}, {true, true, true, true, true, 
true});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, 
array_b});
+
+  // Evaluate expression
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, pool_, &outputs);
+  EXPECT_TRUE(status.ok());
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+
 TEST_F(TestIfExpr, TestBigNested) {
   // schema for input fields
   auto fielda = field("a", int32());
diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc
index cda5b75..a7f71ec 100644
--- a/cpp/src/gandiva/tests/projector_test.cc
+++ b/cpp/src/gandiva/tests/projector_test.cc
@@ -77,6 +77,31 @@ TEST_F(TestProjector, TestProjectCache) {
   EXPECT_TRUE(cached_projector.get() != should_be_new_projector1.get());
 }
 
+TEST_F(TestProjector, TestProjectCacheFieldNames) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto field1 = field("f1", int32());
+  auto field2 = field("f2", int32());
+  auto schema = arrow::schema({field0, field1, field2});
+
+  // output fields
+  auto sum_01 = field("sum_01", int32());
+  auto sum_12 = field("sum_12", int32());
+
+  auto sum_expr_01 = TreeExprBuilder::MakeExpression("add", {field0, field1}, 
sum_01);
+  std::shared_ptr<Projector> projector_01;
+  Status status = Projector::Make(schema, {sum_expr_01}, &projector_01);
+  EXPECT_TRUE(status.ok());
+
+  auto sum_expr_12 = TreeExprBuilder::MakeExpression("add", {field1, field2}, 
sum_12);
+  std::shared_ptr<Projector> projector_12;
+  status = Projector::Make(schema, {sum_expr_12}, &projector_12);
+  EXPECT_TRUE(status.ok());
+
+  // add(f0, f1) != add(f1, f2)
+  EXPECT_TRUE(projector_01.get() != projector_12.get());
+}
+
 TEST_F(TestProjector, TestProjectCacheDouble) {
   auto schema = arrow::schema({});
   auto res = field("result", arrow::float64());
diff --git a/cpp/src/gandiva/tests/to_string_test.cc b/cpp/src/gandiva/tests/to_string_test.cc
index 97919d4..55db6e9 100644
--- a/cpp/src/gandiva/tests/to_string_test.cc
+++ b/cpp/src/gandiva/tests/to_string_test.cc
@@ -67,9 +67,10 @@ TEST_F(TestToString, TestAll) {
 
   auto if_node = TreeExprBuilder::MakeIf(cond_node, then_node, else_node, 
int64());
   auto if_expr = TreeExprBuilder::MakeExpression(if_node, f1);
-  CHECK_EXPR_TO_STRING(
-      if_expr,
-      "if (bool lesser_than((double) f0, (const float) 0 raw(0))) { (int64) f1 
} else { (int64) f2 }");
+
+  CHECK_EXPR_TO_STRING(if_expr,
+                       "if (bool lesser_than((double) f0, (const float) 0 
raw(0))) { "
+                       "(int64) f1 } else { (int64) f2 }");
 
   auto f1_gt_100 =
       TreeExprBuilder::MakeFunction("greater_than", {f1_node, literal_node}, 
boolean());
@@ -78,9 +79,10 @@ TEST_F(TestToString, TestAll) {
   auto and_node = TreeExprBuilder::MakeAnd({f1_gt_100, f2_equals_100});
   auto and_expr =
       TreeExprBuilder::MakeExpression(and_node, arrow::field("f0", boolean()));
-  CHECK_EXPR_TO_STRING(
-      and_expr,
-      "bool greater_than((int64) f1, (const uint64) 100) && bool 
equals((int64) f2, (const uint64) 100)");
+
+  CHECK_EXPR_TO_STRING(and_expr,
+                       "bool greater_than((int64) f1, (const uint64) 100) && 
bool "
+                       "equals((int64) f2, (const uint64) 100)");
 }
 
 }  // namespace gandiva

Reply via email to