This is an automated email from the ASF dual-hosted git repository.
philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 405bac0619 [GLUTEN-10226][FLINK] Support decimal arithmetic operations
(#10218)
405bac0619 is described below
commit 405bac06195fd6244bbfbc403b656b6fa1080cb2
Author: lgbo <[email protected]>
AuthorDate: Thu Jul 24 11:47:30 2025 +0800
[GLUTEN-10226][FLINK] Support decimal arithmetic operations (#10218)
---
.../functions/DecimalRexCallConverters.java | 103 +++++++++++++++++++++
.../rexnode/functions/RexCallConverterFactory.java | 19 +++-
.../runtime/stream/custom/ScalarFunctionsTest.java | 43 +++++++++
3 files changed, 162 insertions(+), 3 deletions(-)
diff --git
a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/DecimalRexCallConverters.java
b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/DecimalRexCallConverters.java
new file mode 100644
index 0000000000..2b36a4705a
--- /dev/null
+++
b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/DecimalRexCallConverters.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.rexnode.functions;
+
+import org.apache.gluten.rexnode.RexConversionContext;
+import org.apache.gluten.rexnode.RexNodeConverter;
+import org.apache.gluten.rexnode.ValidationResult;
+
+import io.github.zhztheplayer.velox4j.expression.CallTypedExpr;
+import io.github.zhztheplayer.velox4j.expression.CastTypedExpr;
+import io.github.zhztheplayer.velox4j.expression.TypedExpr;
+import io.github.zhztheplayer.velox4j.type.BigIntType;
+import io.github.zhztheplayer.velox4j.type.DecimalType;
+import io.github.zhztheplayer.velox4j.type.DoubleType;
+import io.github.zhztheplayer.velox4j.type.IntegerType;
+import io.github.zhztheplayer.velox4j.type.Type;
+
+import org.apache.flink.util.FlinkRuntimeException;
+
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexNode;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+class DecimalArithmeticOperatorRexCallConverters extends BaseRexCallConverter {
+
+ public DecimalArithmeticOperatorRexCallConverters(String functionName) {
+ super(functionName);
+ }
+
+ @Override
+ public ValidationResult isSuitable(RexCall callNode, RexConversionContext
context) {
+ Type resultType = getResultType(callNode);
+ List<Type> operandsType =
+ callNode.getOperands().stream()
+ .map(RexNode::getType)
+ .map(RexNodeConverter::toType)
+ .collect(Collectors.toList());
+ // Suitable when the result type or at least one operand type is decimal.
+ if (resultType instanceof DecimalType
+ || operandsType.stream().anyMatch(type -> type instanceof
DecimalType)) {
+ return ValidationResult.success();
+ }
+ return ValidationResult.failure(
+ String.format(
+ "Decimal operation requires operands to be of decimal type, but
found: %s",
+ getFunctionProtoTypeName(callNode)));
+ }
+
+ @Override
+ public TypedExpr toTypedExpr(RexCall callNode, RexConversionContext context)
{
+ List<TypedExpr> params = getParams(callNode, context);
+ Type resultType = getResultType(callNode);
+ List<TypedExpr> castedParams =
+ params.stream()
+ .map(param -> castExprToDecimalType(param, resultType))
+ .collect(Collectors.toList());
+ return new CallTypedExpr(resultType, castedParams, functionName);
+ }
+
+ // Casts int/bigint operands to decimal, and decimal operands to double for double results.
+ private TypedExpr castExprToDecimalType(TypedExpr expr, Type
functionResultType) {
+ Type returnType = expr.getReturnType();
+
+ if (returnType instanceof IntegerType) {
+ // Cast Integer to DecimalType(10, 0).
+ return CastTypedExpr.create(new DecimalType(10, 0), expr, false);
+ } else if (returnType instanceof BigIntType) {
+ // Cast BigInt to DecimalType(19, 0).
+ return CastTypedExpr.create(new DecimalType(19, 0), expr, false);
+ } else if (returnType instanceof DecimalType) {
+ if (functionResultType instanceof DecimalType) {
+ // The function result is also DecimalType, so the operand is used without a cast.
+ return expr;
+ } else if (functionResultType instanceof DoubleType) {
+ // The result is of type double when a decimal is operated with a
double.
+ return CastTypedExpr.create(new DoubleType(), expr, false);
+ }
+ throw new FlinkRuntimeException(
+ "Not supported type for decimal conversion: " +
functionResultType.getClass().getName());
+ } else if (returnType instanceof DoubleType) {
+ return expr;
+ } else {
+ throw new FlinkRuntimeException(
+ "Not supported type for decimal conversion: " +
returnType.getClass().getName());
+ }
+ }
+}
diff --git
a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java
b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java
index eac71dca7e..8cf2a4e4cf 100644
---
a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java
+++
b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java
@@ -57,9 +57,22 @@ public class RexCallConverterFactory {
() -> new StringCompareRexCallConverter("equalto"),
() -> new StringNumberCompareRexCallConverter("equalto"))),
Map.entry(
- "*", Arrays.asList(() -> new
BasicArithmeticOperatorRexCallConverter("multiply"))),
- Map.entry("-", Arrays.asList(() -> new SubtractRexCallConverter())),
- Map.entry("+", Arrays.asList(() -> new
BasicArithmeticOperatorRexCallConverter("add"))),
+ "/", Arrays.asList(() -> new
DecimalArithmeticOperatorRexCallConverters("divide"))),
+ Map.entry(
+ "*",
+ Arrays.asList(
+ () -> new
BasicArithmeticOperatorRexCallConverter("multiply"),
+ () -> new
DecimalArithmeticOperatorRexCallConverters("multiply"))),
+ Map.entry(
+ "-",
+ Arrays.asList(
+ () -> new SubtractRexCallConverter(),
+ () -> new
DecimalArithmeticOperatorRexCallConverters("subtract"))),
+ Map.entry(
+ "+",
+ Arrays.asList(
+ () -> new BasicArithmeticOperatorRexCallConverter("add"),
+ () -> new
DecimalArithmeticOperatorRexCallConverters("add"))),
Map.entry("MOD", Arrays.asList(() -> new ModRexCallConverter())),
Map.entry("CAST", Arrays.asList(() -> new
DefaultRexCallConverter("cast"))),
Map.entry("CASE", Arrays.asList(() -> new
DefaultRexCallConverter("if"))),
diff --git
a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java
b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java
index 8f0371d70e..1bc5b09f17 100644
---
a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java
+++
b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.types.Row;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import java.math.BigDecimal;
import java.util.Arrays;
import java.util.List;
@@ -126,4 +127,46 @@ class ScalarFunctionsTest extends GlutenStreamingTestBase {
String query4 = "select c = d as x from tblEqual where a > 0";
runAndCheck(query4, Arrays.asList("+I[false]", "+I[true]", "+I[false]"));
}
+
+ @Test
+ void testDecimal() {
+ List<Row> rows =
+ Arrays.asList(
+ Row.of(1, new BigDecimal("1.0"), new BigDecimal("1.0"), 2L, 1.0),
+ Row.of(2, new BigDecimal("2.0"), new BigDecimal("2.0"), 3L, 3.0),
+ Row.of(3, new BigDecimal("3.0"), new BigDecimal("3.0"), 4L, 4.0));
+ createSimpleBoundedValuesTable(
+ "tblDecimal", "a int, b decimal(11, 2), c decimal(10, 3), d bigint, e
double", rows);
+ String query = "select b + c as x from tblDecimal where a > 0"; // decimal + decimal
+ runAndCheck(query, Arrays.asList("+I[2.000]", "+I[4.000]", "+I[6.000]"));
+
+ query = "select b + a as x from tblDecimal where a > 0"; // decimal + int
+ runAndCheck(query, Arrays.asList("+I[2.00]", "+I[4.00]", "+I[6.00]"));
+
+ query = "select b + d as x from tblDecimal where a > 0"; // decimal + bigint
+ runAndCheck(query, Arrays.asList("+I[3.00]", "+I[5.00]", "+I[7.00]"));
+
+ query = "select b - c as x from tblDecimal where a > 0"; // decimal - decimal
+ runAndCheck(query, Arrays.asList("+I[0.000]", "+I[0.000]", "+I[0.000]"));
+
+ query = "select b - a as x from tblDecimal where a > 0"; // decimal - int
+ runAndCheck(query, Arrays.asList("+I[0.00]", "+I[0.00]", "+I[0.00]"));
+
+ query = "select b * c as x from tblDecimal where a > 0"; // decimal * decimal
+ runAndCheck(query, Arrays.asList("+I[1.00000]", "+I[4.00000]",
"+I[9.00000]"));
+
+ query = "select b * d as x from tblDecimal where a > 0"; // decimal * bigint
+ runAndCheck(query, Arrays.asList("+I[2.00]", "+I[6.00]", "+I[12.00]"));
+
+ query = "select b / c as x from tblDecimal where a > 0"; // decimal / decimal
+ runAndCheck(
+ query, Arrays.asList("+I[1.0000000000000]", "+I[1.0000000000000]",
"+I[1.0000000000000]"));
+
+ query = "select b / a as x from tblDecimal where a > 0"; // decimal / int
+ runAndCheck(
+ query, Arrays.asList("+I[1.0000000000000]", "+I[1.0000000000000]",
"+I[1.0000000000000]"));
+
+ query = "select b + e as x from tblDecimal where a > 0"; // decimal + double -> double result
+ runAndCheck(query, Arrays.asList("+I[2.0]", "+I[5.0]", "+I[7.0]"));
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]