This is an automated email from the ASF dual-hosted git repository. mblow pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/asterixdb.git
commit e2a63b25ecf2f07381b9dc9aba811ae3fd59fe9d Author: Dmitry Lychagin <[email protected]> AuthorDate: Fri Mar 12 12:13:52 2021 -0800 [ASTERIXDB-2843][COMP] Fix type computer for scalar aggregates - user model changes: no - storage format changes: no - interface changes: no Details: - Align type computation for scalar aggregates with regular aggregates - Add testcase to verify it for all aggregate functions Change-Id: Iddd8075b490c83cb6f493d02b7bea1eedb4a4129 Reviewed-on: https://asterix-gerrit.ics.uci.edu/c/asterixdb/+/10483 Integration-Tests: Jenkins <[email protected]> Tested-by: Jenkins <[email protected]> Reviewed-by: Dmitry Lychagin <[email protected]> --- .../scalar_sum_type/scalar_sum_type.1.query.sqlpp | 28 +++ .../scalar_sum_type/scalar_sum_type.1.query.sqlpp | 28 +++ .../sum/scalar_sum_type/scalar_sum_type.1.adm | 1 + .../sum/scalar_sum_type/scalar_sum_type.1.adm | 1 + .../test/resources/runtimets/testsuite_sqlpp.xml | 10 + .../asterix/om/functions/BuiltinFunctions.java | 29 +-- ...puter.java => AggregateResultTypeComputer.java} | 28 +-- .../typecomputer/impl/MinMaxAggTypeComputer.java | 3 +- .../impl/NumericSumAggTypeComputer.java | 33 +-- .../impl/ScalarVersionOfAggregateResultType.java | 50 +++-- .../functions/ScalarAggregateTypeComputerTest.java | 239 +++++++++++++++++++++ 11 files changed, 372 insertions(+), 78 deletions(-) diff --git a/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp b/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp new file mode 100644 index 0000000..932661c9 --- /dev/null +++ b/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Test that scalar sum() produces correct output type + */ + +select array_sum(array_reverse(lst)) +let lst = ( + from range(1, 3) r + select value int32(to_string(r)) +) \ No newline at end of file diff --git a/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp b/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp new file mode 100644 index 0000000..361a59b --- /dev/null +++ b/asterixdb/asterix-app/src/test/resources/runtimets/queries_sqlpp/aggregate/sum/scalar_sum_type/scalar_sum_type.1.query.sqlpp @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Test that scalar sum() produces correct output type + */ + +select strict_sum(array_reverse(lst)) +let lst = ( + from range(1, 3) r + select value int32(to_string(r)) +) \ No newline at end of file diff --git a/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.adm b/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.adm new file mode 100644 index 0000000..d9b1127 --- /dev/null +++ b/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate-sql/sum/scalar_sum_type/scalar_sum_type.1.adm @@ -0,0 +1 @@ +{ "$1": 6 } \ No newline at end of file diff --git a/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate/sum/scalar_sum_type/scalar_sum_type.1.adm b/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate/sum/scalar_sum_type/scalar_sum_type.1.adm new file mode 100644 index 0000000..d9b1127 --- /dev/null +++ b/asterixdb/asterix-app/src/test/resources/runtimets/results/aggregate/sum/scalar_sum_type/scalar_sum_type.1.adm @@ -0,0 +1 @@ +{ "$1": 6 } \ No newline at end of file diff --git a/asterixdb/asterix-app/src/test/resources/runtimets/testsuite_sqlpp.xml b/asterixdb/asterix-app/src/test/resources/runtimets/testsuite_sqlpp.xml index f8164e8..0c3b15b 100644 --- a/asterixdb/asterix-app/src/test/resources/runtimets/testsuite_sqlpp.xml +++ b/asterixdb/asterix-app/src/test/resources/runtimets/testsuite_sqlpp.xml @@ -834,6 +834,11 @@ </compilation-unit> </test-case> <test-case FilePath="aggregate"> + <compilation-unit name="sum/scalar_sum_type"> + <output-dir compare="Text">sum/scalar_sum_type</output-dir> + </compilation-unit> + </test-case> + <test-case FilePath="aggregate"> <compilation-unit name="scalar_var"> <output-dir compare="Text">scalar_var</output-dir> </compilation-unit> @@ -2097,6 +2102,11 @@ </compilation-unit> </test-case> <test-case FilePath="aggregate-sql"> + <compilation-unit name="sum/scalar_sum_type"> + <output-dir compare="Text">sum/scalar_sum_type</output-dir> + </compilation-unit> + </test-case> + <test-case FilePath="aggregate-sql"> <compilation-unit name="scalar_var"> <output-dir compare="Text">scalar_var</output-dir> </compilation-unit> diff --git a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/functions/BuiltinFunctions.java b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/functions/BuiltinFunctions.java index 268c1f6..63573a1 100644 --- a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/functions/BuiltinFunctions.java +++ b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/functions/BuiltinFunctions.java @@ -1832,6 +1832,11 @@ public class BuiltinFunctions { addFunction(NEGINF_IF, DoubleIfTypeComputer.INSTANCE, true); // Aggregate Functions + ScalarVersionOfAggregateResultType scalarNumericSumTypeComputer = + new ScalarVersionOfAggregateResultType(NumericSumAggTypeComputer.INSTANCE); + ScalarVersionOfAggregateResultType scalarMinMaxTypeComputer = + new ScalarVersionOfAggregateResultType(MinMaxAggTypeComputer.INSTANCE); + addPrivateFunction(LISTIFY, OrderedListConstructorTypeComputer.INSTANCE, true); addFunction(SCALAR_ARRAYAGG, ScalarArrayAggTypeComputer.INSTANCE, true); addFunction(MAX, MinMaxAggTypeComputer.INSTANCE, true); @@ -1877,7 +1882,7 @@ public class BuiltinFunctions { // SUM addFunction(SUM, NumericSumAggTypeComputer.INSTANCE, true); - addFunction(SCALAR_SUM, ScalarVersionOfAggregateResultType.INSTANCE, true); + addFunction(SCALAR_SUM, scalarNumericSumTypeComputer, true); addPrivateFunction(LOCAL_SUM, NumericSumAggTypeComputer.INSTANCE, true); addPrivateFunction(INTERMEDIATE_SUM, NumericSumAggTypeComputer.INSTANCE, true); addPrivateFunction(GLOBAL_SUM, NumericSumAggTypeComputer.INSTANCE, true); @@ -1893,8 +1898,8 @@ public class BuiltinFunctions { addPrivateFunction(SERIAL_INTERMEDIATE_SQL_AVG, LocalAvgTypeComputer.INSTANCE, true); addFunction(SCALAR_AVG, NullableDoubleTypeComputer.INSTANCE, true); addFunction(SCALAR_COUNT, AInt64TypeComputer.INSTANCE, true); - addFunction(SCALAR_MAX, ScalarVersionOfAggregateResultType.INSTANCE, true); - addFunction(SCALAR_MIN, ScalarVersionOfAggregateResultType.INSTANCE, true); + addFunction(SCALAR_MAX, scalarMinMaxTypeComputer, true); + addFunction(SCALAR_MIN, scalarMinMaxTypeComputer, true); addPrivateFunction(INTERMEDIATE_AVG, LocalAvgTypeComputer.INSTANCE, true); addFunction(SCALAR_STDDEV_SAMP, NullableDoubleTypeComputer.INSTANCE, true); addPrivateFunction(INTERMEDIATE_STDDEV_SAMP, LocalSingleVarStatisticsTypeComputer.INSTANCE, true); @@ -1935,7 +1940,7 @@ public class BuiltinFunctions { // SQL SUM addFunction(SQL_SUM, NumericSumAggTypeComputer.INSTANCE, true); - addFunction(SCALAR_SQL_SUM, ScalarVersionOfAggregateResultType.INSTANCE, true); + addFunction(SCALAR_SQL_SUM, scalarNumericSumTypeComputer, true); addPrivateFunction(LOCAL_SQL_SUM, NumericSumAggTypeComputer.INSTANCE, true); addPrivateFunction(INTERMEDIATE_SQL_SUM, NumericSumAggTypeComputer.INSTANCE, true); addPrivateFunction(GLOBAL_SQL_SUM, NumericSumAggTypeComputer.INSTANCE, true); @@ -1959,8 +1964,8 @@ public class BuiltinFunctions { addPrivateFunction(GLOBAL_SQL_MIN, MinMaxAggTypeComputer.INSTANCE, true); addFunction(SCALAR_SQL_AVG, NullableDoubleTypeComputer.INSTANCE, true); addFunction(SCALAR_SQL_COUNT, AInt64TypeComputer.INSTANCE, true); - addFunction(SCALAR_SQL_MAX, ScalarVersionOfAggregateResultType.INSTANCE, true); - addFunction(SCALAR_SQL_MIN, ScalarVersionOfAggregateResultType.INSTANCE, true); + addFunction(SCALAR_SQL_MAX, scalarMinMaxTypeComputer, true); + addFunction(SCALAR_SQL_MIN, scalarMinMaxTypeComputer, true); addPrivateFunction(INTERMEDIATE_SQL_AVG, LocalAvgTypeComputer.INSTANCE, true); addFunction(SQL_STDDEV_SAMP, NullableDoubleTypeComputer.INSTANCE, true); addPrivateFunction(GLOBAL_SQL_STDDEV_SAMP, NullableDoubleTypeComputer.INSTANCE, true); @@ -2035,9 +2040,9 @@ public class BuiltinFunctions { addFunction(SCALAR_SQL_COUNT_DISTINCT, AInt64TypeComputer.INSTANCE, true); addFunction(SUM_DISTINCT, NumericSumAggTypeComputer.INSTANCE, true); - addFunction(SCALAR_SUM_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true); + addFunction(SCALAR_SUM_DISTINCT, scalarNumericSumTypeComputer, true); addFunction(SQL_SUM_DISTINCT, NumericSumAggTypeComputer.INSTANCE, true); - addFunction(SCALAR_SQL_SUM_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true); + addFunction(SCALAR_SQL_SUM_DISTINCT, scalarNumericSumTypeComputer, true); addFunction(AVG_DISTINCT, NullableDoubleTypeComputer.INSTANCE, true); addFunction(SCALAR_AVG_DISTINCT, NullableDoubleTypeComputer.INSTANCE, true); @@ -2045,14 +2050,14 @@ public class BuiltinFunctions { addFunction(SCALAR_SQL_AVG_DISTINCT, NullableDoubleTypeComputer.INSTANCE, true); addFunction(MAX_DISTINCT, MinMaxAggTypeComputer.INSTANCE, true); - addFunction(SCALAR_MAX_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true); + addFunction(SCALAR_MAX_DISTINCT, scalarMinMaxTypeComputer, true); addFunction(SQL_MAX_DISTINCT, MinMaxAggTypeComputer.INSTANCE, true); - addFunction(SCALAR_SQL_MAX_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true); + addFunction(SCALAR_SQL_MAX_DISTINCT, scalarMinMaxTypeComputer, true); addFunction(MIN_DISTINCT, MinMaxAggTypeComputer.INSTANCE, true); - addFunction(SCALAR_MIN_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true); + addFunction(SCALAR_MIN_DISTINCT, scalarMinMaxTypeComputer, true); addFunction(SQL_MIN_DISTINCT, MinMaxAggTypeComputer.INSTANCE, true); - addFunction(SCALAR_SQL_MIN_DISTINCT, ScalarVersionOfAggregateResultType.INSTANCE, true); + addFunction(SCALAR_SQL_MIN_DISTINCT, scalarMinMaxTypeComputer, true); addFunction(STDDEV_SAMP_DISTINCT, NullableDoubleTypeComputer.INSTANCE, true); addFunction(SCALAR_STDDEV_SAMP_DISTINCT, NullableDoubleTypeComputer.INSTANCE, true); diff --git a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/MinMaxAggTypeComputer.java b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/AggregateResultTypeComputer.java similarity index 57% copy from asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/MinMaxAggTypeComputer.java copy to asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/AggregateResultTypeComputer.java index c34b5ed..8e663a7 100644 --- a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/MinMaxAggTypeComputer.java +++ b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/AggregateResultTypeComputer.java @@ -16,30 +16,24 @@ * specific language governing permissions and limitations * under the License. */ + package org.apache.asterix.om.typecomputer.impl; -import org.apache.asterix.dataflow.data.common.ILogicalBinaryComparator; import org.apache.asterix.om.typecomputer.base.AbstractResultTypeComputer; -import org.apache.asterix.om.types.ATypeTag; -import org.apache.asterix.om.types.AUnionType; -import org.apache.asterix.om.types.BuiltinType; import org.apache.asterix.om.types.IAType; import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException; import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression; +import org.apache.hyracks.algebricks.core.algebra.functions.FunctionIdentifier; +import org.apache.hyracks.api.exceptions.SourceLocation; -public class MinMaxAggTypeComputer extends AbstractResultTypeComputer { - - public static final MinMaxAggTypeComputer INSTANCE = new MinMaxAggTypeComputer(); - - private MinMaxAggTypeComputer() { +public abstract class AggregateResultTypeComputer extends AbstractResultTypeComputer { + @Override + protected void checkArgType(FunctionIdentifier funcId, int argIndex, IAType type, SourceLocation sourceLoc) + throws AlgebricksException { + super.checkArgType(funcId, argIndex, type, sourceLoc); } @Override - protected IAType getResultType(ILogicalExpression expr, IAType... strippedInputTypes) throws AlgebricksException { - ATypeTag typeTag = strippedInputTypes[0].getTypeTag(); - if (ILogicalBinaryComparator.inequalityUndefined(typeTag)) { - return BuiltinType.ANULL; - } - return typeTag == ATypeTag.ANY ? BuiltinType.ANY : AUnionType.createUnknownableType(strippedInputTypes[0]); - } -} + protected abstract IAType getResultType(ILogicalExpression expr, IAType... strippedInputTypes) + throws AlgebricksException; +} \ No newline at end of file diff --git a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/MinMaxAggTypeComputer.java b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/MinMaxAggTypeComputer.java index c34b5ed..fc1eee5 100644 --- a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/MinMaxAggTypeComputer.java +++ b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/MinMaxAggTypeComputer.java @@ -19,7 +19,6 @@ package org.apache.asterix.om.typecomputer.impl; import org.apache.asterix.dataflow.data.common.ILogicalBinaryComparator; -import org.apache.asterix.om.typecomputer.base.AbstractResultTypeComputer; import org.apache.asterix.om.types.ATypeTag; import org.apache.asterix.om.types.AUnionType; import org.apache.asterix.om.types.BuiltinType; @@ -27,7 +26,7 @@ import org.apache.asterix.om.types.IAType; import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException; import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression; -public class MinMaxAggTypeComputer extends AbstractResultTypeComputer { +public class MinMaxAggTypeComputer extends AggregateResultTypeComputer { public static final MinMaxAggTypeComputer INSTANCE = new MinMaxAggTypeComputer(); diff --git a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/NumericSumAggTypeComputer.java b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/NumericSumAggTypeComputer.java index 1c67e56..a4b5e34 100644 --- a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/NumericSumAggTypeComputer.java +++ b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/NumericSumAggTypeComputer.java @@ -18,42 +18,20 @@ */ package org.apache.asterix.om.typecomputer.impl; -import org.apache.asterix.om.exceptions.UnsupportedTypeException; -import org.apache.asterix.om.typecomputer.base.AbstractResultTypeComputer; import org.apache.asterix.om.types.ATypeTag; import org.apache.asterix.om.types.AUnionType; import org.apache.asterix.om.types.BuiltinType; import org.apache.asterix.om.types.IAType; import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException; import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression; -import org.apache.hyracks.algebricks.core.algebra.functions.FunctionIdentifier; -import org.apache.hyracks.api.exceptions.SourceLocation; -public class NumericSumAggTypeComputer extends AbstractResultTypeComputer { +public class NumericSumAggTypeComputer extends AggregateResultTypeComputer { public static final NumericSumAggTypeComputer INSTANCE = new NumericSumAggTypeComputer(); private NumericSumAggTypeComputer() { } @Override - protected void checkArgType(FunctionIdentifier funcId, int argIndex, IAType type, SourceLocation sourceLoc) - throws AlgebricksException { - ATypeTag tag = type.getTypeTag(); - switch (tag) { - case DOUBLE: - case FLOAT: - case BIGINT: - case INTEGER: - case SMALLINT: - case TINYINT: - case ANY: - break; - default: - throw new UnsupportedTypeException(sourceLoc, funcId, tag); - } - } - - @Override protected IAType getResultType(ILogicalExpression expr, IAType... strippedInputTypes) throws AlgebricksException { ATypeTag tag = strippedInputTypes[0].getTypeTag(); switch (tag) { @@ -61,15 +39,12 @@ public class NumericSumAggTypeComputer extends AbstractResultTypeComputer { case SMALLINT: case INTEGER: case BIGINT: - IAType int64Type = BuiltinType.AINT64; - return AUnionType.createNullableType(int64Type, "AggResult"); + return AUnionType.createNullableType(BuiltinType.AINT64); case FLOAT: case DOUBLE: - IAType doubleType = BuiltinType.ADOUBLE; - return AUnionType.createNullableType(doubleType, "AggResult"); + return AUnionType.createNullableType(BuiltinType.ADOUBLE); case ANY: - IAType anyType = strippedInputTypes[0]; - return AUnionType.createNullableType(anyType, "AggResult"); + return BuiltinType.ANY; default: // All other possible cases. return BuiltinType.ANULL; diff --git a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/ScalarVersionOfAggregateResultType.java b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/ScalarVersionOfAggregateResultType.java index 5b90974..fda0285 100644 --- a/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/ScalarVersionOfAggregateResultType.java +++ b/asterixdb/asterix-om/src/main/java/org/apache/asterix/om/typecomputer/impl/ScalarVersionOfAggregateResultType.java @@ -18,9 +18,7 @@ */ package org.apache.asterix.om.typecomputer.impl; -import org.apache.asterix.om.exceptions.TypeMismatchException; import org.apache.asterix.om.typecomputer.base.AbstractResultTypeComputer; -import org.apache.asterix.om.types.ATypeTag; import org.apache.asterix.om.types.AUnionType; import org.apache.asterix.om.types.AbstractCollectionType; import org.apache.asterix.om.types.BuiltinType; @@ -32,32 +30,48 @@ import org.apache.hyracks.api.exceptions.SourceLocation; public class ScalarVersionOfAggregateResultType extends AbstractResultTypeComputer { - public static final ScalarVersionOfAggregateResultType INSTANCE = new ScalarVersionOfAggregateResultType(); + private final AggregateResultTypeComputer aggResultTypeComputer; - private ScalarVersionOfAggregateResultType() { + public ScalarVersionOfAggregateResultType(AggregateResultTypeComputer aggResultTypeComputer) { + this.aggResultTypeComputer = aggResultTypeComputer; } @Override - public void checkArgType(FunctionIdentifier funcId, int argIndex, IAType type, SourceLocation sourceLoc) + protected void checkArgType(FunctionIdentifier funcId, int argIndex, IAType argType, SourceLocation sourceLoc) throws AlgebricksException { - ATypeTag tag = type.getTypeTag(); - if (tag != ATypeTag.ANY && tag != ATypeTag.ARRAY && tag != ATypeTag.MULTISET) { - throw new TypeMismatchException(sourceLoc, funcId, argIndex, tag, ATypeTag.ARRAY, ATypeTag.MULTISET); + if (argIndex == 0) { + switch (argType.getTypeTag()) { + case ARRAY: + case MULTISET: + AbstractCollectionType act = (AbstractCollectionType) argType; + aggResultTypeComputer.checkArgType(funcId, argIndex, act.getItemType(), sourceLoc); + break; + } } } @Override protected IAType getResultType(ILogicalExpression expr, IAType... strippedInputTypes) throws AlgebricksException { - ATypeTag tag = strippedInputTypes[0].getTypeTag(); - if (tag == ATypeTag.ANY) { - return BuiltinType.ANY; + IAType argType = strippedInputTypes[0]; + switch (argType.getTypeTag()) { + case ARRAY: + case MULTISET: + AbstractCollectionType act = (AbstractCollectionType) argType; + IAType[] strippedInputTypes2 = strippedInputTypes.clone(); + strippedInputTypes2[0] = act.getItemType(); + IAType resultType = aggResultTypeComputer.getResultType(expr, strippedInputTypes2); + switch (resultType.getTypeTag()) { + case NULL: + case MISSING: + case ANY: + return resultType; + case UNION: + return AUnionType.createUnknownableType(((AUnionType) resultType).getActualType()); + default: + return AUnionType.createUnknownableType(resultType); + } + default: + return BuiltinType.ANY; } - if (tag != ATypeTag.ARRAY && tag != ATypeTag.MULTISET) { - // this condition being true would've thrown an exception above, no? - return strippedInputTypes[0]; - } - AbstractCollectionType act = (AbstractCollectionType) strippedInputTypes[0]; - IAType t = act.getItemType(); - return AUnionType.createUnknownableType(t); } } diff --git a/asterixdb/asterix-runtime/src/test/java/org/apache/asterix/runtime/functions/ScalarAggregateTypeComputerTest.java b/asterixdb/asterix-runtime/src/test/java/org/apache/asterix/runtime/functions/ScalarAggregateTypeComputerTest.java new file mode 100644 index 0000000..cbde36c --- /dev/null +++ b/asterixdb/asterix-runtime/src/test/java/org/apache/asterix/runtime/functions/ScalarAggregateTypeComputerTest.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.asterix.runtime.functions; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import org.apache.asterix.dataflow.data.common.ExpressionTypeComputer; +import org.apache.asterix.om.base.ABoolean; +import org.apache.asterix.om.base.ADate; +import org.apache.asterix.om.base.ADateTime; +import org.apache.asterix.om.base.ADayTimeDuration; +import org.apache.asterix.om.base.ADouble; +import org.apache.asterix.om.base.ADuration; +import org.apache.asterix.om.base.AFloat; +import org.apache.asterix.om.base.AInt16; +import org.apache.asterix.om.base.AInt32; +import org.apache.asterix.om.base.AInt64; +import org.apache.asterix.om.base.AInt8; +import org.apache.asterix.om.base.AInterval; +import org.apache.asterix.om.base.AOrderedList; +import org.apache.asterix.om.base.ARecord; +import org.apache.asterix.om.base.AString; +import org.apache.asterix.om.base.ATime; +import org.apache.asterix.om.base.AUnorderedList; +import org.apache.asterix.om.base.AYearMonthDuration; +import org.apache.asterix.om.base.IAObject; +import org.apache.asterix.om.constants.AsterixConstantValue; +import org.apache.asterix.om.exceptions.UnsupportedTypeException; +import org.apache.asterix.om.functions.BuiltinFunctionInfo; +import org.apache.asterix.om.functions.BuiltinFunctions; +import org.apache.asterix.om.functions.IFunctionDescriptorFactory; +import org.apache.asterix.om.types.AOrderedListType; +import org.apache.asterix.om.types.ARecordType; +import org.apache.asterix.om.types.ATypeTag; +import org.apache.asterix.om.types.AUnionType; +import org.apache.asterix.om.types.AUnorderedListType; +import org.apache.asterix.om.types.BuiltinType; +import org.apache.asterix.om.types.IAType; +import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException; +import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression; +import org.apache.hyracks.algebricks.core.algebra.base.LogicalVariable; +import org.apache.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression; +import org.apache.hyracks.algebricks.core.algebra.expressions.AggregateFunctionCallExpression; +import org.apache.hyracks.algebricks.core.algebra.expressions.ConstantExpression; +import org.apache.hyracks.algebricks.core.algebra.expressions.IVariableTypeEnvironment; +import org.apache.hyracks.algebricks.core.algebra.expressions.ScalarFunctionCallExpression; +import org.apache.hyracks.algebricks.core.algebra.functions.FunctionIdentifier; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** + * Test alignment of type computers between aggregate functions and their scalar versions + */ +@RunWith(Parameterized.class) +public class ScalarAggregateTypeComputerTest { + + private static final IAObject[] ITEMS = { + // + ABoolean.TRUE, + // + new AInt8((byte) 0), + // + new AInt16((short) 0), + // + new AInt32(0), + // + new AInt64(0), + // + new AFloat(0), + // + new ADouble(0), + // + new AString(""), + // + new ADate(0), + // + new ADateTime(0), + // + new ATime(0), + // + new ADuration(0, 0), + // + new AYearMonthDuration(0), + // + new ADayTimeDuration(0), + // + new AInterval(0, 0, ATypeTag.DATETIME.serialize()), + // + new AOrderedList(AOrderedListType.FULL_OPEN_ORDEREDLIST_TYPE, Collections.singletonList(new AString(""))), + // + new AUnorderedList(AUnorderedListType.FULLY_OPEN_UNORDEREDLIST_TYPE, + Collections.singletonList(new AString(""))), + // + new ARecord( + new ARecordType("record-type", new String[] { "a" }, new IAType[] { BuiltinType.ASTRING }, false), + new IAObject[] { new AString("") }) }; + + // Test parameters + @Parameterized.Parameter + public String testName; + + @Parameterized.Parameter(1) + public FunctionIdentifier scalarfid; + + @Parameterized.Parameter(2) + public FunctionIdentifier aggfid; + + @Parameterized.Parameter(3) + public IAObject item; + + @Parameterized.Parameters(name = "ScalarAggregateTypeComputerTest {index}: {0}({3})") + public static Collection<Object[]> tests() { + List<Object[]> tests = new ArrayList<>(); + + FunctionCollection fcoll = FunctionCollection.createDefaultFunctionCollection(); + for (IFunctionDescriptorFactory fdf : fcoll.getFunctionDescriptorFactories()) { + FunctionIdentifier fid = fdf.createFunctionDescriptor().getIdentifier(); + FunctionIdentifier aggfid = BuiltinFunctions.getAggregateFunction(fid); + if (aggfid == null) { + continue; + } + for (IAObject item : ITEMS) { + tests.add(new Object[] { fid.getName(), fid, aggfid, item }); + } + + } + return tests; + } + + @Test + public void test() throws Exception { + + AOrderedListType listType = new AOrderedListType(item.getType(), null); + AOrderedList list = new AOrderedList(listType, Collections.singletonList(item)); + ConstantExpression scalarArgExpr = new ConstantExpression(new AsterixConstantValue(list)); + BuiltinFunctionInfo scalarfi = BuiltinFunctions.getBuiltinFunctionInfo(scalarfid); + ScalarFunctionCallExpression scalarCallExpr = + new ScalarFunctionCallExpression(scalarfi, new MutableObject<>(scalarArgExpr)); + IAType scalarResultType = computeType(scalarCallExpr); + + ConstantExpression aggArgExpr = new ConstantExpression(new AsterixConstantValue(item)); + BuiltinFunctionInfo aggfi = BuiltinFunctions.getBuiltinFunctionInfo(aggfid); + AggregateFunctionCallExpression aggCallExpr = new AggregateFunctionCallExpression(aggfi, false, + Collections.singletonList(new MutableObject<>(aggArgExpr))); + IAType aggResultType = computeType(aggCallExpr); + + if (!compareResultTypes(scalarResultType, aggResultType)) { + Assert.fail(String.format("%s(%s) returns %s != %s(%s) returns %s", scalarfid.getName(), item.getType(), + formatResultType(scalarResultType), aggfid.getName(), item.getType(), + formatResultType(aggResultType))); + } + } + + private boolean compareResultTypes(IAType t1, IAType t2) { + // null means ERROR + if (t1 == null) { + // OK if both types are ERROR + return t2 == null; + } else if (t2 == null) { + return false; + } + boolean t1Union = false, t2Union = false; + if (t1.getTypeTag() == ATypeTag.UNION) { + t1Union = true; + t1 = ((AUnionType) t1).getActualType(); + } + if (t2.getTypeTag() == ATypeTag.UNION) { + t2Union = true; + t2 = ((AUnionType) t2).getActualType(); + } + return (t1Union == t2Union) && t1.deepEqual(t2); + } + + private String formatResultType(IAType t) { + return t == null ? "ERROR" : t.toString(); + } + + private IAType computeType(AbstractFunctionCallExpression callExpr) throws AlgebricksException { + try { + BuiltinFunctionInfo fi = Objects.requireNonNull((BuiltinFunctionInfo) callExpr.getFunctionInfo()); + return fi.getResultTypeComputer().computeType(callExpr, EMPTY_TYPE_ENV, null); + } catch (UnsupportedTypeException e) { + return null; + } + } + + private static final IVariableTypeEnvironment EMPTY_TYPE_ENV = new IVariableTypeEnvironment() { + + @Override + public boolean substituteProducedVariable(LogicalVariable v1, LogicalVariable v2) { + throw new IllegalStateException(); + } + + @Override + public void setVarType(LogicalVariable var, Object type) { + throw new IllegalStateException(); + } + + @Override + public Object getVarType(LogicalVariable var, List<LogicalVariable> nonNullVariables, + List<List<LogicalVariable>> correlatedNullableVariableLists) { + throw new IllegalStateException(); + } + + @Override + public Object getVarType(LogicalVariable var) { + throw new IllegalStateException(); + } + + @Override + public Object getType(ILogicalExpression expr) throws AlgebricksException { + return ExpressionTypeComputer.INSTANCE.getType(expr, null, this); + } + }; +}
