This is an automated email from the ASF dual-hosted git repository.
morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push:
new 675aef7 [AliasFunction] Add support for cast in alias function (#6754)
675aef7 is described below
commit 675aef7d75f2cfd6a25e588d1dc480c8cbba67ba
Author: qiye <[email protected]>
AuthorDate: Sun Oct 10 23:05:44 2021 +0800
[AliasFunction] Add support for cast in alias function (#6754)
support #6753
---
.../Data Definition/create-function.md | 9 +-
.../Data Definition/create-function.md | 9 +-
fe/fe-core/src/main/cup/sql_parser.cup | 22 ++++
.../java/org/apache/doris/analysis/CastExpr.java | 126 ++++++++++++++++++++-
.../main/java/org/apache/doris/analysis/Expr.java | 9 +-
.../org/apache/doris/catalog/AliasFunction.java | 45 +++++++-
.../org/apache/doris/catalog/PrimitiveType.java | 5 +-
.../java/org/apache/doris/catalog/ScalarType.java | 91 ++++++++++++++-
.../main/java/org/apache/doris/catalog/Type.java | 8 +-
.../doris/rewrite/RewriteAliasFunctionRule.java | 9 +-
.../apache/doris/catalog/CreateFunctionTest.java | 99 ++++++++++++++++
gensrc/thrift/Types.thrift | 3 +-
12 files changed, 415 insertions(+), 20 deletions(-)
diff --git a/docs/en/sql-reference/sql-statements/Data Definition/create-function.md b/docs/en/sql-reference/sql-statements/Data Definition/create-function.md
index 7eed4e4..417678c 100644
--- a/docs/en/sql-reference/sql-statements/Data Definition/create-function.md
+++ b/docs/en/sql-reference/sql-statements/Data Definition/create-function.md
@@ -47,8 +47,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
>
> `Function_name`: To create the name of the function, you can include the
> name of the database. For example: `db1.my_func`.
>
-> `arg_type`: The parameter type of the function is the same as the type
defined at the time of table building. Variable-length parameters can be
represented by `,...`. If it is a variable-length type, the type of the
variable-length part of the parameters is the same as the last
non-variable-length parameter type.
-> **NOTICE**: `ALIAS FUNCTION` variable-length parameters are not supported,
and there is at least one parameter.
+> `arg_type`: The parameter type of the function is the same as the type
defined at the time of table building. Variable-length parameters can be
represented by `,...`. If it is a variable-length type, the type of the
variable-length part of the parameters is the same as the last
non-variable-length parameter type.
+> **NOTICE**: `ALIAS FUNCTION` does not support variable-length parameters and requires at least one parameter. In particular, the type `ALL` refers to any data type and can only be used for `ALIAS FUNCTION`.
>
> `ret_type`: Required for creating a new function. This parameter is not
> required if you are aliasing an existing function.
>
@@ -130,8 +130,13 @@ If the `function_name` contains the database name, the
custom function will be c
5. Create a custom alias function
```
+ -- create a custom functional alias function
CREATE ALIAS FUNCTION id_masking(INT) WITH PARAMETER(id)
AS CONCAT(LEFT(id, 3), '****', RIGHT(id, 4));
+
+ -- create a custom cast alias function
+ CREATE ALIAS FUNCTION decimal(ALL, INT, INT) WITH PARAMETER(col, precision, scale)
+ AS CAST(col AS decimal(precision, scale));
```
## keyword
diff --git a/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md b/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md
index 8582626..cf6a4fe 100644
--- a/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md
+++ b/docs/zh-CN/sql-reference/sql-statements/Data Definition/create-function.md
@@ -47,8 +47,8 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
>
> `function_name`: 要创建函数的名字, 可以包含数据库的名字。比如:`db1.my_func`。
>
-> `arg_type`: 函数的参数类型,与建表时定义的类型一致。变长参数时可以使用`,
...`来表示,如果是变长类型,那么变长部分参数的类型与最后一个非变长参数类型一致。
-> **注意**:`ALIAS FUNCTION` 不支持变长参数,且至少有一个参数。
+> `arg_type`: 函数的参数类型,与建表时定义的类型一致。变长参数时可以使用`,
...`来表示,如果是变长类型,那么变长部分参数的类型与最后一个非变长参数类型一致。
+> **注意**:`ALIAS FUNCTION` 不支持变长参数,且至少有一个参数。特别地,`ALL` 类型指任一数据类型,只可以用于 `ALIAS FUNCTION`。
>
> `ret_type`: 对创建新的函数来说,是必填项。如果是给已有函数取别名则可不用填写该参数。
>
@@ -131,8 +131,13 @@ CREATE [AGGREGATE] [ALIAS] FUNCTION function_name
5. 创建一个自定义别名函数
```
+ -- 创建自定义功能别名函数
CREATE ALIAS FUNCTION id_masking(INT) WITH PARAMETER(id)
AS CONCAT(LEFT(id, 3), '****', RIGHT(id, 4));
+
+ -- 创建自定义 CAST 别名函数
+ CREATE ALIAS FUNCTION decimal(ALL, INT, INT) WITH PARAMETER(col, precision, scale)
+ AS CAST(col AS decimal(precision, scale));
```
## keyword
diff --git a/fe/fe-core/src/main/cup/sql_parser.cup
b/fe/fe-core/src/main/cup/sql_parser.cup
index bfb936e..20e9c8e 100644
--- a/fe/fe-core/src/main/cup/sql_parser.cup
+++ b/fe/fe-core/src/main/cup/sql_parser.cup
@@ -4352,6 +4352,11 @@ type ::=
type.setAssignedStrLenInColDefinition();
RESULT = type;
:}
+ | KW_VARCHAR LPAREN ident_or_text:lenStr RPAREN
+ {: ScalarType type = ScalarType.createVarcharType(lenStr);
+ type.setAssignedStrLenInColDefinition();
+ RESULT = type;
+ :}
| KW_VARCHAR
{: RESULT = ScalarType.createVarcharType(-1); :}
| KW_ARRAY LESSTHAN type:value_type GREATERTHAN
@@ -4365,6 +4370,11 @@ type ::=
type.setAssignedStrLenInColDefinition();
RESULT = type;
:}
+ | KW_CHAR LPAREN ident_or_text:lenStr RPAREN
+ {: ScalarType type = ScalarType.createCharType(lenStr);
+ type.setAssignedStrLenInColDefinition();
+ RESULT = type;
+ :}
| KW_CHAR
{: RESULT = ScalarType.createCharType(-1); :}
| KW_DECIMAL LPAREN INTEGER_LITERAL:precision RPAREN
@@ -4373,11 +4383,17 @@ type ::=
{: RESULT = ScalarType.createDecimalV2Type(precision.intValue(),
scale.intValue()); :}
| KW_DECIMAL
{: RESULT = ScalarType.createDecimalV2Type(); :}
+ | KW_DECIMAL LPAREN ident_or_text:precision RPAREN
+ {: RESULT = ScalarType.createDecimalV2Type(precision); :}
+ | KW_DECIMAL LPAREN ident_or_text:precision COMMA ident_or_text:scale RPAREN
+ {: RESULT = ScalarType.createDecimalV2Type(precision, scale); :}
| KW_HLL
{: ScalarType type = ScalarType.createHllType();
type.setAssignedStrLenInColDefinition();
RESULT = type;
:}
+ | KW_ALL
+ {: RESULT = Type.ALL; :}
;
opt_field_length ::=
@@ -5180,6 +5196,8 @@ keyword ::=
{: RESULT = id; :}
| KW_CHAIN:id
{: RESULT = id; :}
+ | KW_CHAR:id
+ {: RESULT = id; :}
| KW_CHARSET:id
{: RESULT = id; :}
| KW_CHECK:id
@@ -5210,6 +5228,8 @@ keyword ::=
{: RESULT = id; :}
| KW_DATETIME:id
{: RESULT = id; :}
+ | KW_DECIMAL:id
+ {: RESULT = id; :}
| KW_DISTINCTPC:id
{: RESULT = id; :}
| KW_DISTINCTPCSA:id
@@ -5456,6 +5476,8 @@ keyword ::=
{: RESULT = id; :}
| KW_MAP:id
{: RESULT = id; :}
+ | KW_VARCHAR:id
+ {: RESULT = id; :}
;
// Identifier that contain keyword
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
index ca46025..f071e0f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
@@ -17,7 +17,11 @@
package org.apache.doris.analysis;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
import java.util.Arrays;
+import java.util.List;
import java.util.Map;
import org.apache.doris.catalog.Catalog;
@@ -46,7 +50,7 @@ public class CastExpr extends Expr {
private static final Logger LOG = LogManager.getLogger(CastExpr.class);
// Only set for explicit casts. Null for implicit casts.
- private final TypeDef targetTypeDef;
+ private TypeDef targetTypeDef;
// True if this is a "pre-analyzed" implicit cast.
private boolean isImplicit;
@@ -77,6 +81,11 @@ public class CastExpr extends Expr {
}
}
+ // Only used when restoring an instance via readFields.
+ public CastExpr() {
+
+ }
+
public CastExpr(Type targetType, Expr e) {
super();
Preconditions.checkArgument(targetType.isValid());
@@ -120,6 +129,10 @@ public class CastExpr extends Expr {
return "castTo" + targetType.getPrimitiveType().toString();
}
+ public TypeDef getTargetTypeDef() {
+ return targetTypeDef;
+ }
+
public static void initBuiltins(FunctionSet functionSet) {
for (Type fromType : Type.getSupportedTypes()) {
if (fromType.isNull()) {
@@ -206,6 +219,10 @@ public class CastExpr extends Expr {
}
public void analyze() throws AnalysisException {
+ // do not analyze ALL cast
+ if (type == Type.ALL) {
+ return;
+ }
// cast was asked for in the query, check for validity of cast
Type childType = getChild(0).getType();
@@ -327,4 +344,111 @@ public class CastExpr extends Expr {
}
return this;
}
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeBoolean(isImplicit);
+ if (targetTypeDef.getType() instanceof ScalarType) {
+ ScalarType scalarType = (ScalarType) targetTypeDef.getType();
+ scalarType.write(out);
+ } else {
+ throw new IOException("Can not write type " +
targetTypeDef.getType());
+ }
+ out.writeInt(children.size());
+ for (Expr expr : children) {
+ Expr.writeTo(expr, out);
+ }
+ }
+
+ public static CastExpr read(DataInput input) throws IOException {
+ CastExpr castExpr = new CastExpr();
+ castExpr.readFields(input);
+ return castExpr;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ isImplicit = in.readBoolean();
+ ScalarType scalarType = ScalarType.read(in);
+ targetTypeDef = new TypeDef(scalarType);
+ int counter = in.readInt();
+ for (int i = 0; i < counter; i++) {
+ children.add(Expr.readIn(in));
+ }
+ }
+
+ public CastExpr rewriteExpr(List<String> parameters, List<Expr>
inputParamsExprs) throws AnalysisException {
+ // child
+ Expr child = this.getChild(0);
+ Expr newChild = null;
+ if (child instanceof SlotRef) {
+ String columnName = ((SlotRef) child).getColumnName();
+ int index = parameters.indexOf(columnName);
+ if (index != -1) {
+ newChild = inputParamsExprs.get(index);
+ }
+ }
+ // rewrite cast expr in children
+ if (child instanceof CastExpr) {
+ newChild = ((CastExpr) child).rewriteExpr(parameters,
inputParamsExprs);
+ }
+
+ // type def
+ ScalarType targetType = (ScalarType) targetTypeDef.getType();
+ PrimitiveType primitiveType = targetType.getPrimitiveType();
+ ScalarType newTargetType = null;
+ switch (primitiveType) {
+ case DECIMALV2:
+ // normal decimal
+ if (targetType.getPrecision() != 0) {
+ newTargetType = targetType;
+ break;
+ }
+ int precision = getDigital(targetType.getScalarPrecisionStr(),
parameters, inputParamsExprs);
+ int scale = getDigital(targetType.getScalarScaleStr(),
parameters, inputParamsExprs);
+ if (precision != -1 && scale != -1) {
+ newTargetType = ScalarType.createType(primitiveType, 0,
precision, scale);
+ } else if (precision != -1 && scale == -1) {
+ newTargetType = ScalarType.createType(primitiveType, 0,
precision, ScalarType.DEFAULT_SCALE);
+ }
+ break;
+ case CHAR:
+ case VARCHAR:
+ // normal char/varchar
+ if (targetType.getLength() != -1) {
+ newTargetType = targetType;
+ break;
+ }
+ int len = getDigital(targetType.getLenStr(), parameters,
inputParamsExprs);
+ if (len != -1) {
+ newTargetType = ScalarType.createType(primitiveType, len,
0, 0);
+ }
+ // default char/varchar, which len is -1
+ if (len == -1 && targetType.getLength() == -1) {
+ newTargetType = targetType;
+ }
+ break;
+ default:
+ newTargetType = targetType;
+ break;
+ }
+
+ if (newTargetType != null && newChild != null) {
+ TypeDef typeDef = new TypeDef(newTargetType);
+ return new CastExpr(typeDef, newChild);
+ }
+
+ return this;
+ }
+
+ private int getDigital(String desc, List<String> parameters, List<Expr>
inputParamsExprs) {
+ int index = parameters.indexOf(desc);
+ if (index != -1) {
+ Expr expr = inputParamsExprs.get(index);
+ if (expr.getType().isIntegerType()) {
+ return ((Long)((IntLiteral) expr).getRealValue()).intValue();
+ }
+ }
+ return -1;
+ }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
index c935ab5..a7098a9 100755
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
@@ -1665,7 +1665,8 @@ abstract public class Expr extends TreeNode<Expr>
implements ParseNode, Cloneabl
MAX_LITERAL(10),
BINARY_PREDICATE(11),
FUNCTION_CALL(12),
- ARRAY_LITERAL(13);
+ ARRAY_LITERAL(13),
+ CAST_EXPR(14);
private static Map<Integer, ExprSerCode> codeMap = Maps.newHashMap();
@@ -1715,7 +1716,9 @@ abstract public class Expr extends TreeNode<Expr>
implements ParseNode, Cloneabl
output.writeInt(ExprSerCode.FUNCTION_CALL.getCode());
} else if (expr instanceof ArrayLiteral) {
output.writeInt(ExprSerCode.ARRAY_LITERAL.getCode());
- } else {
+ } else if (expr instanceof CastExpr){
+ output.writeInt(ExprSerCode.CAST_EXPR.getCode());
+ }else {
throw new IOException("Unknown class " +
expr.getClass().getName());
}
expr.write(output);
@@ -1758,6 +1761,8 @@ abstract public class Expr extends TreeNode<Expr>
implements ParseNode, Cloneabl
return FunctionCallExpr.read(in);
case ARRAY_LITERAL:
return ArrayLiteral.read(in);
+ case CAST_EXPR:
+ return CastExpr.read(in);
default:
throw new IOException("Unknown code: " + code);
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java
index 53476da..2e91f33 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AliasFunction.java
@@ -17,6 +17,7 @@
package org.apache.doris.catalog;
+import org.apache.doris.analysis.CastExpr;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.FunctionName;
@@ -24,12 +25,14 @@ import org.apache.doris.analysis.SelectStmt;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.SqlParser;
import org.apache.doris.analysis.SqlScanner;
+import org.apache.doris.analysis.TypeDef;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.io.Text;
import org.apache.doris.common.util.SqlParserUtils;
import org.apache.doris.qe.SqlModeHelper;
import org.apache.doris.thrift.TFunctionBinaryType;
+import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.gson.Gson;
@@ -59,6 +62,7 @@ public class AliasFunction extends Function {
private Expr originFunction;
private List<String> parameters = new ArrayList<>();
+ private List<String> typeDefParams = new ArrayList<>();
// Only used for serialization
protected AliasFunction() {
@@ -152,30 +156,63 @@ public class AliasFunction extends Function {
if (parameters.size() != getArgs().length) {
throw new AnalysisException("Alias function [" + functionName() +
"] args number is not equal to parameters number");
}
- List<Expr> exprs = ((FunctionCallExpr)
originFunction).getFnParams().exprs();
+ List<Expr> exprs;
+ if (originFunction instanceof FunctionCallExpr) {
+ exprs = ((FunctionCallExpr) originFunction).getFnParams().exprs();
+ } else if (originFunction instanceof CastExpr) {
+ exprs = originFunction.getChildren();
+ TypeDef targetTypeDef = ((CastExpr)
originFunction).getTargetTypeDef();
+ if (targetTypeDef.getType().isScalarType()) {
+ ScalarType scalarType = (ScalarType) targetTypeDef.getType();
+ PrimitiveType primitiveType = scalarType.getPrimitiveType();
+ switch (primitiveType) {
+ case DECIMALV2:
+ if
(!Strings.isNullOrEmpty(scalarType.getScalarPrecisionStr())) {
+
typeDefParams.add(scalarType.getScalarPrecisionStr());
+ }
+ if
(!Strings.isNullOrEmpty(scalarType.getScalarScaleStr())) {
+ typeDefParams.add(scalarType.getScalarScaleStr());
+ }
+ break;
+ case CHAR:
+ case VARCHAR:
+ if (!Strings.isNullOrEmpty(scalarType.getLenStr())) {
+ typeDefParams.add(scalarType.getLenStr());
+ }
+ break;
+ }
+ }
+ } else {
+ throw new AnalysisException("Not supported expr type: " +
originFunction);
+ }
Set<String> set = new HashSet<>();
for (String str : parameters) {
if (!set.add(str)) {
throw new AnalysisException("Alias function [" +
functionName() + "] has duplicate parameter [" + str + "].");
}
boolean existFlag = false;
+ // check exprs
for (Expr expr : exprs) {
existFlag |= checkParams(expr, str);
}
+ // check targetTypeDef
+ for (String typeDefParam : typeDefParams) {
+ existFlag |= typeDefParam.equals(str);
+ }
if (!existFlag) {
throw new AnalysisException("Alias function [" +
functionName() + "] do not contain parameter [" + str + "].");
}
}
}
- private boolean checkParams(Expr expr, String parma) {
+ private boolean checkParams(Expr expr, String param) {
for (Expr e : expr.getChildren()) {
- if (checkParams(e, parma)) {
+ if (checkParams(e, param)) {
return true;
}
}
if (expr instanceof SlotRef) {
- if (parma.equals(((SlotRef) expr).getColumnName())) {
+ if (param.equals(((SlotRef) expr).getColumnName())) {
return true;
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java
index 992a9e0..7def685 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/PrimitiveType.java
@@ -59,7 +59,8 @@ public enum PrimitiveType {
STRUCT("MAP", 24, TPrimitiveType.STRUCT),
STRING("STRING", 16, TPrimitiveType.STRING),
// Unsupported scalar types.
- BINARY("BINARY", -1, TPrimitiveType.BINARY);
+ BINARY("BINARY", -1, TPrimitiveType.BINARY),
+ ALL("ALL", -1, TPrimitiveType.INVALID_TYPE);
private static final int DATE_INDEX_LEN = 3;
@@ -611,6 +612,8 @@ public enum PrimitiveType {
return MAP;
case STRUCT:
return STRUCT;
+ case ALL:
+ return ALL;
default:
return INVALID_TYPE;
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java
index b56d8a7..57c9034 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarType.java
@@ -17,8 +17,13 @@
package org.apache.doris.catalog;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
import java.util.Objects;
+import org.apache.doris.common.io.Text;
+import org.apache.doris.persist.gson.GsonUtils;
import org.apache.doris.thrift.TColumnType;
import org.apache.doris.thrift.TScalarType;
import org.apache.doris.thrift.TTypeDesc;
@@ -85,6 +90,16 @@ public class ScalarType extends Type {
@SerializedName(value = "scale")
private int scale;
+ // Only used for alias function decimal
+ @SerializedName(value = "precisionStr")
+ private String precisionStr;
+ // Only used for alias function decimal
+ @SerializedName(value = "scaleStr")
+ private String scaleStr;
+ // Only used for alias function char/varchar
+ @SerializedName(value = "lenStr")
+ private String lenStr;
+
protected ScalarType(PrimitiveType type) {
this.type = type;
}
@@ -144,6 +159,8 @@ public class ScalarType extends Type {
return DEFAULT_DECIMALV2;
case LARGEINT:
return LARGEINT;
+ case ALL:
+ return ALL;
default:
LOG.warn("type={}", type);
Preconditions.checkState(false);
@@ -207,6 +224,12 @@ public class ScalarType extends Type {
return type;
}
+ public static ScalarType createCharType(String lenStr) {
+ ScalarType type = new ScalarType(PrimitiveType.CHAR);
+ type.lenStr = lenStr;
+ return type;
+ }
+
public static ScalarType createChar(int len) {
ScalarType type = new ScalarType(PrimitiveType.CHAR);
type.len = len;
@@ -230,6 +253,20 @@ public class ScalarType extends Type {
return type;
}
+ public static ScalarType createDecimalV2Type(String precisionStr) {
+ ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
+ type.precisionStr = precisionStr;
+ type.scaleStr = null;
+ return type;
+ }
+
+ public static ScalarType createDecimalV2Type(String precisionStr, String
scaleStr) {
+ ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
+ type.precisionStr = precisionStr;
+ type.scaleStr = scaleStr;
+ return type;
+ }
+
public static ScalarType createDecimalV2TypeInternal(int precision, int
scale) {
ScalarType type = new ScalarType(PrimitiveType.DECIMALV2);
type.precision = Math.min(precision, MAX_PRECISION);
@@ -244,6 +281,13 @@ public class ScalarType extends Type {
return type;
}
+ public static ScalarType createVarcharType(String lenStr) {
+ // length checked in analysis
+ ScalarType type = new ScalarType(PrimitiveType.VARCHAR);
+ type.lenStr = lenStr;
+ return type;
+ }
+
public static ScalarType createStringType() {
// length checked in analysis
ScalarType type = new ScalarType(PrimitiveType.STRING);
@@ -296,13 +340,27 @@ public class ScalarType extends Type {
StringBuilder stringBuilder = new StringBuilder();
switch (type) {
case CHAR:
-
stringBuilder.append("char").append("(").append(len).append(")");
+ if (Strings.isNullOrEmpty(lenStr)) {
+
stringBuilder.append("char").append("(").append(len).append(")");
+ } else {
+
stringBuilder.append("char").append("(`").append(lenStr).append("`)");
+ }
break;
case VARCHAR:
-
stringBuilder.append("varchar").append("(").append(len).append(")");
+ if (Strings.isNullOrEmpty(lenStr)) {
+
stringBuilder.append("varchar").append("(").append(len).append(")");
+ } else {
+
stringBuilder.append("varchar").append("(`").append(lenStr).append("`)");
+ }
break;
case DECIMALV2:
-
stringBuilder.append("decimal").append("(").append(precision).append(",
").append(scale).append(")");
+ if (Strings.isNullOrEmpty(precisionStr)) {
+
stringBuilder.append("decimal").append("(").append(precision).append(",
").append(scale).append(")");
+ } else if (!Strings.isNullOrEmpty(precisionStr) &&
!Strings.isNullOrEmpty(scaleStr)) {
+
stringBuilder.append("decimal").append("(`").append(precisionStr).append("`,
`").append(scaleStr).append("`)");
+ } else {
+
stringBuilder.append("decimal").append("(`").append(precisionStr).append("`)");
+ }
break;
case BOOLEAN:
return "boolean";
@@ -393,6 +451,18 @@ public class ScalarType extends Type {
public int getScalarScale() { return scale; }
public int getScalarPrecision() { return precision; }
+ public String getScalarPrecisionStr() {
+ return precisionStr;
+ }
+
+ public String getScalarScaleStr() {
+ return scaleStr;
+ }
+
+ public String getLenStr() {
+ return lenStr;
+ }
+
@Override
public boolean isWildcardDecimal() {
return (type == PrimitiveType.DECIMALV2)
@@ -606,6 +676,11 @@ public class ScalarType extends Type {
return INVALID;
}
+ // for cast all type
+ if (t1.type == PrimitiveType.ALL || t2.type == PrimitiveType.ALL) {
+ return Type.ALL;
+ }
+
if (t1.isStringType() || t2.isStringType()) {
if (t1.type == PrimitiveType.STRING || t2.type ==
PrimitiveType.STRING) {
return createStringType();
@@ -708,4 +783,14 @@ public class ScalarType extends Type {
result = 31 * result + scale;
return result;
}
+
+ public void write(DataOutput out) throws IOException {
+ String json = GsonUtils.GSON.toJson(this);
+ Text.writeString(out, json);
+ }
+
+ public static ScalarType read(DataInput input) throws IOException {
+ String json = Text.readString(input);
+ return GsonUtils.GSON.fromJson(json, ScalarType.class);
+ }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java
index 7c74d34..7e2d2ab 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java
@@ -76,6 +76,8 @@ public abstract class Type {
public static final ScalarType HLL = ScalarType.createHllType();
public static final ScalarType CHAR = (ScalarType)
ScalarType.createCharType(-1);
public static final ScalarType BITMAP = new
ScalarType(PrimitiveType.BITMAP);
+ // Only used for alias function, to represent any type in function args
+ public static final ScalarType ALL = new ScalarType(PrimitiveType.ALL);
public static final MapType Map = new MapType();
private static ArrayList<ScalarType> integerTypes;
@@ -944,9 +946,9 @@ public abstract class Type {
compatibilityMatrix[TIME.ordinal()][TIME.ordinal()] =
PrimitiveType.INVALID_TYPE;
// Check all of the necessary entries that should be filled.
- // ignore binary
- for (int i = 0; i < PrimitiveType.values().length - 1; ++i) {
- for (int j = i; j < PrimitiveType.values().length - 1; ++j) {
+ // ignore binary and all
+ for (int i = 0; i < PrimitiveType.values().length - 2; ++i) {
+ for (int j = i; j < PrimitiveType.values().length - 2; ++j) {
PrimitiveType t1 = PrimitiveType.values()[i];
PrimitiveType t2 = PrimitiveType.values()[j];
// DECIMAL, NULL, and INVALID_TYPE are handled separately.
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java
index c3a0f3e..09e61fc 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteAliasFunctionRule.java
@@ -18,6 +18,7 @@
package org.apache.doris.rewrite;
import org.apache.doris.analysis.Analyzer;
+import org.apache.doris.analysis.CastExpr;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.catalog.AliasFunction;
@@ -39,7 +40,13 @@ public class RewriteAliasFunctionRule implements
ExprRewriteRule{
if (expr instanceof FunctionCallExpr) {
Function fn = expr.getFn();
if (fn instanceof AliasFunction) {
- return ((FunctionCallExpr) expr).rewriteExpr();
+ Expr originFn = ((AliasFunction) fn).getOriginFunction();
+ if (originFn instanceof FunctionCallExpr) {
+ return ((FunctionCallExpr) expr).rewriteExpr();
+ } else if (originFn instanceof CastExpr) {
+ return ((CastExpr) originFn).rewriteExpr(((AliasFunction)
fn).getParameters(),
+ ((FunctionCallExpr) expr).getParams().exprs());
+ }
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
index d56856d..17ad666 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
@@ -19,8 +19,10 @@ package org.apache.doris.catalog;
import org.apache.doris.analysis.CreateDbStmt;
import org.apache.doris.analysis.CreateFunctionStmt;
+import org.apache.doris.analysis.CreateTableStmt;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
+import org.apache.doris.analysis.StringLiteral;
import org.apache.doris.common.FeConstants;
import org.apache.doris.common.jmockit.Deencapsulation;
import org.apache.doris.planner.PlanFragment;
@@ -29,6 +31,7 @@ import org.apache.doris.planner.UnionNode;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.QueryState;
import org.apache.doris.qe.StmtExecutor;
+import org.apache.doris.utframe.DorisAssert;
import org.apache.doris.utframe.UtFrameUtils;
import org.junit.AfterClass;
@@ -48,11 +51,15 @@ import java.util.UUID;
public class CreateFunctionTest {
private static String runningDir = "fe/mocked/CreateFunctionTest/" +
UUID.randomUUID().toString() + "/";
+ private static ConnectContext connectContext;
+ private static DorisAssert dorisAssert;
@BeforeClass
public static void setup() throws Exception {
UtFrameUtils.createDorisCluster(runningDir);
FeConstants.runningUnitTest = true;
+ // create connect context
+ connectContext = UtFrameUtils.createDefaultCtx();
}
@AfterClass
@@ -71,6 +78,14 @@ public class CreateFunctionTest {
Catalog.getCurrentCatalog().createDb(createDbStmt);
System.out.println(Catalog.getCurrentCatalog().getDbNames());
+ String createTblStmtStr = "create table db1.tbl1(k1 int, k2 bigint, k3
varchar(10), k4 char(5)) duplicate key(k1) "
+ + "distributed by hash(k2) buckets 1
properties('replication_num' = '1');";
+ CreateTableStmt createTableStmt = (CreateTableStmt)
UtFrameUtils.parseAndAnalyzeStmt(createTblStmtStr, connectContext);
+ Catalog.getCurrentCatalog().createTable(createTableStmt);
+
+ dorisAssert = new DorisAssert();
+ dorisAssert.useDatabase("db1");
+
Database db =
Catalog.getCurrentCatalog().getDbNullable("default_cluster:db1");
Assert.assertNotNull(db);
@@ -126,5 +141,89 @@ public class CreateFunctionTest {
Assert.assertEquals(1, constExprLists.size());
Assert.assertEquals(1, constExprLists.get(0).size());
Assert.assertTrue(constExprLists.get(0).get(0) instanceof
FunctionCallExpr);
+
+ queryStr = "select db1.id_masking(k1) from db1.tbl1";
+
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("concat(left(`k1`,
3), '****', right(`k1`, 4))"));
+
+ // create alias function with cast
+ // cast any type to decimal with specific precision and scale
+ createFuncStr = "create alias function db1.decimal(all, int, int) with
parameter(col, precision, scale)" +
+ " as cast(col as decimal(precision, scale));";
+ createFunctionStmt = (CreateFunctionStmt)
UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
+ Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
+
+ functions = db.getFunctions();
+ Assert.assertEquals(3, functions.size());
+
+ queryStr = "select db1.decimal(333, 4, 1);";
+ ctx.getState().reset();
+ stmtExecutor = new StmtExecutor(ctx, queryStr);
+ stmtExecutor.execute();
+ Assert.assertNotEquals(QueryState.MysqlStateType.ERR,
ctx.getState().getStateType());
+ planner = stmtExecutor.planner();
+ Assert.assertEquals(1, planner.getFragments().size());
+ fragment = planner.getFragments().get(0);
+ Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
+ unionNode = (UnionNode)fragment.getPlanRoot();
+ constExprLists = Deencapsulation.getField(unionNode,
"constExprLists_");
+ System.out.println(constExprLists.get(0).get(0));
+ Assert.assertTrue(constExprLists.get(0).get(0) instanceof
StringLiteral);
+
+ queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;";
+
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3`
AS DECIMAL(4,1))"));
+
+ // cast any type to varchar with fixed length
+ createFuncStr = "create alias function db1.varchar(all, int) with
parameter(text, length) as " +
+ "cast(text as varchar(length));";
+ createFunctionStmt = (CreateFunctionStmt)
UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
+ Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
+
+ functions = db.getFunctions();
+ Assert.assertEquals(4, functions.size());
+
+ queryStr = "select db1.varchar(333, 4);";
+ ctx.getState().reset();
+ stmtExecutor = new StmtExecutor(ctx, queryStr);
+ stmtExecutor.execute();
+ Assert.assertNotEquals(QueryState.MysqlStateType.ERR,
ctx.getState().getStateType());
+ planner = stmtExecutor.planner();
+ Assert.assertEquals(1, planner.getFragments().size());
+ fragment = planner.getFragments().get(0);
+ Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
+ unionNode = (UnionNode)fragment.getPlanRoot();
+ constExprLists = Deencapsulation.getField(unionNode,
"constExprLists_");
+ Assert.assertEquals(1, constExprLists.size());
+ Assert.assertEquals(1, constExprLists.get(0).size());
+ Assert.assertTrue(constExprLists.get(0).get(0) instanceof
StringLiteral);
+
+ queryStr = "select db1.varchar(k1, 4) from db1.tbl1;";
+
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1`
AS CHARACTER)"));
+
+ // cast any type to char with fixed length
+ createFuncStr = "create alias function db1.char(all, int) with
parameter(text, length) as " +
+ "cast(text as char(length));";
+ createFunctionStmt = (CreateFunctionStmt)
UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
+ Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
+
+ functions = db.getFunctions();
+ Assert.assertEquals(5, functions.size());
+
+ queryStr = "select db1.char(333, 4);";
+ ctx.getState().reset();
+ stmtExecutor = new StmtExecutor(ctx, queryStr);
+ stmtExecutor.execute();
+ Assert.assertNotEquals(QueryState.MysqlStateType.ERR,
ctx.getState().getStateType());
+ planner = stmtExecutor.planner();
+ Assert.assertEquals(1, planner.getFragments().size());
+ fragment = planner.getFragments().get(0);
+ Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
+ unionNode = (UnionNode)fragment.getPlanRoot();
+ constExprLists = Deencapsulation.getField(unionNode,
"constExprLists_");
+ Assert.assertEquals(1, constExprLists.size());
+ Assert.assertEquals(1, constExprLists.get(0).size());
+ Assert.assertTrue(constExprLists.get(0).get(0) instanceof
StringLiteral);
+
+ queryStr = "select db1.char(k1, 4) from db1.tbl1;";
+
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1`
AS CHARACTER)"));
}
}
diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift
index 8c9a46f..efa657a 100644
--- a/gensrc/thrift/Types.thrift
+++ b/gensrc/thrift/Types.thrift
@@ -78,7 +78,8 @@ enum TPrimitiveType {
ARRAY,
MAP,
STRUCT,
- STRING
+ STRING,
+ ALL
}
enum TTypeNodeType {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]