This is an automated email from the ASF dual-hosted git repository.

praveenbingo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 8f5bda4  ARROW-11986: [C++][Gandiva] Implement IN expressions for 
doubles and floats
8f5bda4 is described below

commit 8f5bda4f785b0a73e5ef8d6786bea9867dee03d8
Author: frank400 <[email protected]>
AuthorDate: Thu May 20 10:44:17 2021 +0530

    ARROW-11986: [C++][Gandiva] Implement IN expressions for doubles and floats
    
    Add functions to process IN expressions for Arrow fields of double and float types.
    
    Closes #9724 from jvictorhuguenin/feature/add-float-double-decimal-in-expr 
and squashes the following commits:
    
    05c283d62 <frank400> Fix Expression validation error message
    182b340a8 <frank400> Fix checkstyle
    9088b09c4 <frank400> Fix checkstyle
    69745f61a <frank400> Fix checkstyle
    27a8adf84 <frank400> Add comments to java tests to make it clearer
    24d32cfe2 <frank400> Fix InNode constructor parameters
    ece7702f6 <frank400> Fix InNode constructor parameters
    50cfa1426 <frank400> Fix rebase errors
    1730fd50e <frank400> Fix problems with backward compatibility
    4be099baa <frank400> Fix lint problem after declaring float values 
explicitly
    6f2686791 <frank400> Fix test typing for float numbers
    7b41e9767 <frank400> Add test cases for -0.0, +inf and -inf and NaN
    d3af44db2 <frank400> Fix Lint problem within TestInFloat
    9adbd5692 <frank400> Fix array float tiping for build propose
    e744d5d3e <frank400> Add JNI functions and tests
    cb887ab69 <frank400> Fix Lint problems
    af2d70998 <frank400> Fix Lint problems
    a2b904b57 <frank400> Change the float_t and double_t to float and double
    570006d44 <frank400> Fix CI problem for truncating doubles to floats
    1a206c93e <frank400> Remove unnecessary call
    8cc34739e <frank400> Fix build problems with mingw
    6be7a60ae <frank400> Fix lint problems
    64905ddc4 <frank400> Fix CI problems
    efa4d1e2f <frank400> test the implemented in expressions
    94adfb8d3 <frank400> implements in expressions for floats and doubles
    94f111708 <frank400> fix wrong typed double_t and float_t at stub functions
    54baa9cc6 <frank400> Fix problems with backward compatibility
    36ed0b708 <frank400> Fix lint problem after declaring float values 
explicitly
    6b47c75d2 <frank400> Fix test typing for float numbers
    9f57c9549 <frank400> Add test cases for -0.0, +inf and -inf and NaN
    c7901663f <frank400> Fix Lint problem within TestInFloat
    074013cc3 <frank400> Fix jni register for double expressions
    9c1cea8df <frank400> Fix array float tiping for build propose
    2b464cae9 <frank400> Add JNI functions and tests
    7e226726c <frank400> Fix Lint problems
    195a129ba <frank400> Fix Lint problems
    b610afec6 <frank400> Change the float_t and double_t to float and double
    f8d7b6e8c <frank400> Fix CI problem for truncating doubles to floats
    8b448cfc3 <frank400> Remove unnecessary call
    89d822548 <frank400> Fix build problems with mingw
    a25558e5f <frank400> Fix lint problems
    f51ac3d01 <frank400> Fix CI problems
    f28ed9b56 <frank400> uncomment the implemented expressions
    3312ace71 <frank400> test the implemented in expressions
    8fbe192ed <frank400> implements in expressions for floats and doubles
    
    Authored-by: frank400 <[email protected]>
    Signed-off-by: Praveen <[email protected]>
