This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 8407490053c [feature-wip](nereids) Support some spark-sql built-in
functions when dialect=spark_sql is set (#28531)
8407490053c is described below
commit 8407490053cdc4cea94cef661739b7852b2d550d
Author: Xiangyu Wang <[email protected]>
AuthorDate: Sat Dec 30 00:10:35 2023 +0800
[feature-wip](nereids) Support some spark-sql built-in functions when
dialect=spark_sql is set (#28531)
---
.../nereids/analyzer/PlaceholderExpression.java | 31 ++++++--
.../nereids/parser/AbstractFnCallTransformers.java | 6 +-
.../nereids/parser/CommonFnCallTransformer.java | 91 +++++++++++++---------
...nsformer.java => ComplexFnCallTransformer.java} | 6 +-
.../DateTruncFnCallTransformer.java} | 49 ++++++++----
.../parser/spark/SparkSql3FnCallTransformers.java | 56 +++++++++----
.../parser/trino/DateDiffFnCallTransformer.java | 8 +-
.../parser/trino/TrinoFnCallTransformers.java | 7 +-
.../nereids/parser/spark/FnTransformTest.java | 72 ++++++++++++-----
9 files changed, 214 insertions(+), 112 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java
index b2acdddd64f..9b2dcde49bd 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java
@@ -23,9 +23,11 @@ import
org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Objects;
+import java.util.Set;
/**
* Expression placeHolder, the expression in PlaceHolderExpression will be
collected by
@@ -33,15 +35,25 @@ import java.util.Objects;
* @see PlaceholderCollector
*/
public class PlaceholderExpression extends Expression implements
AlwaysNotNullable {
- private final Class<? extends Expression> delegateClazz;
+
+ private final ImmutableSet<Class<? extends Expression>> delegateClazzSet;
/**
- * 1 based
+ * start from 1, set the index of this placeholderExpression in
sourceFnTransformedArguments
+ * this placeholderExpression will be replaced later
*/
private final int position;
public PlaceholderExpression(List<Expression> children, Class<? extends
Expression> delegateClazz, int position) {
super(children);
- this.delegateClazz = Objects.requireNonNull(delegateClazz,
"delegateClazz should not be null");
+ this.delegateClazzSet = ImmutableSet.of(
+ Objects.requireNonNull(delegateClazz, "delegateClazz should
not be null"));
+ this.position = position;
+ }
+
+ public PlaceholderExpression(List<Expression> children,
+ Set<Class<? extends Expression>>
delegateClazzSet, int position) {
+ super(children);
+ this.delegateClazzSet = ImmutableSet.copyOf(delegateClazzSet);
this.position = position;
}
@@ -49,13 +61,18 @@ public class PlaceholderExpression extends Expression
implements AlwaysNotNullab
return new PlaceholderExpression(ImmutableList.of(), delegateClazz,
position);
}
+ public static PlaceholderExpression of(Set<Class<? extends Expression>>
delegateClazzSet, int position) {
+ return new PlaceholderExpression(ImmutableList.of(), delegateClazzSet,
position);
+ }
+
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+ visitor.visitPlaceholderExpression(this, context);
return visitor.visit(this, context);
}
- public Class<? extends Expression> getDelegateClazz() {
- return delegateClazz;
+ public Set<Class<? extends Expression>> getDelegateClazzSet() {
+ return delegateClazzSet;
}
public int getPosition() {
@@ -74,11 +91,11 @@ public class PlaceholderExpression extends Expression
implements AlwaysNotNullab
return false;
}
PlaceholderExpression that = (PlaceholderExpression) o;
- return position == that.position && Objects.equals(delegateClazz,
that.delegateClazz);
+ return position == that.position && Objects.equals(delegateClazzSet,
that.delegateClazzSet);
}
@Override
public int hashCode() {
- return Objects.hash(super.hashCode(), delegateClazz, position);
+ return Objects.hash(super.hashCode(), delegateClazzSet, position);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/AbstractFnCallTransformers.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/AbstractFnCallTransformers.java
index 75b5f87263c..4386123c428 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/AbstractFnCallTransformers.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/AbstractFnCallTransformers.java
@@ -77,17 +77,15 @@ public abstract class AbstractFnCallTransformers {
protected void doRegister(
String sourceFnNme,
- int sourceFnArgumentsNum,
String targetFnName,
- List<? extends Expression> targetFnArguments,
- boolean variableArgument) {
+ List<? extends Expression> targetFnArguments) {
List<Expression> castedTargetFnArguments = targetFnArguments
.stream()
.map(each -> (Expression) each)
.collect(Collectors.toList());
transformerBuilder.put(sourceFnNme, new CommonFnCallTransformer(new
UnboundFunction(
- targetFnName, castedTargetFnArguments), variableArgument,
sourceFnArgumentsNum));
+ targetFnName, castedTargetFnArguments)));
}
protected void doRegister(
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/CommonFnCallTransformer.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/CommonFnCallTransformer.java
index 872c21a71e5..fa054160674 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/CommonFnCallTransformer.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/CommonFnCallTransformer.java
@@ -23,56 +23,76 @@ import
org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
+import com.google.common.collect.Lists;
+
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
- * Trino function transformer
+ * Common function transformer,
+ * can transform functions which the size and type of target arguments are
both the same with the source function,
+ * or source function is a variable-arguments function.
*/
public class CommonFnCallTransformer extends AbstractFnCallTransformer {
private final UnboundFunction targetFunction;
private final List<PlaceholderExpression> targetArguments;
- private final boolean variableArgument;
- private final int sourceArgumentsNum;
+
+ // true means the arguments of this function is dynamic, for example:
+ // - named_struct('f1', 1, 'f2', 'a', 'f3', "abc")
+ // - struct(1, 'a', 'abc');
+ private final boolean variableArguments;
/**
- * Trino function transformer, mostly this handle common function.
+ * Common function transformer, mostly this handle common function.
*/
- public CommonFnCallTransformer(UnboundFunction targetFunction,
- boolean variableArgument,
- int sourceArgumentsNum) {
+ public CommonFnCallTransformer(UnboundFunction targetFunction, boolean
variableArguments) {
+ this.targetFunction = targetFunction;
+ PlaceholderCollector placeHolderCollector = new PlaceholderCollector();
+ placeHolderCollector.visit(targetFunction, null);
+ this.targetArguments =
placeHolderCollector.getPlaceholderExpressions();
+ this.variableArguments = variableArguments;
+ }
+
+ public CommonFnCallTransformer(UnboundFunction targetFunction) {
this.targetFunction = targetFunction;
- this.variableArgument = variableArgument;
- this.sourceArgumentsNum = sourceArgumentsNum;
- PlaceholderCollector placeHolderCollector = new
PlaceholderCollector(variableArgument);
+ PlaceholderCollector placeHolderCollector = new PlaceholderCollector();
placeHolderCollector.visit(targetFunction, null);
this.targetArguments =
placeHolderCollector.getPlaceholderExpressions();
+ this.variableArguments = false;
}
@Override
protected boolean check(String sourceFnName,
List<Expression> sourceFnTransformedArguments,
ParserContext context) {
+ // if variableArguments=true, we can not recognize if the type of all
arguments is valid or not,
+ // because:
+ // 1. the argument size is not sure
+ // 2. there are some functions which can accept different types of
arguments,
+ // for example: struct(1, 'a', 'abc')
+ // so just return true here.
+ if (variableArguments) {
+ return true;
+ }
List<Class<? extends Expression>> sourceFnTransformedArgClazz =
sourceFnTransformedArguments.stream()
.map(Expression::getClass)
.collect(Collectors.toList());
- if (variableArgument) {
- if (targetArguments.isEmpty()) {
- return false;
- }
- Class<? extends Expression> targetArgumentClazz =
targetArguments.get(0).getDelegateClazz();
- for (Expression argument : sourceFnTransformedArguments) {
- if
(!targetArgumentClazz.isAssignableFrom(argument.getClass())) {
- return false;
- }
- }
- }
- if (sourceFnTransformedArguments.size() != sourceArgumentsNum) {
+ if (sourceFnTransformedArguments.size() != targetArguments.size()) {
return false;
}
- for (int i = 0; i < targetArguments.size(); i++) {
- if
(!targetArguments.get(i).getDelegateClazz().isAssignableFrom(sourceFnTransformedArgClazz.get(i)))
{
+ for (PlaceholderExpression targetArgument : targetArguments) {
+ // replace the arguments of target function by the position of
target argument
+ int position = targetArgument.getPosition();
+ Class<? extends Expression> sourceArgClazz =
sourceFnTransformedArgClazz.get(position - 1);
+ boolean valid = false;
+ for (Class<? extends Expression> targetArgClazz :
targetArgument.getDelegateClazzSet()) {
+ if (targetArgClazz.isAssignableFrom(sourceArgClazz)) {
+ valid = true;
+ break;
+ }
+ }
+ if (!valid) {
return false;
}
}
@@ -83,7 +103,16 @@ public class CommonFnCallTransformer extends
AbstractFnCallTransformer {
protected Function transform(String sourceFnName,
List<Expression> sourceFnTransformedArguments,
ParserContext context) {
- return targetFunction.withChildren(sourceFnTransformedArguments);
+ if (variableArguments) {
+ // not support adjust the order of arguments when
variableArguments=true
+ return targetFunction.withChildren(sourceFnTransformedArguments);
+ }
+ List<Expression> sourceFnTransformedArgumentsInorder =
Lists.newArrayList();
+ for (PlaceholderExpression placeholderExpression : targetArguments) {
+ Expression expression =
sourceFnTransformedArguments.get(placeholderExpression.getPosition() - 1);
+ sourceFnTransformedArgumentsInorder.add(expression);
+ }
+ return
targetFunction.withChildren(sourceFnTransformedArgumentsInorder);
}
/**
@@ -93,20 +122,12 @@ public class CommonFnCallTransformer extends
AbstractFnCallTransformer {
public static final class PlaceholderCollector extends
DefaultExpressionVisitor<Void, Void> {
private final List<PlaceholderExpression> placeholderExpressions = new
ArrayList<>();
- private final boolean variableArgument;
- public PlaceholderCollector(boolean variableArgument) {
- this.variableArgument = variableArgument;
- }
+ public PlaceholderCollector() {}
@Override
public Void visitPlaceholderExpression(PlaceholderExpression
placeholderExpression, Void context) {
-
- if (variableArgument) {
- placeholderExpressions.add(placeholderExpression);
- return null;
- }
- placeholderExpressions.set(placeholderExpression.getPosition() -
1, placeholderExpression);
+ placeholderExpressions.add(placeholderExpression);
return null;
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/ComplexTrinoFnCallTransformer.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/ComplexFnCallTransformer.java
similarity index 81%
rename from
fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/ComplexTrinoFnCallTransformer.java
rename to
fe/fe-core/src/main/java/org/apache/doris/nereids/parser/ComplexFnCallTransformer.java
index d3a687a289f..2583320b534 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/ComplexTrinoFnCallTransformer.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/ComplexFnCallTransformer.java
@@ -15,14 +15,12 @@
// specific language governing permissions and limitations
// under the License.
-package org.apache.doris.nereids.parser.trino;
-
-import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
+package org.apache.doris.nereids.parser;
/**
* Trino complex function transformer
*/
-public abstract class ComplexTrinoFnCallTransformer extends
AbstractFnCallTransformer {
+public abstract class ComplexFnCallTransformer extends
AbstractFnCallTransformer {
protected abstract String getSourceFnName();
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/DateTruncFnCallTransformer.java
similarity index 51%
copy from
fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
copy to
fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/DateTruncFnCallTransformer.java
index b59a9327bde..1503c6db22c 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/DateTruncFnCallTransformer.java
@@ -15,52 +15,67 @@
// specific language governing permissions and limitations
// under the License.
-package org.apache.doris.nereids.parser.trino;
+package org.apache.doris.nereids.parser.spark;
import org.apache.doris.nereids.analyzer.UnboundFunction;
+import org.apache.doris.nereids.parser.ComplexFnCallTransformer;
import org.apache.doris.nereids.parser.ParserContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
import java.util.List;
/**
- * DateDiff complex function transformer
+ * DateTrunc complex function transformer
*/
-public class DateDiffFnCallTransformer extends ComplexTrinoFnCallTransformer {
+public class DateTruncFnCallTransformer extends ComplexFnCallTransformer {
- private static final String SECOND = "second";
- private static final String HOUR = "hour";
- private static final String DAY = "day";
- private static final String MILLI_SECOND = "millisecond";
+ // reference: https://spark.apache.org/docs/latest/api/sql/index.html#trunc
+ // spark-sql support YEAR/YYYY/YY for year, support MONTH/MON/MM for month
+ private static final ImmutableSet<String> YEAR =
ImmutableSet.<String>builder()
+ .add("YEAR")
+ .add("YYYY")
+ .add("YY")
+ .build();
+
+ private static final ImmutableSet<String> MONTH =
ImmutableSet.<String>builder()
+ .add("MONTH")
+ .add("MON")
+ .add("MM")
+ .build();
@Override
public String getSourceFnName() {
- return "date_diff";
+ return "trunc";
}
@Override
protected boolean check(String sourceFnName, List<Expression>
sourceFnTransformedArguments,
ParserContext context) {
- return getSourceFnName().equalsIgnoreCase(sourceFnName);
+ return getSourceFnName().equalsIgnoreCase(sourceFnName) &&
(sourceFnTransformedArguments.size() == 2);
}
@Override
protected Function transform(String sourceFnName, List<Expression>
sourceFnTransformedArguments,
ParserContext context) {
- if (sourceFnTransformedArguments.size() != 3) {
- return null;
+ VarcharLiteral fmtLiteral = (VarcharLiteral)
sourceFnTransformedArguments.get(1);
+ if (YEAR.contains(fmtLiteral.getValue().toUpperCase())) {
+ return new UnboundFunction(
+ "date_trunc",
+ ImmutableList.of(sourceFnTransformedArguments.get(0), new
VarcharLiteral("YEAR")));
}
- VarcharLiteral diffGranularity = (VarcharLiteral)
sourceFnTransformedArguments.get(0);
- if (SECOND.equals(diffGranularity.getValue())) {
+ if (MONTH.contains(fmtLiteral.getValue().toUpperCase())) {
return new UnboundFunction(
- "seconds_diff",
- ImmutableList.of(sourceFnTransformedArguments.get(1),
sourceFnTransformedArguments.get(2)));
+ "date_trunc",
+ ImmutableList.of(sourceFnTransformedArguments.get(0), new
VarcharLiteral("MONTH")));
}
- // TODO: support other date diff granularity
- return null;
+
+ return new UnboundFunction(
+ "date_trunc",
+ ImmutableList.of(sourceFnTransformedArguments.get(0),
sourceFnTransformedArguments.get(1)));
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/SparkSql3FnCallTransformers.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/SparkSql3FnCallTransformers.java
index 5a6ec21fc9a..341a3d1951d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/SparkSql3FnCallTransformers.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/SparkSql3FnCallTransformers.java
@@ -18,15 +18,13 @@
package org.apache.doris.nereids.parser.spark;
import org.apache.doris.nereids.analyzer.PlaceholderExpression;
-import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
import org.apache.doris.nereids.parser.AbstractFnCallTransformers;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.Lists;
/**
- * The builder and factory for spark-sql 3.x {@link AbstractFnCallTransformer},
- * and supply transform facade ability.
+ * The builder and factory for spark-sql 3.x FnCallTransformers, supply
transform facade ability.
*/
public class SparkSql3FnCallTransformers extends AbstractFnCallTransformers {
@@ -35,30 +33,54 @@ public class SparkSql3FnCallTransformers extends
AbstractFnCallTransformers {
@Override
protected void registerTransformers() {
- doRegister("get_json_object", 2, "json_extract",
- Lists.newArrayList(
- PlaceholderExpression.of(Expression.class, 1),
- PlaceholderExpression.of(Expression.class, 2)), true);
+ // register json functions
+ registerJsonFunctionTransformers();
+ // register string functions
+ registerStringFunctionTransformers();
+ // register date functions
+ registerDateFunctionTransformers();
+ // register numeric functions
+ registerNumericFunctionTransformers();
+ // TODO: add other function transformer
+ }
+
+ @Override
+ protected void registerComplexTransformers() {
+ DateTruncFnCallTransformer dateTruncFnCallTransformer = new
DateTruncFnCallTransformer();
+ doRegister(dateTruncFnCallTransformer.getSourceFnName(),
dateTruncFnCallTransformer);
+ // TODO: add other complex function transformer
+ }
- doRegister("get_json_object", 2, "json_extract",
+ private void registerJsonFunctionTransformers() {
+ doRegister("get_json_object", "json_extract",
Lists.newArrayList(
PlaceholderExpression.of(Expression.class, 1),
- PlaceholderExpression.of(Expression.class, 2)), false);
+ PlaceholderExpression.of(Expression.class, 2)));
+ }
- doRegister("split", 2, "split_by_string",
+ private void registerStringFunctionTransformers() {
+ doRegister("split", "split_by_string",
Lists.newArrayList(
PlaceholderExpression.of(Expression.class, 1),
- PlaceholderExpression.of(Expression.class, 2)), true);
- doRegister("split", 2, "split_by_string",
+ PlaceholderExpression.of(Expression.class, 2)));
+ }
+
+ private void registerDateFunctionTransformers() {
+ // spark-sql support to_date(date_str, fmt) function but doris only
support to_date(date_str)
+ // here try to compat with this situation by using str_to_date(str,
fmt),
+ // this function support the following three formats which can handle
the mainly situations:
+ // 1. yyyyMMdd
+ // 2. yyyy-MM-dd
+ // 3. yyyy-MM-dd HH:mm:ss
+ doRegister("to_date", "str_to_date",
Lists.newArrayList(
PlaceholderExpression.of(Expression.class, 1),
- PlaceholderExpression.of(Expression.class, 2)), false);
- // TODO: add other function transformer
+ PlaceholderExpression.of(Expression.class, 2)));
}
- @Override
- protected void registerComplexTransformers() {
- // TODO: add other complex function transformer
+ private void registerNumericFunctionTransformers() {
+ doRegister("mean", "avg",
+ Lists.newArrayList(PlaceholderExpression.of(Expression.class,
1)));
}
static class SingletonHolder {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
index b59a9327bde..986f9996f90 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
@@ -18,6 +18,7 @@
package org.apache.doris.nereids.parser.trino;
import org.apache.doris.nereids.analyzer.UnboundFunction;
+import org.apache.doris.nereids.parser.ComplexFnCallTransformer;
import org.apache.doris.nereids.parser.ParserContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
@@ -30,7 +31,7 @@ import java.util.List;
/**
* DateDiff complex function transformer
*/
-public class DateDiffFnCallTransformer extends ComplexTrinoFnCallTransformer {
+public class DateDiffFnCallTransformer extends ComplexFnCallTransformer {
private static final String SECOND = "second";
private static final String HOUR = "hour";
@@ -45,15 +46,12 @@ public class DateDiffFnCallTransformer extends
ComplexTrinoFnCallTransformer {
@Override
protected boolean check(String sourceFnName, List<Expression>
sourceFnTransformedArguments,
ParserContext context) {
- return getSourceFnName().equalsIgnoreCase(sourceFnName);
+ return getSourceFnName().equalsIgnoreCase(sourceFnName) &&
(sourceFnTransformedArguments.size() == 3);
}
@Override
protected Function transform(String sourceFnName, List<Expression>
sourceFnTransformedArguments,
ParserContext context) {
- if (sourceFnTransformedArguments.size() != 3) {
- return null;
- }
VarcharLiteral diffGranularity = (VarcharLiteral)
sourceFnTransformedArguments.get(0);
if (SECOND.equals(diffGranularity.getValue())) {
return new UnboundFunction(
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/TrinoFnCallTransformers.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/TrinoFnCallTransformers.java
index 883cb1cd132..662439cfc9d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/TrinoFnCallTransformers.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/TrinoFnCallTransformers.java
@@ -18,14 +18,13 @@
package org.apache.doris.nereids.parser.trino;
import org.apache.doris.nereids.analyzer.PlaceholderExpression;
-import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
import org.apache.doris.nereids.parser.AbstractFnCallTransformers;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.Lists;
/**
- * The builder and factory for trino {@link AbstractFnCallTransformer},
+ * The builder and factory for trino function call transformers,
* and supply transform facade ability.
*/
public class TrinoFnCallTransformers extends AbstractFnCallTransformers {
@@ -47,8 +46,8 @@ public class TrinoFnCallTransformers extends
AbstractFnCallTransformers {
}
protected void registerStringFunctionTransformer() {
- doRegister("codepoint", 1, "ascii",
- Lists.newArrayList(PlaceholderExpression.of(Expression.class,
1)), false);
+ doRegister("codepoint", "ascii",
+ Lists.newArrayList(PlaceholderExpression.of(Expression.class,
1)));
// TODO: add other string function transformer
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/spark/FnTransformTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/spark/FnTransformTest.java
index f652e171280..a4e972c577a 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/spark/FnTransformTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/spark/FnTransformTest.java
@@ -21,6 +21,7 @@ import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.parser.ParserTestBase;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@@ -30,26 +31,59 @@ import org.junit.jupiter.api.Test;
public class FnTransformTest extends ParserTestBase {
@Test
- public void testCommonFnTransform() {
- NereidsParser nereidsParser = new NereidsParser();
+ public void testCommonFnTransformers() {
+ // test json functions
+ testFunction("SELECT json_extract('{\"c1\": 1}', '$.c1') as b FROM t",
+ "SELECT get_json_object('{\"c1\": 1}', '$.c1') as b FROM
t",
+ "json_extract('{\"c1\": 1}', '$.c1')");
- String sql1 = "SELECT json_extract('{\"a\": 1}', '$.a') as b FROM t";
- String dialectSql1 = "SELECT get_json_object('{\"a\": 1}', '$.a') as b
FROM t";
- LogicalPlan logicalPlan1 = nereidsParser.parseSingle(sql1);
- LogicalPlan dialectLogicalPlan1 =
nereidsParser.parseSingle(dialectSql1,
- new SparkSql3LogicalPlanBuilder());
- Assertions.assertEquals(dialectLogicalPlan1, logicalPlan1);
-
Assertions.assertTrue(dialectLogicalPlan1.child(0).toString().toLowerCase()
- .contains("json_extract('{\"a\": 1}', '$.a')"));
-
- String sql2 = "SELECT json_extract(a, '$.a') as b FROM t";
- String dialectSql2 = "SELECT get_json_object(a, '$.a') as b FROM t";
- LogicalPlan logicalPlan2 = nereidsParser.parseSingle(sql2);
- LogicalPlan dialectLogicalPlan2 =
nereidsParser.parseSingle(dialectSql2,
- new SparkSql3LogicalPlanBuilder());
- Assertions.assertEquals(dialectLogicalPlan2, logicalPlan2);
-
Assertions.assertTrue(dialectLogicalPlan2.child(0).toString().toLowerCase()
- .contains("json_extract('a, '$.a')"));
+ testFunction("SELECT json_extract(c1, '$.c1') as b FROM t",
+ "SELECT get_json_object(c1, '$.c1') as b FROM t",
+ "json_extract('c1, '$.c1')");
+
+ // test string functions
+ testFunction("SELECT str_to_date('2023-12-16', 'yyyy-MM-dd') as b FROM
t",
+ "SELECT to_date('2023-12-16', 'yyyy-MM-dd') as b FROM t",
+ "str_to_date('2023-12-16', 'yyyy-MM-dd')");
+ testFunction("SELECT str_to_date(c1, 'yyyy-MM-dd') as b FROM t",
+ "SELECT to_date(c1, 'yyyy-MM-dd') as b FROM t",
+ "str_to_date('c1, 'yyyy-MM-dd')");
+
+ testFunction("SELECT date_trunc('2023-12-16', 'YEAR') as a FROM t",
+ "SELECT trunc('2023-12-16', 'YEAR') as a FROM t",
+ "date_trunc('2023-12-16', 'YEAR')");
+ testFunction("SELECT date_trunc(c1, 'YEAR') as a FROM t",
+ "SELECT trunc(c1, 'YEAR') as a FROM t",
+ "date_trunc('c1, 'YEAR')");
+
+ testFunction("SELECT date_trunc('2023-12-16', 'YEAR') as a FROM t",
+ "SELECT trunc('2023-12-16', 'YY') as a FROM t",
+ "date_trunc('2023-12-16', 'YEAR')");
+ testFunction("SELECT date_trunc(c1, 'YEAR') as a FROM t",
+ "SELECT trunc(c1, 'YY') as a FROM t",
+ "date_trunc('c1, 'YEAR')");
+
+ testFunction("SELECT date_trunc('2023-12-16', 'MONTH') as a FROM t",
+ "SELECT trunc('2023-12-16', 'MON') as a FROM t",
+ "date_trunc('2023-12-16', 'MONTH')");
+ testFunction("SELECT date_trunc(c1, 'MONTH') as a FROM t",
+ "SELECT trunc(c1, 'MON') as a FROM t",
+ "date_trunc('c1, 'MONTH')");
+
+ // test numeric functions
+ testFunction("SELECT avg(c1) as a from t",
+ "SELECT mean(c1) as a from t",
+ "avg('c1)");
}
+ private void testFunction(String sql, String dialectSql, String
expectLogicalPlanStr) {
+ NereidsParser nereidsParser = new NereidsParser();
+ LogicalPlan logicalPlan = nereidsParser.parseSingle(sql);
+ LogicalPlan dialectLogicalPlan = nereidsParser.parseSingle(dialectSql,
+ new SparkSql3LogicalPlanBuilder());
+ Assertions.assertEquals(dialectLogicalPlan, logicalPlan);
+ String dialectLogicalPlanStr =
dialectLogicalPlan.child(0).toString().toLowerCase();
+ System.out.println("dialectLogicalPlanStr: " + dialectLogicalPlanStr);
+
Assertions.assertTrue(StringUtils.containsIgnoreCase(dialectLogicalPlanStr,
expectLogicalPlanStr));
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]