This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 8407490053c [feature-wip](nereids) Support some spark-sql built-in functions when set dialect=spark_sql (#28531)
8407490053c is described below

commit 8407490053cdc4cea94cef661739b7852b2d550d
Author: Xiangyu Wang <[email protected]>
AuthorDate: Sat Dec 30 00:10:35 2023 +0800

    [feature-wip](nereids) Support some spark-sql built-in functions when set dialect=spark_sql (#28531)
---
 .../nereids/analyzer/PlaceholderExpression.java    | 31 ++++++--
 .../nereids/parser/AbstractFnCallTransformers.java |  6 +-
 .../nereids/parser/CommonFnCallTransformer.java    | 91 +++++++++++++---------
 ...nsformer.java => ComplexFnCallTransformer.java} |  6 +-
 .../DateTruncFnCallTransformer.java}               | 49 ++++++++----
 .../parser/spark/SparkSql3FnCallTransformers.java  | 56 +++++++++----
 .../parser/trino/DateDiffFnCallTransformer.java    |  8 +-
 .../parser/trino/TrinoFnCallTransformers.java      |  7 +-
 .../nereids/parser/spark/FnTransformTest.java      | 72 ++++++++++++-----
 9 files changed, 214 insertions(+), 112 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java
index b2acdddd64f..9b2dcde49bd 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/PlaceholderExpression.java
@@ -23,9 +23,11 @@ import 
org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 
 import java.util.List;
 import java.util.Objects;
+import java.util.Set;
 
 /**
  * Expression placeHolder, the expression in PlaceHolderExpression will be 
collected by
@@ -33,15 +35,25 @@ import java.util.Objects;
  * @see PlaceholderCollector
  */
 public class PlaceholderExpression extends Expression implements 
