This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 675aef7  [AliasFunction] Add support for cast in alias function (#6754)
675aef7 is described below

commit 675aef7d75f2cfd6a25e588d1dc480c8cbba67ba
Author: qiye <[email protected]>
AuthorDate: Sun Oct 10 23:05:44 2021 +0800

    [AliasFunction] Add support for cast in alias function (#6754)
    
    support #6753
---
 .../Data Definition/create-function.md             |   9 +-
 .../Data Definition/create-function.md             |   9 +-
 fe/fe-core/src/main/cup/sql_parser.cup             |  22 ++++
 .../java/org/apache/doris/analysis/CastExpr.java   | 126 ++++++++++++++++++++-
 .../main/java/org/apache/doris/analysis/Expr.java  |   9 +-
 .../org/apache/doris/catalog/AliasFunction.java    |  45 +++++++-
 .../org/apache/doris/catalog/PrimitiveType.java    |   5 +-
 .../java/org/apache/doris/catalog/ScalarType.java  |  91 ++++++++++++++-
 .../main/java/org/apache/doris/catalog/Type.java   |   8 +-
 .../doris/rewrite/RewriteAliasFunctionRule.java    |   9 +-
 .../apache/doris/catalog/CreateFunctionTest.java   |  99 ++++++++++++++++
 gensrc/thrift/Types.thrift                         |   3 +-
 12 files changed, 415 insertions(+), 20 deletions(-)

diff --git a/docs/en/sql-reference/sql-statements/Data 
Definition/create-function.md b/docs/en/sql-reference/sql-statements/Data 
Definition/create-function.md
index 7eed4e4..417678c 100644
--- a/docs/en/sql-reference/sql-statements/Data Definition/create-function.md   
+++ b/docs/en/sql-reference/sql-statements/Data Definition/create-function.md   
@@ -47,8 +47,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
 > 
 > `Function_name`: To create the name of the function, you can include the 
 > name of the database. For example: `db1.my_func'.
 >
-> `arg_type`: The parameter type of the function is the same as the type 
defined at the time of table building. Variable-length parameters can be 
represented by `,...`. If it is a variable-length type, the type of the 
variable-length part of the parameters is the same as the last 
non-variable-length parameter type.
-> **NOTICE**: `ALIAS FUNCTION` variable-length parameters are not supported, 
and there is at least one parameter.
+> `arg_type`: The parameter types of the function, written with the same type names used when creating a table. Variable-length parameters can be represented by `,...`; if the function is variadic, the variable-length arguments take the type of the last non-variable-length parameter.  
+> **NOTICE**: `ALIAS FUNCTION` does not support variable-length parameters and must have at least one parameter. The special type `ALL` matches any data type and can only be used in an `ALIAS FUNCTION`.
 > 
 > `ret_type`: Required for creating a new function. This parameter is not 
 > required if you are aliasing an existing function.
 >
@@ -130,8 +130,13 @@ If the `function_name` contains the database name, the 
custom function will be c
 5. Create a custom alias function
 
     ```
+    -- create an alias function that wraps a function call
     CREATE ALIAS FUNCTION id_masking(INT) WITH PARAMETER(id) 
         AS CONCAT(LEFT(id, 3), '****', RIGHT(id, 4));
+
+    -- create a custom cast alias function
+    CREATE ALIAS FUNCTION decimal(ALL, INT, INT) WITH PARAMETER(col, 
precision, scale) 
+        AS CAST(col AS decimal(precision, scale));
     ```
 
 ## keyword
diff --git a/docs/zh-CN/sql-reference/sql-statements/Data 
Definition/create-function.md b/docs/zh-CN/sql-reference/sql-statements/Data 
Definition/create-function.md
index 8582626..cf6a4fe 100644
--- a/docs/zh-CN/sql-reference/sql-statements/Data 
Definition/create-function.md        
+++ b/docs/zh-CN/sql-reference/sql-statements/Data 
Definition/create-function.md        
@@ -47,8 +47,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
 > 
 > `function_name`: 要创建函数的名字, 可以包含数据库的名字。比如:`db1.my_func`。
 >
-> `arg_type`: 函数的参数类型,与建表时定义的类型一致。变长参数时可以使用`, 
...`来表示,如果是变长类型,那么变长部分参数的类型与最后一个非变长参数类型一致。
-> **注意**:`ALIAS FUNCTION` 不支持变长参数,且至少有一个参数。
+> `arg_type`: 函数的参数类型,与建表时定义的类型一致。变长参数时可以使用`, ...`来表示,如果是变长类型,那么变长部分参数的类型与最后一个非变长参数类型一致。  
+> **注意**:`ALIAS FUNCTION` 不支持变长参数,且至少有一个参数。特别地,`ALL` 类型指任一数据类型,只可以用于 `ALIAS FUNCTION`。
 > 
 > `ret_type`: 对创建新的函数来说,是必填项。如果是给已有函数取别名则可不用填写该参数。
 > 
@@ -131,8 +131,13 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
 5. 创建一个自定义别名函数
 
     ```
+    -- 创建自定义功能别名函数
     CREATE ALIAS FUNCTION id_masking(INT) WITH PARAMETER(id) 
         AS CONCAT(LEFT(id, 3), '****', RIGHT(id, 4));
+    
+    -- 创建自定义 CAST 别名函数
+    CREATE ALIAS FUNCTION decimal(ALL, INT, INT) WITH PARAMETER(col, 
precision, scale) 
+        AS CAST(col AS decimal(precision, scale));
     ```
 
 ## keyword
diff --git a/fe/fe-core/src/main/cup/sql_parser.cup 
b/fe/fe-core/src/main/cup/sql_parser.cup
index bfb936e..20e9c8e 100644
--- a/fe/fe-core/src/main/cup/sql_parser.cup
+++ b/fe/fe-core/src/main/cup/sql_parser.cup
@@ -4352,6 +4352,11 @@ type ::=
      type.setAssignedStrLenInColDefinition();
      RESULT = type;
   :}
+  | KW_VARCHAR LPAREN ident_or_text:lenStr RPAREN
+  {: ScalarType type = ScalarType.createVarcharType(lenStr);
+     type.setAssignedStrLenInColDefinition();
+     RESULT = type;
+  :}
   | KW_VARCHAR
   {: RESULT = ScalarType.createVarcharType(-1); :}
   | KW_ARRAY LESSTHAN type:value_type GREATERTHAN
@@ -4365,6 +4370,11 @@ type ::=
      type.setAssignedStrLenInColDefinition();
      RESULT = type;
   :}
+  | KW_CHAR LPAREN ident_or_text:lenStr RPAREN
+  {: ScalarType type = ScalarType.createCharType(lenStr);
+     type.setAssignedStrLenInColDefinition();
+     RESULT = type;
+  :}
   | KW_CHAR
   {: RESULT = ScalarType.createCharType(-1); :}
   | KW_DECIMAL LPAREN INTEGER_LITERAL:precision RPAREN
@@ -4373,11 +4383,17 @@ type ::=
   {: RESULT = ScalarType.createDecimalV2Type(precision.intValue(), 
scale.intValue()); :}
   | KW_DECIMAL
   {: RESULT = ScalarType.createDecimalV2Type(); :}
+  | KW_DECIMAL LPAREN ident_or_text:precision RPAREN
+  {: RESULT = ScalarType.createDecimalV2Type(precision); :}
+  | KW_DECIMAL LPAREN ident_or_text:precision COMMA ident_or_text:scale RPAREN
+  {: RESULT = ScalarType.createDecimalV2Type(precision, scale); :}
   | KW_HLL
   {: ScalarType type = ScalarType.createHllType();
      type.setAssignedStrLenInColDefinition();
      RESULT = type;
   :}
+  | KW_ALL
+  {: RESULT = Type.ALL; :}
   ;
 
 opt_field_length ::=
@@ -5180,6 +5196,8 @@ keyword ::=
     {: RESULT = id; :}
     | KW_CHAIN:id
     {: RESULT = id; :}
+    | KW_CHAR:id
+    {: RESULT = id; :}
     | KW_CHARSET:id
     {: RESULT = id; :}
     | KW_CHECK:id
@@ -5210,6 +5228,8 @@ keyword ::=
     {: RESULT = id; :}
     | KW_DATETIME:id
     {: RESULT = id; :}
+    | KW_DECIMAL:id
+    {: RESULT = id; :}
     | KW_DISTINCTPC:id
     {: RESULT = id; :}
     | KW_DISTINCTPCSA:id
@@ -5456,6 +5476,8 @@ keyword ::=
     {: RESULT = id; :}
     | KW_MAP:id
     {: RESULT = id; :}
+    | KW_VARCHAR:id
+    {: RESULT = id; :}
     ;
 
 // Identifier that contain keyword
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
index ca46025..f071e0f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
@@ -17,7 +17,11 @@
 
 package org.apache.doris.analysis;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 
 import org.apache.doris.catalog.Catalog;
@@ -46,7 +50,7 @@ public class CastExpr extends Expr {
     private static final Logger LOG = LogManager.getLogger(CastExpr.class);
 
     // Only set for explicit casts. Null for implicit casts.
-    private final TypeDef targetTypeDef;
+    private TypeDef targetTypeDef;
 
     // True if this is a "pre-analyzed" implicit cast.
     private boolean isImplicit;
@@ -77,6 +81,11 @@ public class CastExpr extends Expr {
         }
     }
 
+    // only used restore from readFields.
+    public CastExpr() {
+
+    }
+
     public CastExpr(Type targetType, Expr e) {
         super();
         Preconditions.checkArgument(targetType.isValid());
@@ -120,6 +129,10 @@ public class CastExpr extends Expr {
         return "castTo" + targetType.getPrimitiveType().toString();
     }
 
+    public TypeDef getTargetTypeDef() {
+        return targetTypeDef;
+    }
+
     public static void initBuiltins(FunctionSet functionSet) {
         for (Type fromType : Type.getSupportedTypes()) {
             if (fromType.isNull()) {
@@ -206,6 +219,10 @@ public class CastExpr extends Expr {
     }
 
     public void analyze() throws AnalysisException {
+        // do not analyze ALL cast
+        if (type == Type.ALL) {
+            return;
+        }
         // cast was asked for in the query, check for validity of cast
         Type childType = getChild(0).getType();
 
@@ -327,4 +344,111 @@ public class CastExpr extends Expr {
         }
         return this;
     }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+        out.writeBoolean(isImplicit);
+        if (targetTypeDef.getType() instanceof ScalarType) {
+            ScalarType scalarType = (ScalarType) targetTypeDef.getType();
+            scalarType.write(out);
+        } else {
+            throw new IOException("Can not write type " + 
targetTypeDef.getType());
+        }
+        out.writeInt(children.size());
+        for (Expr expr : children) {
+            Expr.writeTo(expr, out);
+        }
+    }
+
+    public static CastExpr read(DataInput input) throws IOException {
+        CastExpr castExpr = new CastExpr();
+        castExpr.readFields(input);
+        return castExpr;
+    }
+
+    @Override
+    public void readFields(DataInput in) throws IOException {
+        isImplicit = in.readBoolean();
+        ScalarType scalarType = ScalarType.read(in);
+        targetTypeDef = new TypeDef(scalarType);
+        int counter = in.readInt();
+        for (int i = 0; i < counter; i++) {
+            children.add(Expr.readIn(in));
+        }
+    }
+
+    public CastExpr rewriteExpr(List<String> parameters, List<Expr> 
inputParamsExprs) throws AnalysisException {
+        // child
+        Expr child = this.getChild(0);
+        Expr newChild = null;
+        if (child instanceof SlotRef) {
+            String columnName = ((SlotRef) child).getColumnName();
+            int index = parameters.indexOf(columnName);
+            if (index != -1) {
+                newChild = inputParamsExprs.get(index);
+            }
+        }
+        // rewrite cast expr in children
+        if (child instanceof CastExpr) {
+            newChild = ((CastExpr) child).rewriteExpr(parameters, 
inputParamsExprs);
+        }
+
+        // type def
+        ScalarType targetType = (ScalarType) targetTypeDef.getType();
+        PrimitiveType primitiveType = targetType.getPrimitiveType();
+        ScalarType newTargetType = null;
+        switch (primitiveType) {
+            case DECIMALV2:
+                // normal decimal
+                if (targetType.getPrecision() != 0) {
+                    newTargetType = targetType;
+                    break;
+                }
+                int precision = getDigital(targetType.getScalarPrecisionStr(), 
parameters, inputParamsExprs);
+                int scale = getDigital(targetType.getScalarScaleStr(), 
parameters, inputParamsExprs);
+                if (precision != -1 && scale != -1) {
+                    newTargetType = ScalarType.createType(primitiveType, 0, 
precision, scale);
+                } else if (precision != -1 && scale == -1) {
+                    newTargetType = ScalarType.createType(primitiveType, 0, 
precision, ScalarType.DEFAULT_SCALE);
+                }
+                break;
+            case CHAR:
+            case VARCHAR:
+                // normal char/varchar
+                if (targetType.getLength() != -1) {
+                    newTargetType = targetType;
+                    break;
+                }
+                int len = getDigital(targetType.getLenStr(), parameters, 
inputParamsExprs);
+                if (len != -1) {
+                    newTargetType = ScalarType.createType(primitiveType, len, 
0, 0);
+                }
+                // default char/varchar, which len is -1
+                if (len == -1 && targetType.getLength() == -1) {
+                    newTargetType = targetType;
+                }
+                break;
+            default:
+                newTargetType = targetType;
+                break;
+        }
+
+        if (newTargetType != null && newChild != null) {
+            TypeDef typeDef = new TypeDef(newTargetType);
+            return new CastExpr(typeDef, newChild);
+        }
+
+        return this;
+    }
+
+    private int getDigital(String desc, List<String> parameters, List<Expr> 
inputParamsExprs) {
+        int index = parameters.indexOf(desc);
+        if (index != -1) {
+            Expr expr = inputParamsExprs.get(index);
+            if (expr.getType().isIntegerType()) {
+                return ((Long)((IntLiteral) expr).getRealValue()).intValue();
+            }
+        }
+        return -1;
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
index c935ab5..a7098a9 100755
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
@@ -1665,7 +1665,8 @@ abstract public class Expr extends TreeNode<Expr> 
implements ParseNode, Cloneabl
         MAX_LITERAL(10),
         BINARY_PREDICATE(11),
         FUNCTION_CALL(12),
-        ARRAY_LITERAL(13);
+        ARRAY_LITERAL(13),
+        CAST_EXPR(14);
 
         private static Map<Integer, ExprSerCode> codeMap = Maps.newHashMap();
 
@@ -1715,7 +1716,9 @@ abstract public class Expr extends TreeNode<Expr> 
implements ParseNode, Cloneabl
             output.writeInt(ExprSerCode.FUNCTION_CALL.getCode());
         } else if (expr instanceof ArrayLiteral) {
             output.writeInt(ExprSerCode.ARRAY_LITERAL.getCode());
-        } else {
+        } else if (expr instanceof CastExpr){
+            output.writeInt(ExprSerCode.CAST_EXPR.getCode());
+        }else {
             throw new IOException("Unknown class " + 
expr.getClass().getName());
         }
         expr.write(output);
@@ -1758,6 +1761,8 @@ abstract public class Expr extends TreeNode<Expr> 
implements ParseNode, Cloneabl
                 return FunctionCallExpr.read(in);
             case ARRAY_LITERAL:
                 return ArrayLiteral.read(in);
+            case CAST_EXPR:
+                return CastExpr.read(in);
             default:
                 throw new IOException("Unknown code: " + code);
         }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java
index 53476da..2e91f33 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.catalog;
 
+import org.apache.doris.analysis.CastExpr;
 import org.apache.doris.analysis.Expr;
 import org.apache.doris.analysis.FunctionCallExpr;
 import org.apache.doris.analysis.FunctionName;
@@ -24,12 +25,14 @@ import org.apache.doris.analysis.SelectStmt;
 import org.apache.doris.analysis.SlotRef;
 import org.apache.doris.analysis.SqlParser;
 import org.apache.doris.analysis.SqlScanner;
+import org.apache.doris.analysis.TypeDef;
 import org.apache.doris.common.AnalysisException;
 import org.apache.doris.common.io.Text;
 import org.apache.doris.common.util.SqlParserUtils;
 import org.apache.doris.qe.SqlModeHelper;
 import org.apache.doris.thrift.TFunctionBinaryType;
 
+import com.google.common.base.Strings;
 import com.google.common.collect.Lists;
 import com.google.gson.Gson;
 
@@ -59,6 +62,7 @@ public class AliasFunction extends Function {
 
     private Expr originFunction;
     private List<String> parameters = new ArrayList<>();
+    private List<String> typeDefParams = new ArrayList<>();
 
     // Only used for serialization
     protected AliasFunction() {
@@ -152,30 +156,63 @@ public class AliasFunction extends Function {
         if (parameters.size() != getArgs().length) {
             throw new AnalysisException("Alias function [" + functionName() + 
"] args number is not equal to parameters number");
         }
-        List<Expr> exprs = ((FunctionCallExpr) 
originFunction).getFnParams().exprs();
+        List<Expr> exprs;
+        if (originFunction instanceof FunctionCallExpr) {
+            exprs = ((FunctionCallExpr) originFunction).getFnParams().exprs();
+        } else if (originFunction instanceof CastExpr) {
+            exprs = originFunction.getChildren();
+            TypeDef targetTypeDef = ((CastExpr) 
originFunction).getTargetTypeDef();
+            if (targetTypeDef.getType().isScalarType()) {
+                ScalarType scalarType = (ScalarType) targetTypeDef.getType();
+                PrimitiveType primitiveType = scalarType.getPrimitiveType();
+                switch (primitiveType) {
+                    case DECIMALV2:
+                        if 
(!Strings.isNullOrEmpty(scalarType.getScalarPrecisionStr())) {
+                            
typeDefParams.add(scalarType.getScalarPrecisionStr());
+                        }
+                        if 
(!Strings.isNullOrEmpty(scalarType.getScalarScaleStr())) {
+                            typeDefParams.add(scalarType.getScalarScaleStr());
+                        }
+                        break;
+                    case CHAR:
+                    case VARCHAR:
+                        if (!Strings.isNullOrEmpty(scalarType.getLenStr())) {
+                            typeDefParams.add(scalarType.getLenStr());
+                        }
+                        break;
+                }
+            }
+        } else {
+            throw new AnalysisException("Not supported expr type: " + 
originFunction);
+        }
         Set<String> set = new HashSet<>();
         for (String str : parameters) {
             if (!set.add(str)) {
                 throw new AnalysisException("Alias function [" + 
functionName() + "] has duplicate parameter [" + str + "].");
             }
             boolean existFlag = false;
+            // check exprs
             for (Expr expr : exprs) {
                 existFlag |= checkParams(expr, str);
             }
+            // check targetTypeDef
+            for (String typeDefParam : typeDefParams) {
+                existFlag |= typeDefParam.equals(str);
+            }
             if (!existFlag) {
                 throw new AnalysisException("Alias function [" + 
functionName() + "]  do not contain parameter [" + str + "].");
             }
         }
     }
 
-    private boolean checkParams(Expr expr, String parma) {
+    private boolean checkParams(Expr expr, String param) {
         for (Expr e : expr.getChildren()) {
-            if (checkParams(e, parma)) {
+            if (checkParams(e, param)) {
                 return true;
             }
         }
         if (expr instanceof SlotRef) {
-            if (parma.equals(((SlotRef) expr).getColumnName())) {
+            if (param.equals(((SlotRef) expr).getColumnName())) {
                 return true;
             }
         }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java
index 992a9e0..7def685 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java
@@ -59,7 +59,8 @@ public enum PrimitiveType {
     STRUCT("MAP", 24, TPrimitiveType.STRUCT),
     STRING("STRING", 16, TPrimitiveType.STRING),
     // Unsupported scalar types.
-    BINARY("BINARY", -1, TPrimitiveType.BINARY);
+    BINARY("BINARY", -1, TPrimitiveType.BINARY),
+    ALL("ALL", -1, TPrimitiveType.INVALID_TYPE);
 
 
     private static final int DATE_INDEX_LEN = 3;
@@ -611,6 +612,8 @@ public enum PrimitiveType {
                 return MAP;
             case STRUCT:
                 return STRUCT;
+            case ALL:
+                return ALL;
             default:
                 return INVALID_TYPE;
         }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java
index b56d8a7..57c9034 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java
@@ -17,8 +17,13 @@
 
 package org.apache.doris.catalog;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 import java.util.Objects;
 
+import org.apache.doris.common.io.Text;
+import org.apache.doris.persist.gson.GsonUtils;
 import org.apache.doris.thrift.TColumnType;
 import org.apache.doris.thrift.TScalarType;
 import org.apache.doris.thrift.TTypeDesc;
@@ -85,6 +90,16 @@ public class ScalarType extends Type {
     @SerializedName(value = "scale")
     private int scale;
 
+    // Only used for alias function decimal
+    @SerializedName(value = "precisionStr")
+    private String precisionStr;
+    // Only used for alias function decimal
+    @SerializedName(value = "scaleStr")
+    private String scaleStr;
+    // Only used for alias function char/varchar
+    @SerializedName(value = "lenStr")
+    private String lenStr;
+
     protected ScalarType(PrimitiveType type) {
         this.type = type;
     }
@@ -144,6 +159,8 @@ public class ScalarType extends Type {
                 return DEFAULT_DECIMALV2;
             case LARGEINT:
                 return LARGEINT;
+            case ALL:
+                return ALL;
             default:
                 LOG.warn("type={}", type);
                 Preconditions.checkState(false);
@@ -207,6 +224,12 @@ public class ScalarType extends Type {
         return type;
     }
 
+    public static ScalarType createCharType(String lenStr) {
+        ScalarType type = new ScalarType(PrimitiveType.CHAR);
+        type.lenStr = lenStr;
+        return type;
+    }
+
     public static ScalarType createChar(int len) {
         ScalarType type = new ScalarType(PrimitiveType.CHAR);
         type.len = len;
@@ -230,6 +253,20 @@ public class ScalarType extends Type {
         return type;
     }
 
+    public static ScalarType createDecimalV2Type(String precisionStr) {
+        ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
+        type.precisionStr = precisionStr;
+        type.scaleStr = null;
+        return type;
+    }
+
+    public static ScalarType createDecimalV2Type(String precisionStr, String 
scaleStr) {
+        ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
+        type.precisionStr = precisionStr;
+        type.scaleStr = scaleStr;
+        return type;
+    }
+
     public static ScalarType createDecimalV2TypeInternal(int precision, int 
scale) {
         ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
         type.precision = Math.min(precision, MAX_PRECISION);
@@ -244,6 +281,13 @@ public class ScalarType extends Type {
         return type;
     }
 
+    public static ScalarType createVarcharType(String lenStr) {
+        // length checked in analysis
+        ScalarType type = new ScalarType(PrimitiveType.VARCHAR);
+        type.lenStr = lenStr;
+        return type;
+    }
+
     public static ScalarType createStringType() {
         // length checked in analysis
         ScalarType type = new ScalarType(PrimitiveType.STRING);
@@ -296,13 +340,27 @@ public class ScalarType extends Type {
         StringBuilder stringBuilder = new StringBuilder();
         switch (type) {
             case CHAR:
-                
stringBuilder.append("char").append("(").append(len).append(")");
+                if (Strings.isNullOrEmpty(lenStr)) {
+                    
stringBuilder.append("char").append("(").append(len).append(")");
+                } else {
+                    
stringBuilder.append("char").append("(`").append(lenStr).append("`)");
+                }
                 break;
             case VARCHAR:
-                
stringBuilder.append("varchar").append("(").append(len).append(")");
+                if (Strings.isNullOrEmpty(lenStr)) {
+                    
stringBuilder.append("varchar").append("(").append(len).append(")");
+                } else {
+                    
stringBuilder.append("varchar").append("(`").append(lenStr).append("`)");
+                }
                 break;
             case DECIMALV2:
-                
stringBuilder.append("decimal").append("(").append(precision).append(", 
").append(scale).append(")");
+                if (Strings.isNullOrEmpty(precisionStr)) {
+                    
stringBuilder.append("decimal").append("(").append(precision).append(", 
").append(scale).append(")");
+                } else if (!Strings.isNullOrEmpty(precisionStr) && 
!Strings.isNullOrEmpty(scaleStr)) {
+                    
stringBuilder.append("decimal").append("(`").append(precisionStr).append("`, 
`").append(scaleStr).append("`)");
+                } else {
+                    
stringBuilder.append("decimal").append("(`").append(precisionStr).append("`)");
+                }
                 break;
             case BOOLEAN:
                 return "boolean";
@@ -393,6 +451,18 @@ public class ScalarType extends Type {
     public int getScalarScale() { return scale; }
     public int getScalarPrecision() { return precision; }
 
+    public String getScalarPrecisionStr() {
+        return precisionStr;
+    }
+
+    public String getScalarScaleStr() {
+        return scaleStr;
+    }
+
+    public String getLenStr() {
+        return lenStr;
+    }
+
     @Override
     public boolean isWildcardDecimal() {
         return (type == PrimitiveType.DECIMALV2)
@@ -606,6 +676,11 @@ public class ScalarType extends Type {
             return INVALID;
         }
 
+        // for cast all type
+        if (t1.type == PrimitiveType.ALL || t2.type == PrimitiveType.ALL) {
+            return Type.ALL;
+        }
+
         if (t1.isStringType() || t2.isStringType()) {
             if (t1.type == PrimitiveType.STRING || t2.type == 
PrimitiveType.STRING) {
                 return createStringType();
@@ -708,4 +783,14 @@ public class ScalarType extends Type {
         result = 31 * result + scale;
         return result;
     }
+
+    public void write(DataOutput out) throws IOException {
+        String json = GsonUtils.GSON.toJson(this);
+        Text.writeString(out, json);
+    }
+
+    public static ScalarType read(DataInput input) throws IOException {
+        String json = Text.readString(input);
+        return GsonUtils.GSON.fromJson(json, ScalarType.class);
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java
index 7c74d34..7e2d2ab 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java
@@ -76,6 +76,8 @@ public abstract class Type {
     public static final ScalarType HLL = ScalarType.createHllType();
     public static final ScalarType CHAR = (ScalarType) 
ScalarType.createCharType(-1);
     public static final ScalarType BITMAP = new 
ScalarType(PrimitiveType.BITMAP);
+    // Only used for alias function, to represent any type in function args
+    public static final ScalarType ALL = new ScalarType(PrimitiveType.ALL);
     public static final MapType Map = new MapType();
 
     private static ArrayList<ScalarType> integerTypes;
@@ -944,9 +946,9 @@ public abstract class Type {
         compatibilityMatrix[TIME.ordinal()][TIME.ordinal()] = 
PrimitiveType.INVALID_TYPE;
 
         // Check all of the necessary entries that should be filled.
-        // ignore binary
-        for (int i = 0; i < PrimitiveType.values().length - 1; ++i) {
-            for (int j = i; j < PrimitiveType.values().length - 1; ++j) {
+        // ignore binary and all
+        for (int i = 0; i < PrimitiveType.values().length - 2; ++i) {
+            for (int j = i; j < PrimitiveType.values().length - 2; ++j) {
                 PrimitiveType t1 = PrimitiveType.values()[i];
                 PrimitiveType t2 = PrimitiveType.values()[j];
                 // DECIMAL, NULL, and INVALID_TYPE  are handled separately.
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java
 
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java
index c3a0f3e..09e61fc 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java
@@ -18,6 +18,7 @@
 package org.apache.doris.rewrite;
 
 import org.apache.doris.analysis.Analyzer;
+import org.apache.doris.analysis.CastExpr;
 import org.apache.doris.analysis.Expr;
 import org.apache.doris.analysis.FunctionCallExpr;
 import org.apache.doris.catalog.AliasFunction;
@@ -39,7 +40,13 @@ public class RewriteAliasFunctionRule implements 
ExprRewriteRule{
         if (expr instanceof FunctionCallExpr) {
             Function fn = expr.getFn();
             if (fn instanceof AliasFunction) {
-                return ((FunctionCallExpr) expr).rewriteExpr();
+                Expr originFn = ((AliasFunction) fn).getOriginFunction();
+                if (originFn instanceof FunctionCallExpr) {
+                    return ((FunctionCallExpr) expr).rewriteExpr();
+                } else if (originFn instanceof CastExpr) {
+                    return ((CastExpr) originFn).rewriteExpr(((AliasFunction) 
fn).getParameters(),
+                            ((FunctionCallExpr) expr).getParams().exprs());
+                }
             }
 
         }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
index d56856d..17ad666 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
@@ -19,8 +19,10 @@ package org.apache.doris.catalog;
 
 import org.apache.doris.analysis.CreateDbStmt;
 import org.apache.doris.analysis.CreateFunctionStmt;
+import org.apache.doris.analysis.CreateTableStmt;
 import org.apache.doris.analysis.Expr;
 import org.apache.doris.analysis.FunctionCallExpr;
+import org.apache.doris.analysis.StringLiteral;
 import org.apache.doris.common.FeConstants;
 import org.apache.doris.common.jmockit.Deencapsulation;
 import org.apache.doris.planner.PlanFragment;
@@ -29,6 +31,7 @@ import org.apache.doris.planner.UnionNode;
 import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.qe.QueryState;
 import org.apache.doris.qe.StmtExecutor;
+import org.apache.doris.utframe.DorisAssert;
 import org.apache.doris.utframe.UtFrameUtils;
 
 import org.junit.AfterClass;
@@ -48,11 +51,15 @@ import java.util.UUID;
 public class CreateFunctionTest {
 
     private static String runningDir = "fe/mocked/CreateFunctionTest/" + 
UUID.randomUUID().toString() + "/";
+    private static ConnectContext connectContext;
+    private static DorisAssert dorisAssert;
 
     @BeforeClass
     public static void setup() throws Exception {
         UtFrameUtils.createDorisCluster(runningDir);
         FeConstants.runningUnitTest = true;
+        // create connect context
+        connectContext = UtFrameUtils.createDefaultCtx();
     }
 
     @AfterClass
@@ -71,6 +78,14 @@ public class CreateFunctionTest {
         Catalog.getCurrentCatalog().createDb(createDbStmt);
         System.out.println(Catalog.getCurrentCatalog().getDbNames());
 
+        String createTblStmtStr = "create table db1.tbl1(k1 int, k2 bigint, k3 
varchar(10), k4 char(5)) duplicate key(k1) "
+                + "distributed by hash(k2) buckets 1 
properties('replication_num' = '1');";
+        CreateTableStmt createTableStmt = (CreateTableStmt) 
UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, connectContext);
+        Catalog.getCurrentCatalog().createTable(createTableStmt);
+
+        dorisAssert = new DorisAssert();
+        dorisAssert.useDatabase("db1");
+
         Database db = 
Catalog.getCurrentCatalog().getDbNullable("default_cluster:db1");
         Assert.assertNotNull(db);
 
@@ -126,5 +141,89 @@ public class CreateFunctionTest {
         Assert.assertEquals(1, constExprLists.size());
         Assert.assertEquals(1, constExprLists.get(0).size());
         Assert.assertTrue(constExprLists.get(0).get(0) instanceof 
FunctionCallExpr);
+
+        queryStr = "select db1.id_masking(k1) from db1.tbl1";
+        
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("concat(left(`k1`,
 3), '****', right(`k1`, 4))"));
+
+        // create alias function with cast
+        // cast any type to decimal with specific precision and scale
+        createFuncStr = "create alias function db1.decimal(all, int, int) with 
parameter(col, precision, scale)" +
+                " as cast(col as decimal(precision, scale));";
+        createFunctionStmt = (CreateFunctionStmt) 
UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
+        Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
+
+        functions = db.getFunctions();
+        Assert.assertEquals(3, functions.size());
+
+        queryStr = "select db1.decimal(333, 4, 1);";
+        ctx.getState().reset();
+        stmtExecutor = new StmtExecutor(ctx, queryStr);
+        stmtExecutor.execute();
+        Assert.assertNotEquals(QueryState.MysqlStateType.ERR, 
ctx.getState().getStateType());
+        planner = stmtExecutor.planner();
+        Assert.assertEquals(1, planner.getFragments().size());
+        fragment = planner.getFragments().get(0);
+        Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
+        unionNode =  (UnionNode)fragment.getPlanRoot();
+        constExprLists = Deencapsulation.getField(unionNode, 
"constExprLists_");
+        System.out.println(constExprLists.get(0).get(0));
+        Assert.assertTrue(constExprLists.get(0).get(0) instanceof 
StringLiteral);
+
+        queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;";
+        
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3`
 AS DECIMAL(4,1))"));
+
+        // cast any type to varchar with fixed length
+        createFuncStr = "create alias function db1.varchar(all, int) with 
parameter(text, length) as " +
+                "cast(text as varchar(length));";
+        createFunctionStmt = (CreateFunctionStmt) 
UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
+        Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
+
+        functions = db.getFunctions();
+        Assert.assertEquals(4, functions.size());
+
+        queryStr = "select db1.varchar(333, 4);";
+        ctx.getState().reset();
+        stmtExecutor = new StmtExecutor(ctx, queryStr);
+        stmtExecutor.execute();
+        Assert.assertNotEquals(QueryState.MysqlStateType.ERR, 
ctx.getState().getStateType());
+        planner = stmtExecutor.planner();
+        Assert.assertEquals(1, planner.getFragments().size());
+        fragment = planner.getFragments().get(0);
+        Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
+        unionNode =  (UnionNode)fragment.getPlanRoot();
+        constExprLists = Deencapsulation.getField(unionNode, 
"constExprLists_");
+        Assert.assertEquals(1, constExprLists.size());
+        Assert.assertEquals(1, constExprLists.get(0).size());
+        Assert.assertTrue(constExprLists.get(0).get(0) instanceof 
StringLiteral);
+
+        queryStr = "select db1.varchar(k1, 4) from db1.tbl1;";
+        
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1`
 AS CHARACTER)"));
+
+        // cast any type to char with fixed length
+        createFuncStr = "create alias function db1.char(all, int) with 
parameter(text, length) as " +
+                "cast(text as char(length));";
+        createFunctionStmt = (CreateFunctionStmt) 
UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
+        Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
+
+        functions = db.getFunctions();
+        Assert.assertEquals(5, functions.size());
+
+        queryStr = "select db1.char(333, 4);";
+        ctx.getState().reset();
+        stmtExecutor = new StmtExecutor(ctx, queryStr);
+        stmtExecutor.execute();
+        Assert.assertNotEquals(QueryState.MysqlStateType.ERR, 
ctx.getState().getStateType());
+        planner = stmtExecutor.planner();
+        Assert.assertEquals(1, planner.getFragments().size());
+        fragment = planner.getFragments().get(0);
+        Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
+        unionNode =  (UnionNode)fragment.getPlanRoot();
+        constExprLists = Deencapsulation.getField(unionNode, 
"constExprLists_");
+        Assert.assertEquals(1, constExprLists.size());
+        Assert.assertEquals(1, constExprLists.get(0).size());
+        Assert.assertTrue(constExprLists.get(0).get(0) instanceof 
StringLiteral);
+
+        queryStr = "select db1.char(k1, 4) from db1.tbl1;";
+        
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1`
 AS CHARACTER)"));
     }
 }
diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift
index 8c9a46f..efa657a 100644
--- a/gensrc/thrift/Types.thrift
+++ b/gensrc/thrift/Types.thrift
@@ -78,7 +78,8 @@ enum TPrimitiveType {
   ARRAY,
   MAP,
   STRUCT,
-  STRING
+  STRING,
+  ALL
 }
 
 enum TTypeNodeType {

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to