This is an automated email from the ASF dual-hosted git repository.

amashenkov pushed a commit to branch ignite-22988
in repository https://gitbox.apache.org/repos/asf/ignite-3.git

commit b1b68c83da372dacc1a817b7ff9b9daedf853bd8
Author: amashenkov <[email protected]>
AuthorDate: Tue Aug 27 17:51:38 2024 +0300

    WIP
---
 .../internal/sql/engine/ItAggregatesTest.java      | 23 ++++++-----------
 .../sql/engine/exec/exp/IgniteSqlFunctions.java    |  2 +-
 .../sql/engine/exec/exp/agg/Accumulators.java      | 30 ++++++++++++++++------
 .../sql/engine/rel/agg/MapReduceAggregates.java    | 11 ++++++--
 4 files changed, 40 insertions(+), 26 deletions(-)

diff --git 
a/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java
 
b/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java
index 82c13c341e..b2efca852e 100644
--- 
a/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java
+++ 
b/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java
@@ -24,6 +24,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import java.math.BigDecimal;
 import java.math.MathContext;
+import java.math.RoundingMode;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
@@ -38,7 +39,6 @@ import 
org.apache.ignite.internal.testframework.WithSystemProperty;
 import org.apache.ignite.lang.IgniteException;
 import org.junit.jupiter.api.Assumptions;
 import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.Arguments;
