This is an automated email from the ASF dual-hosted git repository.
praveenbingo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 8f5bda4 ARROW-11986: [C++][Gandiva] Implement IN expressions for
doubles and floats
8f5bda4 is described below
commit 8f5bda4f785b0a73e5ef8d6786bea9867dee03d8
Author: frank400 <[email protected]>
AuthorDate: Thu May 20 10:44:17 2021 +0530
ARROW-11986: [C++][Gandiva] Implement IN expressions for doubles and floats
Add functions to process IN expressions for Arrow fields with double and
float types.
Closes #9724 from jvictorhuguenin/feature/add-float-double-decimal-in-expr
and squashes the following commits:
05c283d62 <frank400> Fix Expression validation error message
182b340a8 <frank400> Fix checkstyle
9088b09c4 <frank400> Fix checkstyle
69745f61a <frank400> Fix checkstyle
27a8adf84 <frank400> Add comments to java tests to make it clearer
24d32cfe2 <frank400> Fix InNode constructor parameters
ece7702f6 <frank400> Fix InNode constructor parameters
50cfa1426 <frank400> Fix rebase errors
1730fd50e <frank400> Fix problems with backward compatibility
4be099baa <frank400> Fix lint problem after declaring float values
explicitly
6f2686791 <frank400> Fix test typing for float numbers
7b41e9767 <frank400> Add test cases for -0.0, +inf and -inf and NaN
d3af44db2 <frank400> Fix Lint problem within TestInFloat
9adbd5692 <frank400> Fix array float tiping for build propose
e744d5d3e <frank400> Add JNI functions and tests
cb887ab69 <frank400> Fix Lint problems
af2d70998 <frank400> Fix Lint problems
a2b904b57 <frank400> Change the float_t and double_t to float and double
570006d44 <frank400> Fix CI problem for truncating doubles to floats
1a206c93e <frank400> Remove unnecessary call
8cc34739e <frank400> Fix build problems with mingw
6be7a60ae <frank400> Fix lint problems
64905ddc4 <frank400> Fix CI problems
efa4d1e2f <frank400> test the implemented in expressions
94adfb8d3 <frank400> implements in expressions for floats and doubles
94f111708 <frank400> fix wrong typed double_t and float_t at stub functions
54baa9cc6 <frank400> Fix problems with backward compatibility
36ed0b708 <frank400> Fix lint problem after declaring float values
explicitly
6b47c75d2 <frank400> Fix test typing for float numbers
9f57c9549 <frank400> Add test cases for -0.0, +inf and -inf and NaN
c7901663f <frank400> Fix Lint problem within TestInFloat
074013cc3 <frank400> Fix jni register for double expressions
9c1cea8df <frank400> Fix array float tiping for build propose
2b464cae9 <frank400> Add JNI functions and tests
7e226726c <frank400> Fix Lint problems
195a129ba <frank400> Fix Lint problems
b610afec6 <frank400> Change the float_t and double_t to float and double
f8d7b6e8c <frank400> Fix CI problem for truncating doubles to floats
8b448cfc3 <frank400> Remove unnecessary call
89d822548 <frank400> Fix build problems with mingw
a25558e5f <frank400> Fix lint problems
f51ac3d01 <frank400> Fix CI problems
f28ed9b56 <frank400> uncomment the implemented expressions
3312ace71 <frank400> test the implemented in expressions
8fbe192ed <frank400> implements in expressions for floats and doubles
Authored-by: frank400 <[email protected]>
Signed-off-by: Praveen <[email protected]>
---
cpp/src/gandiva/dex.h | 18 +++++
cpp/src/gandiva/dex_visitor.h | 5 ++
cpp/src/gandiva/expr_decomposer.cc | 2 +
cpp/src/gandiva/expr_decomposer.h | 3 +
cpp/src/gandiva/expr_validator.cc | 17 +++--
cpp/src/gandiva/expr_validator.h | 2 +
cpp/src/gandiva/gdv_function_stubs.cc | 31 ++++++++
cpp/src/gandiva/jni/jni_common.cc | 16 +++++
cpp/src/gandiva/llvm_generator.cc | 7 ++
cpp/src/gandiva/llvm_generator.h | 2 +
cpp/src/gandiva/node_visitor.h | 3 +
cpp/src/gandiva/proto/Types.proto | 10 +++
cpp/src/gandiva/tests/in_expr_test.cc | 82 ++++++++++++++++++++++
cpp/src/gandiva/tree_expr_builder.cc | 2 +
cpp/src/gandiva/tree_expr_builder.h | 9 +++
.../apache/arrow/gandiva/expression/InNode.java | 39 ++++++++--
.../arrow/gandiva/expression/TreeBuilder.java | 10 +++
.../arrow/gandiva/evaluator/ProjectorTest.java | 57 ++++++++++++++-
18 files changed, 303 insertions(+), 12 deletions(-)
diff --git a/cpp/src/gandiva/dex.h b/cpp/src/gandiva/dex.h
index 3920f82..d1115c0 100644
--- a/cpp/src/gandiva/dex.h
+++ b/cpp/src/gandiva/dex.h
@@ -354,6 +354,24 @@ class InExprDex<int64_t> : public InExprDexBase<int64_t> {
};
template <>
+class InExprDex<float> : public InExprDexBase<float> {
+ public:
+ InExprDex(const ValueValidityPairVector& args, const
std::unordered_set<float>& values)
+ : InExprDexBase(args, values) {
+ runtime_function_ = "gdv_fn_in_expr_lookup_float";
+ }
+};
+
+template <>
+class InExprDex<double> : public InExprDexBase<double> {
+ public:
+ InExprDex(const ValueValidityPairVector& args, const
std::unordered_set<double>& values)
+ : InExprDexBase(args, values) {
+ runtime_function_ = "gdv_fn_in_expr_lookup_double";
+ }
+};
+
+template <>
class InExprDex<gandiva::DecimalScalar128>
: public InExprDexBase<gandiva::DecimalScalar128> {
public:
diff --git a/cpp/src/gandiva/dex_visitor.h b/cpp/src/gandiva/dex_visitor.h
index ba5de97..5d160bb 100644
--- a/cpp/src/gandiva/dex_visitor.h
+++ b/cpp/src/gandiva/dex_visitor.h
@@ -17,6 +17,7 @@
#pragma once
+#include <cmath>
#include <string>
#include "arrow/util/logging.h"
@@ -61,6 +62,8 @@ class GANDIVA_EXPORT DexVisitor {
virtual void Visit(const BooleanOrDex& dex) = 0;
virtual void Visit(const InExprDexBase<int32_t>& dex) = 0;
virtual void Visit(const InExprDexBase<int64_t>& dex) = 0;
+ virtual void Visit(const InExprDexBase<float>& dex) = 0;
+ virtual void Visit(const InExprDexBase<double>& dex) = 0;
virtual void Visit(const InExprDexBase<gandiva::DecimalScalar128>& dex) = 0;
virtual void Visit(const InExprDexBase<std::string>& dex) = 0;
};
@@ -85,6 +88,8 @@ class GANDIVA_EXPORT DexDefaultVisitor : public DexVisitor {
VISIT_DCHECK(BooleanOrDex)
VISIT_DCHECK(InExprDexBase<int32_t>)
VISIT_DCHECK(InExprDexBase<int64_t>)
+ VISIT_DCHECK(InExprDexBase<float>)
+ VISIT_DCHECK(InExprDexBase<double>)
VISIT_DCHECK(InExprDexBase<gandiva::DecimalScalar128>)
VISIT_DCHECK(InExprDexBase<std::string>)
};
diff --git a/cpp/src/gandiva/expr_decomposer.cc
b/cpp/src/gandiva/expr_decomposer.cc
index 07252b4..1c09d28 100644
--- a/cpp/src/gandiva/expr_decomposer.cc
+++ b/cpp/src/gandiva/expr_decomposer.cc
@@ -209,6 +209,8 @@ Status ExprDecomposer::Visit(const
InExpressionNode<gandiva::DecimalScalar128>&
MAKE_VISIT_IN(int32_t);
MAKE_VISIT_IN(int64_t);
+MAKE_VISIT_IN(float);
+MAKE_VISIT_IN(double);
MAKE_VISIT_IN(std::string);
Status ExprDecomposer::Visit(const LiteralNode& node) {
diff --git a/cpp/src/gandiva/expr_decomposer.h
b/cpp/src/gandiva/expr_decomposer.h
index 3e8e67d..f68b8a8 100644
--- a/cpp/src/gandiva/expr_decomposer.h
+++ b/cpp/src/gandiva/expr_decomposer.h
@@ -17,6 +17,7 @@
#pragma once
+#include <cmath>
#include <memory>
#include <stack>
#include <string>
@@ -66,6 +67,8 @@ class GANDIVA_EXPORT ExprDecomposer : public NodeVisitor {
Status Visit(const BooleanNode& node) override;
Status Visit(const InExpressionNode<int32_t>& node) override;
Status Visit(const InExpressionNode<int64_t>& node) override;
+ Status Visit(const InExpressionNode<float>& node) override;
+ Status Visit(const InExpressionNode<double>& node) override;
Status Visit(const InExpressionNode<gandiva::DecimalScalar128>& node)
override;
Status Visit(const InExpressionNode<std::string>& node) override;
diff --git a/cpp/src/gandiva/expr_validator.cc
b/cpp/src/gandiva/expr_validator.cc
index fd46c28..c3c784c 100644
--- a/cpp/src/gandiva/expr_validator.cc
+++ b/cpp/src/gandiva/expr_validator.cc
@@ -156,6 +156,14 @@ Status ExprValidator::Visit(const
InExpressionNode<int64_t>& node) {
return ValidateInExpression(node.values().size(),
node.eval_expr()->return_type(),
arrow::int64());
}
+Status ExprValidator::Visit(const InExpressionNode<float>& node) {
+ return ValidateInExpression(node.values().size(),
node.eval_expr()->return_type(),
+ arrow::float32());
+}
+Status ExprValidator::Visit(const InExpressionNode<double>& node) {
+ return ValidateInExpression(node.values().size(),
node.eval_expr()->return_type(),
+ arrow::float64());
+}
Status ExprValidator::Visit(const InExpressionNode<gandiva::DecimalScalar128>&
node) {
return ValidateInExpression(node.values().size(),
node.eval_expr()->return_type(),
@@ -173,10 +181,11 @@ Status ExprValidator::ValidateInExpression(size_t
number_of_values,
ARROW_RETURN_IF(number_of_values == 0,
Status::ExpressionValidationError(
"IN Expression needs a non-empty constant list to
match."));
- ARROW_RETURN_IF(!in_expr_return_type->Equals(type_of_values),
- Status::ExpressionValidationError(
- "Evaluation expression for IN clause returns ",
in_expr_return_type,
- " values are of type", type_of_values));
+ ARROW_RETURN_IF(
+ !in_expr_return_type->Equals(type_of_values),
+ Status::ExpressionValidationError(
+ "Evaluation expression for IN clause returns ",
in_expr_return_type->ToString(),
+ " values are of type", type_of_values->ToString()));
return Status::OK();
}
diff --git a/cpp/src/gandiva/expr_validator.h b/cpp/src/gandiva/expr_validator.h
index e25afe5..daaf508 100644
--- a/cpp/src/gandiva/expr_validator.h
+++ b/cpp/src/gandiva/expr_validator.h
@@ -60,6 +60,8 @@ class ExprValidator : public NodeVisitor {
Status Visit(const BooleanNode& node) override;
Status Visit(const InExpressionNode<int32_t>& node) override;
Status Visit(const InExpressionNode<int64_t>& node) override;
+ Status Visit(const InExpressionNode<float>& node) override;
+ Status Visit(const InExpressionNode<double>& node) override;
Status Visit(const InExpressionNode<gandiva::DecimalScalar128>& node)
override;
Status Visit(const InExpressionNode<std::string>& node) override;
Status ValidateInExpression(size_t number_of_values, DataTypePtr
in_expr_return_type,
diff --git a/cpp/src/gandiva/gdv_function_stubs.cc
b/cpp/src/gandiva/gdv_function_stubs.cc
index 832eebc..acf3f56 100644
--- a/cpp/src/gandiva/gdv_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_function_stubs.cc
@@ -102,6 +102,22 @@ bool gdv_fn_in_expr_lookup_decimal(int64_t ptr, int64_t
value_high, int64_t valu
return holder->HasValue(value);
}
+bool gdv_fn_in_expr_lookup_float(int64_t ptr, float value, bool in_validity) {
+ if (!in_validity) {
+ return false;
+ }
+ gandiva::InHolder<float>* holder =
reinterpret_cast<gandiva::InHolder<float>*>(ptr);
+ return holder->HasValue(value);
+}
+
+bool gdv_fn_in_expr_lookup_double(int64_t ptr, double value, bool in_validity)
{
+ if (!in_validity) {
+ return false;
+ }
+ gandiva::InHolder<double>* holder =
reinterpret_cast<gandiva::InHolder<double>*>(ptr);
+ return holder->HasValue(value);
+}
+
bool gdv_fn_in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len,
bool in_validity) {
if (!in_validity) {
@@ -504,7 +520,22 @@ void ExportedStubFunctions::AddMappings(Engine* engine)
const {
engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_utf8",
types->i1_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_in_expr_lookup_utf8));
+ // gdv_fn_in_expr_lookup_float
+ args = {types->i64_type(), // int64_t in holder ptr
+ types->float_type(), // float value
+ types->i1_type()}; // bool in_validity
+
+ engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_float",
+ types->i1_type() /*return_type*/, args,
+
reinterpret_cast<void*>(gdv_fn_in_expr_lookup_float));
+ // gdv_fn_in_expr_lookup_double
+ args = {types->i64_type(), // int64_t in holder ptr
+ types->double_type(), // double value
+ types->i1_type()}; // bool in_validity
+ engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_double",
+ types->i1_type() /*return_type*/, args,
+
reinterpret_cast<void*>(gdv_fn_in_expr_lookup_double));
// gdv_fn_populate_varlen_vector
args = {types->i64_type(), // int64_t execution_context
types->i8_ptr_type(), // int8_t* data ptr
diff --git a/cpp/src/gandiva/jni/jni_common.cc
b/cpp/src/gandiva/jni/jni_common.cc
index 0495330..5a4cbb0 100644
--- a/cpp/src/gandiva/jni/jni_common.cc
+++ b/cpp/src/gandiva/jni/jni_common.cc
@@ -380,6 +380,22 @@ NodePtr ProtoTypeToInNode(const types::InNode& node) {
return TreeExprBuilder::MakeInExpressionDecimal(field, decimal_values);
}
+ if (node.has_floatvalues()) {
+ std::unordered_set<float> float_values;
+ for (int i = 0; i < node.floatvalues().floatvalues_size(); i++) {
+ float_values.insert(node.floatvalues().floatvalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionFloat(field, float_values);
+ }
+
+ if (node.has_doublevalues()) {
+ std::unordered_set<double> double_values;
+ for (int i = 0; i < node.doublevalues().doublevalues_size(); i++) {
+ double_values.insert(node.doublevalues().doublevalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionDouble(field, double_values);
+ }
+
if (node.has_stringvalues()) {
std::unordered_set<std::string> stringvalues;
for (int i = 0; i < node.stringvalues().stringvalues_size(); i++) {
diff --git a/cpp/src/gandiva/llvm_generator.cc
b/cpp/src/gandiva/llvm_generator.cc
index 1a80f1e..77feb99 100644
--- a/cpp/src/gandiva/llvm_generator.cc
+++ b/cpp/src/gandiva/llvm_generator.cc
@@ -1084,6 +1084,13 @@ void LLVMGenerator::Visitor::Visit(const
InExprDexBase<int64_t>& dex) {
VisitInExpression<int64_t>(dex);
}
+void LLVMGenerator::Visitor::Visit(const InExprDexBase<float>& dex) {
+ VisitInExpression<float>(dex);
+}
+void LLVMGenerator::Visitor::Visit(const InExprDexBase<double>& dex) {
+ VisitInExpression<double>(dex);
+}
+
void LLVMGenerator::Visitor::Visit(const
InExprDexBase<gandiva::DecimalScalar128>& dex) {
VisitInExpression<gandiva::DecimalScalar128>(dex);
}
diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h
index 8ff9711..ff6d846 100644
--- a/cpp/src/gandiva/llvm_generator.h
+++ b/cpp/src/gandiva/llvm_generator.h
@@ -108,6 +108,8 @@ class GANDIVA_EXPORT LLVMGenerator {
void Visit(const BooleanOrDex& dex) override;
void Visit(const InExprDexBase<int32_t>& dex) override;
void Visit(const InExprDexBase<int64_t>& dex) override;
+ void Visit(const InExprDexBase<float>& dex) override;
+ void Visit(const InExprDexBase<double>& dex) override;
void Visit(const InExprDexBase<gandiva::DecimalScalar128>& dex) override;
void Visit(const InExprDexBase<std::string>& dex) override;
template <typename Type>
diff --git a/cpp/src/gandiva/node_visitor.h b/cpp/src/gandiva/node_visitor.h
index b118e49..8f233f5 100644
--- a/cpp/src/gandiva/node_visitor.h
+++ b/cpp/src/gandiva/node_visitor.h
@@ -17,6 +17,7 @@
#pragma once
+#include <cmath>
#include <string>
#include "arrow/status.h"
@@ -46,6 +47,8 @@ class GANDIVA_EXPORT NodeVisitor {
virtual Status Visit(const BooleanNode& node) = 0;
virtual Status Visit(const InExpressionNode<int32_t>& node) = 0;
virtual Status Visit(const InExpressionNode<int64_t>& node) = 0;
+ virtual Status Visit(const InExpressionNode<float>& node) = 0;
+ virtual Status Visit(const InExpressionNode<double>& node) = 0;
virtual Status Visit(const InExpressionNode<gandiva::DecimalScalar128>&
node) = 0;
virtual Status Visit(const InExpressionNode<std::string>& node) = 0;
};
diff --git a/cpp/src/gandiva/proto/Types.proto
b/cpp/src/gandiva/proto/Types.proto
index 7c0c49f..eb0d996 100644
--- a/cpp/src/gandiva/proto/Types.proto
+++ b/cpp/src/gandiva/proto/Types.proto
@@ -222,6 +222,8 @@ message InNode {
optional StringConstants stringValues = 4;
optional BinaryConstants binaryValues = 5;
optional DecimalConstants decimalValues = 6;
+ optional FloatConstants floatValues = 7;
+ optional DoubleConstants doubleValues = 8;
}
message IntConstants {
@@ -236,6 +238,14 @@ message DecimalConstants {
repeated DecimalNode decimalValues = 1;
}
+message FloatConstants {
+ repeated FloatNode floatValues = 1;
+}
+
+message DoubleConstants {
+ repeated DoubleNode doubleValues = 1;
+}
+
message StringConstants {
repeated StringNode stringValues = 1;
}
diff --git a/cpp/src/gandiva/tests/in_expr_test.cc
b/cpp/src/gandiva/tests/in_expr_test.cc
index 6a31b1c..fc1a8a7 100644
--- a/cpp/src/gandiva/tests/in_expr_test.cc
+++ b/cpp/src/gandiva/tests/in_expr_test.cc
@@ -16,6 +16,7 @@
// under the License.
#include <gtest/gtest.h>
+#include <cmath>
#include "arrow/memory_pool.h"
#include "gandiva/filter.h"
@@ -26,6 +27,7 @@ namespace gandiva {
using arrow::boolean;
using arrow::float32;
+using arrow::float64;
using arrow::int32;
class TestIn : public ::testing::Test {
@@ -91,6 +93,86 @@ TEST_F(TestIn, TestInSimple) {
EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
}
+TEST_F(TestIn, TestInFloat) {
+ // schema for input fields
+ auto field0 = field("f0", float32());
+ auto schema = arrow::schema({field0});
+
+ // Build the IN expression: f0 IN (6.5, 12.0, 11.5)
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+
+ std::unordered_set<float> in_constants({6.5f, 12.0f, 11.5f});
+ auto in_expr = TreeExprBuilder::MakeInExpressionFloat(node_f0, in_constants);
+ auto condition = TreeExprBuilder::MakeCondition(in_expr);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 =
+ MakeArrowArrayFloat32({6.5f, 11.5f, 4, 3.15f, 6}, {true, true, false,
true, true});
+ // expected output (indices for which condition matches)
+ auto exp = MakeArrowArrayUint16({0, 1});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+TEST_F(TestIn, TestInDouble) {
+ // schema for input fields
+ auto field0 = field("double0", float64());
+ auto field1 = field("double1", float64());
+ auto schema = arrow::schema({field0, field1});
+
+ auto node_f0 = TreeExprBuilder::MakeField(field0);
+ auto node_f1 = TreeExprBuilder::MakeField(field1);
+ auto sum_func =
+ TreeExprBuilder::MakeFunction("add", {node_f0, node_f1},
arrow::float64());
+ std::unordered_set<double> in_constants({3.14159265359, 15.5555555});
+ auto in_expr = TreeExprBuilder::MakeInExpressionDouble(sum_func,
in_constants);
+ auto condition = TreeExprBuilder::MakeCondition(in_expr);
+
+ std::shared_ptr<Filter> filter;
+ auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+ EXPECT_TRUE(status.ok());
+
+ // Create a row-batch with some sample data
+ int num_records = 5;
+ auto array0 = MakeArrowArrayFloat64({1, 2, 3, 4, 11}, {true, true, true,
false, false});
+ auto array1 = MakeArrowArrayFloat64({5, 9, 0.14159265359, 17, 4.5555555},
+ {true, true, true, true, true});
+
+ // expected output (indices for which condition matches)
+ auto exp = MakeArrowArrayUint16({2});
+
+ // prepare input record batch
+ auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0,
array1});
+
+ std::shared_ptr<SelectionVector> selection_vector;
+ status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Evaluate expression
+ status = filter->Evaluate(*in_batch, selection_vector);
+ EXPECT_TRUE(status.ok());
+
+ // Validate results
+ EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
TEST_F(TestIn, TestInDecimal) {
int32_t precision = 38;
int32_t scale = 5;
diff --git a/cpp/src/gandiva/tree_expr_builder.cc
b/cpp/src/gandiva/tree_expr_builder.cc
index b27b920..de8e344 100644
--- a/cpp/src/gandiva/tree_expr_builder.cc
+++ b/cpp/src/gandiva/tree_expr_builder.cc
@@ -215,6 +215,8 @@ MAKE_IN(Date64, int64_t);
MAKE_IN(TimeStamp, int64_t);
MAKE_IN(Time32, int32_t);
MAKE_IN(Time64, int64_t);
+MAKE_IN(Float, float);
+MAKE_IN(Double, double);
MAKE_IN(String, std::string);
MAKE_IN(Binary, std::string);
diff --git a/cpp/src/gandiva/tree_expr_builder.h
b/cpp/src/gandiva/tree_expr_builder.h
index 9c24fb9..94a4a17 100644
--- a/cpp/src/gandiva/tree_expr_builder.h
+++ b/cpp/src/gandiva/tree_expr_builder.h
@@ -17,6 +17,7 @@
#pragma once
+#include <cmath>
#include <memory>
#include <string>
#include <unordered_set>
@@ -106,6 +107,14 @@ class GANDIVA_EXPORT TreeExprBuilder {
static NodePtr MakeInExpressionBinary(NodePtr node,
const std::unordered_set<std::string>&
constants);
+ /// \brief creates an in expression for float
+ static NodePtr MakeInExpressionFloat(NodePtr node,
+ const std::unordered_set<float>&
constants);
+
+ /// \brief creates an in expression for double
+ static NodePtr MakeInExpressionDouble(NodePtr node,
+ const std::unordered_set<double>&
constants);
+
/// \brief Date as s/millis since epoch.
static NodePtr MakeInExpressionDate32(NodePtr node,
const std::unordered_set<int32_t>&
constants);
diff --git
a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
index 08ef7f0..fef8e31 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
@@ -34,6 +34,8 @@ public class InNode implements TreeNode {
private final Set<Integer> intValues;
private final Set<Long> longValues;
+ private final Set<Float> floatValues;
+ private final Set<Double> doubleValues;
private final Set<BigDecimal> decimalValues;
private final Set<String> stringValues;
private final Set<byte[]> binaryValues;
@@ -43,7 +45,8 @@ public class InNode implements TreeNode {
private final Integer scale;
private InNode(Set<Integer> values, Set<Long> longValues, Set<String>
stringValues, Set<byte[]>
- binaryValues, Set<BigDecimal> decimalValues, Integer precision,
Integer scale, TreeNode node) {
+ binaryValues, Set<BigDecimal> decimalValues, Integer precision,
Integer scale,
+ Set<Float> floatValues, Set<Double> doubleValues, TreeNode
node) {
this.intValues = values;
this.longValues = longValues;
this.decimalValues = decimalValues;
@@ -51,33 +54,47 @@ public class InNode implements TreeNode {
this.scale = scale;
this.stringValues = stringValues;
this.binaryValues = binaryValues;
+ this.floatValues = floatValues;
+ this.doubleValues = doubleValues;
this.input = node;
}
public static InNode makeIntInExpr(TreeNode node, Set<Integer> intValues) {
return new InNode(intValues,
- null, null, null, null, null, null, node);
+ null, null, null, null, null, null, null,
+ null, node);
}
public static InNode makeLongInExpr(TreeNode node, Set<Long> longValues) {
return new InNode(null, longValues,
- null, null, null, null, null, node);
+ null, null, null, null, null, null,
+ null, node);
+ }
+
+ public static InNode makeFloatInExpr(TreeNode node, Set<Float> floatValues) {
+ return new InNode(null, null, null, null, null, null,
+ null, floatValues, null, node);
+ }
+
+ public static InNode makeDoubleInExpr(TreeNode node, Set<Double>
doubleValues) {
+ return new InNode(null, null, null, null, null,
+ null, null, null, doubleValues, node);
}
public static InNode makeDecimalInExpr(TreeNode node, Set<BigDecimal>
decimalValues,
Integer precision, Integer scale) {
return new InNode(null, null, null, null,
- decimalValues, precision, scale, node);
+ decimalValues, precision, scale, null, null, node);
}
public static InNode makeStringInExpr(TreeNode node, Set<String>
stringValues) {
return new InNode(null, null, stringValues, null,
- null, null, null, node);
+ null, null, null, null, null, node);
}
public static InNode makeBinaryInExpr(TreeNode node, Set<byte[]>
binaryValues) {
return new InNode(null, null, null, binaryValues,
- null, null, null, node);
+ null, null, null, null, null, node);
}
@Override
@@ -96,6 +113,16 @@ public class InNode implements TreeNode {
longValues.stream().forEach(val ->
longConstants.addLongValues(GandivaTypes.LongNode.newBuilder()
.setValue(val).build()));
inNode.setLongValues(longConstants.build());
+ } else if (floatValues != null) {
+ GandivaTypes.FloatConstants.Builder floatConstants =
GandivaTypes.FloatConstants.newBuilder();
+ floatValues.stream().forEach(val ->
floatConstants.addFloatValues(GandivaTypes.FloatNode.newBuilder()
+ .setValue(val).build()));
+ inNode.setFloatValues(floatConstants.build());
+ } else if (doubleValues != null) {
+ GandivaTypes.DoubleConstants.Builder doubleConstants =
GandivaTypes.DoubleConstants.newBuilder();
+ doubleValues.stream().forEach(val ->
doubleConstants.addDoubleValues(GandivaTypes.DoubleNode.newBuilder()
+ .setValue(val).build()));
+ inNode.setDoubleValues(doubleConstants.build());
} else if (decimalValues != null) {
GandivaTypes.DecimalConstants.Builder decimalConstants =
GandivaTypes.DecimalConstants.newBuilder();
decimalValues.stream().forEach(val ->
decimalConstants.addDecimalValues(GandivaTypes.DecimalNode.newBuilder()
diff --git
a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
index 067715c..8656e88 100644
---
a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
+++
b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
@@ -208,6 +208,16 @@ public class TreeBuilder {
return InNode.makeDecimalInExpr(resultNode, decimalValues, precision,
scale);
}
+ public static TreeNode makeInExpressionFloat(TreeNode resultNode,
+ Set<Float> floatValues) {
+ return InNode.makeFloatInExpr(resultNode, floatValues);
+ }
+
+ public static TreeNode makeInExpressionDouble(TreeNode resultNode,
+ Set<Double> doubleValues) {
+ return InNode.makeDoubleInExpr(resultNode, doubleValues);
+ }
+
public static TreeNode makeInExpressionString(TreeNode resultNode,
Set<String> stringValues) {
return InNode.makeStringInExpr(resultNode, stringValues);
diff --git
a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
index 606c1a9..e51f458 100644
---
a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
+++
b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
@@ -1220,10 +1220,10 @@ public class ProjectorTest extends BaseEvaluatorTest {
output.add(bitVector);
eval.evaluate(batch, output);
- for (int i = 1; i < 5; i++) {
+ for (int i = 0; i < 4; i++) {
assertTrue(bitVector.getObject(i).booleanValue());
}
- for (int i = 5; i < 16; i++) {
+ for (int i = 4; i < 16; i++) {
assertFalse(bitVector.getObject(i).booleanValue());
}
@@ -1252,7 +1252,9 @@ public class ProjectorTest extends BaseEvaluatorTest {
Schema schema = new Schema(Lists.newArrayList(c1));
Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+ // Create a row-batch with some sample data to look for
int numRows = 16;
+ // Only the first 8 values will be valid.
byte[] validity = new byte[]{(byte) 255, 0};
String[] c1Values =
new String[]{"1", "2", "3", "4", "-0.0", "6", "7", "8", "9", "10",
"11", "12", "13", "14",
@@ -1276,6 +1278,57 @@ public class ProjectorTest extends BaseEvaluatorTest {
output.add(bitVector);
eval.evaluate(batch, output);
+ // The first four values in the vector must match the expression, but not
the other ones.
+ for (int i = 0; i < 4; i++) {
+ assertTrue(bitVector.getObject(i).booleanValue());
+ }
+ for (int i = 4; i < 16; i++) {
+ assertFalse(bitVector.getObject(i).booleanValue());
+ }
+
+ releaseRecordBatch(batch);
+ releaseValueVectors(output);
+ eval.close();
+ }
+
+ @Test
+ public void testInExprDouble() throws GandivaException, Exception {
+ Field c1 = Field.nullable("c1", float64);
+
+ TreeNode inExpr =
+ TreeBuilder.makeInExpressionDouble(TreeBuilder.makeField(c1),
+ Sets.newHashSet(1.0, -0.0, 3.0, 4.0, Double.NaN,
+ Double.POSITIVE_INFINITY,
Double.NEGATIVE_INFINITY));
+ ExpressionTree expr = TreeBuilder.makeExpression(inExpr,
Field.nullable("result", boolType));
+ Schema schema = new Schema(Lists.newArrayList(c1));
+ Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+ // Create a row-batch with some sample data to look for
+ int numRows = 16;
+ // Only the first 8 values will be valid.
+ byte[] validity = new byte[]{(byte) 255, 0};
+ double[] c1Values = new double[]{1, -0.0, Double.NEGATIVE_INFINITY ,
Double.POSITIVE_INFINITY, Double.NaN,
+ 6, 7, 8, 9, 10, 11, 12, 13, 14, 4 , 3};
+
+ ArrowBuf c1Validity = buf(validity);
+ ArrowBuf c1Data = doubleBuf(c1Values);
+ ArrowBuf c2Validity = buf(validity);
+
+ ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+ ArrowRecordBatch batch =
+ new ArrowRecordBatch(
+ numRows,
+ Lists.newArrayList(fieldNode, fieldNode),
+ Lists.newArrayList(c1Validity, c1Data, c2Validity));
+
+ BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+ bitVector.allocateNew(numRows);
+
+ List<ValueVector> output = new ArrayList<ValueVector>();
+ output.add(bitVector);
+ eval.evaluate(batch, output);
+
+ // The first five values in the vector must match the expression, but not
the other ones.
for (int i = 1; i < 5; i++) {
assertTrue(bitVector.getObject(i).booleanValue());
}