This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 19b79f406c Polymorphic binary arithmetic scalar functions (#14089)
19b79f406c is described below
commit 19b79f406c36e8b075e9c1698a8752af5b4e6d23
Author: Yash Mayya <[email protected]>
AuthorDate: Tue Oct 1 01:32:38 2024 +0530
Polymorphic binary arithmetic scalar functions (#14089)
---
.../function/scalar/ArithmeticFunctions.java | 15 --
.../scalar/arithmetic/MinusScalarFunction.java | 66 +++++++++
.../scalar/arithmetic/MultScalarFunction.java | 66 +++++++++
.../scalar/arithmetic/PlusScalarFunction.java | 66 +++++++++
.../PolymorphicBinaryArithmeticScalarFunction.java | 67 +++++++++
.../scalar/comparison/EqualsScalarFunction.java | 4 +-
.../GreaterThanOrEqualScalarFunction.java | 9 +-
.../comparison/GreaterThanScalarFunction.java | 4 +-
.../comparison/LessThanOrEqualScalarFunction.java | 4 +-
.../scalar/comparison/LessThanScalarFunction.java | 4 +-
.../scalar/comparison/NotEqualsScalarFunction.java | 4 +-
.../pinot/sql/parsers/CalciteSqlCompilerTest.java | 24 ++++
.../PostAggregationFunctionTest.java | 4 +-
.../tests/OfflineClusterIntegrationTest.java | 151 +++++++++++++--------
.../pinot/calcite/sql/fun/PinotOperatorTable.java | 5 +-
.../resources/queries/LiteralEvaluationPlans.json | 4 +-
.../ExpressionTransformerTest.java | 2 +-
17 files changed, 406 insertions(+), 93 deletions(-)
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
index 94489c92b1..27c4952b1f 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java
@@ -30,21 +30,6 @@ public class ArithmeticFunctions {
private ArithmeticFunctions() {
}
- @ScalarFunction(names = {"add", "plus"})
- public static double plus(double a, double b) {
- return a + b;
- }
-
- @ScalarFunction(names = {"sub", "minus"})
- public static double minus(double a, double b) {
- return a - b;
- }
-
- @ScalarFunction(names = {"mult", "times"})
- public static double times(double a, double b) {
- return a * b;
- }
-
@ScalarFunction(names = {"div", "divide"})
public static double divide(double a, double b) {
return a / b;
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MinusScalarFunction.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MinusScalarFunction.java
new file mode 100644
index 0000000000..61488e58e7
--- /dev/null
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MinusScalarFunction.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar.arithmetic;
+
+import java.util.EnumMap;
+import java.util.Map;
+import org.apache.pinot.common.function.FunctionInfo;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.spi.annotations.ScalarFunction;
+
+
+@ScalarFunction(names = {"sub", "minus"})
+public class MinusScalarFunction extends
PolymorphicBinaryArithmeticScalarFunction {
+
+ private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
+
+ static {
+ try {
+ TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG,
+ new FunctionInfo(MinusScalarFunction.class.getMethod("longMinus",
long.class, long.class),
+ MinusScalarFunction.class, false));
+ TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE,
+ new FunctionInfo(MinusScalarFunction.class.getMethod("doubleMinus",
double.class, double.class),
+ MinusScalarFunction.class, false));
+ } catch (NoSuchMethodException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ protected FunctionInfo functionInfoForType(ColumnDataType argumentType) {
+ FunctionInfo functionInfo = TYPE_FUNCTION_INFO_MAP.get(argumentType);
+
+ // Fall back to double based comparison by default
+ return functionInfo != null ? functionInfo :
TYPE_FUNCTION_INFO_MAP.get(ColumnDataType.DOUBLE);
+ }
+
+ @Override
+ public String getName() {
+ return "minus";
+ }
+
+ public static long longMinus(long a, long b) {
+ return a - b;
+ }
+
+ public static double doubleMinus(double a, double b) {
+ return a - b;
+ }
+}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MultScalarFunction.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MultScalarFunction.java
new file mode 100644
index 0000000000..a737045393
--- /dev/null
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/MultScalarFunction.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar.arithmetic;
+
+import java.util.EnumMap;
+import java.util.Map;
+import org.apache.pinot.common.function.FunctionInfo;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.spi.annotations.ScalarFunction;
+
+
+@ScalarFunction(names = {"mult", "times"})
+public class MultScalarFunction extends
PolymorphicBinaryArithmeticScalarFunction {
+
+ private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
+
+ static {
+ try {
+ TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG,
+ new FunctionInfo(MultScalarFunction.class.getMethod("longMult",
long.class, long.class),
+ MultScalarFunction.class, false));
+ TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE,
+ new FunctionInfo(MultScalarFunction.class.getMethod("doubleMult",
double.class, double.class),
+ MultScalarFunction.class, false));
+ } catch (NoSuchMethodException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ protected FunctionInfo functionInfoForType(ColumnDataType argumentType) {
+ FunctionInfo functionInfo = TYPE_FUNCTION_INFO_MAP.get(argumentType);
+
+ // Fall back to double based comparison by default
+ return functionInfo != null ? functionInfo :
TYPE_FUNCTION_INFO_MAP.get(ColumnDataType.DOUBLE);
+ }
+
+ @Override
+ public String getName() {
+ return "mult";
+ }
+
+ public static long longMult(long a, long b) {
+ return a * b;
+ }
+
+ public static double doubleMult(double a, double b) {
+ return a * b;
+ }
+}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PlusScalarFunction.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PlusScalarFunction.java
new file mode 100644
index 0000000000..5951afa527
--- /dev/null
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PlusScalarFunction.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar.arithmetic;
+
+import java.util.EnumMap;
+import java.util.Map;
+import org.apache.pinot.common.function.FunctionInfo;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.spi.annotations.ScalarFunction;
+
+
+@ScalarFunction(names = {"add", "plus"})
+public class PlusScalarFunction extends
PolymorphicBinaryArithmeticScalarFunction {
+
+ private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
+
+ static {
+ try {
+ TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.LONG,
+ new FunctionInfo(PlusScalarFunction.class.getMethod("longPlus",
long.class, long.class),
+ PlusScalarFunction.class, false));
+ TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.DOUBLE,
+ new FunctionInfo(PlusScalarFunction.class.getMethod("doublePlus",
double.class, double.class),
+ PlusScalarFunction.class, false));
+ } catch (NoSuchMethodException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ protected FunctionInfo functionInfoForType(ColumnDataType argumentType) {
+ FunctionInfo functionInfo = TYPE_FUNCTION_INFO_MAP.get(argumentType);
+
+ // Fall back to double based comparison by default
+ return functionInfo != null ? functionInfo :
TYPE_FUNCTION_INFO_MAP.get(ColumnDataType.DOUBLE);
+ }
+
+ @Override
+ public String getName() {
+ return "plus";
+ }
+
+ public static long longPlus(long a, long b) {
+ return a + b;
+ }
+
+ public static double doublePlus(double a, double b) {
+ return a + b;
+ }
+}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PolymorphicBinaryArithmeticScalarFunction.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PolymorphicBinaryArithmeticScalarFunction.java
new file mode 100644
index 0000000000..10167161f9
--- /dev/null
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/arithmetic/PolymorphicBinaryArithmeticScalarFunction.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar.arithmetic;
+
+import javax.annotation.Nullable;
+import org.apache.pinot.common.function.FunctionInfo;
+import org.apache.pinot.common.function.PinotScalarFunction;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+
+
+/**
+ * Base class for polymorphic binary arithmetic scalar functions
+ */
+public abstract class PolymorphicBinaryArithmeticScalarFunction implements
PinotScalarFunction {
+
+ @Nullable
+ @Override
+ public FunctionInfo getFunctionInfo(ColumnDataType[] argumentTypes) {
+ if (argumentTypes.length != 2) {
+ return null;
+ }
+
+ return functionInfoForTypes(argumentTypes[0].getStoredType(),
argumentTypes[1].getStoredType());
+ }
+
+ @Nullable
+ @Override
+ public FunctionInfo getFunctionInfo(int numArguments) {
+ if (numArguments != 2) {
+ return null;
+ }
+
+ // For backward compatibility
+ return functionInfoForType(ColumnDataType.DOUBLE);
+ }
+
+ private FunctionInfo functionInfoForTypes(ColumnDataType argumentType1,
ColumnDataType argumentType2) {
+ if ((argumentType1 == ColumnDataType.LONG || argumentType1 ==
ColumnDataType.INT) && (
+ argumentType2 == ColumnDataType.LONG || argumentType2 ==
ColumnDataType.INT)) {
+ return functionInfoForType(ColumnDataType.LONG);
+ }
+
+ // Fall back to double based comparison by default
+ return functionInfoForType(ColumnDataType.DOUBLE);
+ }
+
+ /**
+ * Get the binary arithmetic scalar function's {@link FunctionInfo} for the
given argument type.
+ */
+ protected abstract FunctionInfo functionInfoForType(ColumnDataType
argumentType);
+}
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java
index 656722ccc8..0bc0fcb075 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/EqualsScalarFunction.java
@@ -20,7 +20,7 @@ package org.apache.pinot.common.function.scalar.comparison;
import java.math.BigDecimal;
import java.util.Arrays;
-import java.util.HashMap;
+import java.util.EnumMap;
import java.util.Map;
import java.util.Objects;
import org.apache.pinot.common.function.FunctionInfo;
@@ -33,7 +33,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
@ScalarFunction
public class EqualsScalarFunction extends PolymorphicComparisonScalarFunction {
- private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+ private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
private static final FunctionInfo DOUBLE_EQUALS_WITH_TOLERANCE;
static {
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java
index cdf27b0f5e..d7782cf7e7 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanOrEqualScalarFunction.java
@@ -19,7 +19,7 @@
package org.apache.pinot.common.function.scalar.comparison;
import java.math.BigDecimal;
-import java.util.HashMap;
+import java.util.EnumMap;
import java.util.Map;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
@@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
@ScalarFunction
public class GreaterThanOrEqualScalarFunction extends
PolymorphicComparisonScalarFunction {
- private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+ private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
static {
try {
@@ -51,9 +51,8 @@ public class GreaterThanOrEqualScalarFunction extends
PolymorphicComparisonScala
GreaterThanOrEqualScalarFunction.class.getMethod("doubleGreaterThanOrEqual",
double.class, double.class),
GreaterThanOrEqualScalarFunction.class, false));
TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.BIG_DECIMAL, new FunctionInfo(
-
GreaterThanOrEqualScalarFunction.class.getMethod("bigDecimalGreaterThanOrEqual",
- BigDecimal.class, BigDecimal.class),
- GreaterThanOrEqualScalarFunction.class, false));
+
GreaterThanOrEqualScalarFunction.class.getMethod("bigDecimalGreaterThanOrEqual",
BigDecimal.class,
+ BigDecimal.class), GreaterThanOrEqualScalarFunction.class,
false));
TYPE_FUNCTION_INFO_MAP.put(ColumnDataType.STRING, new FunctionInfo(
GreaterThanOrEqualScalarFunction.class.getMethod("stringGreaterThanOrEqual",
String.class, String.class),
GreaterThanOrEqualScalarFunction.class, false));
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java
index be8775f549..a41ddb6823 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/GreaterThanScalarFunction.java
@@ -19,7 +19,7 @@
package org.apache.pinot.common.function.scalar.comparison;
import java.math.BigDecimal;
-import java.util.HashMap;
+import java.util.EnumMap;
import java.util.Map;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
@@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
@ScalarFunction
public class GreaterThanScalarFunction extends
PolymorphicComparisonScalarFunction {
- private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+ private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
static {
try {
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java
index 941c1a6d56..7ff076744e 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanOrEqualScalarFunction.java
@@ -19,7 +19,7 @@
package org.apache.pinot.common.function.scalar.comparison;
import java.math.BigDecimal;
-import java.util.HashMap;
+import java.util.EnumMap;
import java.util.Map;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
@@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
@ScalarFunction
public class LessThanOrEqualScalarFunction extends
PolymorphicComparisonScalarFunction {
- private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+ private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
static {
try {
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java
index e9d722370e..d2d85d9bbf 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/LessThanScalarFunction.java
@@ -19,7 +19,7 @@
package org.apache.pinot.common.function.scalar.comparison;
import java.math.BigDecimal;
-import java.util.HashMap;
+import java.util.EnumMap;
import java.util.Map;
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
@@ -32,7 +32,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
@ScalarFunction
public class LessThanScalarFunction extends
PolymorphicComparisonScalarFunction {
- private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+ private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
static {
try {
diff --git
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java
index 7f63a1eb9e..8344514646 100644
---
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java
+++
b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/comparison/NotEqualsScalarFunction.java
@@ -20,7 +20,7 @@ package org.apache.pinot.common.function.scalar.comparison;
import java.math.BigDecimal;
import java.util.Arrays;
-import java.util.HashMap;
+import java.util.EnumMap;
import java.util.Map;
import java.util.Objects;
import org.apache.pinot.common.function.FunctionInfo;
@@ -33,7 +33,7 @@ import org.apache.pinot.spi.annotations.ScalarFunction;
@ScalarFunction
public class NotEqualsScalarFunction extends
PolymorphicComparisonScalarFunction {
- private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new HashMap<>();
+ private static final Map<ColumnDataType, FunctionInfo>
TYPE_FUNCTION_INFO_MAP = new EnumMap<>(ColumnDataType.class);
private static final FunctionInfo DOUBLE_NOT_EQUALS_WITH_TOLERANCE;
static {
diff --git
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index 35a625505a..34e2a6b5f5 100644
---
a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++
b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -2339,6 +2339,30 @@ public class CalciteSqlCompilerTest {
long result = expression.getLiteral().getLongValue();
Assert.assertTrue(result >= lowerBound && result <= upperBound);
+ expression = compileToExpression("now() - 0");
+ Assert.assertNotNull(expression.getFunctionCall());
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+ Assert.assertNotNull(expression.getLiteral());
+ upperBound = System.currentTimeMillis();
+ result = expression.getLiteral().getLongValue();
+ Assert.assertTrue(result >= lowerBound && result <= upperBound);
+
+ expression = compileToExpression("now() + 0");
+ Assert.assertNotNull(expression.getFunctionCall());
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+ Assert.assertNotNull(expression.getLiteral());
+ upperBound = System.currentTimeMillis();
+ result = expression.getLiteral().getLongValue();
+ Assert.assertTrue(result >= lowerBound && result <= upperBound);
+
+ expression = compileToExpression("now() * 1");
+ Assert.assertNotNull(expression.getFunctionCall());
+ expression =
CompileTimeFunctionsInvoker.invokeCompileTimeFunctionExpression(expression);
+ Assert.assertNotNull(expression.getLiteral());
+ upperBound = System.currentTimeMillis();
+ result = expression.getLiteral().getLongValue();
+ Assert.assertTrue(result >= lowerBound && result <= upperBound);
+
lowerBound = TimeUnit.MILLISECONDS.toHours(System.currentTimeMillis()) + 1;
expression = compileToExpression("to_epoch_hours(now() + 3600000)");
Assert.assertNotNull(expression.getFunctionCall());
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
index 0c7b0e3e52..6f4cd02a29 100644
---
a/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
+++
b/pinot-core/src/test/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunctionTest.java
@@ -37,8 +37,8 @@ public class PostAggregationFunctionTest {
// Plus
PostAggregationFunction function =
new PostAggregationFunction("plus", new
ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG});
- assertEquals(function.getResultType(), ColumnDataType.DOUBLE);
- assertEquals(function.invoke(new Object[]{1, 2L}), 3.0);
+ assertEquals(function.getResultType(), ColumnDataType.LONG);
+ assertEquals(function.invoke(new Object[]{1, 2L}), 3L);
// Minus
function = new PostAggregationFunction("MINUS", new
ColumnDataType[]{ColumnDataType.FLOAT, ColumnDataType.DOUBLE});
diff --git
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
index be438702bf..2bcfcabef1 100644
---
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
+++
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java
@@ -2040,54 +2040,61 @@ public class OfflineClusterIntegrationTest extends
BaseClusterIntegrationTestSet
assertEquals(row.get(0).asLong(), 16138 * 24);
assertEquals(row.get(1).asLong(), 605);
- if (useMultiStageQueryEngine) {
- query = "SELECT add(DaysSinceEpoch,add(DaysSinceEpoch,15)), COUNT(*)
FROM mytable "
- + "GROUP BY add(DaysSinceEpoch,add(DaysSinceEpoch,15)) ORDER BY
COUNT(*) DESC";
- } else {
- query = "SELECT add(DaysSinceEpoch,DaysSinceEpoch,15), COUNT(*) FROM
mytable "
- + "GROUP BY add(DaysSinceEpoch,DaysSinceEpoch,15) ORDER BY COUNT(*)
DESC";
- }
+ query = "SELECT arrayLength(DivAirports), COUNT(*) FROM mytable "
+ + "GROUP BY arrayLength(DivAirports) ORDER BY COUNT(*) DESC";
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
- assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"DOUBLE\",\"LONG\"]");
+ assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"INT\",\"LONG\"]");
rows = resultTable.get("rows");
assertFalse(rows.isEmpty());
row = rows.get(0);
assertEquals(row.size(), 2);
- assertEquals(row.get(0).asDouble(), 16138.0 + 16138 + 15);
- assertEquals(row.get(1).asLong(), 605);
+ assertEquals(row.get(0).asInt(), 5);
+ assertEquals(row.get(1).asLong(), 115545);
- query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable "
- + "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC";
+ query = "SELECT arrayLength(valueIn(DivAirports,'DFW','ORD')), COUNT(*)
FROM mytable GROUP BY "
+ + "arrayLength(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*)
DESC";
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
- assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"DOUBLE\",\"LONG\"]");
+ assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"INT\",\"LONG\"]");
rows = resultTable.get("rows");
- assertFalse(rows.isEmpty());
+ assertEquals(rows.size(), 3);
row = rows.get(0);
assertEquals(row.size(), 2);
- assertEquals(row.get(0).asDouble(), 16138.0 - 25);
- assertEquals(row.get(1).asLong(), 605);
+ assertEquals(row.get(0).asInt(), 0);
+ assertEquals(row.get(1).asLong(), 114895);
+ row = rows.get(1);
+ assertEquals(row.size(), 2);
+ assertEquals(row.get(0).asInt(), 1);
+ assertEquals(row.get(1).asLong(), 648);
+ row = rows.get(2);
+ assertEquals(row.size(), 2);
+ assertEquals(row.get(0).asInt(), 2);
+ assertEquals(row.get(1).asLong(), 2);
- if (useMultiStageQueryEngine) {
- query = "SELECT mult(DaysSinceEpoch,mult(24,3600)), COUNT(*) FROM
mytable "
- + "GROUP BY mult(DaysSinceEpoch,mult(24,3600)) ORDER BY COUNT(*)
DESC";
+ if (useMultiStageQueryEngine()) {
+ query = "SELECT arrayToMV(valueIn(DivAirports,'DFW','ORD')), COUNT(*)
FROM mytable "
+ + "GROUP BY arrayToMV(valueIn(DivAirports,'DFW','ORD')) ORDER BY
COUNT(*) DESC";
} else {
- query = "SELECT mult(DaysSinceEpoch,24,3600), COUNT(*) FROM mytable "
- + "GROUP BY mult(DaysSinceEpoch,24,3600) ORDER BY COUNT(*) DESC";
+ query = "SELECT valueIn(DivAirports,'DFW','ORD'), COUNT(*) FROM mytable "
+ + "GROUP BY valueIn(DivAirports,'DFW','ORD') ORDER BY COUNT(*) DESC";
}
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
- assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"DOUBLE\",\"LONG\"]");
+ assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"STRING\",\"LONG\"]");
rows = resultTable.get("rows");
- assertFalse(rows.isEmpty());
+ assertEquals(rows.size(), 2);
row = rows.get(0);
assertEquals(row.size(), 2);
- assertEquals(row.get(0).asDouble(), 16138.0 * 24 * 3600);
- assertEquals(row.get(1).asLong(), 605);
+ assertEquals(row.get(0).asText(), "ORD");
+ assertEquals(row.get(1).asLong(), 336);
+ row = rows.get(1);
+ assertEquals(row.size(), 2);
+ assertEquals(row.get(0).asText(), "DFW");
+ assertEquals(row.get(1).asLong(), 316);
query = "SELECT div(DaysSinceEpoch,2), COUNT(*) FROM mytable "
+ "GROUP BY div(DaysSinceEpoch,2) ORDER BY COUNT(*) DESC";
@@ -2101,62 +2108,92 @@ public class OfflineClusterIntegrationTest extends
BaseClusterIntegrationTestSet
assertEquals(row.size(), 2);
assertEquals(row.get(0).asDouble(), 16138.0 / 2);
assertEquals(row.get(1).asLong(), 605);
+ }
- query = "SELECT arrayLength(DivAirports), COUNT(*) FROM mytable "
- + "GROUP BY arrayLength(DivAirports) ORDER BY COUNT(*) DESC";
+ @Test
+ public void testGroupByUDFV1() throws Exception {
+ setUseMultiStageQueryEngine(false);
+ String query = "SELECT add(DaysSinceEpoch,DaysSinceEpoch,15), COUNT(*)
FROM mytable "
+ + "GROUP BY add(DaysSinceEpoch,DaysSinceEpoch,15) ORDER BY COUNT(*)
DESC";
+ JsonNode response = postQuery(query);
+ JsonNode resultTable = response.get("resultTable");
+ JsonNode dataSchema = resultTable.get("dataSchema");
+ assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"DOUBLE\",\"LONG\"]");
+ JsonNode rows = resultTable.get("rows");
+ assertFalse(rows.isEmpty());
+ JsonNode row = rows.get(0);
+ assertEquals(row.size(), 2);
+ assertEquals(row.get(0).asDouble(), 16138.0 + 16138 + 15);
+ assertEquals(row.get(1).asLong(), 605);
+
+ query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable "
+ + "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC";
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
- assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"INT\",\"LONG\"]");
+ assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"DOUBLE\",\"LONG\"]");
rows = resultTable.get("rows");
assertFalse(rows.isEmpty());
row = rows.get(0);
assertEquals(row.size(), 2);
- assertEquals(row.get(0).asInt(), 5);
- assertEquals(row.get(1).asLong(), 115545);
+ assertEquals(row.get(0).asDouble(), 16138.0 - 25);
+ assertEquals(row.get(1).asLong(), 605);
- query = "SELECT arrayLength(valueIn(DivAirports,'DFW','ORD')), COUNT(*)
FROM mytable GROUP BY "
- + "arrayLength(valueIn(DivAirports,'DFW','ORD')) ORDER BY COUNT(*)
DESC";
+ query = "SELECT mult(DaysSinceEpoch,24,3600), COUNT(*) FROM mytable "
+ + "GROUP BY mult(DaysSinceEpoch,24,3600) ORDER BY COUNT(*) DESC";
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
- assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"INT\",\"LONG\"]");
+ assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"DOUBLE\",\"LONG\"]");
rows = resultTable.get("rows");
- assertEquals(rows.size(), 3);
+ assertFalse(rows.isEmpty());
row = rows.get(0);
assertEquals(row.size(), 2);
- assertEquals(row.get(0).asInt(), 0);
- assertEquals(row.get(1).asLong(), 114895);
- row = rows.get(1);
- assertEquals(row.size(), 2);
- assertEquals(row.get(0).asInt(), 1);
- assertEquals(row.get(1).asLong(), 648);
- row = rows.get(2);
+ assertEquals(row.get(0).asDouble(), 16138.0 * 24 * 3600);
+ assertEquals(row.get(1).asLong(), 605);
+ }
+
+ @Test
+ public void testGroupByUDFV2() throws Exception {
+ setUseMultiStageQueryEngine(true);
+ String query = "SELECT add(DaysSinceEpoch,add(DaysSinceEpoch,15)),
COUNT(*) FROM mytable "
+ + "GROUP BY add(DaysSinceEpoch,add(DaysSinceEpoch,15)) ORDER BY
COUNT(*) DESC";
+ JsonNode response = postQuery(query);
+ JsonNode resultTable = response.get("resultTable");
+ JsonNode dataSchema = resultTable.get("dataSchema");
+ assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"INT\",\"LONG\"]");
+ JsonNode rows = resultTable.get("rows");
+ assertFalse(rows.isEmpty());
+ JsonNode row = rows.get(0);
assertEquals(row.size(), 2);
- assertEquals(row.get(0).asInt(), 2);
- assertEquals(row.get(1).asLong(), 2);
+ assertEquals(row.get(0).asInt(), 16138 + 16138 + 15);
+ assertEquals(row.get(1).asLong(), 605);
- if (useMultiStageQueryEngine()) {
- query = "SELECT arrayToMV(valueIn(DivAirports,'DFW','ORD')), COUNT(*)
FROM mytable "
- + "GROUP BY arrayToMV(valueIn(DivAirports,'DFW','ORD')) ORDER BY
COUNT(*) DESC";
- } else {
- query = "SELECT valueIn(DivAirports,'DFW','ORD'), COUNT(*) FROM mytable "
- + "GROUP BY valueIn(DivAirports,'DFW','ORD') ORDER BY COUNT(*) DESC";
- }
+ query = "SELECT sub(DaysSinceEpoch,25), COUNT(*) FROM mytable "
+ + "GROUP BY sub(DaysSinceEpoch,25) ORDER BY COUNT(*) DESC";
response = postQuery(query);
resultTable = response.get("resultTable");
dataSchema = resultTable.get("dataSchema");
- assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"STRING\",\"LONG\"]");
+ assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"INT\",\"LONG\"]");
rows = resultTable.get("rows");
- assertEquals(rows.size(), 2);
+ assertFalse(rows.isEmpty());
row = rows.get(0);
assertEquals(row.size(), 2);
- assertEquals(row.get(0).asText(), "ORD");
- assertEquals(row.get(1).asLong(), 336);
- row = rows.get(1);
+ assertEquals(row.get(0).asInt(), 16138 - 25);
+ assertEquals(row.get(1).asLong(), 605);
+
+ query = "SELECT mult(DaysSinceEpoch,mult(24,3600)), COUNT(*) FROM mytable "
+ + "GROUP BY mult(DaysSinceEpoch,mult(24,3600)) ORDER BY COUNT(*) DESC";
+ response = postQuery(query);
+ resultTable = response.get("resultTable");
+ dataSchema = resultTable.get("dataSchema");
+ assertEquals(dataSchema.get("columnDataTypes").toString(),
"[\"INT\",\"LONG\"]");
+ rows = resultTable.get("rows");
+ assertFalse(rows.isEmpty());
+ row = rows.get(0);
assertEquals(row.size(), 2);
- assertEquals(row.get(0).asText(), "DFW");
- assertEquals(row.get(1).asLong(), 316);
+ assertEquals(row.get(0).asInt(), 16138 * 24 * 3600);
+ assertEquals(row.get(1).asLong(), 605);
}
@Test
diff --git
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
index 5e282544d2..0c1a8d8a48 100644
---
a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
+++
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java
@@ -223,7 +223,10 @@ public class PinotOperatorTable implements
SqlOperatorTable {
Pair.of(SqlStdOperatorTable.GREATER_THAN, List.of("GREATER_THAN")),
Pair.of(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL,
List.of("GREATER_THAN_OR_EQUAL")),
Pair.of(SqlStdOperatorTable.LESS_THAN, List.of("LESS_THAN")),
- Pair.of(SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
List.of("LESS_THAN_OR_EQUAL"))
+ Pair.of(SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
List.of("LESS_THAN_OR_EQUAL")),
+ Pair.of(SqlStdOperatorTable.MINUS, List.of("SUB", "MINUS")),
+ Pair.of(SqlStdOperatorTable.PLUS, List.of("ADD", "PLUS")),
+ Pair.of(SqlStdOperatorTable.MULTIPLY, List.of("MULT", "TIMES"))
);
/**
diff --git
a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
index 6298709bf5..8e513b76fa 100644
--- a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
@@ -15,7 +15,7 @@
"sql": "EXPLAIN PLAN FOR SELECT 5*6,5+6 FROM d",
"output": [
"Execution Plan",
- "\nLogicalProject(EXPR$0=[30.0], EXPR$1=[11.0])",
+ "\nLogicalProject(EXPR$0=[30], EXPR$1=[11])",
"\n LogicalTableScan(table=[[default, d]])",
"\n"
]
@@ -175,7 +175,7 @@
"sql": "EXPLAIN PLAN FOR SELECT 1 +
ToEpochDays(fromDateTime('1970-01-15', 'yyyy-MM-dd')) FROM a",
"output": [
"Execution Plan",
- "\nLogicalProject(EXPR$0=[15.0:BIGINT])",
+ "\nLogicalProject(EXPR$0=[15:BIGINT])",
"\n LogicalTableScan(table=[[default, a]])",
"\n"
]
diff --git
a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java
b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java
index 55d8d7172f..58de9ec70c 100644
---
a/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java
+++
b/pinot-segment-local/src/test/java/org/apache/pinot/segment/local/recordtransformer/ExpressionTransformerTest.java
@@ -365,7 +365,7 @@ public class ExpressionTransformerTest {
expressionTransformer.transform(genericRow);
Assert.fail();
} catch (Exception e) {
- Assert.assertEquals(e.getCause().getMessage(), "Caught exception while
executing function: plus(x,'10')");
+ Assert.assertTrue(e.getCause().getMessage().contains("Caught exception
while executing function"));
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]