This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 394caaf6850 [Feature](nereids) support median agg function using
percentile (#46377)
394caaf6850 is described below
commit 394caaf6850ca989fa09384abf47c4c60191a920
Author: shee <[email protected]>
AuthorDate: Thu Jan 9 17:14:04 2025 +0800
[Feature](nereids) support median agg function using percentile (#46377)
### What problem does this PR solve?
reference :
https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/median
---------
Co-authored-by: garenshi <[email protected]>
---
.../doris/catalog/BuiltinAggregateFunctions.java | 2 +
.../rules/expression/ExpressionNormalization.java | 2 +
.../rules/expression/ExpressionRuleType.java | 1 +
.../rules/expression/rules/MedianConvert.java | 46 +++++++++
.../trees/expressions/functions/agg/Median.java | 106 +++++++++++++++++++++
.../visitor/AggregateFunctionVisitor.java | 5 +
.../test_convert_median_to_percentile.out | 21 ++++
.../test_convert_median_to_percentile.groovy | 78 +++++++++++++++
8 files changed, 261 insertions(+)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
index 655a3dfce29..5e2e7291987 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
@@ -50,6 +50,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.LinearHistogram;
import org.apache.doris.nereids.trees.expressions.functions.agg.MapAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.MaxBy;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Median;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.MinBy;
import
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
@@ -131,6 +132,7 @@ public class BuiltinAggregateFunctions implements
FunctionHelper {
agg(MapAgg.class, "map_agg"),
agg(Max.class, "max"),
agg(MaxBy.class, "max_by"),
+ agg(Median.class, "median"),
agg(Min.class, "min"),
agg(MinBy.class, "min_by"),
agg(MultiDistinctCount.class, "multi_distinct_count"),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
index 0e52f2aaaf6..b4430d33087 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
@@ -23,6 +23,7 @@ import
org.apache.doris.nereids.rules.expression.rules.DigitalMaskingConvert;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup;
import
org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule;
+import org.apache.doris.nereids.rules.expression.rules.MedianConvert;
import org.apache.doris.nereids.rules.expression.rules.MergeDateTrunc;
import
org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule;
import
org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticComparisonRule;
@@ -52,6 +53,7 @@ public class ExpressionNormalization extends
ExpressionRewrite {
FoldConstantRule.INSTANCE,
SimplifyCastRule.INSTANCE,
DigitalMaskingConvert.INSTANCE,
+ MedianConvert.INSTANCE,
SimplifyArithmeticComparisonRule.INSTANCE,
ConvertAggStateCast.INSTANCE,
MergeDateTrunc.INSTANCE,
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
index 16881a61add..bc12c0459ee 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
@@ -38,6 +38,7 @@ public enum ExpressionRuleType {
IN_PREDICATE_TO_EQUAL_TO,
LIKE_TO_EQUAL,
MERGE_DATE_TRUNC,
+ MEDIAN_CONVERT,
NORMALIZE_BINARY_PREDICATES,
NULL_SAFE_EQUAL_TO_EQUAL,
REPLACE_VARIABLE_BY_LITERAL,
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/MedianConvert.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/MedianConvert.java
new file mode 100644
index 00000000000..e6e295ef98c
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/MedianConvert.java
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
+import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
+import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Median;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Percentile;
+import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * Rewrites median(col) to percentile(col, 0.5), carrying over the DISTINCT flag.
+ */
+public class MedianConvert implements ExpressionPatternRuleFactory {
+ public static final MedianConvert INSTANCE = new MedianConvert(); // shared, never reassigned
+
+ @Override
+ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
+ return ImmutableList.of(
+ matchesType(Median.class).then(median ->
+ new Percentile(median.isDistinct(), median.child(0), DoubleLiteral.of(0.5))
+ ).toRule(ExpressionRuleType.MEDIAN_CONVERT)
+ );
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Median.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Median.java
new file mode 100644
index 00000000000..342604f526c
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Median.java
@@ -0,0 +1,106 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.agg;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.LargeIntType;
+import org.apache.doris.nereids.types.SmallIntType;
+import org.apache.doris.nereids.types.TinyIntType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * AggregateFunction 'median'. This class is generated by GenerateFunction.
+ */
+public class Median extends NullableAggregateFunction
+ implements UnaryExpression, ExplicitlyCastableSignature {
+
+ public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE)
+
+ );
+
+ /**
+ * constructor with 1 argument; the distinct flag defaults to false.
+ */
+ public Median(Expression arg) {
+ this(false, arg);
+ }
+
+ /**
+ * constructor with 2 arguments; alwaysNullable defaults to false.
+ */
+ public Median(boolean distinct, Expression arg) {
+ this(distinct, false, arg);
+ }
+
+ public Median(boolean distinct, boolean alwaysNullable, Expression arg) {
+ super("median", distinct, alwaysNullable, arg);
+ }
+
+ @Override
+ public void checkLegalityBeforeTypeCoercion() {
+ DataType argType = child().getDataType();
+ if (((!argType.isNumericType() && !argType.isNullType()) ||
argType.isOnlyMetricType())) {
+ throw new AnalysisException("median requires a numeric parameter:
" + toSql());
+ }
+ }
+
+ /**
+ * Returns a copy with the given distinct flag and exactly one child, keeping alwaysNullable.
+ */
+ @Override
+ public Median withDistinctAndChildren(boolean distinct, List<Expression>
children) {
+ Preconditions.checkArgument(children.size() == 1);
+ return new Median(distinct, alwaysNullable, children.get(0));
+ }
+
+ @Override
+ public NullableAggregateFunction withAlwaysNullable(boolean
alwaysNullable) {
+ return new Median(distinct, alwaysNullable, children.get(0));
+ }
+
+ @Override
+ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+ return visitor.visitMedian(this, context);
+ }
+
+ @Override
+ public List<FunctionSignature> getSignatures() {
+ return SIGNATURES;
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
index 50ca233cdbc..1918eb1b7b4 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
@@ -51,6 +51,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.LinearHistogram;
import org.apache.doris.nereids.trees.expressions.functions.agg.MapAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.MaxBy;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Median;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.MinBy;
import
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
@@ -261,6 +262,10 @@ public interface AggregateFunctionVisitor<R, C> {
return visitAggregateFunction(function, context);
}
+ default R visitMedian(Median median, C context) { // delegates to the generic nullable-aggregate visit
+ return visitNullableAggregateFunction(median, context);
+ }
+
default R visitPercentile(Percentile percentile, C context) {
return visitNullableAggregateFunction(percentile, context);
}
diff --git
a/regression-test/data/nereids_rules_p0/expression/test_convert_median_to_percentile.out
b/regression-test/data/nereids_rules_p0/expression/test_convert_median_to_percentile.out
new file mode 100644
index 00000000000..3cccc3d6fb0
--- /dev/null
+++
b/regression-test/data/nereids_rules_p0/expression/test_convert_median_to_percentile.out
@@ -0,0 +1,21 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !select_1 --
+101
+
+-- !select_2 --
+101
+
+-- !select_3 --
+2000 100
+2001 125.5
+
+-- !select_4 --
+2000 100
+2001 125.5
+
+-- !select_5 --
+2001 125.5
+
+-- !select_6 --
+2001 125.5
+
diff --git
a/regression-test/suites/nereids_rules_p0/expression/test_convert_median_to_percentile.groovy
b/regression-test/suites/nereids_rules_p0/expression/test_convert_median_to_percentile.groovy
new file mode 100644
index 00000000000..dd5b8032df6
--- /dev/null
+++
b/regression-test/suites/nereids_rules_p0/expression/test_convert_median_to_percentile.groovy
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_convert_median_to_percentile") {
+ sql "SET enable_nereids_planner=true"
+ sql "SET enable_fallback_to_original_planner=false"
+
+ sql "create database if not exists test_convert_median_to_percentile"
+ sql "use test_convert_median_to_percentile"
+
+ sql "DROP TABLE IF EXISTS sales"
+ sql """
+ CREATE TABLE sales (
+ year INT,
+ country STRING,
+ product STRING,
+ profit INT
+ )
+ DISTRIBUTED BY HASH(`year`) BUCKETS 1
+ PROPERTIES (
+ "replication_num" = "1"
+ )
+ """
+ sql """
+ INSERT INTO sales VALUES
+ (2000,'Finland','Computer',1501),
+ (2000,'Finland','Phone',100),
+ (2001,'Finland','Phone',10),
+ (2000,'India','Calculator',75),
+ (2000,'India','Calculator',76),
+ (2000,'India','Computer',1201),
+ (2000,'USA','Calculator',77),
+ (2000,'USA','Computer',1502),
+ (2001,'USA','Calculator',50),
+ (2001,'USA','Computer',1503),
+ (2001,'USA','Computer',1202),
+ (2001,'USA','TV',150),
+ (2001,'USA','TV',101);
+ """
+
+ def sql1 = "select median(profit) from sales"
+ def sql2 = "select percentile(profit, 0.5) from sales"
+ def explainStr1 = sql """ explain ${sql1} """
+ assertTrue(explainStr1.toString().contains("percentile(profit, 0.5)")) // explain text must show the percentile call
+ qt_select_1 "${sql1}" // median form
+ qt_select_2 "${sql2}" // explicit percentile form of the same query
+
+ def sql3 = "select year, median(profit) from sales group by year order by
year"
+ def sql4 = "select year, percentile(profit, 0.5) from sales group by year
order by year"
+ def explainStr3 = sql """ explain ${sql3} """
+ assertTrue(explainStr3.toString().contains("percentile(profit, 0.5)")) // same check for the grouped query
+ qt_select_3 "${sql3}"
+ qt_select_4 "${sql4}"
+
+ def sql5 = "select year, median(profit) from sales group by year having
median(profit) > 100"
+ def sql6 = "select year, percentile(profit, 0.5) from sales group by year
having percentile(profit, 0.5) > 100"
+ def explainStr5 = sql """ explain ${sql5} """
+ assertTrue(explainStr5.toString().contains("percentile(profit, 0.5)")) // same check with median in HAVING
+ qt_select_5 "${sql5}"
+ qt_select_6 "${sql6}"
+
+ sql "DROP TABLE if exists sales" // remove the table created by this suite
+ sql "DROP DATABASE if exists test_convert_median_to_percentile" // remove the database created by this suite
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]