This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 405bac0619 [GLUTEN-10226][FLINK] Support decimal arithmetic operations 
(#10218)
405bac0619 is described below

commit 405bac06195fd6244bbfbc403b656b6fa1080cb2
Author: lgbo <[email protected]>
AuthorDate: Thu Jul 24 11:47:30 2025 +0800

    [GLUTEN-10226][FLINK] Support decimal arithmetic operations (#10218)
---
 .../functions/DecimalRexCallConverters.java        | 103 +++++++++++++++++++++
 .../rexnode/functions/RexCallConverterFactory.java |  19 +++-
 .../runtime/stream/custom/ScalarFunctionsTest.java |  43 +++++++++
 3 files changed, 162 insertions(+), 3 deletions(-)

diff --git 
a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/DecimalRexCallConverters.java
 
b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/DecimalRexCallConverters.java
new file mode 100644
index 0000000000..2b36a4705a
--- /dev/null
+++ 
b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/DecimalRexCallConverters.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.rexnode.functions;
+
+import org.apache.gluten.rexnode.RexConversionContext;
+import org.apache.gluten.rexnode.RexNodeConverter;
+import org.apache.gluten.rexnode.ValidationResult;
+
+import io.github.zhztheplayer.velox4j.expression.CallTypedExpr;
+import io.github.zhztheplayer.velox4j.expression.CastTypedExpr;
+import io.github.zhztheplayer.velox4j.expression.TypedExpr;
+import io.github.zhztheplayer.velox4j.type.BigIntType;
+import io.github.zhztheplayer.velox4j.type.DecimalType;
+import io.github.zhztheplayer.velox4j.type.DoubleType;
+import io.github.zhztheplayer.velox4j.type.IntegerType;
+import io.github.zhztheplayer.velox4j.type.Type;
+
+import org.apache.flink.util.FlinkRuntimeException;
+
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexNode;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+class DecimalArithmeticOperatorRexCallConverters extends BaseRexCallConverter {
+  // This converter turns an arithmetic call involving decimals into a CallTypedExpr.
+  public DecimalArithmeticOperatorRexCallConverters(String functionName) {
+    super(functionName);
+  }
+
+  @Override
+  public ValidationResult isSuitable(RexCall callNode, RexConversionContext 
context) {
+    Type resultType = getResultType(callNode);
+    List<Type> operandsType =
+        callNode.getOperands().stream()
+            .map(RexNode::getType)
+            .map(RexNodeConverter::toType)
+            .collect(Collectors.toList());
+    // Accept the call when its result type or at least one operand type is a DecimalType.
+    if (resultType instanceof DecimalType
+        || operandsType.stream().anyMatch(type -> type instanceof 
DecimalType)) {
+      return ValidationResult.success();
+    }
+    return ValidationResult.failure(
+        String.format(
+            "Decimal operation requires operands to be of decimal type, but 
found: %s",
+            getFunctionProtoTypeName(callNode)));
+  }
+  // Builds the call after passing every operand through castExprToDecimalType.
+  @Override
+  public TypedExpr toTypedExpr(RexCall callNode, RexConversionContext context) 
{
+    List<TypedExpr> params = getParams(callNode, context);
+    Type resultType = getResultType(callNode);
+    List<TypedExpr> castedParams =
+        params.stream()
+            .map(param -> castExprToDecimalType(param, resultType))
+            .collect(Collectors.toList());
+    return new CallTypedExpr(resultType, castedParams, functionName);
+  }
+
+  // Casts int/bigint operands to decimal, and decimal operands to double for a double result.
+  private TypedExpr castExprToDecimalType(TypedExpr expr, Type 
functionResultType) {
+    Type returnType = expr.getReturnType();
+
+    if (returnType instanceof IntegerType) {
+      // Cast Integer to DecimalType(10, 0).
+      return CastTypedExpr.create(new DecimalType(10, 0), expr, false);
+    } else if (returnType instanceof BigIntType) {
+      // Cast BigInt to DecimalType(19, 0).
+      return CastTypedExpr.create(new DecimalType(19, 0), expr, false);
+    } else if (returnType instanceof DecimalType) {
+      if (functionResultType instanceof DecimalType) {
+        // Decimal operand with a decimal result: use the operand unchanged.
+        return expr;
+      } else if (functionResultType instanceof DoubleType) {
+        // The result is of type double when a decimal is operated with a 
double.
+        return CastTypedExpr.create(new DoubleType(), expr, false);
+      }
+      throw new FlinkRuntimeException(
+          "Not supported type for decimal conversion: " + 
functionResultType.getClass().getName());
+    } else if (returnType instanceof DoubleType) {
+      return expr; // Double operands are passed through unchanged.
+    } else {
+      throw new FlinkRuntimeException(
+          "Not supported type for decimal conversion: " + 
returnType.getClass().getName());
+    }
+  }
+}
diff --git 
a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java
 
b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java
index eac71dca7e..8cf2a4e4cf 100644
--- 
a/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java
+++ 
b/gluten-flink/planner/src/main/java/org/apache/gluten/rexnode/functions/RexCallConverterFactory.java
@@ -57,9 +57,22 @@ public class RexCallConverterFactory {
                   () -> new StringCompareRexCallConverter("equalto"),
                   () -> new StringNumberCompareRexCallConverter("equalto"))),
           Map.entry(
-              "*", Arrays.asList(() -> new 
BasicArithmeticOperatorRexCallConverter("multiply"))),
-          Map.entry("-", Arrays.asList(() -> new SubtractRexCallConverter())),
-          Map.entry("+", Arrays.asList(() -> new 
BasicArithmeticOperatorRexCallConverter("add"))),
+              "/", Arrays.asList(() -> new 
DecimalArithmeticOperatorRexCallConverters("divide"))),
+          Map.entry(
+              "*",
+              Arrays.asList(
+                  () -> new 
BasicArithmeticOperatorRexCallConverter("multiply"),
+                  () -> new 
DecimalArithmeticOperatorRexCallConverters("multiply"))),
+          Map.entry(
+              "-",
+              Arrays.asList(
+                  () -> new SubtractRexCallConverter(),
+                  () -> new 
DecimalArithmeticOperatorRexCallConverters("subtract"))),
+          Map.entry(
+              "+",
+              Arrays.asList(
+                  () -> new BasicArithmeticOperatorRexCallConverter("add"),
+                  () -> new 
DecimalArithmeticOperatorRexCallConverters("add"))),
           Map.entry("MOD", Arrays.asList(() -> new ModRexCallConverter())),
           Map.entry("CAST", Arrays.asList(() -> new 
DefaultRexCallConverter("cast"))),
           Map.entry("CASE", Arrays.asList(() -> new 
DefaultRexCallConverter("if"))),
diff --git 
a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java
 
