gortiz commented on code in PR #13573: URL: https://github.com/apache/pinot/pull/13573#discussion_r1674280241
########## pinot-core/src/test/java/org/apache/pinot/core/function/FunctionDefinitionRegistryTest.java: ########## @@ -0,0 +1,106 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.function; + +import java.util.EnumSet; +import org.apache.pinot.common.function.FunctionRegistry; +import org.apache.pinot.common.function.TransformFunctionType; +import org.apache.pinot.segment.spi.AggregationFunctionType; +import org.apache.pinot.spi.annotations.ScalarFunction; +import org.apache.pinot.sql.FilterKind; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + + +// NOTE: Keep this test in pinot-core to include all built-in scalar functions. +// TODO: Consider breaking this test into multiple tests. Review Comment: I think we should change this to one test per function before merging ########## pinot-common/src/main/java/org/apache/pinot/common/function/PinotScalarFunction.java: ########## @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function; + +import javax.annotation.Nullable; +import org.apache.pinot.common.function.sql.PinotSqlFunction; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.spi.annotations.ScalarFunction; + + +/** + * Provides finer control to the scalar functions annotated with {@link ScalarFunction}. + */ +public interface PinotScalarFunction { Review Comment: We should also document here the conditions FunctionRegistry imposes in order to find implementations (the class must be in a package that contains `.function.`, be public, etc.). ########## pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java: ########## @@ -95,108 +153,152 @@ public static void init() { } /** - * Registers a method with the name of the method. + * Registers a {@link PinotScalarFunction} under the given canonical name.
*/ - public static void registerFunction(Method method, boolean nullableParameters, boolean isPlaceholder, - boolean isVarArg) { - registerFunction(method.getName(), method, nullableParameters, isPlaceholder, isVarArg); + private static void register(String canonicalName, PinotScalarFunction function, + Map<String, PinotScalarFunction> functionMap) { + Preconditions.checkState(functionMap.put(canonicalName, function) == null, "Function: %s is already registered", + canonicalName); } /** - * Registers a method with the given function name. + * Registers a {@link FunctionInfo} under the given canonical name. */ - public static void registerFunction(String functionName, Method method, boolean nullableParameters, - boolean isPlaceholder, boolean isVarArg) { - if (!isPlaceholder) { - registerFunctionInfoMap(functionName, method, nullableParameters, isVarArg); - } - registerCalciteNamedFunctionMap(functionName, method, nullableParameters, isVarArg); + private static void register(String canonicalName, FunctionInfo functionInfo, int numArguments, + Map<String, Map<Integer, FunctionInfo>> functionInfoMap) { + Preconditions.checkState( + functionInfoMap.computeIfAbsent(canonicalName, k -> new HashMap<>()).put(numArguments, functionInfo) == null, + "Function: %s with %s arguments is already registered", canonicalName, + numArguments == VAR_ARG_KEY ? 
"variable" : numArguments); } - private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters, - boolean isVarArg) { - FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters); - String canonicalName = canonicalize(functionName); - Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>()); - if (isVarArg) { - FunctionInfo existFunctionInfo = functionInfoMap.put(VAR_ARG_KEY, functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with variable number of parameters is already registered", functionName); - } else { - FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with %s parameters is already registered", functionName, method.getParameterCount()); - } - } - - private static void registerCalciteNamedFunctionMap(String functionName, Method method, boolean nullableParameters, - boolean isVarArg) { - if (method.getAnnotation(Deprecated.class) == null) { - FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method)); - } - } - - public static Map<String, List<Function>> getRegisteredCalciteFunctionMap() { - return FUNCTION_MAP.map(); + /** + * Returns {@code true} if the given canonical name is registered, {@code false} otherwise. + */ + public static boolean contains(String canonicalName) { + return FUNCTION_MAP.containsKey(canonicalName); } - public static Set<String> getRegisteredCalciteFunctionNames() { - return FUNCTION_MAP.map().keySet(); + @Deprecated + public static boolean containsFunction(String name) { + return contains(canonicalize(name)); } /** - * Returns {@code true} if the given function name is registered, {@code false} otherwise. 
+ * Returns the {@link FunctionInfo} associated with the given canonical name and argument types, or {@code null} if + * there is no matching method. This method should be called after the FunctionRegistry is initialized and all methods + * are already registered. */ - public static boolean containsFunction(String functionName) { - return FUNCTION_INFO_MAP.containsKey(canonicalize(functionName)); + @Nullable + public static FunctionInfo lookupFunctionInfo(String canonicalName, ColumnDataType[] argumentTypes) { + PinotScalarFunction function = FUNCTION_MAP.get(canonicalName); + return function != null ? function.getFunctionInfo(argumentTypes) : null; } /** - * Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null} + * Returns the {@link FunctionInfo} associated with the given canonical name and number of arguments, or {@code null} * if there is no matching method. This method should be called after the FunctionRegistry is initialized and all * methods are already registered. + * TODO: Move all usages to {@link #lookupFunctionInfo(String, ColumnDataType[])}. */ @Nullable - public static FunctionInfo getFunctionInfo(String functionName, int numParameters) { - Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName)); - if (functionInfoMap != null) { - FunctionInfo functionInfo = functionInfoMap.get(numParameters); - if (functionInfo != null) { - return functionInfo; - } - return functionInfoMap.get(VAR_ARG_KEY); - } - return null; + public static FunctionInfo lookupFunctionInfo(String canonicalName, int numArguments) { + PinotScalarFunction function = FUNCTION_MAP.get(canonicalName); + return function != null ? 
function.getFunctionInfo(numArguments) : null; } - private static String canonicalize(String functionName) { - return StringUtils.remove(functionName, '_').toLowerCase(); + @Deprecated + @Nullable + public static FunctionInfo getFunctionInfo(String name, int numArguments) { + return lookupFunctionInfo(canonicalize(name), numArguments); } - /** - * Placeholders for scalar function, they register and represents the signature for transform and filter predicate - * so that v2 engine can understand and plan them correctly. - */ - private static class PlaceholderScalarFunctions { + public static String canonicalize(String name) { + return StringUtils.remove(name, '_').toLowerCase(); + } + + public static class ArgumentCountBasedScalarFunction implements PinotScalarFunction { + private final String _name; + private final Map<Integer, FunctionInfo> _functionInfoMap; + + private ArgumentCountBasedScalarFunction(String name, Map<Integer, FunctionInfo> functionInfoMap) { + _name = name.toUpperCase(); + _functionInfoMap = functionInfoMap; + } - @ScalarFunction(names = {"textContains", "text_contains"}, isPlaceholder = true) - public static boolean textContains(String text, String pattern) { - throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); + @Override + public String getName() { + return _name; } - @ScalarFunction(names = {"textMatch", "text_match"}, isPlaceholder = true) - public static boolean textMatch(String text, String pattern) { - throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); + @Override + public PinotSqlFunction toPinotSqlFunction() { + return new PinotSqlFunction(_name, getReturnTypeInference(), getOperandTypeChecker()); } - @ScalarFunction(names = {"jsonMatch", "json_match"}, isPlaceholder = true) - public static boolean jsonMatch(String text, String pattern) { - throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); + private 
SqlReturnTypeInference getReturnTypeInference() { + return opBinding -> { + int numArguments = opBinding.getOperandCount(); + FunctionInfo functionInfo = getFunctionInfo(numArguments); + Preconditions.checkState(functionInfo != null, "Failed to find function: %s with %s arguments", _name, + numArguments); + Method method = functionInfo.getMethod(); + Class<?> returnClass = method.getReturnType(); + RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + RelDataType returnType = returnClass == Object.class ? typeFactory.createSqlType(SqlTypeName.ANY) + : JavaTypeFactoryImpl.toSql(typeFactory, typeFactory.createJavaType(returnClass)); + + if (!functionInfo.hasNullableParameters()) { + // When any parameter is null, return is null + for (RelDataType type : opBinding.collectOperandTypes()) { + if (type.isNullable()) { + return typeFactory.createTypeWithNullability(returnType, true); + } + } + } + + return method.isAnnotationPresent(Nullable.class) ? typeFactory.createTypeWithNullability(returnType, true) + : returnType; + }; + } + + private SqlOperandTypeChecker getOperandTypeChecker() { + if (_functionInfoMap.containsKey(VAR_ARG_KEY)) { + return OperandTypes.VARIADIC; + } + if (_functionInfoMap.size() == 1) { + return getOperandTypeChecker(_functionInfoMap.values().iterator().next().getMethod()); + } + List<SqlOperandTypeChecker> operandTypeCheckers = new ArrayList<>(_functionInfoMap.size()); + for (FunctionInfo functionInfo : _functionInfoMap.values()) { + operandTypeCheckers.add(getOperandTypeChecker(functionInfo.getMethod())); + } + return OperandTypes.or(operandTypeCheckers.toArray(new SqlOperandTypeChecker[0])); + } + + private static SqlTypeFamily getSqlTypeFamily(Class<?> clazz) { + // NOTE: Pinot allows some non-standard type conversions such as Timestamp <-> long, boolean <-> int etc. Do not + // restrict the type family for now. We only restrict the type family for String so that cast can be added. 
+ // Explicit cast is required to correctly convert boolean and Timestamp to String. + // TODO: Revisit this. + return clazz == String.class ? SqlTypeFamily.CHARACTER : SqlTypeFamily.ANY; + } + + private static SqlOperandTypeChecker getOperandTypeChecker(Method method) { + Class<?>[] parameterTypes = method.getParameterTypes(); + int length = parameterTypes.length; + SqlTypeFamily[] typeFamilies = new SqlTypeFamily[length]; + for (int i = 0; i < length; i++) { + typeFamilies[i] = getSqlTypeFamily(parameterTypes[i]); + } + return OperandTypes.family(typeFamilies); } - @ScalarFunction(names = {"vectorSimilarity", "vector_similarity"}, isPlaceholder = true) - public static boolean vectorSimilarity(float[] vector1, float[] vector2, int topk) { - throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); + @Nullable + @Override + public FunctionInfo getFunctionInfo(int numArguments) { + FunctionInfo functionInfo = _functionInfoMap.get(numArguments); + return functionInfo != null ? functionInfo : _functionInfoMap.get(VAR_ARG_KEY); Review Comment: I think it would be good to document somewhere how functions can be registered. AFAIU the options are: - Using annotated methods: simpler and shorter, but less expressive. For example, you cannot support polymorphism. - Using annotated classes that implement PinotScalarFunction: more expressive. Based on @Jackie-Jiang's comment here, it looks like there is a third way, which consists of registering the function explicitly in PinotOperatorTable. But such functions won't be usable in V1, am I right?
########## pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java: ########## @@ -46,127 +59,296 @@ * </ul> */ @SuppressWarnings("unused") // unused fields are accessed by reflection -public class PinotOperatorTable extends SqlStdOperatorTable { - - private static @MonotonicNonNull PinotOperatorTable _instance; - - // TODO: clean up lazy init by using Suppliers.memorized(this::computeInstance) and make getter wrapped around - // supplier instance. this should replace all lazy init static objects in the codebase - public static synchronized PinotOperatorTable instance() { - if (_instance == null) { - // Creates and initializes the standard operator table. - // Uses two-phase construction, because we can't initialize the - // table until the constructor of the sub-class has completed. - _instance = new PinotOperatorTable(); - _instance.initNoDuplicate(); - } - return _instance; +public class PinotOperatorTable implements SqlOperatorTable { + private static final Supplier<PinotOperatorTable> INSTANCE = Suppliers.memoize(PinotOperatorTable::new); + + public static PinotOperatorTable instance() { + return INSTANCE.get(); } /** - * Initialize without duplicate, e.g. when 2 duplicate operator is linked with the same op - * {@link org.apache.calcite.sql.SqlKind} it causes problem. - * - * <p>This is a direct copy of the {@link org.apache.calcite.sql.util.ReflectiveSqlOperatorTable} and can be hard to - * debug, suggest changing to a non-dynamic registration. Dynamic function support should happen via catalog. - * - * This also registers aggregation functions defined in {@link org.apache.pinot.segment.spi.AggregationFunctionType} - * which are multistage enabled. + * This list includes the supported standard {@link SqlOperator}s defined in {@link SqlStdOperatorTable}. + * NOTE: The operator order follows the same order as defined in {@link SqlStdOperatorTable} for easier search. 
+ * Some operators are commented out and re-declared in {@link #STANDARD_OPERATORS_WITH_ALIASES}. + * TODO: Add more operators as needed. */ - public final void initNoDuplicate() { - // Pinot supports native COALESCE function, thus no need to create CASE WHEN conversion. - register(new PinotSqlCoalesceFunction()); - // Ensure ArrayValueConstructor is registered before ArrayQueryConstructor - register(ARRAY_VALUE_CONSTRUCTOR); - - // TODO: reflection based registration is not ideal, we should use a static list of operators and register them - // Use reflection to register the expressions stored in public fields. - for (Field field : getClass().getFields()) { - try { - if (SqlFunction.class.isAssignableFrom(field.getType())) { - SqlFunction op = (SqlFunction) field.get(this); - if (op != null && notRegistered(op)) { - register(op); - } - } else if (SqlOperator.class.isAssignableFrom(field.getType())) { - SqlOperator op = (SqlOperator) field.get(this); - if (op != null && notRegistered(op)) { - register(op); - } - } - } catch (IllegalArgumentException | IllegalAccessException e) { - throw Util.throwAsRuntime(Util.causeOrSelf(e)); + //@formatter:off + private static final List<SqlOperator> STANDARD_OPERATORS = List.of( + // SET OPERATORS + SqlStdOperatorTable.UNION, + SqlStdOperatorTable.UNION_ALL, + SqlStdOperatorTable.EXCEPT, + SqlStdOperatorTable.EXCEPT_ALL, + SqlStdOperatorTable.INTERSECT, + SqlStdOperatorTable.INTERSECT_ALL, + + // BINARY OPERATORS + SqlStdOperatorTable.AND, + SqlStdOperatorTable.AS, + SqlStdOperatorTable.FILTER, + SqlStdOperatorTable.WITHIN_GROUP, + SqlStdOperatorTable.WITHIN_DISTINCT, + SqlStdOperatorTable.CONCAT, + SqlStdOperatorTable.DIVIDE, + SqlStdOperatorTable.PERCENT_REMAINDER, + SqlStdOperatorTable.DOT, + SqlStdOperatorTable.EQUALS, + SqlStdOperatorTable.GREATER_THAN, + SqlStdOperatorTable.IS_DISTINCT_FROM, + SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, + SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, + SqlStdOperatorTable.IN, + 
SqlStdOperatorTable.NOT_IN, + SqlStdOperatorTable.SEARCH, + SqlStdOperatorTable.LESS_THAN, + SqlStdOperatorTable.LESS_THAN_OR_EQUAL, + SqlStdOperatorTable.MINUS, + SqlStdOperatorTable.MULTIPLY, + SqlStdOperatorTable.NOT_EQUALS, + SqlStdOperatorTable.OR, + SqlStdOperatorTable.PLUS, + SqlStdOperatorTable.INTERVAL, + + // POSTFIX OPERATORS + SqlStdOperatorTable.DESC, + SqlStdOperatorTable.NULLS_FIRST, + SqlStdOperatorTable.NULLS_LAST, + SqlStdOperatorTable.IS_NOT_NULL, + SqlStdOperatorTable.IS_NULL, + SqlStdOperatorTable.IS_NOT_TRUE, + SqlStdOperatorTable.IS_TRUE, + SqlStdOperatorTable.IS_NOT_FALSE, + SqlStdOperatorTable.IS_FALSE, + SqlStdOperatorTable.IS_NOT_UNKNOWN, + SqlStdOperatorTable.IS_UNKNOWN, + + // PREFIX OPERATORS + SqlStdOperatorTable.EXISTS, + SqlStdOperatorTable.NOT, + + // AGGREGATE OPERATORS + SqlStdOperatorTable.SUM, + SqlStdOperatorTable.COUNT, + SqlStdOperatorTable.MODE, + SqlStdOperatorTable.MIN, + SqlStdOperatorTable.MAX, + SqlStdOperatorTable.LAST_VALUE, + SqlStdOperatorTable.FIRST_VALUE, + SqlStdOperatorTable.LEAD, + SqlStdOperatorTable.LAG, + SqlStdOperatorTable.AVG, + SqlStdOperatorTable.STDDEV_POP, + SqlStdOperatorTable.COVAR_POP, + SqlStdOperatorTable.COVAR_SAMP, + SqlStdOperatorTable.STDDEV_SAMP, + SqlStdOperatorTable.VAR_POP, + SqlStdOperatorTable.VAR_SAMP, + SqlStdOperatorTable.SUM0, + + // WINDOW Rank Functions + SqlStdOperatorTable.DENSE_RANK, + SqlStdOperatorTable.RANK, + SqlStdOperatorTable.ROW_NUMBER, + + // SPECIAL OPERATORS + SqlStdOperatorTable.BETWEEN, + SqlStdOperatorTable.SYMMETRIC_BETWEEN, + SqlStdOperatorTable.NOT_BETWEEN, + SqlStdOperatorTable.SYMMETRIC_NOT_BETWEEN, + SqlStdOperatorTable.NOT_LIKE, + SqlStdOperatorTable.LIKE, +// SqlStdOperatorTable.CASE, + SqlStdOperatorTable.OVER, + + // FUNCTIONS + // String functions + SqlStdOperatorTable.SUBSTRING, + SqlStdOperatorTable.REPLACE, + SqlStdOperatorTable.TRIM, + SqlStdOperatorTable.UPPER, + SqlStdOperatorTable.LOWER, + // Arithmetic functions + SqlStdOperatorTable.POWER, + 
SqlStdOperatorTable.SQRT, + SqlStdOperatorTable.MOD, +// SqlStdOperatorTable.LN, + SqlStdOperatorTable.LOG10, + SqlStdOperatorTable.ABS, + SqlStdOperatorTable.ACOS, + SqlStdOperatorTable.ASIN, + SqlStdOperatorTable.ATAN, + SqlStdOperatorTable.ATAN2, + SqlStdOperatorTable.COS, + SqlStdOperatorTable.COT, + SqlStdOperatorTable.DEGREES, + SqlStdOperatorTable.EXP, + SqlStdOperatorTable.RADIANS, + SqlStdOperatorTable.ROUND, + SqlStdOperatorTable.SIGN, + SqlStdOperatorTable.SIN, + SqlStdOperatorTable.TAN, + SqlStdOperatorTable.TRUNCATE, + SqlStdOperatorTable.FLOOR, + SqlStdOperatorTable.CEIL, + SqlStdOperatorTable.TIMESTAMP_ADD, + SqlStdOperatorTable.TIMESTAMP_DIFF, + SqlStdOperatorTable.CAST, + + SqlStdOperatorTable.EXTRACT, + // TODO: The following operators are all rewritten to EXTRACT. Consider removing them because they are all + // supported without rewrite. + SqlStdOperatorTable.YEAR, + SqlStdOperatorTable.QUARTER, + SqlStdOperatorTable.MONTH, + SqlStdOperatorTable.WEEK, + SqlStdOperatorTable.DAYOFYEAR, + SqlStdOperatorTable.DAYOFMONTH, + SqlStdOperatorTable.DAYOFWEEK, + SqlStdOperatorTable.HOUR, + SqlStdOperatorTable.MINUTE, + SqlStdOperatorTable.SECOND, + + SqlStdOperatorTable.ITEM, + SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, + SqlStdOperatorTable.LISTAGG + ); + + private static final List<Pair<SqlOperator, List<String>>> STANDARD_OPERATORS_WITH_ALIASES = List.of( + Pair.of(SqlStdOperatorTable.CASE, List.of("CASE", "CASE_WHEN")), + Pair.of(SqlStdOperatorTable.LN, List.of("LN", "LOG")) + ); + + /** + * This list includes the customized {@link SqlOperator}s. 
+ */ + private static final List<SqlOperator> PINOT_OPERATORS = List.of( + // Placeholder for special predicates + new PinotSqlFunction("TEXT_MATCH", ReturnTypes.BOOLEAN, OperandTypes.CHARACTER_CHARACTER), + new PinotSqlFunction("TEXT_CONTAINS", ReturnTypes.BOOLEAN, OperandTypes.CHARACTER_CHARACTER), + new PinotSqlFunction("JSON_MATCH", ReturnTypes.BOOLEAN, OperandTypes.CHARACTER_CHARACTER), + new PinotSqlFunction("VECTOR_SIMILARITY", ReturnTypes.BOOLEAN, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER), i -> i == 2)), + + // Placeholder for special functions to handle MV + new PinotSqlFunction("ARRAY_TO_MV", opBinding -> opBinding.getOperandType(0).getComponentType(), + OperandTypes.ARRAY), + + // SqlStdOperatorTable.COALESCE without rewrite + new SqlFunction("COALESCE", SqlKind.COALESCE, + ReturnTypes.LEAST_RESTRICTIVE.andThen(SqlTypeTransforms.LEAST_NULLABLE), null, OperandTypes.SAME_VARIADIC, + SqlFunctionCategory.SYSTEM), + + // The scalar function version returns long instead of Timestamp + // TODO: Consider unifying the return type to Timestamp + new PinotSqlFunction("FROM_DATE_TIME", ReturnTypes.TIMESTAMP_NULLABLE, OperandTypes.family( + List.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY), + i -> i > 1)) + ); + + private static final List<Pair<SqlOperator, List<String>>> PINOT_OPERATORS_WITH_ALIASES = List.of( + ); + //@formatter:on + + // Key is canonical name + private final Map<String, SqlOperator> _operatorMap; + private final List<SqlOperator> _operatorList; + + private PinotOperatorTable() { + Map<String, SqlOperator> operatorMap = new HashMap<>(); Review Comment: It is not clear to me what the relation is between FunctionRegistry and this class. I would have expected us to iterate over all functions declared in FunctionRegistry and add them to Calcite, but it looks like we are iterating over a separate static list. What am I missing?
########## pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateReduceFunctionsRule.java: ########## @@ -16,21 +16,27 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.calcite.sql.fun; +package org.apache.pinot.calcite.rel.rules; -import org.apache.calcite.sql.SqlCall; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.fun.SqlCoalesceFunction; -import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; +import org.apache.calcite.sql.SqlKind; /** - * Pinot supports native COALESCE function, thus no need to create CASE WHEN conversion. + * Pinot customized version of {@link AggregateReduceFunctionsRule} which only reduce on SUM and AVG. */ -public class PinotSqlCoalesceFunction extends SqlCoalesceFunction { +public class PinotAggregateReduceFunctionsRule extends AggregateReduceFunctionsRule { + public static final PinotAggregateReduceFunctionsRule INSTANCE = + new PinotAggregateReduceFunctionsRule(Config.DEFAULT); + + private PinotAggregateReduceFunctionsRule(Config config) { + super(config); + } @Override - public SqlNode rewriteCall(SqlValidator validator, SqlCall call) { - return call; + public boolean canReduce(AggregateCall call) { + SqlKind kind = call.getAggregation().getKind(); + return kind == SqlKind.SUM || kind == SqlKind.AVG; Review Comment: Does this rule apply only in the leaf stage, or also in the intermediate stage? How do we merge data from different workers in the non-simplified form?
########## pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java: ########## @@ -127,11 +125,7 @@ public void testNestedFunction() { } @Test - public void testStateSharedBetweenRowsForExecution() - throws Exception { - MyFunc myFunc = new MyFunc(); - Method method = myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class); - FunctionRegistry.registerFunction(method, false, false, false); Review Comment: The new code doesn't need to register the function because it is detected on the classpath, right? I think a comment saying so would be useful for future readers. Also, I think we should try to avoid this pattern. Otherwise, one test class will be adding methods that could be seen by other test classes, which may end up being problematic. ########## pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java: ########## @@ -46,127 +59,296 @@ * </ul> */ @SuppressWarnings("unused") // unused fields are accessed by reflection -public class PinotOperatorTable extends SqlStdOperatorTable { - - private static @MonotonicNonNull PinotOperatorTable _instance; - - // TODO: clean up lazy init by using Suppliers.memorized(this::computeInstance) and make getter wrapped around - // supplier instance. this should replace all lazy init static objects in the codebase - public static synchronized PinotOperatorTable instance() { - if (_instance == null) { - // Creates and initializes the standard operator table. - // Uses two-phase construction, because we can't initialize the - // table until the constructor of the sub-class has completed. - _instance = new PinotOperatorTable(); - _instance.initNoDuplicate(); - } - return _instance; +public class PinotOperatorTable implements SqlOperatorTable { Review Comment: Is this the class that is used at parsing time? I.e., here we enumerate operator signatures (name, types, etc.) but not their semantics, right? Can we also add that to the javadoc?
########## pinot-core/src/test/java/org/apache/pinot/queries/TimestampQueriesTest.java: ########## @@ -223,10 +223,8 @@ public void testQueries() { } } - @Test( - expectedExceptions = BadQueryRequestException.class, - expectedExceptionsMessageRegExp = ".*attimezone not found.*" - ) + @Test(expectedExceptions = BadQueryRequestException.class, expectedExceptionsMessageRegExp = "Unsupported function:" + + " attimezone") Review Comment: nit: I prefer the older code style with one property per line. ########## pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java: ########## @@ -103,187 +98,119 @@ public enum TransformFunctionType { // date type conversion functions CAST("cast"), - // object type - ARRAY_TO_MV("arrayToMV", - ReturnTypes.cascade(opBinding -> positionalComponentReturnType(opBinding, 0), SqlTypeTransforms.FORCE_NULLABLE), - OperandTypes.family(SqlTypeFamily.ARRAY), "array_to_mv"), Review Comment: ~Is it safe to remove this function?~ ~Does this mean we don't need ARRAY_TO_MV now? What is the impact on already-written queries that use it?~ I've noticed the operator is now declared in PinotOperatorTable, which, IIUC, includes only the signatures (name and type checks/inference) and not the implementation. That is exactly what we want for ARRAY_TO_MV, right? ########## pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java: ########## @@ -95,108 +153,162 @@ public static void init() { } /** - * Registers a method with the name of the method. + * Registers a {@link PinotScalarFunction} under the given canonical name.
*/ - public static void registerFunction(Method method, boolean nullableParameters, boolean isPlaceholder, - boolean isVarArg) { - registerFunction(method.getName(), method, nullableParameters, isPlaceholder, isVarArg); + private static void register(String canonicalName, PinotScalarFunction function, + Map<String, PinotScalarFunction> functionMap) { + Preconditions.checkState(functionMap.put(canonicalName, function) == null, "Function: %s is already registered", + canonicalName); } /** - * Registers a method with the given function name. + * Registers a {@link FunctionInfo} under the given canonical name. */ - public static void registerFunction(String functionName, Method method, boolean nullableParameters, - boolean isPlaceholder, boolean isVarArg) { - if (!isPlaceholder) { - registerFunctionInfoMap(functionName, method, nullableParameters, isVarArg); - } - registerCalciteNamedFunctionMap(functionName, method, nullableParameters, isVarArg); - } - - private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters, - boolean isVarArg) { - FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters); - String canonicalName = canonicalize(functionName); - Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>()); - if (isVarArg) { - FunctionInfo existFunctionInfo = functionInfoMap.put(VAR_ARG_KEY, functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with variable number of parameters is already registered", functionName); - } else { - FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with %s parameters is already registered", functionName, 
method.getParameterCount()); - } - } - - private static void registerCalciteNamedFunctionMap(String functionName, Method method, boolean nullableParameters, - boolean isVarArg) { - if (method.getAnnotation(Deprecated.class) == null) { - FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method)); - } + private static void register(String canonicalName, FunctionInfo functionInfo, int numArguments, + Map<String, Map<Integer, FunctionInfo>> functionInfoMap) { + Preconditions.checkState( + functionInfoMap.computeIfAbsent(canonicalName, k -> new HashMap<>()).put(numArguments, functionInfo) == null, + "Function: %s with %s arguments is already registered", canonicalName, + numArguments == VAR_ARG_KEY ? "variable" : numArguments); } - public static Map<String, List<Function>> getRegisteredCalciteFunctionMap() { - return FUNCTION_MAP.map(); + /** + * Returns {@code true} if the given canonical name is registered, {@code false} otherwise. + */ + public static boolean contains(String canonicalName) { + return FUNCTION_MAP.containsKey(canonicalName); } - public static Set<String> getRegisteredCalciteFunctionNames() { - return FUNCTION_MAP.map().keySet(); + /** + * @deprecated For performance concern, use {@link #contains(String)} instead to avoid invoking + * {@link #canonicalize(String)} multiple times. + */ + @Deprecated + public static boolean containsFunction(String name) { + return contains(canonicalize(name)); } /** - * Returns {@code true} if the given function name is registered, {@code false} otherwise. + * Returns the {@link FunctionInfo} associated with the given canonical name and argument types, or {@code null} if + * there is no matching method. This method should be called after the FunctionRegistry is initialized and all methods + * are already registered. 
*/ - public static boolean containsFunction(String functionName) { - return FUNCTION_INFO_MAP.containsKey(canonicalize(functionName)); + @Nullable + public static FunctionInfo lookupFunctionInfo(String canonicalName, ColumnDataType[] argumentTypes) { + PinotScalarFunction function = FUNCTION_MAP.get(canonicalName); + return function != null ? function.getFunctionInfo(argumentTypes) : null; } /** - * Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null} + * Returns the {@link FunctionInfo} associated with the given canonical name and number of arguments, or {@code null} * if there is no matching method. This method should be called after the FunctionRegistry is initialized and all * methods are already registered. + * TODO: Move all usages to {@link #lookupFunctionInfo(String, ColumnDataType[])}. */ @Nullable - public static FunctionInfo getFunctionInfo(String functionName, int numParameters) { - Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName)); - if (functionInfoMap != null) { - FunctionInfo functionInfo = functionInfoMap.get(numParameters); - if (functionInfo != null) { - return functionInfo; - } - return functionInfoMap.get(VAR_ARG_KEY); - } - return null; - } - - private static String canonicalize(String functionName) { - return StringUtils.remove(functionName, '_').toLowerCase(); + public static FunctionInfo lookupFunctionInfo(String canonicalName, int numArguments) { + PinotScalarFunction function = FUNCTION_MAP.get(canonicalName); + return function != null ? function.getFunctionInfo(numArguments) : null; } /** - * Placeholders for scalar function, they register and represents the signature for transform and filter predicate - * so that v2 engine can understand and plan them correctly. + * @deprecated For performance concern, use {@link #lookupFunctionInfo(String, int)} instead to avoid invoking + * {@link #canonicalize(String)} multiple times. 
*/ - private static class PlaceholderScalarFunctions { + @Deprecated + @Nullable + public static FunctionInfo getFunctionInfo(String name, int numArguments) { + return lookupFunctionInfo(canonicalize(name), numArguments); + } + + public static String canonicalize(String name) { + return StringUtils.remove(name, '_').toLowerCase(); + } + + public static class ArgumentCountBasedScalarFunction implements PinotScalarFunction { + private final String _name; + private final Map<Integer, FunctionInfo> _functionInfoMap; - @ScalarFunction(names = {"textContains", "text_contains"}, isPlaceholder = true) - public static boolean textContains(String text, String pattern) { - throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); + private ArgumentCountBasedScalarFunction(String name, Map<Integer, FunctionInfo> functionInfoMap) { + _name = name; + _functionInfoMap = functionInfoMap; } - @ScalarFunction(names = {"textMatch", "text_match"}, isPlaceholder = true) - public static boolean textMatch(String text, String pattern) { - throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); + @Override + public String getName() { + return _name; + } + + @Override + public PinotSqlFunction toPinotSqlFunction() { + return new PinotSqlFunction(_name, getReturnTypeInference(), getOperandTypeChecker()); + } + + private SqlReturnTypeInference getReturnTypeInference() { + return opBinding -> { + int numArguments = opBinding.getOperandCount(); + FunctionInfo functionInfo = getFunctionInfo(numArguments); + Preconditions.checkState(functionInfo != null, "Failed to find function: %s with %s arguments", _name, + numArguments); + RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + Method method = functionInfo.getMethod(); + RelDataType returnType = FunctionUtils.getRelDataType(opBinding.getTypeFactory(), method.getReturnType()); + + if (!functionInfo.hasNullableParameters()) { + // When any parameter is null, 
return is null + for (RelDataType type : opBinding.collectOperandTypes()) { + if (type.isNullable()) { + return typeFactory.createTypeWithNullability(returnType, true); + } + } + } + + return method.isAnnotationPresent(Nullable.class) ? typeFactory.createTypeWithNullability(returnType, true) + : returnType; + }; + } + + private SqlOperandTypeChecker getOperandTypeChecker() { + if (_functionInfoMap.containsKey(VAR_ARG_KEY)) { + return OperandTypes.VARIADIC; + } + int numCheckers = _functionInfoMap.size(); + if (numCheckers == 1) { + return getOperandTypeChecker(_functionInfoMap.values().iterator().next().getMethod()); + } + SqlOperandTypeChecker[] operandTypeCheckers = new SqlOperandTypeChecker[numCheckers]; + int index = 0; + for (FunctionInfo functionInfo : _functionInfoMap.values()) { + operandTypeCheckers[index++] = getOperandTypeChecker(functionInfo.getMethod()); + } + return OperandTypes.or(operandTypeCheckers); + } + + private static SqlOperandTypeChecker getOperandTypeChecker(Method method) { + Class<?>[] parameterTypes = method.getParameterTypes(); + int length = parameterTypes.length; + SqlTypeFamily[] typeFamilies = new SqlTypeFamily[length]; + for (int i = 0; i < length; i++) { + typeFamilies[i] = getSqlTypeFamily(parameterTypes[i]); + } + return OperandTypes.family(typeFamilies); } - @ScalarFunction(names = {"jsonMatch", "json_match"}, isPlaceholder = true) - public static boolean jsonMatch(String text, String pattern) { - throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); + private static SqlTypeFamily getSqlTypeFamily(Class<?> clazz) { + // NOTE: Pinot allows some non-standard type conversions such as Timestamp <-> long, boolean <-> int etc. Do not + // restrict the type family for now. We only restrict the type family for String so that cast can be added. + // Explicit cast is required to correctly convert boolean and Timestamp to String. 
Without explicit case, + // BOOLEAN and TIMESTAMP type will be converted with their internal stored format which is INT and LONG + // respectively. E.g. true will be converted to "1", timestamp will be converted to long value string. + // TODO: Revisit this. Review Comment: We should be able to add these implicit type conversions to Calcite. In the past I tried to remove casts that already existed in Calcite (from String to varbinary) and wasn't able to do so. I even asked on the Calcite mailing list, and from the answers I got, my impression is that _removing implicit casts_ is not supported. But I think _adding_ new implicit casts is supported in Calcite. ########## pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java: ########## @@ -95,108 +153,152 @@ public static void init() { } /** - * Registers a method with the name of the method. + * Registers a {@link PinotScalarFunction} under the given canonical name. */ - public static void registerFunction(Method method, boolean nullableParameters, boolean isPlaceholder, - boolean isVarArg) { - registerFunction(method.getName(), method, nullableParameters, isPlaceholder, isVarArg); + private static void register(String canonicalName, PinotScalarFunction function, + Map<String, PinotScalarFunction> functionMap) { + Preconditions.checkState(functionMap.put(canonicalName, function) == null, "Function: %s is already registered", + canonicalName); } /** - * Registers a method with the given function name. + * Registers a {@link FunctionInfo} under the given canonical name.
*/ - public static void registerFunction(String functionName, Method method, boolean nullableParameters, - boolean isPlaceholder, boolean isVarArg) { - if (!isPlaceholder) { - registerFunctionInfoMap(functionName, method, nullableParameters, isVarArg); - } - registerCalciteNamedFunctionMap(functionName, method, nullableParameters, isVarArg); + private static void register(String canonicalName, FunctionInfo functionInfo, int numArguments, + Map<String, Map<Integer, FunctionInfo>> functionInfoMap) { + Preconditions.checkState( + functionInfoMap.computeIfAbsent(canonicalName, k -> new HashMap<>()).put(numArguments, functionInfo) == null, + "Function: %s with %s arguments is already registered", canonicalName, + numArguments == VAR_ARG_KEY ? "variable" : numArguments); } - private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters, - boolean isVarArg) { - FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters); - String canonicalName = canonicalize(functionName); - Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>()); - if (isVarArg) { - FunctionInfo existFunctionInfo = functionInfoMap.put(VAR_ARG_KEY, functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with variable number of parameters is already registered", functionName); - } else { - FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with %s parameters is already registered", functionName, method.getParameterCount()); - } - } - - private static void registerCalciteNamedFunctionMap(String functionName, Method method, boolean nullableParameters, - boolean isVarArg) { - if 
(method.getAnnotation(Deprecated.class) == null) { - FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method)); - } - } - - public static Map<String, List<Function>> getRegisteredCalciteFunctionMap() { - return FUNCTION_MAP.map(); + /** + * Returns {@code true} if the given canonical name is registered, {@code false} otherwise. + */ + public static boolean contains(String canonicalName) { + return FUNCTION_MAP.containsKey(canonicalName); } - public static Set<String> getRegisteredCalciteFunctionNames() { - return FUNCTION_MAP.map().keySet(); + @Deprecated + public static boolean containsFunction(String name) { + return contains(canonicalize(name)); } /** - * Returns {@code true} if the given function name is registered, {@code false} otherwise. + * Returns the {@link FunctionInfo} associated with the given canonical name and argument types, or {@code null} if + * there is no matching method. This method should be called after the FunctionRegistry is initialized and all methods + * are already registered. */ - public static boolean containsFunction(String functionName) { - return FUNCTION_INFO_MAP.containsKey(canonicalize(functionName)); + @Nullable + public static FunctionInfo lookupFunctionInfo(String canonicalName, ColumnDataType[] argumentTypes) { + PinotScalarFunction function = FUNCTION_MAP.get(canonicalName); + return function != null ? function.getFunctionInfo(argumentTypes) : null; } /** - * Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null} + * Returns the {@link FunctionInfo} associated with the given canonical name and number of arguments, or {@code null} * if there is no matching method. This method should be called after the FunctionRegistry is initialized and all * methods are already registered. + * TODO: Move all usages to {@link #lookupFunctionInfo(String, ColumnDataType[])}. 
*/ @Nullable - public static FunctionInfo getFunctionInfo(String functionName, int numParameters) { - Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName)); - if (functionInfoMap != null) { - FunctionInfo functionInfo = functionInfoMap.get(numParameters); - if (functionInfo != null) { - return functionInfo; - } - return functionInfoMap.get(VAR_ARG_KEY); - } - return null; + public static FunctionInfo lookupFunctionInfo(String canonicalName, int numArguments) { + PinotScalarFunction function = FUNCTION_MAP.get(canonicalName); + return function != null ? function.getFunctionInfo(numArguments) : null; } - private static String canonicalize(String functionName) { - return StringUtils.remove(functionName, '_').toLowerCase(); + @Deprecated + @Nullable + public static FunctionInfo getFunctionInfo(String name, int numArguments) { + return lookupFunctionInfo(canonicalize(name), numArguments); Review Comment: I would even recommend creating a class called `CanonicalName` that wraps a String, and then using that class as the input type. We could have a static helper that converts Strings into CanonicalNames. This would let us: - Have type-safe checks, so Java won't let us call lookupFunctionInfo with non-canonical names. - Cache the CanonicalNames, so we don't need to allocate them repeatedly.
########## pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java: ########## @@ -294,42 +299,37 @@ private static List<AggregateCall> buildAggCalls(Aggregate aggRel, AggType aggTy // - argList is replaced with rexList private static AggregateCall buildAggCall(RelNode input, AggregateCall orgAggCall, List<RexNode> rexList, int numGroups, AggType aggType) { - String functionName = orgAggCall.getAggregation().getName(); + SqlAggFunction orgAggFunction = orgAggCall.getAggregation(); + String functionName = orgAggFunction.getName(); + SqlKind kind = orgAggFunction.getKind(); + SqlFunctionCategory functionCategory = orgAggFunction.getFunctionType(); if (orgAggCall.isDistinct()) { - if (functionName.equals("COUNT")) { + if (kind == SqlKind.COUNT) { functionName = "DISTINCTCOUNT"; - } else if (functionName.equals("LISTAGG")) { + kind = SqlKind.OTHER_FUNCTION; + functionCategory = SqlFunctionCategory.USER_DEFINED_FUNCTION; + } else if (kind == SqlKind.LISTAGG) { rexList.add(input.getCluster().getRexBuilder().makeLiteral(true)); } } - AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName); - SqlAggFunction sqlAggFunction; - switch (aggType) { - case DIRECT: - sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), - ReturnTypes.explicit(orgAggCall.getType()), null, functionType.getOperandTypeChecker(), - functionType.getSqlFunctionCategory()); - break; - case LEAF: - sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), - functionType.getIntermediateReturnTypeInference(), null, functionType.getOperandTypeChecker(), - functionType.getSqlFunctionCategory()); - break; - case INTERMEDIATE: - sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), - functionType.getIntermediateReturnTypeInference(), null, OperandTypes.ANY, - functionType.getSqlFunctionCategory()); - break; - case 
FINAL: - sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), - ReturnTypes.explicit(orgAggCall.getType()), null, OperandTypes.ANY, functionType.getSqlFunctionCategory()); - break; - default: - throw new IllegalStateException("Unsupported AggType: " + aggType); + SqlReturnTypeInference returnTypeInference = null; + RelDataType returnType = null; + // Override the intermediate result type inference if it is provided + if (aggType.isOutputIntermediateFormat()) { + AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName); + returnTypeInference = functionType.getIntermediateReturnTypeInference(); } + if (returnTypeInference == null) { + returnType = orgAggCall.getType(); + returnTypeInference = ReturnTypes.explicit(returnType); + } Review Comment: Would it be OK to move this code into AggregationFunctionType.getIntermediateReturnTypeInference so that it can never return null? Or would that break some other invariant? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