@@ -553,9 +553,6 @@ public class ItAggregatesTest extends 
BaseSqlIntegrationTest {
     @ParameterizedTest
     @MethodSource("provideRules")
     public void testAvg(String[] rules) {
-        Assumptions.assumeFalse(Arrays.stream(rules).filter(r -> 
r.startsWith("MapReduce")).count() == 1,
-                "need to be fixed after: 
https://issues.apache.org/jira/browse/IGNITE-22988");
-
         sql("DELETE FROM numbers");
         sql("INSERT INTO numbers VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1), (2, 2, 
2, 2, 2, 2, 2, 2, 2, 2)");
 
@@ -564,7 +561,7 @@ public class ItAggregatesTest extends 
BaseSqlIntegrationTest {
                 + "AVG(float_col), AVG(double_col), AVG(dec2_col), 
AVG(dec4_2_col) "
                 + "FROM numbers")
                 .disableRules(rules)
-                .returns((byte) 1, (short) 1, 1, 1L, 1.5f, 1.5d, new 
BigDecimal("1.5"), new BigDecimal("1.50"))
+                .returns((byte) 2, (short) 2, 2, 2L, 1.5f, 1.5d, new 
BigDecimal("2"), new BigDecimal("1.50"))
                 .check();
 
         sql("DELETE FROM numbers");
@@ -580,7 +577,7 @@ public class ItAggregatesTest extends 
BaseSqlIntegrationTest {
 
         assertQuery("SELECT AVG(dec4_2_col) FROM numbers")
                 .disableRules(rules)
-                .returns(new BigDecimal("1.665"))
+                .returns(new BigDecimal("1.66"))
                 .check();
 
         sql("DELETE FROM numbers");
@@ -601,7 +598,6 @@ public class ItAggregatesTest extends 
BaseSqlIntegrationTest {
     }
 
     @Test
-    @Disabled("https://issues.apache.org/jira/browse/IGNITE-22988")
     public void testAvgRandom() {
         long seed = System.nanoTime();
         Random random = new Random(seed);
@@ -617,12 +613,12 @@ public class ItAggregatesTest extends 
BaseSqlIntegrationTest {
             numbers.add(num);
 
             String query = "INSERT INTO numbers (id, int_col, dec10_2_col) 
VALUES(?, ?, ?)";
-            sql(query, i, num.intValue(), num);
+            sql(query, i, num.setScale(0, RoundingMode.HALF_UP).intValue(), 
num);
         }
 
         BigDecimal avg = numbers.stream()
                 .reduce(new BigDecimal("0.00"), BigDecimal::add)
-                .divide(BigDecimal.valueOf(numbers.size()), 
MathContext.DECIMAL64);
+                .divide(BigDecimal.valueOf(numbers.size()), 2, 
RoundingMode.HALF_UP);
 
         for (String[] rules : makePermutations(DISABLED_RULES)) {
             assertQuery("SELECT AVG(int_col), AVG(dec10_2_col) FROM numbers")
@@ -640,7 +636,7 @@ public class ItAggregatesTest extends 
BaseSqlIntegrationTest {
 
         assertQuery("SELECT AVG(int_col), AVG(dec4_2_col) FROM 
not_null_numbers")
                 .disableRules(rules)
-                .returns(1, new BigDecimal("1.50"))
+                .returns(2, new BigDecimal("1.50"))
                 .check();
 
         // Return type of an AVG aggregate can never be null.
@@ -689,9 +685,6 @@ public class ItAggregatesTest extends 
BaseSqlIntegrationTest {
     @ParameterizedTest
     @MethodSource("provideRules")
     public void testAvgFromLiterals(String[] rules) {
-        Assumptions.assumeFalse(Arrays.stream(rules).filter(r -> 
r.startsWith("MapReduce")).count() == 1,
-                "need to be fixed after: 
https://issues.apache.org/jira/browse/IGNITE-22988");
-
         assertQuery("SELECT "
                 + "AVG(tinyint_col), AVG(smallint_col), AVG(int_col), 
AVG(bigint_col), "
                 + "AVG(float_col), AVG(double_col), AVG(dec2_col), 
AVG(dec4_2_col) "
@@ -701,7 +694,7 @@ public class ItAggregatesTest extends 
BaseSqlIntegrationTest {
                 + ") "
                 + "t(tinyint_col, smallint_col, int_col, bigint_col, 
float_col, double_col, dec2_col, dec4_2_col)")
                 .disableRules(rules)
-                .returns((byte) 1, (short) 1, 1, 1L, 1.5f, 1.5d, new 
BigDecimal("1.5"), new BigDecimal("1.50"))
+                .returns((byte) 2, (short) 2, 2, 2L, 1.5f, 1.5d, new 
BigDecimal("2"), new BigDecimal("1.50"))
                 .check();
 
         assertQuery("SELECT "
@@ -717,7 +710,7 @@ public class ItAggregatesTest extends 
BaseSqlIntegrationTest {
                 + "  UNION\n"
                 + "  SELECT 2::DECIMAL(2) as dec2_col, 3.00::DECIMAL(4,2) as 
dec4_2_col\n"
                 + ") as t")
-                .returns(new BigDecimal("1.5"), new BigDecimal("2.50"))
+                .returns(new BigDecimal("2"), new BigDecimal("2.50"))
                 .check();
     }
 
diff --git 
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java
 
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java
index 1b32ca19e0..b7da271bcd 100644
--- 
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java
+++ 
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java
@@ -361,7 +361,7 @@ public class IgniteSqlFunctions {
      * (see {@link IgniteSqlOperatorTable#DECIMAL_DIVIDE}, their values are 
ignored at runtime.
      */
     public static BigDecimal decimalDivide(BigDecimal sum, BigDecimal cnt, int 
p, int s) {
-        return sum.divide(cnt, MathContext.DECIMAL64);
+        return sum.divide(cnt, s, RoundingMode.HALF_EVEN); // TODO: HALF_UP is 
expected
     }
 
     private static BigDecimal processValueWithIntegralPart(Number value, int 
precision, int scale) {
diff --git 
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java
 
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java
index 7e0a78177d..e3cae00ddf 100644
--- 
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java
+++ 
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java
@@ -27,16 +27,17 @@ import static 
org.apache.calcite.sql.type.SqlTypeName.VARCHAR;
 import static org.apache.ignite.internal.util.ArrayUtils.nullOrEmpty;
 
 import java.math.BigDecimal;
-import java.math.MathContext;
 import java.util.Comparator;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+import java.util.function.IntFunction;
 import java.util.function.Supplier;
 import org.apache.calcite.avatica.util.ByteString;
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.ignite.internal.sql.engine.exec.exp.IgniteSqlFunctions;
 import org.apache.ignite.internal.sql.engine.type.IgniteCustomType;
 import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory;
 import org.apache.ignite.internal.util.ArrayUtils;
@@ -104,7 +105,7 @@ public class Accumulators {
         switch (call.type.getSqlTypeName()) {
             case BIGINT:
             case DECIMAL:
-                return DecimalAvg.FACTORY;
+                return () -> DecimalAvg.FACTORY.apply(call.type.getScale());
             case DOUBLE:
             case REAL:
             case FLOAT:
@@ -114,7 +115,8 @@ public class Accumulators {
                 if (call.type.getSqlTypeName() == ANY) {
                     throw unsupportedAggregateFunction(call);
                 }
-                return DoubleAvg.FACTORY;
+                RelDataType dataType = typeFactory.decimalOf(call.type);
+                return () -> DoubleAvg.FACTORY.apply(dataType.getScale());
         }
     }
 
@@ -305,12 +307,18 @@ public class Accumulators {
      * TODO Documentation https://issues.apache.org/jira/browse/IGNITE-15859
      */
     public static class DecimalAvg implements Accumulator {
-        public static final Supplier<Accumulator> FACTORY = DecimalAvg::new;
+        public static final IntFunction<Accumulator> FACTORY = DecimalAvg::new;
+
+        private final int scale;
 
         private BigDecimal sum = BigDecimal.ZERO;
 
         private BigDecimal cnt = BigDecimal.ZERO;
 
+        public DecimalAvg(int scale) {
+            this.scale = scale;
+        }
+
         /** {@inheritDoc} */
         @Override
         public void add(Object... args) {
@@ -327,7 +335,7 @@ public class Accumulators {
         /** {@inheritDoc} */
         @Override
         public Object end() {
-            return cnt.compareTo(BigDecimal.ZERO) == 0 ? null : 
sum.divide(cnt, MathContext.DECIMAL64);
+            return cnt.compareTo(BigDecimal.ZERO) == 0 ? null : 
IgniteSqlFunctions.decimalDivide(sum, cnt, -1, scale);
         }
 
         /** {@inheritDoc} */
@@ -339,7 +347,7 @@ public class Accumulators {
         /** {@inheritDoc} */
         @Override
         public RelDataType returnType(IgniteTypeFactory typeFactory) {
-            return 
typeFactory.createTypeWithNullability(typeFactory.createSqlType(DECIMAL), true);
+            return 
typeFactory.createTypeWithNullability(typeFactory.createSqlType(DECIMAL, 
RelDataType.PRECISION_NOT_SPECIFIED, scale), true);
         }
     }
 
@@ -348,12 +356,18 @@ public class Accumulators {
      * TODO Documentation https://issues.apache.org/jira/browse/IGNITE-15859
      */
     public static class DoubleAvg implements Accumulator {
-        public static final Supplier<Accumulator> FACTORY = DoubleAvg::new;
+        public static final IntFunction<Accumulator> FACTORY = DoubleAvg::new;
 
         private double sum;
 
         private long cnt;
 
+        private int scale;
+
+        public DoubleAvg(int scale) {
+            this.scale = scale;
+        }
+
         /** {@inheritDoc} */
         @Override
         public void add(Object... args) {
@@ -370,7 +384,7 @@ public class Accumulators {
         /** {@inheritDoc} */
         @Override
         public Object end() {
-            return cnt > 0 ? sum / cnt : null;
+            return cnt > 0 ? IgniteSqlFunctions.sround(sum / cnt, scale) : 
null;
         }
 
         /** {@inheritDoc} */
diff --git 
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java
 
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java
index ec8c2819e4..63258c9e2f 100644
--- 
a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java
+++ 
b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java
@@ -645,8 +645,15 @@ public class MapReduceAggregates {
 
                 sumDivCnt = 
rexBuilder.makeCall(IgniteSqlOperatorTable.DECIMAL_DIVIDE, numeratorRef, 
denominatorRef, p, s);
             } else {
-                RexNode divideRef = 
rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
-                sumDivCnt = rexBuilder.makeCast(call.getType(), divideRef, 
true, false);
+                RelDataType resultType = typeFactory.decimalOf(call.type);
+                int precision = resultType.getPrecision(); // not used.
+                int scale = resultType.getScale();
+
+                RexLiteral p = 
rexBuilder.makeExactLiteral(BigDecimal.valueOf(precision), 
tf.createSqlType(SqlTypeName.INTEGER));
+                RexLiteral s = 
rexBuilder.makeExactLiteral(BigDecimal.valueOf(scale), 
tf.createSqlType(SqlTypeName.INTEGER));
+
+                sumDivCnt = 
rexBuilder.makeCall(IgniteSqlOperatorTable.DECIMAL_DIVIDE, numeratorRef, 
denominatorRef, p, s);
+                sumDivCnt = rexBuilder.makeCast(call.getType(), sumDivCnt, 
true, false);
             }
 
             if (canBeNull) {

Reply via email to