b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java
index 8f0371d70e..1bc5b09f17 100644
--- 
a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java
+++ 
b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScalarFunctionsTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.types.Row;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
+import java.math.BigDecimal;
 import java.util.Arrays;
 import java.util.List;
 
@@ -126,4 +127,46 @@ class ScalarFunctionsTest extends GlutenStreamingTestBase {
     String query4 = "select c = d as x from tblEqual where a > 0";
     runAndCheck(query4, Arrays.asList("+I[false]", "+I[true]", "+I[false]"));
   }
+
+  @Test
+  void testDecimal() { // Decimal arithmetic against decimal, int, bigint and double columns.
+    List<Row> rows =
+        Arrays.asList(
+            Row.of(1, new BigDecimal("1.0"), new BigDecimal("1.0"), 2L, 1.0),
+            Row.of(2, new BigDecimal("2.0"), new BigDecimal("2.0"), 3L, 3.0),
+            Row.of(3, new BigDecimal("3.0"), new BigDecimal("3.0"), 4L, 4.0));
+    createSimpleBoundedValuesTable(
+        "tblDecimal", "a int, b decimal(11, 2), c decimal(10, 3), d bigint, e 
double", rows);
+    String query = "select b + c as x from tblDecimal where a > 0"; // decimal + decimal
+    runAndCheck(query, Arrays.asList("+I[2.000]", "+I[4.000]", "+I[6.000]"));
+    // decimal + int
+    query = "select b + a as x from tblDecimal where a > 0";
+    runAndCheck(query, Arrays.asList("+I[2.00]", "+I[4.00]", "+I[6.00]"));
+    // decimal + bigint
+    query = "select b + d as x from tblDecimal where a > 0";
+    runAndCheck(query, Arrays.asList("+I[3.00]", "+I[5.00]", "+I[7.00]"));
+    // decimal - decimal
+    query = "select b - c as x from tblDecimal where a > 0";
+    runAndCheck(query, Arrays.asList("+I[0.000]", "+I[0.000]", "+I[0.000]"));
+    // decimal - int
+    query = "select b - a as x from tblDecimal where a > 0";
+    runAndCheck(query, Arrays.asList("+I[0.00]", "+I[0.00]", "+I[0.00]"));
+    // decimal * decimal
+    query = "select b * c as x from tblDecimal where a > 0";
+    runAndCheck(query, Arrays.asList("+I[1.00000]", "+I[4.00000]", 
"+I[9.00000]"));
+    // decimal * bigint
+    query = "select b * d as x from tblDecimal where a > 0";
+    runAndCheck(query, Arrays.asList("+I[2.00]", "+I[6.00]", "+I[12.00]"));
+    // decimal / decimal
+    query = "select b / c as x from tblDecimal where a > 0";
+    runAndCheck(
+        query, Arrays.asList("+I[1.0000000000000]", "+I[1.0000000000000]", 
"+I[1.0000000000000]"));
+    // decimal / int
+    query = "select b / a as x from tblDecimal where a > 0";
+    runAndCheck(
+        query, Arrays.asList("+I[1.0000000000000]", "+I[1.0000000000000]", 
"+I[1.0000000000000]"));
+    // decimal + double: the expected values are rendered as doubles.
+    query = "select b + e as x from tblDecimal where a > 0";
+    runAndCheck(query, Arrays.asList("+I[2.0]", "+I[5.0]", "+I[7.0]"));
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to