---
 cpp/src/gandiva/dex.h                              | 18 +++++
 cpp/src/gandiva/dex_visitor.h                      |  5 ++
 cpp/src/gandiva/expr_decomposer.cc                 |  2 +
 cpp/src/gandiva/expr_decomposer.h                  |  3 +
 cpp/src/gandiva/expr_validator.cc                  | 17 +++--
 cpp/src/gandiva/expr_validator.h                   |  2 +
 cpp/src/gandiva/gdv_function_stubs.cc              | 31 ++++++++
 cpp/src/gandiva/jni/jni_common.cc                  | 16 +++++
 cpp/src/gandiva/llvm_generator.cc                  |  7 ++
 cpp/src/gandiva/llvm_generator.h                   |  2 +
 cpp/src/gandiva/node_visitor.h                     |  3 +
 cpp/src/gandiva/proto/Types.proto                  | 10 +++
 cpp/src/gandiva/tests/in_expr_test.cc              | 82 ++++++++++++++++++++++
 cpp/src/gandiva/tree_expr_builder.cc               |  2 +
 cpp/src/gandiva/tree_expr_builder.h                |  9 +++
 .../apache/arrow/gandiva/expression/InNode.java    | 39 ++++++++--
 .../arrow/gandiva/expression/TreeBuilder.java      | 10 +++
 .../arrow/gandiva/evaluator/ProjectorTest.java     | 57 ++++++++++++++-
 18 files changed, 303 insertions(+), 12 deletions(-)

diff --git a/cpp/src/gandiva/dex.h b/cpp/src/gandiva/dex.h
index 3920f82..d1115c0 100644
--- a/cpp/src/gandiva/dex.h
+++ b/cpp/src/gandiva/dex.h
@@ -354,6 +354,24 @@ class InExprDex<int64_t> : public InExprDexBase<int64_t> {
 };
 
 template <>
