This is an automated email from the ASF dual-hosted git repository.

korlov pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/ignite-3.git


The following commit(s) were added to refs/heads/main by this push:
     new 51bbcf05e9 IGNITE-20311 Sql. Fix behaviour of ROUND function (#2690)
51bbcf05e9 is described below

commit 51bbcf05e9096821528b1201121bb0081d2dd588
Author: Max Zhuravkov <[email protected]>
AuthorDate: Mon Oct 23 19:16:52 2023 +0300

    IGNITE-20311 Sql. Fix behaviour of ROUND function (#2690)
---
 .../internal/sql/engine/ItDataTypesTest.java       |   2 +-
 .../internal/sql/engine/ItFunctionsTest.java       | 112 +++++++++++++++++++
 .../internal/sql/engine/ItSqlOperatorsTest.java    |   4 +-
 .../sql/types/decimal/test_decimal_ops.test        |  64 +++++------
 .../sql/engine/exec/exp/IgniteSqlFunctions.java    | 122 ++++++++++++++++++++-
 .../internal/sql/engine/exec/exp/RexImpTable.java  |   3 +-
 .../sql/engine/sql/fun/IgniteSqlOperatorTable.java |  34 +++++-
 .../internal/sql/engine/util/IgniteMethod.java     |  27 ++++-
 .../engine/exec/exp/IgniteSqlFunctionsTest.java    | 115 +++++++++++++++++++
 9 files changed, 442 insertions(+), 41 deletions(-)

diff --git a/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItDataTypesTest.java b/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItDataTypesTest.java
index 092e5fca9f..1c05f2cc3a 100644
--- a/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItDataTypesTest.java
+++ b/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItDataTypesTest.java
@@ -322,7 +322,8 @@ public class ItDataTypesTest extends ClusterPerClassIntegrationTest {
                 "SELECT DECIMAL '0.09'  BETWEEN DECIMAL '0.06' AND DECIMAL '0.07'")
                 .returns(false).check();
 
-        assertQuery("SELECT ROUND(DECIMAL '10.000', 2)").returns(new BigDecimal("10.00")).check();
+        // ROUND(x, s) keeps the scale of x: rounded-off digits are replaced with zeros.
+        assertQuery("SELECT ROUND(DECIMAL '10.000', 2)").returns(new BigDecimal("10.000")).check();
         assertQuery("SELECT CAST(DECIMAL '10.000' AS VARCHAR)").returns("10.000").check();
         assertQuery("SELECT CAST(DECIMAL '10.000' AS INTEGER)").returns(10).check();
 
diff --git a/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItFunctionsTest.java b/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItFunctionsTest.java
index 04e74d22c5..997f6274f8 100644
--- a/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItFunctionsTest.java
+++ b/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItFunctionsTest.java
@@ -29,14 +29,22 @@ import static org.junit.jupiter.api.Assertions.assertSame;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
+import java.math.BigDecimal;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
 import java.time.LocalTime;
 import java.time.temporal.Temporal;
+import java.util.function.Function;
+import java.util.stream.Stream;
 import org.apache.calcite.sql.validate.SqlValidatorException;
+import org.apache.ignite.internal.sql.engine.util.MetadataMatcher;
 import org.apache.ignite.lang.ErrorGroups.Sql;
 import org.apache.ignite.lang.IgniteException;
+import org.apache.ignite.sql.ColumnType;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 
 /**
  * Test Ignite SQL functions.
@@ -344,6 +352,112 @@ public class ItFunctionsTest extends ClusterPerClassIntegrationTest {
         assertThrowsWithCause(() -> sql("SELECT SUBSTR('1234567', 1, -3)"), IgniteException.class, "negative substring length");
     }
 
+    /** ROUND over integral types keeps the argument type; a negative scale rounds to tens, hundreds, etc. */
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("integralTypes")
+    public void testRoundIntegralTypes(ParseNum parse, MetadataMatcher matcher) {
+        String v1 = parse.value("42");
+        String v2 = parse.value("45");
+        String v3 = parse.value("47");
+
+        assertQuery(format("SELECT ROUND({}), ROUND({}, 0)", v1, v1))
+                .returns(parse.apply("42"), parse.apply("42"))
+                .columnMetadata(matcher, matcher)
+                .check();
+
+        String query = format(
+                "SELECT ROUND({}, -2), ROUND({}, -1), ROUND({}, -1), ROUND({}, -1)",
+                v1, v1, v2, v3);
+
+        assertQuery(query)
+                .returns(parse.apply("0"), parse.apply("40"), parse.apply("50"), parse.apply("50"))
+                .columnMetadata(matcher, matcher, matcher, matcher)
+                .check();
+    }
+
+    private static Stream<Arguments> integralTypes() {
+        return Stream.of(
+                Arguments.of(new ParseNum("TINYINT", Byte::parseByte), new MetadataMatcher().type(ColumnType.INT8)),
+                Arguments.of(new ParseNum("SMALLINT", Short::parseShort), new MetadataMatcher().type(ColumnType.INT16)),
+                Arguments.of(new ParseNum("INTEGER", Integer::parseInt), new MetadataMatcher().type(ColumnType.INT32)),
+                Arguments.of(new ParseNum("BIGINT", Long::parseLong), new MetadataMatcher().type(ColumnType.INT64)),
+                Arguments.of(new ParseNum("DECIMAL(4)", BigDecimal::new), new MetadataMatcher().type(ColumnType.DECIMAL).precision(4))
+        );
+    }
+
+    /** ROUND(x) drops the scale of a DECIMAL argument, while ROUND(x, s) keeps the argument type. */
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("nonIntegralTypes")
+    public void testRoundNonIntegralTypes(ParseNum parse, MetadataMatcher round1, MetadataMatcher round2) {
+        String v1 = parse.value("42.123");
+        String v2 = parse.value("45.123");
+        String v3 = parse.value("47.123");
+
+        assertQuery(format("SELECT ROUND({}), ROUND({}, 0)", v1, v1))
+                .returns(parse.apply("42"), parse.apply("42.000"))
+                .columnMetadata(round1, round2)
+                .check();
+
+        assertQuery(format("SELECT ROUND({}, -2), ROUND({}, -1), ROUND({}, -1), ROUND({}, -1)", v1, v1, v2, v3))
+                .returns(parse.apply("0.000"), parse.apply("40.000"), parse.apply("50.000"), parse.apply("50.000"))
+                .columnMetadata(round2, round2, round2, round2)
+                .check();
+
+        String v4 = parse.value("1.123");
+
+        assertQuery(format("SELECT ROUND({}, s) FROM (VALUES (-2), (-1), (0), (1), (2), (3), (4), (100) ) t(s)", v4))
+                .returns(parse.apply("0.000"))
+                .returns(parse.apply("0.000"))
+                .returns(parse.apply("1.000"))
+                .returns(parse.apply("1.100"))
+                .returns(parse.apply("1.120"))
+                .returns(parse.apply("1.123"))
+                .returns(parse.apply("1.123"))
+                .returns(parse.apply("1.123"))
+                .columnMetadata(round2)
+                .check();
+    }
+
+    private static Stream<Arguments> nonIntegralTypes() {
+        MetadataMatcher matchFloat = new MetadataMatcher().type(ColumnType.FLOAT);
+        MetadataMatcher matchDouble = new MetadataMatcher().type(ColumnType.DOUBLE);
+
+        MetadataMatcher matchDecimal1 = new MetadataMatcher().type(ColumnType.DECIMAL).precision(5).scale(0);
+        MetadataMatcher matchDecimal2 = new MetadataMatcher().type(ColumnType.DECIMAL).precision(5).scale(3);
+
+        return Stream.of(
+                Arguments.of(new ParseNum("REAL", Float::parseFloat), matchFloat, matchFloat),
+                Arguments.of(new ParseNum("DOUBLE", Double::parseDouble), matchDouble, matchDouble),
+                Arguments.of(new ParseNum("DECIMAL(5, 3)", BigDecimal::new), matchDecimal1, matchDecimal2)
+        );
+    }
+
+    /** Numeric type parser. */
+    public static final class ParseNum {
+
+        private final String typeName;
+
+        private final Function<String, Object> func;
+
+        ParseNum(String typeName, Function<String, Object> func) {
+            this.typeName = typeName;
+            this.func = func;
+        }
+
+        public Object apply(String val) {
+            return func.apply(val);
+        }
+
+        public String value(String val) {
+            return val + "::" + typeName;
+        }
+
+        @Override
+        public String toString() {
+            return typeName;
+        }
+    }
+
     /**
      * An interface describing a clock reporting time in a specified temporal value.
      *
diff --git a/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItSqlOperatorsTest.java b/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItSqlOperatorsTest.java
index b4b03f8c6c..9189220143 100644
--- a/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItSqlOperatorsTest.java
+++ b/modules/runner/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItSqlOperatorsTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.internal.sql.engine;
 
+import java.math.BigDecimal;
 import java.time.LocalDate;
 import java.time.LocalDateTime;
 import java.time.Period;
@@ -192,8 +193,7 @@ public class ItSqlOperatorsTest extends ClusterPerClassIntegrationTest {
         assertExpression("COT(1)").returns(1.0d / Math.tan(1)).check();
         assertExpression("DEGREES(1)").returns(Math.toDegrees(1)).check();
         assertExpression("RADIANS(1)").returns(Math.toRadians(1)).check();
-        // TODO https://issues.apache.org/jira/browse/IGNITE-20311
-        // assertExpression("ROUND(1.7)").returns(BigDecimal.valueOf(2)).check();
+        assertExpression("ROUND(1.7)").returns(BigDecimal.valueOf(2)).check();
         assertExpression("SIGN(-5)").returns(-1).check();
         assertExpression("SIN(1)").returns(Math.sin(1)).check();
         assertExpression("SINH(1)").returns(Math.sinh(1)).check();
diff --git a/modules/runner/src/integrationTest/sql/types/decimal/test_decimal_ops.test b/modules/runner/src/integrationTest/sql/types/decimal/test_decimal_ops.test
index 145edaf1ad..f4f913c87d 100644
--- a/modules/runner/src/integrationTest/sql/types/decimal/test_decimal_ops.test
+++ b/modules/runner/src/integrationTest/sql/types/decimal/test_decimal_ops.test
@@ -210,18 +210,18 @@ SELECT ROUND('100.3908147521'::DECIMAL(18,10), 0)::VARCHAR,
           ROUND('100.3908147521'::DECIMAL(18,10), 20)::VARCHAR,
           ROUND(NULL::DECIMAL, 0)
 ----
-100
-100.4
-100.39
-100.391
-100.3908
-100.39081
-100.390815
-100.3908148
-100.39081475
-100.390814752
+100.0000000000
+100.4000000000
+100.3900000000
+100.3910000000
+100.3908000000
+100.3908100000
+100.3908150000
+100.3908148000
+100.3908147500
+100.3908147520
+100.3908147521
 100.3908147521
-100.39081475210000000000
 NULL
 
 # negative precision
@@ -248,27 +248,27 @@ SELECT ROUND('1049578239572094512.32415'::DECIMAL(30,10), 0)::VARCHAR,
           ROUND('1049578239572094512.32415'::DECIMAL(30,10), -20)::VARCHAR,
           ROUND('1049578239572094512.32415'::DECIMAL(30,10), -19842)::VARCHAR
 ----
-1049578239572094512
-1049578239572094510
-1049578239572094500
-1049578239572095000
-1049578239572090000
-1049578239572100000
-1049578239572000000
-1049578239570000000
-1049578239600000000
-1049578240000000000
-1049578240000000000
-1049578200000000000
-1049578000000000000
-1049580000000000000
-1049600000000000000
-1050000000000000000
-1050000000000000000
-1000000000000000000
-0
-0
-0
+1049578239572094512.0000000000
+1049578239572094510.0000000000
+1049578239572094500.0000000000
+1049578239572095000.0000000000
+1049578239572090000.0000000000
+1049578239572100000.0000000000
+1049578239572000000.0000000000
+1049578239570000000.0000000000
+1049578239600000000.0000000000
+1049578240000000000.0000000000
+1049578240000000000.0000000000
+1049578200000000000.0000000000
+1049578000000000000.0000000000
+1049580000000000000.0000000000
+1049600000000000000.0000000000
+1050000000000000000.0000000000
+1050000000000000000.0000000000
+1000000000000000000.0000000000
+0.0000000000
+0.0000000000
+0.0000000000
 
 # use decimal in subquery
 query I
diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java
index e1d4f7886f..0a3f8135e3 100644
--- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java
+++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java
@@ -128,9 +128,177 @@ public class IgniteSqlFunctions {
         return b instanceof ByteString ? octetLength((ByteString) b) : charLength((String) b);
     }
 
-    private static BigDecimal setScale(int precision, int scale, BigDecimal decimal) {
-        return precision == IgniteTypeSystem.INSTANCE.getDefaultPrecision(SqlTypeName.DECIMAL)
-            ? decimal : decimal.setScale(scale, RoundingMode.HALF_UP);
+    // SQL ROUND function
+
+    /** SQL {@code ROUND} operator applied to byte values. */
+    public static byte sround(byte b0) {
+        return sround(b0, 0);
+    }
+
+    /**
+     * SQL {@code ROUND} operator applied to byte values.
+     *
+     * @throws ArithmeticException If the rounded value does not fit into TINYINT, e.g. {@code ROUND(127, -1)}.
+     */
+    public static byte sround(byte b0, int b1) {
+        int rounded = sround((int) b0, b1);
+
+        if (rounded < Byte.MIN_VALUE || rounded > Byte.MAX_VALUE) {
+            throw new ArithmeticException("TINYINT out of range");
+        }
+
+        return (byte) rounded;
+    }
+
+    /** SQL {@code ROUND} operator applied to short values. */
+    public static short sround(short b0) {
+        return sround(b0, 0);
+    }
+
+    /**
+     * SQL {@code ROUND} operator applied to short values.
+     *
+     * @throws ArithmeticException If the rounded value does not fit into SMALLINT, e.g. {@code ROUND(32767, -1)}.
+     */
+    public static short sround(short b0, int b1) {
+        int rounded = sround((int) b0, b1);
+
+        if (rounded < Short.MIN_VALUE || rounded > Short.MAX_VALUE) {
+            throw new ArithmeticException("SMALLINT out of range");
+        }
+
+        return (short) rounded;
+    }
+
+    /** SQL {@code ROUND} operator applied to int values. */
+    public static int sround(int b0) {
+        return sround(b0, 0);
+    }
+
+    /**
+     * SQL {@code ROUND} operator applied to int values.
+     *
+     * @throws ArithmeticException If the rounded value does not fit into INTEGER, e.g. {@code ROUND(2147483647, -1)}.
+     */
+    public static int sround(int b0, int b1) {
+        if (b1 >= 0) {
+            // Integral values have no fractional digits, so there is nothing to round.
+            return b0;
+        }
+
+        long rounded = sround((long) b0, b1);
+
+        if (rounded < Integer.MIN_VALUE || rounded > Integer.MAX_VALUE) {
+            throw new ArithmeticException("INTEGER out of range");
+        }
+
+        return (int) rounded;
+    }
+
+    /** SQL {@code ROUND} operator applied to long values. */
+    public static long sround(long b0) {
+        return sround(b0, 0);
+    }
+
+    /**
+     * SQL {@code ROUND} operator applied to long values.
+     *
+     * @throws ArithmeticException If the rounded value does not fit into BIGINT, e.g. {@code ROUND(9223372036854775807, -1)}.
+     */
+    public static long sround(long b0, int b1) {
+        if (b1 >= 0) {
+            // Integral values have no fractional digits, so there is nothing to round.
+            return b0;
+        }
+
+        if (b1 < -18) {
+            // 10^19 does not fit into a long, so fall back to exact decimal arithmetic. Every long rounds to zero
+            // at scale -20 or below, so clamping the scale is exact and avoids computing huge powers of ten.
+            return BigDecimal.valueOf(b0).setScale(Math.max(b1, -20), RoundingMode.HALF_UP).longValueExact();
+        }
+
+        // 10^k is exactly representable as a double for k <= 18, so the cast is lossless.
+        long factor = (long) Math.pow(10, -b1);
+
+        return Math.multiplyExact(divide(b0, factor, RoundingMode.HALF_UP), factor);
+    }
+
+    /** SQL {@code ROUND} operator applied to double values; NaN and infinities are returned as is. */
+    public static double sround(double b0) {
+        if (!Double.isFinite(b0)) {
+            // BigDecimal cannot represent NaN or infinities.
+            return b0;
+        }
+
+        return sround(BigDecimal.valueOf(b0)).doubleValue();
+    }
+
+    /** SQL {@code ROUND} operator applied to double values; NaN and infinities are returned as is. */
+    public static double sround(double b0, int b1) {
+        if (!Double.isFinite(b0)) {
+            // BigDecimal cannot represent NaN or infinities.
+            return b0;
+        }
+
+        return sround(BigDecimal.valueOf(b0), b1).doubleValue();
+    }
+
+    /** SQL {@code ROUND} operator applied to float values; NaN and infinities are returned as is. */
+    public static float sround(float b0) {
+        if (!Float.isFinite(b0)) {
+            // BigDecimal cannot represent NaN or infinities.
+            return b0;
+        }
+
+        return sround(floatToDecimal(b0)).floatValue();
+    }
+
+    /** SQL {@code ROUND} operator applied to float values; NaN and infinities are returned as is. */
+    public static float sround(float b0, int b1) {
+        if (!Float.isFinite(b0)) {
+            // BigDecimal cannot represent NaN or infinities.
+            return b0;
+        }
+
+        return sround(floatToDecimal(b0), b1).floatValue();
+    }
+
+    /**
+     * Converts a float to BigDecimal using its shortest decimal representation. Widening to double first
+     * would expose binary noise (e.g. {@code 2.675f} becomes {@code 2.6749999523...}) and skew HALF_UP rounding.
+     */
+    private static BigDecimal floatToDecimal(float f) {
+        return new BigDecimal(Float.toString(f));
+    }
+
+    /** SQL {@code ROUND} operator applied to BigDecimal values. */
+    public static BigDecimal sround(BigDecimal b0) {
+        return b0.setScale(0, RoundingMode.HALF_UP);
+    }
+
+    /**
+     * SQL {@code ROUND} operator applied to BigDecimal values.
+     *
+     * <p>The result keeps the scale of {@code b0}: digits beyond the rounding position are replaced with zeros.
+     */
+    public static BigDecimal sround(BigDecimal b0, int b1) {
+        int originalScale = b0.scale();
+
+        if (b1 >= originalScale) {
+            // There are no digits beyond the requested position.
+            return b0;
+        }
+
+        if (b1 < (long) originalScale - b0.precision()) {
+            // |b0| < 10^(precision - scale) <= 10^(-b1) / 10, so the value rounds to zero. Returning early
+            // avoids computing a huge power of ten for scales such as -2000000000.
+            return BigDecimal.ZERO.setScale(originalScale);
+        }
+
+        BigDecimal roundedValue = b0.setScale(b1, RoundingMode.HALF_UP);
+
+        // Pad with zeros to match the original scale.
+        return roundedValue.setScale(originalScale, RoundingMode.UNNECESSARY);
     }
 
     /** CAST(DOUBLE AS DECIMAL). */
@@ -434,4 +602,31 @@ public class IgniteSqlFunctions {
             return true;
         }
     }
+
+    /** Divides {@code p} by {@code q} rounding the quotient with {@code mode}; only HALF_UP and HALF_DOWN are supported. */
+    private static long divide(long p, long q, RoundingMode mode) {
+        // Stripped down version of guava's LongMath::divide.
+        long div = p / q; // throws if q == 0
+        long rem = p - q * div; // equals p % q
+
+        int signum = 1 | (int) ((p ^ q) >> (Long.SIZE - 1));
+        boolean increment;
+        switch (mode) {
+            case HALF_DOWN:
+            case HALF_UP:
+                long absRem = Math.abs(rem);
+                long cmpRemToHalfDivisor = absRem - (Math.abs(q) - absRem);
+
+                if (cmpRemToHalfDivisor == 0) { // exactly on the half mark
+                    increment = mode == RoundingMode.HALF_UP;
+                } else {
+                    increment = cmpRemToHalfDivisor > 0; // closer to the UP value
+                }
+                break;
+            default:
+                throw new AssertionError("Unsupported rounding mode: " + mode);
+        }
+
+        return increment ? div + signum : div;
+    }
 }
diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java
index 6a82dd4ee1..146a760a07 100644
--- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java
+++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/RexImpTable.java
@@ -599,7 +599,6 @@ public class RexImpTable {
       defineMethod(DEGREES, "degrees", NullPolicy.STRICT);
       defineMethod(POW, "power", NullPolicy.STRICT);
       defineMethod(RADIANS, "radians", NullPolicy.STRICT);
-      defineMethod(ROUND, "sround", NullPolicy.STRICT);
       defineMethod(SEC, "sec", NullPolicy.STRICT);
       defineMethod(SECH, "sech", NullPolicy.STRICT);
       defineMethod(SIGN, "sign", NullPolicy.STRICT);
@@ -994,6 +993,8 @@ public class RexImpTable {
       defineMethod(IgniteSqlOperatorTable.LEAST2, LEAST2.method(), NullPolicy.NONE);
       defineMethod(IgniteSqlOperatorTable.LENGTH, LENGTH.method(), NullPolicy.STRICT);
       defineMethod(SUBSTR, IgniteMethod.SUBSTR.method(), NullPolicy.STRICT);
+      // ROUND is implemented by the IgniteSqlFunctions#sround overloads (see IgniteMethod.ROUND).
+      defineMethod(ROUND, IgniteMethod.ROUND.method(), NullPolicy.STRICT);
 
       map.put(TYPEOF, systemFunctionImplementor);
       map.put(NULL_BOUND, systemFunctionImplementor);
diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java
index 79409d5f38..0025a5ea69 100644
--- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java
+++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/sql/fun/IgniteSqlOperatorTable.java
@@ -17,19 +17,25 @@
 
 package org.apache.ignite.internal.sql.engine.sql.fun;
 
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.sql.SqlBasicFunction;
 import org.apache.calcite.sql.SqlFunction;
 import org.apache.calcite.sql.SqlFunctionCategory;
 import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlOperatorBinding;
 import org.apache.calcite.sql.fun.SqlInternalOperators;
 import org.apache.calcite.sql.fun.SqlLibraryOperators;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.fun.SqlSubstringFunction;
 import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.type.SqlTypeTransforms;
 import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable;
 import org.apache.ignite.internal.sql.engine.type.UuidType;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 /**
  * Operator table that contains only Ignite-specific functions and operators.
@@ -160,6 +166,37 @@ public class IgniteSqlOperatorTable extends ReflectiveSqlOperatorTable {
                 }
             };
 
+    /**
+     * The {@code ROUND(numeric [, integer])} function.
+     *
+     * <p>{@code ROUND(x)} returns the type of {@code x} with scale 0 when that type has precision and scale,
+     * while {@code ROUND(x, s)} returns the type of {@code x} unchanged.
+     */
+    public static final SqlFunction ROUND = SqlBasicFunction.create("ROUND",
+            new SqlReturnTypeInference() {
+                @Override
+                public @Nullable RelDataType inferReturnType(SqlOperatorBinding opBinding) {
+                    RelDataType operandType = opBinding.getOperandType(0);
+
+                    // If there is only one argument and it supports precision and scale, set scale 0.
+                    if (opBinding.getOperandCount() == 1 && operandType.getSqlTypeName().allowsPrecScale(true, true)) {
+                        int precision = operandType.getPrecision();
+                        // Use the type factory of the current binding rather than a global one.
+                        RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
+
+                        RelDataType returnType = typeFactory.createSqlType(operandType.getSqlTypeName(), precision, 0);
+                        // Preserve nullability.
+                        boolean nullable = operandType.isNullable();
+
+                        return typeFactory.createTypeWithNullability(returnType, nullable);
+                    } else {
+                        return operandType;
+                    }
+                }
+            },
+            OperandTypes.NUMERIC_OPTIONAL_INTEGER,
+            SqlFunctionCategory.NUMERIC);
+
     /** Singleton instance. */
     public static final IgniteSqlOperatorTable INSTANCE = new IgniteSqlOperatorTable();
 
@@ -290,7 +327,7 @@ public class IgniteSqlOperatorTable extends ReflectiveSqlOperatorTable {
         register(SqlStdOperatorTable.COT); // Cotangent.
         register(SqlStdOperatorTable.DEGREES); // Radians to degrees.
         register(SqlStdOperatorTable.RADIANS); // Degrees to radians.
-        register(SqlStdOperatorTable.ROUND);
+        register(ROUND); // Replaces SqlStdOperatorTable.ROUND to fix the return type scale.
         register(SqlStdOperatorTable.SIGN);
         register(SqlStdOperatorTable.SIN); // Sine.
         register(SqlLibraryOperators.SINH); // Hyperbolic sine.
diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java
index b661146f09..48bf6bd87c 100644
--- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java
+++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/IgniteMethod.java
@@ -17,8 +17,12 @@
 
 package org.apache.ignite.internal.sql.engine.util;
 
+import static org.apache.ignite.internal.lang.IgniteStringFormatter.format;
+
 import java.lang.reflect.Method;
+import java.lang.reflect.Modifier;
 import java.lang.reflect.Type;
+import java.util.Arrays;
 import java.util.Objects;
 import java.util.UUID;
 import org.apache.calcite.DataContext;
@@ -107,7 +111,10 @@ public enum IgniteMethod {
     /** See {@link IgniteSqlFunctions#consumeFirstArgument(Object, Object)}. **/
     CONSUME_FIRST_ARGUMENT(IgniteSqlFunctions.class, "consumeFirstArgument", Object.class, Object.class),
 
-    SUBSTR(SqlFunctions.class, "substring", String.class, int.class, int.class);
+    SUBSTR(SqlFunctions.class, "substring", String.class, int.class, int.class),
+
+    /** ROUND function. See {@link IgniteSqlFunctions#sround(double)}, {@link IgniteSqlFunctions#sround(double, int)} and variants. */
+    ROUND(IgniteSqlFunctions.class, "sround", true);
 
     private final Method method;
 
@@ -122,6 +129,26 @@ public enum IgniteMethod {
         method = Types.lookupMethod(clazz, methodName, argumentTypes);
     }
 
+    /**
+     * Constructor that allows to specify overloaded methods as SQL function implementations.
+     *
+     * @param clazz Class where to lookup method.
+     * @param methodName Method name.
+     * @param overloadedMethod If {@code true}, any public static method named {@code methodName} is taken, and the
+     *         appropriate overload is selected by Calcite at a call site; otherwise a method without parameters is looked up.
+     */
+    IgniteMethod(Class<?> clazz, String methodName, boolean overloadedMethod) {
+        if (overloadedMethod) {
+            // Allow calcite to select appropriate method at a call site.
+            this.method = Arrays.stream(clazz.getMethods())
+                    .filter(m -> m.getName().equals(methodName) && Modifier.isStatic(m.getModifiers()))
+                    .findFirst()
+                    .orElseThrow(() -> new IllegalArgumentException(format("Public static method {} is not defined", methodName)));
+        } else {
+            this.method = Types.lookupMethod(clazz, methodName);
+        }
+    }
+
     /**
      * Get method.
      */
diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java
index df38db058b..85435c0883 100644
--- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java
+++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctionsTest.java
@@ -242,4 +242,135 @@ public class IgniteSqlFunctionsTest {
             assertThrowsSqlException(Sql.RUNTIME_ERR, "Numeric field overflow", convert::get);
         }
     }
+
+    // ROUND
+
+    /** Tests for ROUND(x) function. */
+    @Test
+    public void testRound() {
+        assertEquals(new BigDecimal("1"), IgniteSqlFunctions.sround(new BigDecimal("1.000")));
+        assertEquals(new BigDecimal("1"), IgniteSqlFunctions.sround(new BigDecimal("1.123")));
+        assertEquals(1, IgniteSqlFunctions.sround(1), "int");
+        assertEquals(1L, IgniteSqlFunctions.sround(1L), "long");
+        assertEquals(1.0d, IgniteSqlFunctions.sround(1.123d), "double");
+        // ROUND(x) must keep the type of x: a SMALLINT outside of the TINYINT range must not be truncated.
+        assertEquals((short) 1000, IgniteSqlFunctions.sround((short) 1000), "short");
+    }
+
+    /** Tests for ROUND(x, s) function, where x is a BigDecimal value. */
+    @ParameterizedTest
+    @CsvSource({
+            "1.123, -1, 0.000",
+            "1.123, 0, 1.000",
+            "1.123, 1, 1.100",
+            "1.123, 2, 1.120",
+            "1.127, 2, 1.130",
+            "1.123, 3, 1.123",
+            "1.123, 4, 1.123",
+            "10.123, 0, 10.000",
+            "10.123, -1, 10.000",
+            "10.123, -2, 0.000",
+            "10.123, 3, 10.123",
+            "10.123, 4, 10.123",
+    })
+    public void testRound2Decimal(String input, int scale, String result) {
+        assertEquals(new BigDecimal(result), IgniteSqlFunctions.sround(new BigDecimal(input), scale));
+    }
+
+    /** Tests for ROUND(x, s) function, where x is a double value. */
+    @ParameterizedTest
+    @CsvSource({
+            "1.123, 3, 1.123",
+            "1.123, 2, 1.12",
+            "1.127, 2, 1.13",
+            "1.245, 1, 1.2",
+            "1.123, 0, 1.0",
+            "1.123, -1, 0.0",
+            "10.123, 0, 10.000",
+            "10.123, -1, 10.000",
+            "10.123, -2, 0.000",
+            "10.123, 3, 10.123",
+            "10.123, 4, 10.123",
+    })
+    public void testRound2Double(double input, int scale, double result) {
+        assertEquals(result, IgniteSqlFunctions.sround(input, scale));
+    }
+
+    /** Tests for ROUND(x, s) function, where x is a byte. */
+    @ParameterizedTest
+    @CsvSource({
+            "42, -2, 0",
+            "42, -1, 40",
+            "47, -1, 50",
+            "42, 0, 42",
+            "42, 1, 42",
+            "42, 2, 42",
+            "-42, -1, -40",
+            "-47, -1, -50",
+    })
+    public void testRound2ByteType(byte input, int scale, byte result) {
+        assertEquals(result, IgniteSqlFunctions.sround(input, scale));
+    }
+
+    /** Tests for ROUND(x, s) function, where x is a short. */
+    @ParameterizedTest
+    @CsvSource({
+            "42, -2, 0",
+            "42, -1, 40",
+            "47, -1, 50",
+            "42, 0, 42",
+            "42, 1, 42",
+            "42, 2, 42",
+            "-42, -1, -40",
+            "-47, -1, -50",
+    })
+    public void testRound2ShortType(short input, int scale, short result) {
+        assertEquals(result, IgniteSqlFunctions.sround(input, scale));
+    }
+
+    /** Tests for ROUND(x, s) function, where x is an int. */
+    @ParameterizedTest
+    @CsvSource({
+            "42, -2, 0",
+            "42, -1, 40",
+            "47, -1, 50",
+            "42, 0, 42",
+            "42, 1, 42",
+            "42, 2, 42",
+            "-42, -1, -40",
+            "-47, -1, -50",
+    })
+    public void testRound2IntType(int input, int scale, int result) {
+        assertEquals(result, IgniteSqlFunctions.sround(input, scale));
+    }
+
+    /** Tests for ROUND(x, s) function, where x is a long. */
+    @ParameterizedTest
+    @CsvSource({
+            "42, -2, 0",
+            "42, -1, 40",
+            "47, -1, 50",
+            "42, 0, 42",
+            "42, 1, 42",
+            "42, 2, 42",
+            "-42, -1, -40",
+            "-47, -1, -50",
+    })
+    public void testRound2LongType(long input, int scale, long result) {
+        assertEquals(result, IgniteSqlFunctions.sround(input, scale));
+    }
+
+    /** Tests for ROUND(x, s) with values and scales at the edges of the supported ranges. */
+    @Test
+    public void testRoundEdgeCases() {
+        // Very small scales round everything to zero without computing huge powers of ten.
+        assertEquals(0L, IgniteSqlFunctions.sround(Long.MAX_VALUE, -20), "long, scale -20");
+        assertEquals(0L, IgniteSqlFunctions.sround(42L, Integer.MIN_VALUE), "long, min scale");
+        assertEquals(new BigDecimal("0.000"), IgniteSqlFunctions.sround(new BigDecimal("1.123"), -2_000_000_000));
+        // NaN and infinities cannot be converted to BigDecimal and are returned as is.
+        assertEquals(Double.NaN, IgniteSqlFunctions.sround(Double.NaN, 1), "NaN");
+        assertEquals(Double.POSITIVE_INFINITY, IgniteSqlFunctions.sround(Double.POSITIVE_INFINITY), "Infinity");
+        // A float is rounded using its own shortest decimal form rather than the widened double (2.6749999523...).
+        assertEquals(2.68f, IgniteSqlFunctions.sround(2.675f, 2), "float");
+    }
 }

Reply via email to