This is an automated email from the ASF dual-hosted git repository.
mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 141dfd9cfd [CALCITE-6684] Support Arrow filter pushdown conditions
that have subexpressions of a decimal type
141dfd9cfd is described below
commit 141dfd9cfd8b32102ce6942e076503041bf245f9
Author: Cancai Cai <[email protected]>
AuthorDate: Mon Nov 11 00:08:45 2024 +0800
[CALCITE-6684] Support Arrow filter pushdown conditions that have
subexpressions of a decimal type
---
.../apache/calcite/adapter/arrow/ArrowTable.java | 19 +++++---
.../calcite/adapter/arrow/ArrowTranslator.java | 29 ++++++------
.../calcite/adapter/arrow/ArrowAdapterTest.java | 54 ++++++++++++++++++++++
3 files changed, 80 insertions(+), 22 deletions(-)
diff --git a/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowTable.java b/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowTable.java
index ec9fe405e8..eddf83c63d 100644
--- a/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowTable.java
+++ b/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowTable.java
@@ -185,18 +185,23 @@ public class ArrowTable extends AbstractTable
}
private static TreeNode makeLiteralNode(String literal, String type) {
- switch (type) {
- case "integer":
+ if (type.startsWith("decimal")) {
+ String[] typeParts =
+ type.substring(type.indexOf('(') + 1, type.indexOf(')')).split(",");
+ int precision = parseInt(typeParts[0]);
+ int scale = parseInt(typeParts[1]);
+ return TreeBuilder.makeDecimalLiteral(literal, precision, scale);
+ } else if (type.equals("integer")) {
return TreeBuilder.makeLiteral(parseInt(literal));
- case "long":
+ } else if (type.equals("long")) {
return TreeBuilder.makeLiteral(parseLong(literal));
- case "float":
+ } else if (type.equals("float")) {
return TreeBuilder.makeLiteral(parseFloat(literal));
- case "double":
+ } else if (type.equals("double")) {
return TreeBuilder.makeLiteral(parseDouble(literal));
- case "string":
+ } else if (type.equals("string")) {
return TreeBuilder.makeStringLiteral(literal.substring(1,
literal.length() - 1));
- default:
+ } else {
throw new IllegalArgumentException("Invalid literal " + literal
+ ", type " + type);
}
diff --git a/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowTranslator.java b/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowTranslator.java
index b2e067db4f..1651974cb4 100644
--- a/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowTranslator.java
+++ b/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowTranslator.java
@@ -31,7 +31,6 @@ import org.apache.calcite.util.DateString;
import org.checkerframework.checker.nullness.qual.Nullable;
-import java.math.BigDecimal;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.List;
@@ -194,7 +193,7 @@ class ArrowTranslator {
private String translateOp2(String op, String name, RexLiteral right) {
Object value = literalValue(right);
String valueString = value.toString();
- String valueType = getLiteralType(value);
+ String valueType = getLiteralType(right.getType());
if (value instanceof String) {
final RelDataTypeField field = requireNonNull(rowType.getField(name,
true, false), "field");
@@ -234,20 +233,20 @@ class ArrowTranslator {
return name + " " + op;
}
- private static String getLiteralType(Object literal) {
- if (literal instanceof BigDecimal) {
- BigDecimal bigDecimalLiteral = (BigDecimal) literal;
- int scale = bigDecimalLiteral.scale();
- if (scale == 0) {
- return "integer";
- } else if (scale > 0) {
- return "float";
- }
- } else if (String.class.equals(literal.getClass())) {
- return "string";
- } else if (literal instanceof Double) {
+ private static String getLiteralType(RelDataType type) {
+ if (type.getSqlTypeName() == SqlTypeName.DECIMAL) {
+ return "decimal" + "(" + type.getPrecision() + "," + type.getScale() +
")";
+ } else if (type.getSqlTypeName() == SqlTypeName.REAL) {
return "float";
+ } else if (type.getSqlTypeName() == SqlTypeName.DOUBLE) {
+ return "double";
+ } else if (type.getSqlTypeName() == SqlTypeName.INTEGER) {
+ return "integer";
+ } else if (type.getSqlTypeName() == SqlTypeName.VARCHAR
+ || type.getSqlTypeName() == SqlTypeName.CHAR) {
+ return "string";
+ } else {
+ throw new UnsupportedOperationException("Unsupported type " + type);
}
- throw new UnsupportedOperationException("Unsupported literal " + literal);
}
}
diff --git a/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterTest.java b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterTest.java
index 9e60bca2a4..39aa328019 100644
--- a/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterTest.java
+++ b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterTest.java
@@ -961,4 +961,58 @@ class ArrowAdapterTest {
.returns(result)
.explainContains(plan);
}
+
+ /** Test case for
+ * <a
href="https://issues.apache.org/jira/browse/CALCITE-6684">[CALCITE-6684]
+ * Arrow adapter should supports filter conditions of Decimal type</a>. */
+ @Test void testArrowProjectFieldsWithDecimalFilter() {
+ String sql = "select \"decimalField\"\n"
+ + "from arrowdatatype\n"
+ + "where \"decimalField\" = 1.00";
+ String plan = "PLAN=ArrowToEnumerableConverter\n"
+ + " ArrowProject(decimalField=[$8])\n"
+ + " ArrowFilter(condition=[=($8, 1)])\n"
+ + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1,
2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ String result = "decimalField=1.00\n";
+
+ CalciteAssert.that()
+ .with(arrow)
+ .query(sql)
+ .returns(result)
+ .explainContains(plan);
+ }
+
+ @Test void testArrowProjectFieldsWithDoubleFilter() {
+ String sql = "select \"doubleField\"\n"
+ + "from arrowdatatype\n"
+ + "where \"doubleField\" = 1.00";
+ String plan = "PLAN=ArrowToEnumerableConverter\n"
+ + " ArrowProject(doubleField=[$6])\n"
+ + " ArrowFilter(condition=[=($6, 1.0E0)])\n"
+ + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1,
2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ String result = "doubleField=1.0\n";
+
+ CalciteAssert.that()
+ .with(arrow)
+ .query(sql)
+ .returns(result)
+ .explainContains(plan);
+ }
+
+ @Test void testArrowProjectFieldsWithStringFilter() {
+ String sql = "select \"stringField\"\n"
+ + "from arrowdatatype\n"
+ + "where \"stringField\" = '1'";
+ String plan = "PLAN=ArrowToEnumerableConverter\n"
+ + " ArrowProject(stringField=[$3])\n"
+ + " ArrowFilter(condition=[=($3, '1')])\n"
+ + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1,
2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ String result = "stringField=1\n";
+
+ CalciteAssert.that()
+ .with(arrow)
+ .query(sql)
+ .returns(result)
+ .explainContains(plan);
+ }
}