+class InExprDex<float> : public InExprDexBase<float> {
+ public:
+  InExprDex(const ValueValidityPairVector& args, const std::unordered_set<float>& values)
+      : InExprDexBase(args, values) {
+    runtime_function_ = "gdv_fn_in_expr_lookup_float";
+  }
+};
+
+template <>
+class InExprDex<double> : public InExprDexBase<double> {
+ public:
+  InExprDex(const ValueValidityPairVector& args, const std::unordered_set<double>& values)
+      : InExprDexBase(args, values) {
+    runtime_function_ = "gdv_fn_in_expr_lookup_double";
+  }
+};
+
+template <>
 class InExprDex<gandiva::DecimalScalar128>
     : public InExprDexBase<gandiva::DecimalScalar128> {
  public:
diff --git a/cpp/src/gandiva/dex_visitor.h b/cpp/src/gandiva/dex_visitor.h
index ba5de97..5d160bb 100644
--- a/cpp/src/gandiva/dex_visitor.h
+++ b/cpp/src/gandiva/dex_visitor.h
@@ -17,6 +17,7 @@
 
 #pragma once
 
+#include <cmath>
 #include <string>
 
 #include "arrow/util/logging.h"
@@ -61,6 +62,8 @@ class GANDIVA_EXPORT DexVisitor {
   virtual void Visit(const BooleanOrDex& dex) = 0;
   virtual void Visit(const InExprDexBase<int32_t>& dex) = 0;
   virtual void Visit(const InExprDexBase<int64_t>& dex) = 0;
+  virtual void Visit(const InExprDexBase<float>& dex) = 0;
+  virtual void Visit(const InExprDexBase<double>& dex) = 0;
   virtual void Visit(const InExprDexBase<gandiva::DecimalScalar128>& dex) = 0;
   virtual void Visit(const InExprDexBase<std::string>& dex) = 0;
 };
@@ -85,6 +88,8 @@ class GANDIVA_EXPORT DexDefaultVisitor : public DexVisitor {
   VISIT_DCHECK(BooleanOrDex)
   VISIT_DCHECK(InExprDexBase<int32_t>)
   VISIT_DCHECK(InExprDexBase<int64_t>)
+  VISIT_DCHECK(InExprDexBase<float>)
+  VISIT_DCHECK(InExprDexBase<double>)
   VISIT_DCHECK(InExprDexBase<gandiva::DecimalScalar128>)
   VISIT_DCHECK(InExprDexBase<std::string>)
 };
diff --git a/cpp/src/gandiva/expr_decomposer.cc 
b/cpp/src/gandiva/expr_decomposer.cc
index 07252b4..1c09d28 100644
--- a/cpp/src/gandiva/expr_decomposer.cc
+++ b/cpp/src/gandiva/expr_decomposer.cc
@@ -209,6 +209,8 @@ Status ExprDecomposer::Visit(const 
InExpressionNode<gandiva::DecimalScalar128>&
 
 MAKE_VISIT_IN(int32_t);
 MAKE_VISIT_IN(int64_t);
+MAKE_VISIT_IN(float);
+MAKE_VISIT_IN(double);
 MAKE_VISIT_IN(std::string);
 
 Status ExprDecomposer::Visit(const LiteralNode& node) {
diff --git a/cpp/src/gandiva/expr_decomposer.h 
b/cpp/src/gandiva/expr_decomposer.h
index 3e8e67d..f68b8a8 100644
--- a/cpp/src/gandiva/expr_decomposer.h
+++ b/cpp/src/gandiva/expr_decomposer.h
@@ -17,6 +17,7 @@
 
 #pragma once
 
+#include <cmath>
 #include <memory>
 #include <stack>
 #include <string>
@@ -66,6 +67,8 @@ class GANDIVA_EXPORT ExprDecomposer : public NodeVisitor {
   Status Visit(const BooleanNode& node) override;
   Status Visit(const InExpressionNode<int32_t>& node) override;
   Status Visit(const InExpressionNode<int64_t>& node) override;
+  Status Visit(const InExpressionNode<float>& node) override;
+  Status Visit(const InExpressionNode<double>& node) override;
   Status Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) 
override;
   Status Visit(const InExpressionNode<std::string>& node) override;
 
diff --git a/cpp/src/gandiva/expr_validator.cc 
b/cpp/src/gandiva/expr_validator.cc
index fd46c28..c3c784c 100644
--- a/cpp/src/gandiva/expr_validator.cc
+++ b/cpp/src/gandiva/expr_validator.cc
@@ -156,6 +156,14 @@ Status ExprValidator::Visit(const 
InExpressionNode<int64_t>& node) {
   return ValidateInExpression(node.values().size(), 
node.eval_expr()->return_type(),
                               arrow::int64());
 }
+Status ExprValidator::Visit(const InExpressionNode<float>& node) {
+  return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(),
+                              arrow::float32());
+}
+Status ExprValidator::Visit(const InExpressionNode<double>& node) {
+  return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(),
+                              arrow::float64());
+}
 
 Status ExprValidator::Visit(const InExpressionNode<gandiva::DecimalScalar128>& 
node) {
   return ValidateInExpression(node.values().size(), 
node.eval_expr()->return_type(),
@@ -173,10 +181,11 @@ Status ExprValidator::ValidateInExpression(size_t 
number_of_values,
   ARROW_RETURN_IF(number_of_values == 0,
                   Status::ExpressionValidationError(
                       "IN Expression needs a non-empty constant list to 
match."));
-  ARROW_RETURN_IF(!in_expr_return_type->Equals(type_of_values),
-                  Status::ExpressionValidationError(
-                      "Evaluation expression for IN clause returns ", 
in_expr_return_type,
-                      " values are of type", type_of_values));
+  ARROW_RETURN_IF(
+      !in_expr_return_type->Equals(type_of_values),
+      Status::ExpressionValidationError(
+          "Evaluation expression for IN clause returns ", 
in_expr_return_type->ToString(),
+          " values are of type", type_of_values->ToString()));
 
   return Status::OK();
 }
diff --git a/cpp/src/gandiva/expr_validator.h b/cpp/src/gandiva/expr_validator.h
index e25afe5..daaf508 100644
--- a/cpp/src/gandiva/expr_validator.h
+++ b/cpp/src/gandiva/expr_validator.h
@@ -60,6 +60,8 @@ class ExprValidator : public NodeVisitor {
   Status Visit(const BooleanNode& node) override;
   Status Visit(const InExpressionNode<int32_t>& node) override;
   Status Visit(const InExpressionNode<int64_t>& node) override;
+  Status Visit(const InExpressionNode<float>& node) override;
+  Status Visit(const InExpressionNode<double>& node) override;
   Status Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) 
override;
   Status Visit(const InExpressionNode<std::string>& node) override;
   Status ValidateInExpression(size_t number_of_values, DataTypePtr 
in_expr_return_type,
diff --git a/cpp/src/gandiva/gdv_function_stubs.cc 
b/cpp/src/gandiva/gdv_function_stubs.cc
index 832eebc..acf3f56 100644
--- a/cpp/src/gandiva/gdv_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_function_stubs.cc
@@ -102,6 +102,22 @@ bool gdv_fn_in_expr_lookup_decimal(int64_t ptr, int64_t 
value_high, int64_t valu
   return holder->HasValue(value);
 }
 
+bool gdv_fn_in_expr_lookup_float(int64_t ptr, float value, bool in_validity) {
+  if (!in_validity) {
+    return false;
+  }
+  gandiva::InHolder<float>* holder = reinterpret_cast<gandiva::InHolder<float>*>(ptr);
+  return holder->HasValue(value);
+}
+
+bool gdv_fn_in_expr_lookup_double(int64_t ptr, double value, bool in_validity) {
+  if (!in_validity) {
+    return false;
+  }
+  gandiva::InHolder<double>* holder = reinterpret_cast<gandiva::InHolder<double>*>(ptr);
+  return holder->HasValue(value);
+}
+
+
 bool gdv_fn_in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len,
                                 bool in_validity) {
   if (!in_validity) {
@@ -504,7 +520,22 @@ void ExportedStubFunctions::AddMappings(Engine* engine) 
const {
   engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_utf8",
                                   types->i1_type() /*return_type*/, args,
                                   
reinterpret_cast<void*>(gdv_fn_in_expr_lookup_utf8));
+  // gdv_fn_in_expr_lookup_float
+  args = {types->i64_type(),    // int64_t in holder ptr
+          types->float_type(),  // float value
+          types->i1_type()};    // bool in_validity
+
+  engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_float",
+                                  types->i1_type() /*return_type*/, args,
+                                  
reinterpret_cast<void*>(gdv_fn_in_expr_lookup_float));
+  // gdv_fn_in_expr_lookup_double
+  args = {types->i64_type(),     // int64_t in holder ptr
+          types->double_type(),  // double value
+          types->i1_type()};     // bool in_validity
 
+  engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_double",
+                                  types->i1_type() /*return_type*/, args,
+                                  
reinterpret_cast<void*>(gdv_fn_in_expr_lookup_double));
   // gdv_fn_populate_varlen_vector
   args = {types->i64_type(),      // int64_t execution_context
           types->i8_ptr_type(),   // int8_t* data ptr
diff --git a/cpp/src/gandiva/jni/jni_common.cc 
b/cpp/src/gandiva/jni/jni_common.cc
index 0495330..5a4cbb0 100644
--- a/cpp/src/gandiva/jni/jni_common.cc
+++ b/cpp/src/gandiva/jni/jni_common.cc
@@ -380,6 +380,22 @@ NodePtr ProtoTypeToInNode(const types::InNode& node) {
     return TreeExprBuilder::MakeInExpressionDecimal(field, decimal_values);
   }
 
+  if (node.has_floatvalues()) {
+    std::unordered_set<float> float_values;
+    for (int i = 0; i < node.floatvalues().floatvalues_size(); i++) {
+      float_values.insert(node.floatvalues().floatvalues(i).value());
+    }
+    return TreeExprBuilder::MakeInExpressionFloat(field, float_values);
+  }
+
+  if (node.has_doublevalues()) {
+    std::unordered_set<double> double_values;
+    for (int i = 0; i < node.doublevalues().doublevalues_size(); i++) {
+      double_values.insert(node.doublevalues().doublevalues(i).value());
+    }
+    return TreeExprBuilder::MakeInExpressionDouble(field, double_values);
+  }
+
   if (node.has_stringvalues()) {
     std::unordered_set<std::string> stringvalues;
     for (int i = 0; i < node.stringvalues().stringvalues_size(); i++) {
diff --git a/cpp/src/gandiva/llvm_generator.cc 
b/cpp/src/gandiva/llvm_generator.cc
index 1a80f1e..77feb99 100644
--- a/cpp/src/gandiva/llvm_generator.cc
+++ b/cpp/src/gandiva/llvm_generator.cc
@@ -1084,6 +1084,13 @@ void LLVMGenerator::Visitor::Visit(const 
InExprDexBase<int64_t>& dex) {
   VisitInExpression<int64_t>(dex);
 }
 
+void LLVMGenerator::Visitor::Visit(const InExprDexBase<float>& dex) {
+  VisitInExpression<float>(dex);
+}
+void LLVMGenerator::Visitor::Visit(const InExprDexBase<double>& dex) {
+  VisitInExpression<double>(dex);
+}
+
 void LLVMGenerator::Visitor::Visit(const 
InExprDexBase<gandiva::DecimalScalar128>& dex) {
   VisitInExpression<gandiva::DecimalScalar128>(dex);
 }
diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h
index 8ff9711..ff6d846 100644
--- a/cpp/src/gandiva/llvm_generator.h
+++ b/cpp/src/gandiva/llvm_generator.h
@@ -108,6 +108,8 @@ class GANDIVA_EXPORT LLVMGenerator {
     void Visit(const BooleanOrDex& dex) override;
     void Visit(const InExprDexBase<int32_t>& dex) override;
     void Visit(const InExprDexBase<int64_t>& dex) override;
+    void Visit(const InExprDexBase<float>& dex) override;
+    void Visit(const InExprDexBase<double>& dex) override;
     void Visit(const InExprDexBase<gandiva::DecimalScalar128>& dex) override;
     void Visit(const InExprDexBase<std::string>& dex) override;
     template <typename Type>
diff --git a/cpp/src/gandiva/node_visitor.h b/cpp/src/gandiva/node_visitor.h
index b118e49..8f233f5 100644
--- a/cpp/src/gandiva/node_visitor.h
+++ b/cpp/src/gandiva/node_visitor.h
@@ -17,6 +17,7 @@
 
 #pragma once
 
+#include <cmath>
 #include <string>
 
 #include "arrow/status.h"
@@ -46,6 +47,8 @@ class GANDIVA_EXPORT NodeVisitor {
   virtual Status Visit(const BooleanNode& node) = 0;
   virtual Status Visit(const InExpressionNode<int32_t>& node) = 0;
   virtual Status Visit(const InExpressionNode<int64_t>& node) = 0;
+  virtual Status Visit(const InExpressionNode<float>& node) = 0;
+  virtual Status Visit(const InExpressionNode<double>& node) = 0;
   virtual Status Visit(const InExpressionNode<gandiva::DecimalScalar128>& 
node) = 0;
   virtual Status Visit(const InExpressionNode<std::string>& node) = 0;
 };
diff --git a/cpp/src/gandiva/proto/Types.proto 
b/cpp/src/gandiva/proto/Types.proto
index 7c0c49f..eb0d996 100644
--- a/cpp/src/gandiva/proto/Types.proto
+++ b/cpp/src/gandiva/proto/Types.proto
@@ -222,6 +222,8 @@ message InNode {
   optional StringConstants stringValues = 4;
   optional BinaryConstants binaryValues = 5;
   optional DecimalConstants decimalValues = 6;
+  optional FloatConstants floatValues = 7;
+  optional DoubleConstants doubleValues = 8;
 }
 
 message IntConstants {
@@ -236,6 +238,14 @@ message DecimalConstants {
   repeated DecimalNode decimalValues = 1;
 }
 
+message FloatConstants {
+  repeated FloatNode floatValues = 1;
+}
+
+message DoubleConstants {
+  repeated DoubleNode doubleValues = 1;
+}
+
 message StringConstants {
   repeated StringNode stringValues = 1;
 }
diff --git a/cpp/src/gandiva/tests/in_expr_test.cc 
b/cpp/src/gandiva/tests/in_expr_test.cc
index 6a31b1c..fc1a8a7 100644
--- a/cpp/src/gandiva/tests/in_expr_test.cc
+++ b/cpp/src/gandiva/tests/in_expr_test.cc
@@ -16,6 +16,7 @@
 // under the License.
 
 #include <gtest/gtest.h>
+#include <cmath>
 
 #include "arrow/memory_pool.h"
 #include "gandiva/filter.h"
@@ -26,6 +27,7 @@ namespace gandiva {
 
 using arrow::boolean;
 using arrow::float32;
+using arrow::float64;
 using arrow::int32;
 
 class TestIn : public ::testing::Test {
@@ -91,6 +93,86 @@ TEST_F(TestIn, TestInSimple) {
   EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
 }
 
+TEST_F(TestIn, TestInFloat) {
+  // schema for input fields
+  auto field0 = field("f0", float32());
+  auto schema = arrow::schema({field0});
+
+  // Build f0 in (6.5, 12.0, 11.5)
+  auto node_f0 = TreeExprBuilder::MakeField(field0);
+
+  std::unordered_set<float> in_constants({6.5f, 12.0f, 11.5f});
+  auto in_expr = TreeExprBuilder::MakeInExpressionFloat(node_f0, in_constants);
+  auto condition = TreeExprBuilder::MakeCondition(in_expr);
+
+  std::shared_ptr<Filter> filter;
+  auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array0 =
+      MakeArrowArrayFloat32({6.5f, 11.5f, 4, 3.15f, 6}, {true, true, false, true, true});
+  // expected output: rows 0 and 1 match; row 2 is null and never matches
+  auto exp = MakeArrowArrayUint16({0, 1});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  // Evaluate expression
+  status = filter->Evaluate(*in_batch, selection_vector);
+  EXPECT_TRUE(status.ok());
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestIn, TestInDouble) {
+  // schema for input fields
+  auto field0 = field("double0", float64());
+  auto field1 = field("double1", float64());
+  auto schema = arrow::schema({field0, field1});
+
+  auto node_f0 = TreeExprBuilder::MakeField(field0);
+  auto node_f1 = TreeExprBuilder::MakeField(field1);
+  auto sum_func =
+      TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::float64());
+  std::unordered_set<double> in_constants({3.14159265359, 15.5555555});
+  auto in_expr = TreeExprBuilder::MakeInExpressionDouble(sum_func, in_constants);
+  auto condition = TreeExprBuilder::MakeCondition(in_expr);
+
+  std::shared_ptr<Filter> filter;
+  auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array0 = MakeArrowArrayFloat64({1, 2, 3, 4, 11}, {true, true, true, false, false});
+  auto array1 = MakeArrowArrayFloat64({5, 9, 0.14159265359, 17, 4.5555555},
+                                      {true, true, true, true, true});
+
+  // expected output: row 2 sums to 3.14159265359; row 4 would hit 15.5555555 but is null
+  auto exp = MakeArrowArrayUint16({2});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  // Evaluate expression
+  status = filter->Evaluate(*in_batch, selection_vector);
+  EXPECT_TRUE(status.ok());
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
 TEST_F(TestIn, TestInDecimal) {
   int32_t precision = 38;
   int32_t scale = 5;
diff --git a/cpp/src/gandiva/tree_expr_builder.cc 
b/cpp/src/gandiva/tree_expr_builder.cc
index b27b920..de8e344 100644
--- a/cpp/src/gandiva/tree_expr_builder.cc
+++ b/cpp/src/gandiva/tree_expr_builder.cc
@@ -215,6 +215,8 @@ MAKE_IN(Date64, int64_t);
 MAKE_IN(TimeStamp, int64_t);
 MAKE_IN(Time32, int32_t);
 MAKE_IN(Time64, int64_t);
+MAKE_IN(Float, float);
+MAKE_IN(Double, double);
 MAKE_IN(String, std::string);
 MAKE_IN(Binary, std::string);
 
diff --git a/cpp/src/gandiva/tree_expr_builder.h 
b/cpp/src/gandiva/tree_expr_builder.h
index 9c24fb9..94a4a17 100644
--- a/cpp/src/gandiva/tree_expr_builder.h
+++ b/cpp/src/gandiva/tree_expr_builder.h
@@ -17,6 +17,7 @@
 
 #pragma once
 
+#include <cmath>
 #include <memory>
 #include <string>
 #include <unordered_set>
@@ -106,6 +107,14 @@ class GANDIVA_EXPORT TreeExprBuilder {
   static NodePtr MakeInExpressionBinary(NodePtr node,
                                         const std::unordered_set<std::string>& 
constants);
 
+  /// \brief creates an in expression for float
+  static NodePtr MakeInExpressionFloat(NodePtr node,
+                                       const std::unordered_set<float>& 
constants);
+
+  /// \brief creates an in expression for double
+  static NodePtr MakeInExpressionDouble(NodePtr node,
+                                        const std::unordered_set<double>& 
constants);
+
   /// \brief Date as s/millis since epoch.
   static NodePtr MakeInExpressionDate32(NodePtr node,
                                         const std::unordered_set<int32_t>& 
constants);
diff --git 
a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java 
b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
index 08ef7f0..fef8e31 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
@@ -34,6 +34,8 @@ public class InNode implements TreeNode {
 
   private final Set<Integer> intValues;
   private final Set<Long> longValues;
+  private final Set<Float> floatValues;
+  private final Set<Double> doubleValues;
   private final Set<BigDecimal> decimalValues;
   private final Set<String> stringValues;
   private final Set<byte[]> binaryValues;
@@ -43,7 +45,8 @@ public class InNode implements TreeNode {
   private final Integer scale;
 
   private InNode(Set<Integer> values, Set<Long> longValues, Set<String> 
stringValues, Set<byte[]>
-          binaryValues, Set<BigDecimal> decimalValues, Integer precision, 
Integer scale, TreeNode node) {
+          binaryValues, Set<BigDecimal> decimalValues, Integer precision, 
Integer scale,
+                 Set<Float> floatValues, Set<Double> doubleValues, TreeNode 
node) {
     this.intValues = values;
     this.longValues = longValues;
     this.decimalValues = decimalValues;
@@ -51,33 +54,47 @@ public class InNode implements TreeNode {
     this.scale = scale;
     this.stringValues = stringValues;
     this.binaryValues = binaryValues;
+    this.floatValues = floatValues;
+    this.doubleValues = doubleValues;
     this.input = node;
   }
 
   public static InNode makeIntInExpr(TreeNode node, Set<Integer> intValues) {
     return new InNode(intValues,
-            null, null, null, null, null, null, node);
+            null, null, null, null, null, null, null,
+            null, node);
   }
 
   public static InNode makeLongInExpr(TreeNode node, Set<Long> longValues) {
     return new InNode(null, longValues,
-            null, null, null, null, null, node);
+            null, null, null, null, null, null,
+            null, node);
+  }
+
+  public static InNode makeFloatInExpr(TreeNode node, Set<Float> floatValues) {
+    return new InNode(null, null, null, null, null, null,
+            null, floatValues, null, node);
+  }
+
+  public static InNode makeDoubleInExpr(TreeNode node, Set<Double> doubleValues) {
+    return new InNode(null, null, null, null, null,
+            null, null, null, doubleValues, node);
   }
 
   public static InNode makeDecimalInExpr(TreeNode node, Set<BigDecimal> 
decimalValues,
                                          Integer precision, Integer scale) {
     return new InNode(null, null, null, null,
-            decimalValues, precision, scale, node);
+            decimalValues, precision, scale, null, null, node);
   }
 
   public static InNode makeStringInExpr(TreeNode node, Set<String> 
stringValues) {
     return new InNode(null, null, stringValues, null,
-            null, null, null, node);
+            null, null, null, null, null, node);
   }
 
   public static InNode makeBinaryInExpr(TreeNode node, Set<byte[]> 
binaryValues) {
     return new InNode(null, null, null, binaryValues,
-            null, null, null, node);
+            null, null, null, null, null, node);
   }
 
   @Override
@@ -96,6 +113,16 @@ public class InNode implements TreeNode {
       longValues.stream().forEach(val -> 
longConstants.addLongValues(GandivaTypes.LongNode.newBuilder()
               .setValue(val).build()));
       inNode.setLongValues(longConstants.build());
+    } else if (floatValues != null) {
+      GandivaTypes.FloatConstants.Builder floatConstants = 
GandivaTypes.FloatConstants.newBuilder();
+      floatValues.stream().forEach(val -> 
floatConstants.addFloatValues(GandivaTypes.FloatNode.newBuilder()
+              .setValue(val).build()));
+      inNode.setFloatValues(floatConstants.build());
+    } else if (doubleValues != null) {
+      GandivaTypes.DoubleConstants.Builder doubleConstants = 
GandivaTypes.DoubleConstants.newBuilder();
+      doubleValues.stream().forEach(val -> 
doubleConstants.addDoubleValues(GandivaTypes.DoubleNode.newBuilder()
+              .setValue(val).build()));
+      inNode.setDoubleValues(doubleConstants.build());
     } else if (decimalValues != null) {
       GandivaTypes.DecimalConstants.Builder decimalConstants = 
GandivaTypes.DecimalConstants.newBuilder();
       decimalValues.stream().forEach(val -> 
decimalConstants.addDecimalValues(GandivaTypes.DecimalNode.newBuilder()
diff --git 
a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
 
b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
index 067715c..8656e88 100644
--- 
a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
+++ 
b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
@@ -208,6 +208,16 @@ public class TreeBuilder {
     return InNode.makeDecimalInExpr(resultNode, decimalValues, precision, 
scale);
   }
 
+  public static TreeNode makeInExpressionFloat(TreeNode resultNode,
+                                                Set<Float> floatValues) {
+    return InNode.makeFloatInExpr(resultNode, floatValues);
+  }
+
+  public static TreeNode makeInExpressionDouble(TreeNode resultNode,
+                                                Set<Double> doubleValues) {
+    return InNode.makeDoubleInExpr(resultNode, doubleValues);
+  }
+
   public static TreeNode makeInExpressionString(TreeNode resultNode,
                                                 Set<String> stringValues) {
     return InNode.makeStringInExpr(resultNode, stringValues);
diff --git 
a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
 
b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
index 606c1a9..e51f458 100644
--- 
a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
+++ 
b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
@@ -1220,10 +1220,10 @@ public class ProjectorTest extends BaseEvaluatorTest {
     output.add(bitVector);
     eval.evaluate(batch, output);
 
-    for (int i = 1; i < 5; i++) {
+    for (int i = 0; i < 4; i++) {
       assertTrue(bitVector.getObject(i).booleanValue());
     }
-    for (int i = 5; i < 16; i++) {
+    for (int i = 4; i < 16; i++) {
       assertFalse(bitVector.getObject(i).booleanValue());
     }
 
@@ -1252,7 +1252,9 @@ public class ProjectorTest extends BaseEvaluatorTest {
     Schema schema = new Schema(Lists.newArrayList(c1));
     Projector eval = Projector.make(schema, Lists.newArrayList(expr));
 
+    // Create a row-batch with some sample data to look for
     int numRows = 16;
+    // Only the first 8 values will be valid.
     byte[] validity = new byte[]{(byte) 255, 0};
     String[] c1Values =
             new String[]{"1", "2", "3", "4", "-0.0", "6", "7", "8", "9", "10", 
"11", "12", "13", "14",
@@ -1276,6 +1278,57 @@ public class ProjectorTest extends BaseEvaluatorTest {
     output.add(bitVector);
     eval.evaluate(batch, output);
 
+    // The first four values in the vector must match the expression, but not 
the other ones.
+    for (int i = 0; i < 4; i++) {
+      assertTrue(bitVector.getObject(i).booleanValue());
+    }
+    for (int i = 4; i < 16; i++) {
+      assertFalse(bitVector.getObject(i).booleanValue());
+    }
+
+    releaseRecordBatch(batch);
+    releaseValueVectors(output);
+    eval.close();
+  }
+
+  @Test
+  public void testInExprDouble() throws GandivaException, Exception {
+    Field c1 = Field.nullable("c1", float64);
+
+    TreeNode inExpr =
+            TreeBuilder.makeInExpressionDouble(TreeBuilder.makeField(c1),
+                    Sets.newHashSet(1.0, -0.0, 3.0, 4.0, Double.NaN,
+                            Double.POSITIVE_INFINITY, 
Double.NEGATIVE_INFINITY));
+    ExpressionTree expr = TreeBuilder.makeExpression(inExpr, 
Field.nullable("result", boolType));
+    Schema schema = new Schema(Lists.newArrayList(c1));
+    Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+    // Create a row-batch with some sample data to look for
+    int numRows = 16;
+    // Only the first 8 values will be valid.
+    byte[] validity = new byte[]{(byte) 255, 0};
+    double[] c1Values = new double[]{1, -0.0, Double.NEGATIVE_INFINITY , 
Double.POSITIVE_INFINITY, Double.NaN,
+        6, 7, 8, 9, 10, 11, 12, 13, 14, 4 , 3};
+
+    ArrowBuf c1Validity = buf(validity);
+    ArrowBuf c1Data = doubleBuf(c1Values);
+    ArrowBuf c2Validity = buf(validity);
+
+    ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+    ArrowRecordBatch batch =
+            new ArrowRecordBatch(
+                    numRows,
+                    Lists.newArrayList(fieldNode, fieldNode),
+                    Lists.newArrayList(c1Validity, c1Data, c2Validity));
+
+    BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+    bitVector.allocateNew(numRows);
+
+    List<ValueVector> output = new ArrayList<ValueVector>();
+    output.add(bitVector);
+    eval.evaluate(batch, output);
+
+    // The first five values in the vector must match the expression, but not 
the other ones.
     for (int i = 1; i < 5; i++) {
       assertTrue(bitVector.getObject(i).booleanValue());
     }

Reply via email to