AlwaysNotNullable {
-    private final Class<? extends Expression> delegateClazz;
+
+    private final ImmutableSet<Class<? extends Expression>> delegateClazzSet;
     /**
-     * 1 based
+     * start from 1, set the index of this placeholderExpression in 
sourceFnTransformedArguments
+     * this placeholderExpression will be replaced later
      */
     private final int position;
 
     public PlaceholderExpression(List<Expression> children, Class<? extends 
Expression> delegateClazz, int position) {
         super(children);
-        this.delegateClazz = Objects.requireNonNull(delegateClazz, 
"delegateClazz should not be null");
+        this.delegateClazzSet = ImmutableSet.of(
+                Objects.requireNonNull(delegateClazz, "delegateClazz should 
not be null"));
+        this.position = position;
+    }
+
+    public PlaceholderExpression(List<Expression> children,
+                                 Set<Class<? extends Expression>> 
delegateClazzSet, int position) {
+        super(children);
+        this.delegateClazzSet = ImmutableSet.copyOf(delegateClazzSet);
         this.position = position;
     }
 
@@ -49,13 +61,18 @@ public class PlaceholderExpression extends Expression 
implements AlwaysNotNullab
         return new PlaceholderExpression(ImmutableList.of(), delegateClazz, 
position);
     }
 
+    public static PlaceholderExpression of(Set<Class<? extends Expression>> 
delegateClazzSet, int position) {
+        return new PlaceholderExpression(ImmutableList.of(), delegateClazzSet, 
position);
+    }
+
     @Override
     public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        visitor.visitPlaceholderExpression(this, context);
         return visitor.visit(this, context);
     }
 
-    public Class<? extends Expression> getDelegateClazz() {
-        return delegateClazz;
+    public Set<Class<? extends Expression>> getDelegateClazzSet() {
+        return delegateClazzSet;
     }
 
     public int getPosition() {
@@ -74,11 +91,11 @@ public class PlaceholderExpression extends Expression 
implements AlwaysNotNullab
             return false;
         }
         PlaceholderExpression that = (PlaceholderExpression) o;
-        return position == that.position && Objects.equals(delegateClazz, 
that.delegateClazz);
+        return position == that.position && Objects.equals(delegateClazzSet, 
that.delegateClazzSet);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(super.hashCode(), delegateClazz, position);
+        return Objects.hash(super.hashCode(), delegateClazzSet, position);
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/AbstractFnCallTransformers.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/AbstractFnCallTransformers.java
index 75b5f87263c..4386123c428 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/AbstractFnCallTransformers.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/AbstractFnCallTransformers.java
@@ -77,17 +77,15 @@ public abstract class AbstractFnCallTransformers {
 
     protected void doRegister(
             String sourceFnNme,
-            int sourceFnArgumentsNum,
             String targetFnName,
-            List<? extends Expression> targetFnArguments,
-            boolean variableArgument) {
+            List<? extends Expression> targetFnArguments) {
 
         List<Expression> castedTargetFnArguments = targetFnArguments
                 .stream()
                 .map(each -> (Expression) each)
                 .collect(Collectors.toList());
         transformerBuilder.put(sourceFnNme, new CommonFnCallTransformer(new 
UnboundFunction(
-                targetFnName, castedTargetFnArguments), variableArgument, 
sourceFnArgumentsNum));
+                targetFnName, castedTargetFnArguments)));
     }
 
     protected void doRegister(
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/CommonFnCallTransformer.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/CommonFnCallTransformer.java
index 872c21a71e5..fa054160674 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/CommonFnCallTransformer.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/CommonFnCallTransformer.java
@@ -23,56 +23,76 @@ import 
org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.functions.Function;
 import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
 
+import com.google.common.collect.Lists;
+
 import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
 
 /**
- * Trino function transformer
+ * Common function transformer,
+ * can transform functions which the size and type of target arguments are 
both the same with the source function,
+ * or source function is a variable-arguments function.
  */
 public class CommonFnCallTransformer extends AbstractFnCallTransformer {
     private final UnboundFunction targetFunction;
     private final List<PlaceholderExpression> targetArguments;
-    private final boolean variableArgument;
-    private final int sourceArgumentsNum;
+
+    // true means the arguments of this function is dynamic, for example:
+    // - named_struct('f1', 1, 'f2', 'a', 'f3', "abc")
+    // - struct(1, 'a', 'abc');
+    private final boolean variableArguments;
 
     /**
-     * Trino function transformer, mostly this handle common function.
+     * Common function transformer, mostly this handle common function.
      */
-    public CommonFnCallTransformer(UnboundFunction targetFunction,
-                                   boolean variableArgument,
-                                   int sourceArgumentsNum) {
+    public CommonFnCallTransformer(UnboundFunction targetFunction, boolean 
variableArguments) {
+        this.targetFunction = targetFunction;
+        PlaceholderCollector placeHolderCollector = new PlaceholderCollector();
+        placeHolderCollector.visit(targetFunction, null);
+        this.targetArguments = 
placeHolderCollector.getPlaceholderExpressions();
+        this.variableArguments = variableArguments;
+    }
+
+    public CommonFnCallTransformer(UnboundFunction targetFunction) {
         this.targetFunction = targetFunction;
-        this.variableArgument = variableArgument;
-        this.sourceArgumentsNum = sourceArgumentsNum;
-        PlaceholderCollector placeHolderCollector = new 
PlaceholderCollector(variableArgument);
+        PlaceholderCollector placeHolderCollector = new PlaceholderCollector();
         placeHolderCollector.visit(targetFunction, null);
         this.targetArguments = 
placeHolderCollector.getPlaceholderExpressions();
+        this.variableArguments = false;
     }
 
     @Override
     protected boolean check(String sourceFnName,
             List<Expression> sourceFnTransformedArguments,
             ParserContext context) {
+        // if variableArguments=true, we can not recognize if the type of all 
arguments is valid or not,
+        // because:
+        //     1. the argument size is not sure
+        //     2. there are some functions which can accept different types of 
arguments,
+        //        for example: struct(1, 'a', 'abc')
+        // so just return true here.
+        if (variableArguments) {
+            return true;
+        }
         List<Class<? extends Expression>> sourceFnTransformedArgClazz = 
sourceFnTransformedArguments.stream()
                 .map(Expression::getClass)
                 .collect(Collectors.toList());
-        if (variableArgument) {
-            if (targetArguments.isEmpty()) {
-                return false;
-            }
-            Class<? extends Expression> targetArgumentClazz = 
targetArguments.get(0).getDelegateClazz();
-            for (Expression argument : sourceFnTransformedArguments) {
-                if 
(!targetArgumentClazz.isAssignableFrom(argument.getClass())) {
-                    return false;
-                }
-            }
-        }
-        if (sourceFnTransformedArguments.size() != sourceArgumentsNum) {
+        if (sourceFnTransformedArguments.size() != targetArguments.size()) {
             return false;
         }
-        for (int i = 0; i < targetArguments.size(); i++) {
-            if 
(!targetArguments.get(i).getDelegateClazz().isAssignableFrom(sourceFnTransformedArgClazz.get(i)))
 {
+        for (PlaceholderExpression targetArgument : targetArguments) {
+            // replace the arguments of target function by the position of 
target argument
+            int position = targetArgument.getPosition();
+            Class<? extends Expression> sourceArgClazz = 
sourceFnTransformedArgClazz.get(position - 1);
+            boolean valid = false;
+            for (Class<? extends Expression> targetArgClazz : 
targetArgument.getDelegateClazzSet()) {
+                if (targetArgClazz.isAssignableFrom(sourceArgClazz)) {
+                    valid = true;
+                    break;
+                }
+            }
+            if (!valid) {
                 return false;
             }
         }
@@ -83,7 +103,16 @@ public class CommonFnCallTransformer extends 
AbstractFnCallTransformer {
     protected Function transform(String sourceFnName,
             List<Expression> sourceFnTransformedArguments,
             ParserContext context) {
-        return targetFunction.withChildren(sourceFnTransformedArguments);
+        if (variableArguments) {
+            // not support adjust the order of arguments when 
variableArguments=true
+            return targetFunction.withChildren(sourceFnTransformedArguments);
+        }
+        List<Expression> sourceFnTransformedArgumentsInorder = 
Lists.newArrayList();
+        for (PlaceholderExpression placeholderExpression : targetArguments) {
+            Expression expression = 
sourceFnTransformedArguments.get(placeholderExpression.getPosition() - 1);
+            sourceFnTransformedArgumentsInorder.add(expression);
+        }
+        return 
targetFunction.withChildren(sourceFnTransformedArgumentsInorder);
     }
 
     /**
@@ -93,20 +122,12 @@ public class CommonFnCallTransformer extends 
AbstractFnCallTransformer {
     public static final class PlaceholderCollector extends 
DefaultExpressionVisitor<Void, Void> {
 
         private final List<PlaceholderExpression> placeholderExpressions = new 
ArrayList<>();
-        private final boolean variableArgument;
 
-        public PlaceholderCollector(boolean variableArgument) {
-            this.variableArgument = variableArgument;
-        }
+        public PlaceholderCollector() {}
 
         @Override
         public Void visitPlaceholderExpression(PlaceholderExpression 
placeholderExpression, Void context) {
-
-            if (variableArgument) {
-                placeholderExpressions.add(placeholderExpression);
-                return null;
-            }
-            placeholderExpressions.set(placeholderExpression.getPosition() - 
1, placeholderExpression);
+            placeholderExpressions.add(placeholderExpression);
             return null;
         }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/ComplexTrinoFnCallTransformer.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/ComplexFnCallTransformer.java
similarity index 81%
rename from 
fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/ComplexTrinoFnCallTransformer.java
rename to 
fe/fe-core/src/main/java/org/apache/doris/nereids/parser/ComplexFnCallTransformer.java
index d3a687a289f..2583320b534 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/ComplexTrinoFnCallTransformer.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/ComplexFnCallTransformer.java
@@ -15,14 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
-package org.apache.doris.nereids.parser.trino;
-
-import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
+package org.apache.doris.nereids.parser;
 
 /**
  * Trino complex function transformer
  */
-public abstract class ComplexTrinoFnCallTransformer extends 
AbstractFnCallTransformer {
+public abstract class ComplexFnCallTransformer extends 
AbstractFnCallTransformer {
 
     protected abstract String getSourceFnName();
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/DateTruncFnCallTransformer.java
similarity index 51%
copy from 
fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
copy to 
fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/DateTruncFnCallTransformer.java
index b59a9327bde..1503c6db22c 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/DateTruncFnCallTransformer.java
@@ -15,52 +15,67 @@
 // specific language governing permissions and limitations
 // under the License.
 
-package org.apache.doris.nereids.parser.trino;
+package org.apache.doris.nereids.parser.spark;
 
 import org.apache.doris.nereids.analyzer.UnboundFunction;
+import org.apache.doris.nereids.parser.ComplexFnCallTransformer;
 import org.apache.doris.nereids.parser.ParserContext;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.functions.Function;
 import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 
 import java.util.List;
 
 /**
- * DateDiff complex function transformer
+ * DateTrunc complex function transformer
  */
-public class DateDiffFnCallTransformer extends ComplexTrinoFnCallTransformer {
+public class DateTruncFnCallTransformer extends ComplexFnCallTransformer {
 
-    private static final String SECOND = "second";
-    private static final String HOUR = "hour";
-    private static final String DAY = "day";
-    private static final String MILLI_SECOND = "millisecond";
+    // reference: https://spark.apache.org/docs/latest/api/sql/index.html#trunc
+    // spark-sql support YEAR/YYYY/YY for year, support MONTH/MON/MM for month
+    private static final ImmutableSet<String> YEAR = 
ImmutableSet.<String>builder()
+            .add("YEAR")
+            .add("YYYY")
+            .add("YY")
+            .build();
+
+    private static final ImmutableSet<String> MONTH = 
ImmutableSet.<String>builder()
+            .add("MONTH")
+            .add("MON")
+            .add("MM")
+            .build();
 
     @Override
     public String getSourceFnName() {
-        return "date_diff";
+        return "trunc";
     }
 
     @Override
     protected boolean check(String sourceFnName, List<Expression> 
sourceFnTransformedArguments,
             ParserContext context) {
-        return getSourceFnName().equalsIgnoreCase(sourceFnName);
+        return getSourceFnName().equalsIgnoreCase(sourceFnName) && 
(sourceFnTransformedArguments.size() == 2);
     }
 
     @Override
     protected Function transform(String sourceFnName, List<Expression> 
sourceFnTransformedArguments,
             ParserContext context) {
-        if (sourceFnTransformedArguments.size() != 3) {
-            return null;
+        VarcharLiteral fmtLiteral = (VarcharLiteral) 
sourceFnTransformedArguments.get(1);
+        if (YEAR.contains(fmtLiteral.getValue().toUpperCase())) {
+            return new UnboundFunction(
+                    "date_trunc",
+                    ImmutableList.of(sourceFnTransformedArguments.get(0), new 
VarcharLiteral("YEAR")));
         }
-        VarcharLiteral diffGranularity = (VarcharLiteral) 
sourceFnTransformedArguments.get(0);
-        if (SECOND.equals(diffGranularity.getValue())) {
+        if (MONTH.contains(fmtLiteral.getValue().toUpperCase())) {
             return new UnboundFunction(
-                    "seconds_diff",
-                    ImmutableList.of(sourceFnTransformedArguments.get(1), 
sourceFnTransformedArguments.get(2)));
+                    "date_trunc",
+                    ImmutableList.of(sourceFnTransformedArguments.get(0), new 
VarcharLiteral("MONTH")));
         }
-        // TODO: support other date diff granularity
-        return null;
+
+        return new UnboundFunction(
+                "date_trunc",
+                ImmutableList.of(sourceFnTransformedArguments.get(0), 
sourceFnTransformedArguments.get(1)));
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/SparkSql3FnCallTransformers.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/SparkSql3FnCallTransformers.java
index 5a6ec21fc9a..341a3d1951d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/SparkSql3FnCallTransformers.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/spark/SparkSql3FnCallTransformers.java
@@ -18,15 +18,13 @@
 package org.apache.doris.nereids.parser.spark;
 
 import org.apache.doris.nereids.analyzer.PlaceholderExpression;
-import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
 import org.apache.doris.nereids.parser.AbstractFnCallTransformers;
 import org.apache.doris.nereids.trees.expressions.Expression;
 
 import com.google.common.collect.Lists;
 
 /**
- * The builder and factory for spark-sql 3.x {@link AbstractFnCallTransformer},
- * and supply transform facade ability.
+ * The builder and factory for spark-sql 3.x FnCallTransformers, supply 
transform facade ability.
  */
 public class SparkSql3FnCallTransformers extends AbstractFnCallTransformers {
 
@@ -35,30 +33,54 @@ public class SparkSql3FnCallTransformers extends 
AbstractFnCallTransformers {
 
     @Override
     protected void registerTransformers() {
-        doRegister("get_json_object", 2, "json_extract",
-                Lists.newArrayList(
-                        PlaceholderExpression.of(Expression.class, 1),
-                        PlaceholderExpression.of(Expression.class, 2)), true);
+        // register json functions
+        registerJsonFunctionTransformers();
+        // register string functions
+        registerStringFunctionTransformers();
+        // register date functions
+        registerDateFunctionTransformers();
+        // register numeric functions
+        registerNumericFunctionTransformers();
+        // TODO: add other function transformer
+    }
+
+    @Override
+    protected void registerComplexTransformers() {
+        DateTruncFnCallTransformer dateTruncFnCallTransformer = new 
DateTruncFnCallTransformer();
+        doRegister(dateTruncFnCallTransformer.getSourceFnName(), 
dateTruncFnCallTransformer);
+        // TODO: add other complex function transformer
+    }
 
-        doRegister("get_json_object", 2, "json_extract",
+    private void registerJsonFunctionTransformers() {
+        doRegister("get_json_object", "json_extract",
                 Lists.newArrayList(
                         PlaceholderExpression.of(Expression.class, 1),
-                        PlaceholderExpression.of(Expression.class, 2)), false);
+                        PlaceholderExpression.of(Expression.class, 2)));
+    }
 
-        doRegister("split", 2, "split_by_string",
+    private void registerStringFunctionTransformers() {
+        doRegister("split", "split_by_string",
                 Lists.newArrayList(
                         PlaceholderExpression.of(Expression.class, 1),
-                        PlaceholderExpression.of(Expression.class, 2)), true);
-        doRegister("split", 2, "split_by_string",
+                        PlaceholderExpression.of(Expression.class, 2)));
+    }
+
+    private void registerDateFunctionTransformers() {
+        // spark-sql support to_date(date_str, fmt) function but doris only 
support to_date(date_str)
+        // here try to compat with this situation by using str_to_date(str, 
fmt),
+        // this function support the following three formats which can handle 
the mainly situations:
+        //  1. yyyyMMdd
+        //  2. yyyy-MM-dd
+        //  3. yyyy-MM-dd HH:mm:ss
+        doRegister("to_date", "str_to_date",
                 Lists.newArrayList(
                         PlaceholderExpression.of(Expression.class, 1),
-                        PlaceholderExpression.of(Expression.class, 2)), false);
-        // TODO: add other function transformer
+                        PlaceholderExpression.of(Expression.class, 2)));
     }
 
-    @Override
-    protected void registerComplexTransformers() {
-        // TODO: add other complex function transformer
+    private void registerNumericFunctionTransformers() {
+        doRegister("mean", "avg",
+                Lists.newArrayList(PlaceholderExpression.of(Expression.class, 
1)));
     }
 
     static class SingletonHolder {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
index b59a9327bde..986f9996f90 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/DateDiffFnCallTransformer.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.parser.trino;
 
 import org.apache.doris.nereids.analyzer.UnboundFunction;
+import org.apache.doris.nereids.parser.ComplexFnCallTransformer;
 import org.apache.doris.nereids.parser.ParserContext;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.functions.Function;
@@ -30,7 +31,7 @@ import java.util.List;
 /**
  * DateDiff complex function transformer
  */
-public class DateDiffFnCallTransformer extends ComplexTrinoFnCallTransformer {
+public class DateDiffFnCallTransformer extends ComplexFnCallTransformer {
 
     private static final String SECOND = "second";
     private static final String HOUR = "hour";
@@ -45,15 +46,12 @@ public class DateDiffFnCallTransformer extends 
ComplexTrinoFnCallTransformer {
     @Override
     protected boolean check(String sourceFnName, List<Expression> 
sourceFnTransformedArguments,
             ParserContext context) {
-        return getSourceFnName().equalsIgnoreCase(sourceFnName);
+        return getSourceFnName().equalsIgnoreCase(sourceFnName) && 
(sourceFnTransformedArguments.size() == 3);
     }
 
     @Override
     protected Function transform(String sourceFnName, List<Expression> 
sourceFnTransformedArguments,
             ParserContext context) {
-        if (sourceFnTransformedArguments.size() != 3) {
-            return null;
-        }
         VarcharLiteral diffGranularity = (VarcharLiteral) 
sourceFnTransformedArguments.get(0);
         if (SECOND.equals(diffGranularity.getValue())) {
             return new UnboundFunction(
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/TrinoFnCallTransformers.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/TrinoFnCallTransformers.java
index 883cb1cd132..662439cfc9d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/TrinoFnCallTransformers.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/trino/TrinoFnCallTransformers.java
@@ -18,14 +18,13 @@
 package org.apache.doris.nereids.parser.trino;
 
 import org.apache.doris.nereids.analyzer.PlaceholderExpression;
-import org.apache.doris.nereids.parser.AbstractFnCallTransformer;
 import org.apache.doris.nereids.parser.AbstractFnCallTransformers;
 import org.apache.doris.nereids.trees.expressions.Expression;
 
 import com.google.common.collect.Lists;
 
 /**
- * The builder and factory for trino {@link AbstractFnCallTransformer},
+ * The builder and factory for trino function call transformers,
  * and supply transform facade ability.
  */
 public class TrinoFnCallTransformers extends AbstractFnCallTransformers {
@@ -47,8 +46,8 @@ public class TrinoFnCallTransformers extends 
AbstractFnCallTransformers {
     }
 
     protected void registerStringFunctionTransformer() {
-        doRegister("codepoint", 1, "ascii",
-                Lists.newArrayList(PlaceholderExpression.of(Expression.class, 
1)), false);
+        doRegister("codepoint", "ascii",
+                Lists.newArrayList(PlaceholderExpression.of(Expression.class, 
1)));
         // TODO: add other string function transformer
     }
 
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/spark/FnTransformTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/spark/FnTransformTest.java
index f652e171280..a4e972c577a 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/spark/FnTransformTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/spark/FnTransformTest.java
@@ -21,6 +21,7 @@ import org.apache.doris.nereids.parser.NereidsParser;
 import org.apache.doris.nereids.parser.ParserTestBase;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 
+import org.apache.commons.lang3.StringUtils;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
@@ -30,26 +31,59 @@ import org.junit.jupiter.api.Test;
 public class FnTransformTest extends ParserTestBase {
 
     @Test
-    public void testCommonFnTransform() {
-        NereidsParser nereidsParser = new NereidsParser();
+    public void testCommonFnTransformers() {
+        // test json functions
+        testFunction("SELECT json_extract('{\"c1\": 1}', '$.c1') as b FROM t",
+                    "SELECT get_json_object('{\"c1\": 1}', '$.c1') as b FROM 
t",
+                    "json_extract('{\"c1\": 1}', '$.c1')");
 
-        String sql1 = "SELECT json_extract('{\"a\": 1}', '$.a') as b FROM t";
-        String dialectSql1 = "SELECT get_json_object('{\"a\": 1}', '$.a') as b 
FROM t";
-        LogicalPlan logicalPlan1 = nereidsParser.parseSingle(sql1);
-        LogicalPlan dialectLogicalPlan1 = 
nereidsParser.parseSingle(dialectSql1,
-                    new SparkSql3LogicalPlanBuilder());
-        Assertions.assertEquals(dialectLogicalPlan1, logicalPlan1);
-        
Assertions.assertTrue(dialectLogicalPlan1.child(0).toString().toLowerCase()
-                    .contains("json_extract('{\"a\": 1}', '$.a')"));
-
-        String sql2 = "SELECT json_extract(a, '$.a') as b FROM t";
-        String dialectSql2 = "SELECT get_json_object(a, '$.a') as b FROM t";
-        LogicalPlan logicalPlan2 = nereidsParser.parseSingle(sql2);
-        LogicalPlan dialectLogicalPlan2 = 
nereidsParser.parseSingle(dialectSql2,
-                new SparkSql3LogicalPlanBuilder());
-        Assertions.assertEquals(dialectLogicalPlan2, logicalPlan2);
-        
Assertions.assertTrue(dialectLogicalPlan2.child(0).toString().toLowerCase()
-                    .contains("json_extract('a, '$.a')"));
+        testFunction("SELECT json_extract(c1, '$.c1') as b FROM t",
+                    "SELECT get_json_object(c1, '$.c1') as b FROM t",
+                    "json_extract('c1, '$.c1')");
+
+        // test string functions
+        testFunction("SELECT str_to_date('2023-12-16', 'yyyy-MM-dd') as b FROM 
t",
+                    "SELECT to_date('2023-12-16', 'yyyy-MM-dd') as b FROM t",
+                    "str_to_date('2023-12-16', 'yyyy-MM-dd')");
+        testFunction("SELECT str_to_date(c1, 'yyyy-MM-dd') as b FROM t",
+                "SELECT to_date(c1, 'yyyy-MM-dd') as b FROM t",
+                "str_to_date('c1, 'yyyy-MM-dd')");
+
+        testFunction("SELECT date_trunc('2023-12-16', 'YEAR') as a FROM t",
+                    "SELECT trunc('2023-12-16', 'YEAR') as a FROM t",
+                    "date_trunc('2023-12-16', 'YEAR')");
+        testFunction("SELECT date_trunc(c1, 'YEAR') as a FROM t",
+                "SELECT trunc(c1, 'YEAR') as a FROM t",
+                "date_trunc('c1, 'YEAR')");
+
+        testFunction("SELECT date_trunc('2023-12-16', 'YEAR') as a FROM t",
+                    "SELECT trunc('2023-12-16', 'YY') as a FROM t",
+                    "date_trunc('2023-12-16', 'YEAR')");
+        testFunction("SELECT date_trunc(c1, 'YEAR') as a FROM t",
+                "SELECT trunc(c1, 'YY') as a FROM t",
+                "date_trunc('c1, 'YEAR')");
+
+        testFunction("SELECT date_trunc('2023-12-16', 'MONTH') as a FROM t",
+                    "SELECT trunc('2023-12-16', 'MON') as a FROM t",
+                    "date_trunc('2023-12-16', 'MONTH')");
+        testFunction("SELECT date_trunc(c1, 'MONTH') as a FROM t",
+                "SELECT trunc(c1, 'MON') as a FROM t",
+                "date_trunc('c1, 'MONTH')");
+
+        // test numeric functions
+        testFunction("SELECT avg(c1) as a from t",
+                    "SELECT mean(c1) as a from t",
+                    "avg('c1)");
     }
 
+    private void testFunction(String sql, String dialectSql, String 
expectLogicalPlanStr) {
+        NereidsParser nereidsParser = new NereidsParser();
+        LogicalPlan logicalPlan = nereidsParser.parseSingle(sql);
+        LogicalPlan dialectLogicalPlan = nereidsParser.parseSingle(dialectSql,
+                new SparkSql3LogicalPlanBuilder());
+        Assertions.assertEquals(dialectLogicalPlan, logicalPlan);
+        String dialectLogicalPlanStr = 
dialectLogicalPlan.child(0).toString().toLowerCase();
+        System.out.println("dialectLogicalPlanStr: " + dialectLogicalPlanStr);
+        
Assertions.assertTrue(StringUtils.containsIgnoreCase(dialectLogicalPlanStr, 
expectLogicalPlanStr));
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to