This is an automated email from the ASF dual-hosted git repository. amashenkov pushed a commit to branch ignite-22988 in repository https://gitbox.apache.org/repos/asf/ignite-3.git
commit b1b68c83da372dacc1a817b7ff9b9daedf853bd8 Author: amashenkov <[email protected]> AuthorDate: Tue Aug 27 17:51:38 2024 +0300 WIP --- .../internal/sql/engine/ItAggregatesTest.java | 23 ++++++----------- .../sql/engine/exec/exp/IgniteSqlFunctions.java | 2 +- .../sql/engine/exec/exp/agg/Accumulators.java | 30 ++++++++++++++++------ .../sql/engine/rel/agg/MapReduceAggregates.java | 11 ++++++-- 4 files changed, 40 insertions(+), 26 deletions(-) diff --git a/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java b/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java index 82c13c341e..b2efca852e 100644 --- a/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java +++ b/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItAggregatesTest.java @@ -24,6 +24,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.math.BigDecimal; import java.math.MathContext; +import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -38,7 +39,6 @@ import org.apache.ignite.internal.testframework.WithSystemProperty; import org.apache.ignite.lang.IgniteException; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -553,9 +553,6 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { @ParameterizedTest @MethodSource("provideRules") public void testAvg(String[] rules) { - Assumptions.assumeFalse(Arrays.stream(rules).filter(r -> r.startsWith("MapReduce")).count() == 1, - "need to be fixed after: https://issues.apache.org/jira/browse/IGNITE-22988"); - sql("DELETE FROM numbers"); sql("INSERT INTO numbers VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1), (2, 2, 2, 2, 2, 2, 2, 2, 2, 2)"); @@ -564,7 +561,7 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { + "AVG(float_col), AVG(double_col), AVG(dec2_col), AVG(dec4_2_col) " + "FROM numbers") .disableRules(rules) - .returns((byte) 1, (short) 1, 1, 1L, 1.5f, 1.5d, new BigDecimal("1.5"), new BigDecimal("1.50")) + .returns((byte) 2, (short) 2, 2, 2L, 1.5f, 1.5d, new BigDecimal("2"), new BigDecimal("1.50")) .check(); sql("DELETE FROM numbers"); @@ -580,7 +577,7 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { assertQuery("SELECT AVG(dec4_2_col) FROM numbers") .disableRules(rules) - .returns(new BigDecimal("1.665")) + .returns(new BigDecimal("1.66")) .check(); sql("DELETE FROM numbers"); @@ -601,7 +598,6 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { } @Test - @Disabled("https://issues.apache.org/jira/browse/IGNITE-22988") public void testAvgRandom() { long seed = System.nanoTime(); Random random = new Random(seed); @@ -617,12 +613,12 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { numbers.add(num); String query = "INSERT INTO numbers (id, int_col, dec10_2_col) VALUES(?, ?, ?)"; - sql(query, i, num.intValue(), num); + sql(query, i, num.setScale(0, RoundingMode.HALF_UP).intValue(), num); } BigDecimal avg = numbers.stream() .reduce(new BigDecimal("0.00"), BigDecimal::add) - .divide(BigDecimal.valueOf(numbers.size()), MathContext.DECIMAL64); + .divide(BigDecimal.valueOf(numbers.size()), 2, RoundingMode.HALF_UP); for (String[] rules : makePermutations(DISABLED_RULES)) { assertQuery("SELECT AVG(int_col), AVG(dec10_2_col) FROM numbers") @@ -640,7 +636,7 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { assertQuery("SELECT AVG(int_col), AVG(dec4_2_col) FROM not_null_numbers") .disableRules(rules) - .returns(1, new BigDecimal("1.50")) + .returns(2, new BigDecimal("1.50")) .check(); // Return type of an AVG aggregate can never be null. @@ -689,9 +685,6 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { @ParameterizedTest @MethodSource("provideRules") public void testAvgFromLiterals(String[] rules) { - Assumptions.assumeFalse(Arrays.stream(rules).filter(r -> r.startsWith("MapReduce")).count() == 1, - "need to be fixed after: https://issues.apache.org/jira/browse/IGNITE-22988"); - assertQuery("SELECT " + "AVG(tinyint_col), AVG(smallint_col), AVG(int_col), AVG(bigint_col), " + "AVG(float_col), AVG(double_col), AVG(dec2_col), AVG(dec4_2_col) " @@ -701,7 +694,7 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { + ") " + "t(tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, dec2_col, dec4_2_col)") .disableRules(rules) - .returns((byte) 1, (short) 1, 1, 1L, 1.5f, 1.5d, new BigDecimal("1.5"), new BigDecimal("1.50")) + .returns((byte) 2, (short) 2, 2, 2L, 1.5f, 1.5d, new BigDecimal("2"), new BigDecimal("1.50")) .check(); assertQuery("SELECT " @@ -717,7 +710,7 @@ public class ItAggregatesTest extends BaseSqlIntegrationTest { + " UNION\n" + " SELECT 2::DECIMAL(2) as dec2_col, 3.00::DECIMAL(4,2) as dec4_2_col\n" + ") as t") - .returns(new BigDecimal("1.5"), new BigDecimal("2.50")) + .returns(new BigDecimal("2"), new BigDecimal("2.50")) .check(); } diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java index 1b32ca19e0..b7da271bcd 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/IgniteSqlFunctions.java @@ -361,7 +361,7 @@ public class IgniteSqlFunctions { * (see {@link IgniteSqlOperatorTable#DECIMAL_DIVIDE}, their values are ignored at runtime. */ public static BigDecimal decimalDivide(BigDecimal sum, BigDecimal cnt, int p, int s) { - return sum.divide(cnt, MathContext.DECIMAL64); + return sum.divide(cnt, s, RoundingMode.HALF_EVEN); // TODO: HALF_UP is expected } private static BigDecimal processValueWithIntegralPart(Number value, int precision, int scale) { diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java index 7e0a78177d..e3cae00ddf 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/exp/agg/Accumulators.java @@ -27,16 +27,17 @@ import static org.apache.calcite.sql.type.SqlTypeName.VARCHAR; import static org.apache.ignite.internal.util.ArrayUtils.nullOrEmpty; import java.math.BigDecimal; -import java.math.MathContext; import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.function.IntFunction; import java.util.function.Supplier; import org.apache.calcite.avatica.util.ByteString; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.ignite.internal.sql.engine.exec.exp.IgniteSqlFunctions; import org.apache.ignite.internal.sql.engine.type.IgniteCustomType; import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory; import org.apache.ignite.internal.util.ArrayUtils; @@ -104,7 +105,7 @@ public class Accumulators { switch (call.type.getSqlTypeName()) { case BIGINT: case DECIMAL: - return DecimalAvg.FACTORY; + return () -> DecimalAvg.FACTORY.apply(call.type.getScale()); case DOUBLE: case REAL: case FLOAT: @@ -114,7 +115,8 @@ public class Accumulators { if (call.type.getSqlTypeName() == ANY) { throw unsupportedAggregateFunction(call); } - return DoubleAvg.FACTORY; + RelDataType dataType = typeFactory.decimalOf(call.type); + return () -> DoubleAvg.FACTORY.apply(dataType.getScale()); } } @@ -305,12 +307,18 @@ public class Accumulators { * TODO Documentation https://issues.apache.org/jira/browse/IGNITE-15859 */ public static class DecimalAvg implements Accumulator { - public static final Supplier<Accumulator> FACTORY = DecimalAvg::new; + public static final IntFunction<Accumulator> FACTORY = DecimalAvg::new; + + private final int scale; private BigDecimal sum = BigDecimal.ZERO; private BigDecimal cnt = BigDecimal.ZERO; + public DecimalAvg(int scale) { + this.scale = scale; + } + /** {@inheritDoc} */ @Override public void add(Object... args) { @@ -327,7 +335,7 @@ public class Accumulators { /** {@inheritDoc} */ @Override public Object end() { - return cnt.compareTo(BigDecimal.ZERO) == 0 ? null : sum.divide(cnt, MathContext.DECIMAL64); + return cnt.compareTo(BigDecimal.ZERO) == 0 ? null : IgniteSqlFunctions.decimalDivide(sum, cnt, -1, scale); } /** {@inheritDoc} */ @@ -339,7 +347,7 @@ public class Accumulators { /** {@inheritDoc} */ @Override public RelDataType returnType(IgniteTypeFactory typeFactory) { - return typeFactory.createTypeWithNullability(typeFactory.createSqlType(DECIMAL), true); + return typeFactory.createTypeWithNullability(typeFactory.createSqlType(DECIMAL, RelDataType.PRECISION_NOT_SPECIFIED, scale), true); } } @@ -348,12 +356,18 @@ public class Accumulators { * TODO Documentation https://issues.apache.org/jira/browse/IGNITE-15859 */ public static class DoubleAvg implements Accumulator { - public static final Supplier<Accumulator> FACTORY = DoubleAvg::new; + public static final IntFunction<Accumulator> FACTORY = DoubleAvg::new; private double sum; private long cnt; + private int scale; + + public DoubleAvg(int scale) { + this.scale = scale; + } + /** {@inheritDoc} */ @Override public void add(Object... args) { @@ -370,7 +384,7 @@ public class Accumulators { /** {@inheritDoc} */ @Override public Object end() { - return cnt > 0 ? sum / cnt : null; + return cnt > 0 ? IgniteSqlFunctions.sround(sum / cnt, scale) : null; } /** {@inheritDoc} */ diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java index ec8c2819e4..63258c9e2f 100644 --- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java +++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rel/agg/MapReduceAggregates.java @@ -645,8 +645,15 @@ public class MapReduceAggregates { sumDivCnt = rexBuilder.makeCall(IgniteSqlOperatorTable.DECIMAL_DIVIDE, numeratorRef, denominatorRef, p, s); } else { - RexNode divideRef = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef); - sumDivCnt = rexBuilder.makeCast(call.getType(), divideRef, true, false); + RelDataType resultType = typeFactory.decimalOf(call.type); + int precision = resultType.getPrecision(); // not used. + int scale = resultType.getScale(); + + RexLiteral p = rexBuilder.makeExactLiteral(BigDecimal.valueOf(precision), tf.createSqlType(SqlTypeName.INTEGER)); + RexLiteral s = rexBuilder.makeExactLiteral(BigDecimal.valueOf(scale), tf.createSqlType(SqlTypeName.INTEGER)); + + sumDivCnt = rexBuilder.makeCall(IgniteSqlOperatorTable.DECIMAL_DIVIDE, numeratorRef, denominatorRef, p, s); + sumDivCnt = rexBuilder.makeCast(call.getType(), sumDivCnt, true, false); } if (canBeNull) {
