This is an automated email from the ASF dual-hosted git repository.
kxiao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 93a934e7753 [Improve](map) support map cast with map literal and
implicate nested scala cast (#26126)
93a934e7753 is described below
commit 93a934e7753dbf54eed84c311fcd3027880cf71c
Author: amory <[email protected]>
AuthorDate: Thu Nov 2 22:56:42 2023 +0800
[Improve](map) support map cast with map literal and implicate nested scala
cast (#26126)
---
be/src/vec/functions/function_cast.h | 58 ++++++++++++++++++---
.../java/org/apache/doris/catalog/MapType.java | 6 ++-
.../main/java/org/apache/doris/catalog/Type.java | 11 ++++
.../java/org/apache/doris/analysis/CastExpr.java | 2 +-
.../main/java/org/apache/doris/analysis/Expr.java | 7 +++
.../cast_function/test_cast_map_function.out | 48 +++++++++++++++++
.../cast_function/test_cast_map_function.out | 48 +++++++++++++++++
.../cast_function/test_cast_map_function.groovy | 60 ++++++++++++++++++++++
.../cast_function/test_cast_map_function.groovy | 60 ++++++++++++++++++++++
9 files changed, 291 insertions(+), 9 deletions(-)
diff --git a/be/src/vec/functions/function_cast.h
b/be/src/vec/functions/function_cast.h
index 8872f373b44..e2bc77f0cc3 100644
--- a/be/src/vec/functions/function_cast.h
+++ b/be/src/vec/functions/function_cast.h
@@ -51,6 +51,7 @@
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
+#include "vec/columns/column_map.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_string.h"
#include "vec/columns/column_struct.h"
@@ -1899,13 +1900,57 @@ private:
}
//TODO(Amory) . Need support more cast for key , value for map
- WrapperType create_map_wrapper(const DataTypePtr& from_type, const
DataTypeMap& to_type) const {
- switch (from_type->get_type_id()) {
- case TypeIndex::String:
+ WrapperType create_map_wrapper(FunctionContext* context, const
DataTypePtr& from_type,
+ const DataTypeMap& to_type) const {
+ if (from_type->get_type_id() == TypeIndex::String) {
return &ConvertImplGenericFromString::execute;
- default:
- return create_unsupport_wrapper(from_type->get_name(),
to_type.get_name());
}
+ auto from = check_and_get_data_type<DataTypeMap>(from_type.get());
+ if (!from) {
+ return create_unsupport_wrapper(
+ fmt::format("CAST AS Map can only be performed between Map
types or from "
+ "String. from type: {}, to type: {}",
+ from_type->get_name(), to_type.get_name()));
+ }
+ DataTypes from_kv_types;
+ DataTypes to_kv_types;
+ from_kv_types.reserve(2);
+ to_kv_types.reserve(2);
+ from_kv_types.push_back(from->get_key_type());
+ from_kv_types.push_back(from->get_value_type());
+ to_kv_types.push_back(to_type.get_key_type());
+ to_kv_types.push_back(to_type.get_value_type());
+
+ auto kv_wrappers = get_element_wrappers(context, from_kv_types,
to_kv_types);
+ return [kv_wrappers, from_kv_types, to_kv_types](
+ FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
+ const size_t result, size_t /*input_rows_count*/) ->
Status {
+ auto& from_column =
block.get_by_position(arguments.front()).column;
+ auto from_col_map =
check_and_get_column<ColumnMap>(from_column.get());
+ if (!from_col_map) {
+ return Status::RuntimeError("Illegal column {} for function
CAST AS MAP",
+ from_column->get_name());
+ }
+
+ Columns converted_columns(2);
+ ColumnsWithTypeAndName columnsWithTypeAndName(2);
+ columnsWithTypeAndName[0] = {from_col_map->get_keys_ptr(),
from_kv_types[0], ""};
+ columnsWithTypeAndName[1] = {from_col_map->get_values_ptr(),
from_kv_types[1], ""};
+
+ for (size_t i = 0; i < 2; ++i) {
+ ColumnNumbers element_arguments {block.columns()};
+ block.insert(columnsWithTypeAndName[i]);
+ size_t element_result = block.columns();
+ block.insert({to_kv_types[i], ""});
+ RETURN_IF_ERROR(kv_wrappers[i](context, block,
element_arguments, element_result,
+
columnsWithTypeAndName[i].column->size()));
+ converted_columns[i] =
block.get_by_position(element_result).column;
+ }
+
+ block.get_by_position(result).column = ColumnMap::create(
+ converted_columns[0], converted_columns[1],
from_col_map->get_offsets_ptr());
+ return Status::OK();
+ };
}
ElementWrappers get_element_wrappers(FunctionContext* context,
@@ -2166,7 +2211,8 @@ private:
return create_struct_wrapper(context, from_type,
static_cast<const
DataTypeStruct&>(*to_type));
case TypeIndex::Map:
- return create_map_wrapper(from_type, static_cast<const
DataTypeMap&>(*to_type));
+ return create_map_wrapper(context, from_type,
+ static_cast<const
DataTypeMap&>(*to_type));
case TypeIndex::HLL:
return create_hll_wrapper(context, from_type,
static_cast<const
DataTypeHLL&>(*to_type));
diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java
b/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java
index e9efc83a8fc..691c4d7e04d 100644
--- a/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java
+++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/MapType.java
@@ -192,8 +192,10 @@ public class MapType extends Type {
}
public static boolean canCastTo(MapType type, MapType targetType) {
- return Type.canCastTo(type.getKeyType(), targetType.getKeyType())
- && Type.canCastTo(type.getValueType(), targetType.getValueType());
+ return (targetType.getKeyType().isStringType() &&
type.getKeyType().isStringType()
+ || Type.canCastTo(type.getKeyType(), targetType.getKeyType()))
+ && (Type.canCastTo(type.getValueType(), targetType.getValueType())
+ || targetType.getValueType().isStringType() &&
type.getValueType().isStringType());
}
@Override
diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
index 36d35f9e2a0..0d2c96ca855 100644
--- a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
+++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
@@ -2169,6 +2169,17 @@ public abstract class Type {
return false;
}
return matchExactType(((ArrayType) type2).getItemType(),
((ArrayType) type1).getItemType());
+ } else if (type2.isMapType()) {
+ // For map types, we also need to check the contains-null flags of key and value,
+ // for cases like cast(map<not_null(int), int> as map<int, int>)
+ if (!((MapType) type2).getIsKeyContainsNull() == ((MapType)
type1).getIsKeyContainsNull()) {
+ return false;
+ }
+ if (!((MapType) type2).getIsValueContainsNull() == ((MapType)
type1).getIsValueContainsNull()) {
+ return false;
+ }
+ return matchExactType(((MapType) type2).getKeyType(),
((MapType) type1).getKeyType())
+ && matchExactType(((MapType) type2).getValueType(),
((MapType) type1).getValueType());
} else {
return true;
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
index bf5c75c58f0..590c87427a0 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java
@@ -268,7 +268,7 @@ public class CastExpr extends Expr {
} else if (type.isMapType()) {
fn = ScalarFunction.createBuiltin(getFnName(Type.MAP),
type, Function.NullableMode.ALWAYS_NULLABLE,
- Lists.newArrayList(Type.VARCHAR), false,
+
Lists.newArrayList(getActualArgTypes(collectChildReturnTypes())[0]), false,
"doris::CastFunctions::cast_to_map_val", null, null, true);
} else if (type.isStructType()) {
fn = ScalarFunction.createBuiltin(getFnName(Type.STRUCT),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
index fe03afd02c3..1b32033ea88 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
@@ -28,6 +28,7 @@ import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.Function.NullableMode;
import org.apache.doris.catalog.FunctionSet;
+import org.apache.doris.catalog.MapType;
import org.apache.doris.catalog.MaterializedIndexMeta;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarFunction;
@@ -2516,6 +2517,8 @@ public abstract class Expr extends TreeNode<Expr>
implements ParseNode, Cloneabl
return getActualScalarType(originType);
} else if (originType.getPrimitiveType() == PrimitiveType.ARRAY) {
return getActualArrayType((ArrayType) originType);
+ } else if (originType.getPrimitiveType().isMapType()) {
+ return getActualMapType((MapType) originType);
} else {
return originType;
}
@@ -2550,6 +2553,10 @@ public abstract class Expr extends TreeNode<Expr>
implements ParseNode, Cloneabl
return
Arrays.stream(originType).map(this::getActualType).toArray(Type[]::new);
}
+ private MapType getActualMapType(MapType originMapType) {
+ return new MapType(getActualType(originMapType.getKeyType()),
getActualType(originMapType.getValueType()));
+ }
+
private ArrayType getActualArrayType(ArrayType originArrayType) {
return new ArrayType(getActualType(originArrayType.getItemType()));
}
diff --git
a/regression-test/data/nereids_function_p0/cast_function/test_cast_map_function.out
b/regression-test/data/nereids_function_p0/cast_function/test_cast_map_function.out
new file mode 100644
index 00000000000..c9da5a1c286
--- /dev/null
+++
b/regression-test/data/nereids_function_p0/cast_function/test_cast_map_function.out
@@ -0,0 +1,48 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !select --
+1 {"aa":1, "b":2, "1234567":77}
+2 {"b":12, "123":7777}
+
+-- !select --
+{}
+
+-- !select --
+{}
+
+-- !sql1 --
+\N
+
+-- !sql2 --
+{"":NULL}
+
+-- !sql3 --
+{"1":2}
+
+-- !sql4 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":7777}
+
+-- !sql5 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":7777}
+
+-- !sql6 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":97}
+
+-- !sql7 --
+{"aa":"1", "b":"2", "1234567":"77"}
+{"b":"12", "123":"7777"}
+
+-- !sql8 --
+{NULL:"1", NULL:"2", 1234567:"77"}
+{NULL:"12", 123:"7777"}
+
+-- !sql9 --
+{NULL:1, NULL:2, 1234567:77}
+{NULL:12, 123:7777}
+
+-- !sql10 --
+{NULL:NULL, NULL:NULL, 1234567:NULL}
+{NULL:NULL, 123:NULL}
+
diff --git
a/regression-test/data/query_p0/sql_functions/cast_function/test_cast_map_function.out
b/regression-test/data/query_p0/sql_functions/cast_function/test_cast_map_function.out
new file mode 100644
index 00000000000..c9da5a1c286
--- /dev/null
+++
b/regression-test/data/query_p0/sql_functions/cast_function/test_cast_map_function.out
@@ -0,0 +1,48 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !select --
+1 {"aa":1, "b":2, "1234567":77}
+2 {"b":12, "123":7777}
+
+-- !select --
+{}
+
+-- !select --
+{}
+
+-- !sql1 --
+\N
+
+-- !sql2 --
+{"":NULL}
+
+-- !sql3 --
+{"1":2}
+
+-- !sql4 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":7777}
+
+-- !sql5 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":7777}
+
+-- !sql6 --
+{"aa":1, "b":2, "1234567":77}
+{"b":12, "123":97}
+
+-- !sql7 --
+{"aa":"1", "b":"2", "1234567":"77"}
+{"b":"12", "123":"7777"}
+
+-- !sql8 --
+{NULL:"1", NULL:"2", 1234567:"77"}
+{NULL:"12", 123:"7777"}
+
+-- !sql9 --
+{NULL:1, NULL:2, 1234567:77}
+{NULL:12, 123:7777}
+
+-- !sql10 --
+{NULL:NULL, NULL:NULL, 1234567:NULL}
+{NULL:NULL, 123:NULL}
+
diff --git
a/regression-test/suites/nereids_function_p0/cast_function/test_cast_map_function.groovy
b/regression-test/suites/nereids_function_p0/cast_function/test_cast_map_function.groovy
new file mode 100644
index 00000000000..e3a4ccdaecc
--- /dev/null
+++
b/regression-test/suites/nereids_function_p0/cast_function/test_cast_map_function.groovy
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_cast_map_function", "query") {
+ sql """ set enable_nereids_planner = true; """
+ sql """ set enable_fallback_to_original_planner=false; """
+ def tableName = "tbl_test_cast_map_function_nereids"
+
+ sql """DROP TABLE IF EXISTS ${tableName}"""
+ sql """
+ CREATE TABLE IF NOT EXISTS ${tableName} (
+ `k1` int(11) NULL COMMENT "",
+ `k2` Map<char(7), int(11)> NOT NULL COMMENT "",
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`k1`)
+ DISTRIBUTED BY HASH(`k1`) BUCKETS 1
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1",
+ "storage_format" = "V2"
+ )
+ """
+ // insert into with implicit cast
+ sql """ INSERT INTO ${tableName} VALUES(1, {"aa": 1, "b": 2, "1234567":
77}) """
+ sql """ INSERT INTO ${tableName} VALUES(2, {"b":12, "123":7777}) """
+
+ qt_select """ select * from ${tableName} order by k1; """
+
+ qt_select " select cast({} as MAP<INT,INT>);"
+ qt_select " select cast(map() as MAP<INT,INT>); "
+ qt_sql1 "select cast(NULL as MAP<string,int>)"
+
+ // literal NONSTRICT_SUPERTYPE_OF cast
+ qt_sql2 "select cast({'':''} as MAP<String,INT>);"
+ qt_sql3 "select cast({1:2} as MAP<String,INT>);"
+
+ // select SUPERTYPE_OF cast
+ qt_sql4 "select cast(k2 as map<varchar, bigint>) from ${tableName} order
by k1;"
+
+ // select NONSTRICT_SUPERTYPE_OF cast; this behavior is the same as for nested
scalar types
+ qt_sql5 "select cast(k2 as map<char(2), smallint>) from ${tableName} order
by k1;"
+ qt_sql6 "select cast(k2 as map<char(1), tinyint>) from ${tableName} order
by k1;"
+ qt_sql7 "select cast(k2 as map<char, string>) from ${tableName} order by
k1;"
+ qt_sql8 "select cast(k2 as map<int, string>) from ${tableName} order by
k1;"
+ qt_sql9 "select cast(k2 as map<largeint, decimal>) from ${tableName} order
by k1;"
+ qt_sql10 "select cast(k2 as map<double, datetime>) from ${tableName} order
by k1;"
+}
diff --git
a/regression-test/suites/query_p0/sql_functions/cast_function/test_cast_map_function.groovy
b/regression-test/suites/query_p0/sql_functions/cast_function/test_cast_map_function.groovy
new file mode 100644
index 00000000000..021f8096b04
--- /dev/null
+++
b/regression-test/suites/query_p0/sql_functions/cast_function/test_cast_map_function.groovy
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_cast_map_function", "query") {
+ sql """set enable_nereids_planner = false """
+ def tableName = "tbl_test_cast_map_function"
+ // array functions only supported in vectorized engine
+
+ sql """DROP TABLE IF EXISTS ${tableName}"""
+ sql """
+ CREATE TABLE IF NOT EXISTS ${tableName} (
+ `k1` int(11) NULL COMMENT "",
+ `k2` Map<char(7), int(11)> NOT NULL COMMENT "",
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`k1`)
+ DISTRIBUTED BY HASH(`k1`) BUCKETS 1
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1",
+ "storage_format" = "V2"
+ )
+ """
+ // insert into with implicit cast
+ sql """ INSERT INTO ${tableName} VALUES(1, {"aa": 1, "b": 2, "1234567":
77}) """
+ sql """ INSERT INTO ${tableName} VALUES(2, {"b":12, "123":7777}) """
+
+ qt_select """ select * from ${tableName} order by k1; """
+
+ qt_select " select cast({} as MAP<INT,INT>);"
+ qt_select " select cast(map() as MAP<INT,INT>); "
+ qt_sql1 "select cast(NULL as MAP<string,int>)"
+
+ // literal NONSTRICT_SUPERTYPE_OF cast
+ qt_sql2 "select cast({'':''} as MAP<String,INT>);"
+ qt_sql3 "select cast({1:2} as MAP<String,INT>);"
+
+ // select SUPERTYPE_OF cast
+ qt_sql4 "select cast(k2 as map<varchar, bigint>) from ${tableName} order
by k1;"
+
+ // select NONSTRICT_SUPERTYPE_OF cast; this behavior is the same as for nested
scalar types
+ qt_sql5 "select cast(k2 as map<char(2), smallint>) from ${tableName} order
by k1;"
+ qt_sql6 "select cast(k2 as map<char(1), tinyint>) from ${tableName} order
by k1;"
+ qt_sql7 "select cast(k2 as map<char, string>) from ${tableName} order by
k1;"
+ qt_sql8 "select cast(k2 as map<int, string>) from ${tableName} order by
k1;"
+ qt_sql9 "select cast(k2 as map<largeint, decimal>) from ${tableName} order
by k1;"
+ qt_sql10 "select cast(k2 as map<double, datetime>) from ${tableName} order
by k1;"
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]